1package middleware
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net/http"
8 "slices"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/bluesky-social/indigo/atproto/identity"
14 "github.com/go-chi/chi/v5"
15 "tangled.sh/tangled.sh/core/appview/db"
16 "tangled.sh/tangled.sh/core/appview/oauth"
17 "tangled.sh/tangled.sh/core/appview/pages"
18 "tangled.sh/tangled.sh/core/appview/pagination"
19 "tangled.sh/tangled.sh/core/appview/reporesolver"
20 "tangled.sh/tangled.sh/core/idresolver"
21 "tangled.sh/tangled.sh/core/rbac"
22)
23
24type Middleware struct {
25 oauth *oauth.OAuth
26 db *db.DB
27 enforcer *rbac.Enforcer
28 repoResolver *reporesolver.RepoResolver
29 idResolver *idresolver.Resolver
30 pages *pages.Pages
31}
32
33func New(oauth *oauth.OAuth, db *db.DB, enforcer *rbac.Enforcer, repoResolver *reporesolver.RepoResolver, idResolver *idresolver.Resolver, pages *pages.Pages) Middleware {
34 return Middleware{
35 oauth: oauth,
36 db: db,
37 enforcer: enforcer,
38 repoResolver: repoResolver,
39 idResolver: idResolver,
40 pages: pages,
41 }
42}
43
// middlewareFunc is the usual net/http middleware shape: it receives the next
// handler in the chain and returns a handler that wraps it.
type middlewareFunc func(next http.Handler) http.Handler
45
46func AuthMiddleware(a *oauth.OAuth) middlewareFunc {
47 return func(next http.Handler) http.Handler {
48 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
50 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
51 }
52 if r.Header.Get("HX-Request") == "true" {
53 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
54 w.Header().Set("HX-Redirect", "/login")
55 w.WriteHeader(http.StatusOK)
56 }
57 }
58
59 _, auth, err := a.GetSession(r)
60 if err != nil {
61 log.Println("not logged in, redirecting", "err", err)
62 redirectFunc(w, r)
63 return
64 }
65
66 if !auth {
67 log.Printf("not logged in, redirecting")
68 redirectFunc(w, r)
69 return
70 }
71
72 next.ServeHTTP(w, r)
73 })
74 }
75}
76
77func Paginate(next http.Handler) http.Handler {
78 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
79 page := pagination.FirstPage()
80
81 offsetVal := r.URL.Query().Get("offset")
82 if offsetVal != "" {
83 offset, err := strconv.Atoi(offsetVal)
84 if err != nil {
85 log.Println("invalid offset")
86 } else {
87 page.Offset = offset
88 }
89 }
90
91 limitVal := r.URL.Query().Get("limit")
92 if limitVal != "" {
93 limit, err := strconv.Atoi(limitVal)
94 if err != nil {
95 log.Println("invalid limit")
96 } else {
97 page.Limit = limit
98 }
99 }
100
101 ctx := context.WithValue(r.Context(), "page", page)
102 next.ServeHTTP(w, r.WithContext(ctx))
103 })
104}
105
106func (mw Middleware) knotRoleMiddleware(group string) middlewareFunc {
107 return func(next http.Handler) http.Handler {
108 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109 // requires auth also
110 actor := mw.oauth.GetUser(r)
111 if actor == nil {
112 // we need a logged in user
113 log.Printf("not logged in, redirecting")
114 http.Error(w, "Forbiden", http.StatusUnauthorized)
115 return
116 }
117 domain := chi.URLParam(r, "domain")
118 if domain == "" {
119 http.Error(w, "malformed url", http.StatusBadRequest)
120 return
121 }
122
123 ok, err := mw.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
124 if err != nil || !ok {
125 // we need a logged in user
126 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
127 http.Error(w, "Forbiden", http.StatusUnauthorized)
128 return
129 }
130
131 next.ServeHTTP(w, r)
132 })
133 }
134}
135
136func (mw Middleware) KnotOwner() middlewareFunc {
137 return mw.knotRoleMiddleware("server:owner")
138}
139
140func (mw Middleware) RepoPermissionMiddleware(requiredPerm string) middlewareFunc {
141 return func(next http.Handler) http.Handler {
142 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
143 // requires auth also
144 actor := mw.oauth.GetUser(r)
145 if actor == nil {
146 // we need a logged in user
147 log.Printf("not logged in, redirecting")
148 http.Error(w, "Forbiden", http.StatusUnauthorized)
149 return
150 }
151 f, err := mw.repoResolver.Resolve(r)
152 if err != nil {
153 http.Error(w, "malformed url", http.StatusBadRequest)
154 return
155 }
156
157 ok, err := mw.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm)
158 if err != nil || !ok {
159 // we need a logged in user
160 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
161 http.Error(w, "Forbiden", http.StatusUnauthorized)
162 return
163 }
164
165 next.ServeHTTP(w, r)
166 })
167 }
168}
169
170func (mw Middleware) ResolveIdent() middlewareFunc {
171 excluded := []string{"favicon.ico"}
172
173 return func(next http.Handler) http.Handler {
174 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
175 didOrHandle := chi.URLParam(req, "user")
176 if slices.Contains(excluded, didOrHandle) {
177 next.ServeHTTP(w, req)
178 return
179 }
180
181 didOrHandle = strings.TrimPrefix(didOrHandle, "@")
182
183 id, err := mw.idResolver.ResolveIdent(req.Context(), didOrHandle)
184 if err != nil {
185 // invalid did or handle
186 log.Printf("failed to resolve did/handle '%s': %s\n", didOrHandle, err)
187 mw.pages.Error404(w)
188 return
189 }
190
191 ctx := context.WithValue(req.Context(), "resolvedId", *id)
192
193 next.ServeHTTP(w, req.WithContext(ctx))
194 })
195 }
196}
197
198func (mw Middleware) ResolveRepo() middlewareFunc {
199 return func(next http.Handler) http.Handler {
200 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
201 repoName := chi.URLParam(req, "repo")
202 id, ok := req.Context().Value("resolvedId").(identity.Identity)
203 if !ok {
204 log.Println("malformed middleware")
205 w.WriteHeader(http.StatusInternalServerError)
206 return
207 }
208
209 repo, err := db.GetRepo(mw.db, id.DID.String(), repoName)
210 if err != nil {
211 // invalid did or handle
212 log.Println("failed to resolve repo")
213 mw.pages.Error404(w)
214 return
215 }
216
217 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
218 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
219 ctx = context.WithValue(ctx, "repoDescription", repo.Description)
220 ctx = context.WithValue(ctx, "repoSpindle", repo.Spindle)
221 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
222 next.ServeHTTP(w, req.WithContext(ctx))
223 })
224 }
225}
226
227// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
228func (mw Middleware) ResolvePull() middlewareFunc {
229 return func(next http.Handler) http.Handler {
230 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
231 f, err := mw.repoResolver.Resolve(r)
232 if err != nil {
233 log.Println("failed to fully resolve repo", err)
234 http.Error(w, "invalid repo url", http.StatusNotFound)
235 return
236 }
237
238 prId := chi.URLParam(r, "pull")
239 prIdInt, err := strconv.Atoi(prId)
240 if err != nil {
241 http.Error(w, "bad pr id", http.StatusBadRequest)
242 log.Println("failed to parse pr id", err)
243 return
244 }
245
246 pr, err := db.GetPull(mw.db, f.RepoAt, prIdInt)
247 if err != nil {
248 log.Println("failed to get pull and comments", err)
249 return
250 }
251
252 ctx := context.WithValue(r.Context(), "pull", pr)
253
254 if pr.IsStacked() {
255 stack, err := db.GetStack(mw.db, pr.StackId)
256 if err != nil {
257 log.Println("failed to get stack", err)
258 return
259 }
260 abandonedPulls, err := db.GetAbandonedPulls(mw.db, pr.StackId)
261 if err != nil {
262 log.Println("failed to get abandoned pulls", err)
263 return
264 }
265
266 ctx = context.WithValue(ctx, "stack", stack)
267 ctx = context.WithValue(ctx, "abandonedPulls", abandonedPulls)
268 }
269
270 next.ServeHTTP(w, r.WithContext(ctx))
271 })
272 }
273}
274
275// this should serve the go-import meta tag even if the path is technically
276// a 404 like tangled.sh/oppi.li/go-git/v5
277func (mw Middleware) GoImport() middlewareFunc {
278 return func(next http.Handler) http.Handler {
279 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
280 f, err := mw.repoResolver.Resolve(r)
281 if err != nil {
282 log.Println("failed to fully resolve repo", err)
283 http.Error(w, "invalid repo url", http.StatusNotFound)
284 return
285 }
286
287 fullName := f.OwnerHandle() + "/" + f.RepoName
288
289 if r.Header.Get("User-Agent") == "Go-http-client/1.1" {
290 if r.URL.Query().Get("go-get") == "1" {
291 html := fmt.Sprintf(
292 `<meta name="go-import" content="tangled.sh/%s git https://tangled.sh/%s"/>`,
293 fullName,
294 fullName,
295 )
296 w.Header().Set("Content-Type", "text/html")
297 w.Write([]byte(html))
298 return
299 }
300 }
301
302 next.ServeHTTP(w, r)
303 })
304 }
305}