1package middleware
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net/http"
8 "net/url"
9 "slices"
10 "strconv"
11 "strings"
12
13 "github.com/bluesky-social/indigo/atproto/identity"
14 "github.com/go-chi/chi/v5"
15 "tangled.org/core/appview/db"
16 "tangled.org/core/appview/oauth"
17 "tangled.org/core/appview/pages"
18 "tangled.org/core/appview/pagination"
19 "tangled.org/core/appview/reporesolver"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
// Middleware bundles the shared services that the HTTP middlewares in this
// package consult while handling a request. Construct it with New.
type Middleware struct {
	// session and logged-in user lookup (GetSession / GetUser)
	oauth *oauth.OAuth
	// database handle for repo, pull, stack and issue lookups
	db *db.DB
	// RBAC policy checks (Enforce / HasGroupingPolicy)
	enforcer *rbac.Enforcer
	// resolves the repository addressed by the request URL
	repoResolver *reporesolver.RepoResolver
	// resolves a DID or handle into an identity
	idResolver *idresolver.Resolver
	// renders error pages (Error404 / ErrorKnot404)
	pages *pages.Pages
}
32
33func New(oauth *oauth.OAuth, db *db.DB, enforcer *rbac.Enforcer, repoResolver *reporesolver.RepoResolver, idResolver *idresolver.Resolver, pages *pages.Pages) Middleware {
34 return Middleware{
35 oauth: oauth,
36 db: db,
37 enforcer: enforcer,
38 repoResolver: repoResolver,
39 idResolver: idResolver,
40 pages: pages,
41 }
42}
43
// middlewareFunc is the standard net/http middleware shape: it wraps next and
// returns the handler that should be served in its place.
type middlewareFunc func(next http.Handler) http.Handler
45
46func (mw *Middleware) TryRefreshSession() middlewareFunc {
47 return func(next http.Handler) http.Handler {
48 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49 _, auth, err := mw.oauth.GetSession(r)
50 if err != nil {
51 log.Println("could not refresh session", "err", err, "auth", auth)
52 }
53
54 next.ServeHTTP(w, r)
55 })
56 }
57}
58
59func AuthMiddleware(a *oauth.OAuth) middlewareFunc {
60 return func(next http.Handler) http.Handler {
61 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
62 returnURL := "/"
63 if u, err := url.Parse(r.Header.Get("Referer")); err == nil {
64 returnURL = u.RequestURI()
65 }
66
67 loginURL := fmt.Sprintf("/login?return_url=%s", url.QueryEscape(returnURL))
68
69 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
70 http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
71 }
72 if r.Header.Get("HX-Request") == "true" {
73 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
74 w.Header().Set("HX-Redirect", loginURL)
75 w.WriteHeader(http.StatusOK)
76 }
77 }
78
79 _, auth, err := a.GetSession(r)
80 if err != nil {
81 log.Println("not logged in, redirecting", "err", err)
82 redirectFunc(w, r)
83 return
84 }
85
86 if !auth {
87 log.Printf("not logged in, redirecting")
88 redirectFunc(w, r)
89 return
90 }
91
92 next.ServeHTTP(w, r)
93 })
94 }
95}
96
97func Paginate(next http.Handler) http.Handler {
98 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
99 page := pagination.FirstPage()
100
101 offsetVal := r.URL.Query().Get("offset")
102 if offsetVal != "" {
103 offset, err := strconv.Atoi(offsetVal)
104 if err != nil {
105 log.Println("invalid offset")
106 } else {
107 page.Offset = offset
108 }
109 }
110
111 limitVal := r.URL.Query().Get("limit")
112 if limitVal != "" {
113 limit, err := strconv.Atoi(limitVal)
114 if err != nil {
115 log.Println("invalid limit")
116 } else {
117 page.Limit = limit
118 }
119 }
120
121 ctx := context.WithValue(r.Context(), "page", page)
122 next.ServeHTTP(w, r.WithContext(ctx))
123 })
124}
125
126func (mw Middleware) knotRoleMiddleware(group string) middlewareFunc {
127 return func(next http.Handler) http.Handler {
128 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
129 // requires auth also
130 actor := mw.oauth.GetUser(r)
131 if actor == nil {
132 // we need a logged in user
133 log.Printf("not logged in, redirecting")
134 http.Error(w, "Forbiden", http.StatusUnauthorized)
135 return
136 }
137 domain := chi.URLParam(r, "domain")
138 if domain == "" {
139 http.Error(w, "malformed url", http.StatusBadRequest)
140 return
141 }
142
143 ok, err := mw.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
144 if err != nil || !ok {
145 // we need a logged in user
146 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
147 http.Error(w, "Forbiden", http.StatusUnauthorized)
148 return
149 }
150
151 next.ServeHTTP(w, r)
152 })
153 }
154}
155
156func (mw Middleware) KnotOwner() middlewareFunc {
157 return mw.knotRoleMiddleware("server:owner")
158}
159
160func (mw Middleware) RepoPermissionMiddleware(requiredPerm string) middlewareFunc {
161 return func(next http.Handler) http.Handler {
162 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
163 // requires auth also
164 actor := mw.oauth.GetUser(r)
165 if actor == nil {
166 // we need a logged in user
167 log.Printf("not logged in, redirecting")
168 http.Error(w, "Forbiden", http.StatusUnauthorized)
169 return
170 }
171 f, err := mw.repoResolver.Resolve(r)
172 if err != nil {
173 http.Error(w, "malformed url", http.StatusBadRequest)
174 return
175 }
176
177 ok, err := mw.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm)
178 if err != nil || !ok {
179 // we need a logged in user
180 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
181 http.Error(w, "Forbiden", http.StatusUnauthorized)
182 return
183 }
184
185 next.ServeHTTP(w, r)
186 })
187 }
188}
189
190func (mw Middleware) ResolveIdent() middlewareFunc {
191 excluded := []string{"favicon.ico"}
192
193 return func(next http.Handler) http.Handler {
194 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
195 didOrHandle := chi.URLParam(req, "user")
196 if slices.Contains(excluded, didOrHandle) {
197 next.ServeHTTP(w, req)
198 return
199 }
200
201 didOrHandle = strings.TrimPrefix(didOrHandle, "@")
202
203 id, err := mw.idResolver.ResolveIdent(req.Context(), didOrHandle)
204 if err != nil {
205 // invalid did or handle
206 log.Printf("failed to resolve did/handle '%s': %s\n", didOrHandle, err)
207 mw.pages.Error404(w)
208 return
209 }
210
211 ctx := context.WithValue(req.Context(), "resolvedId", *id)
212
213 next.ServeHTTP(w, req.WithContext(ctx))
214 })
215 }
216}
217
218func (mw Middleware) ResolveRepo() middlewareFunc {
219 return func(next http.Handler) http.Handler {
220 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
221 repoName := chi.URLParam(req, "repo")
222 id, ok := req.Context().Value("resolvedId").(identity.Identity)
223 if !ok {
224 log.Println("malformed middleware")
225 w.WriteHeader(http.StatusInternalServerError)
226 return
227 }
228
229 repo, err := db.GetRepo(
230 mw.db,
231 db.FilterEq("did", id.DID.String()),
232 db.FilterEq("name", repoName),
233 )
234 if err != nil {
235 log.Println("failed to resolve repo", "err", err)
236 mw.pages.ErrorKnot404(w)
237 return
238 }
239
240 ctx := context.WithValue(req.Context(), "repo", repo)
241 next.ServeHTTP(w, req.WithContext(ctx))
242 })
243 }
244}
245
246// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
247func (mw Middleware) ResolvePull() middlewareFunc {
248 return func(next http.Handler) http.Handler {
249 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
250 f, err := mw.repoResolver.Resolve(r)
251 if err != nil {
252 log.Println("failed to fully resolve repo", err)
253 mw.pages.ErrorKnot404(w)
254 return
255 }
256
257 prId := chi.URLParam(r, "pull")
258 prIdInt, err := strconv.Atoi(prId)
259 if err != nil {
260 http.Error(w, "bad pr id", http.StatusBadRequest)
261 log.Println("failed to parse pr id", err)
262 return
263 }
264
265 pr, err := db.GetPull(mw.db, f.RepoAt(), prIdInt)
266 if err != nil {
267 log.Println("failed to get pull and comments", err)
268 return
269 }
270
271 ctx := context.WithValue(r.Context(), "pull", pr)
272
273 if pr.IsStacked() {
274 stack, err := db.GetStack(mw.db, pr.StackId)
275 if err != nil {
276 log.Println("failed to get stack", err)
277 return
278 }
279 abandonedPulls, err := db.GetAbandonedPulls(mw.db, pr.StackId)
280 if err != nil {
281 log.Println("failed to get abandoned pulls", err)
282 return
283 }
284
285 ctx = context.WithValue(ctx, "stack", stack)
286 ctx = context.WithValue(ctx, "abandonedPulls", abandonedPulls)
287 }
288
289 next.ServeHTTP(w, r.WithContext(ctx))
290 })
291 }
292}
293
294// middleware that is tacked on top of /{user}/{repo}/issues/{issue}
295func (mw Middleware) ResolveIssue(next http.Handler) http.Handler {
296 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
297 f, err := mw.repoResolver.Resolve(r)
298 if err != nil {
299 log.Println("failed to fully resolve repo", err)
300 mw.pages.ErrorKnot404(w)
301 return
302 }
303
304 issueIdStr := chi.URLParam(r, "issue")
305 issueId, err := strconv.Atoi(issueIdStr)
306 if err != nil {
307 log.Println("failed to fully resolve issue ID", err)
308 mw.pages.ErrorKnot404(w)
309 return
310 }
311
312 issues, err := db.GetIssues(
313 mw.db,
314 db.FilterEq("repo_at", f.RepoAt()),
315 db.FilterEq("issue_id", issueId),
316 )
317 if err != nil {
318 log.Println("failed to get issues", "err", err)
319 return
320 }
321 if len(issues) != 1 {
322 log.Println("got incorrect number of issues", "len(issuse)", len(issues))
323 return
324 }
325 issue := issues[0]
326
327 ctx := context.WithValue(r.Context(), "issue", &issue)
328 next.ServeHTTP(w, r.WithContext(ctx))
329 })
330}
331
332// this should serve the go-import meta tag even if the path is technically
333// a 404 like tangled.sh/oppi.li/go-git/v5
334//
335// we're keeping the tangled.sh go-import tag too to maintain backward
336// compatiblity for modules that still point there. they will be redirected
337// to fetch source from tangled.org
338func (mw Middleware) GoImport() middlewareFunc {
339 return func(next http.Handler) http.Handler {
340 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
341 f, err := mw.repoResolver.Resolve(r)
342 if err != nil {
343 log.Println("failed to fully resolve repo", err)
344 mw.pages.ErrorKnot404(w)
345 return
346 }
347
348 fullName := f.OwnerHandle() + "/" + f.Name
349
350 if r.Header.Get("User-Agent") == "Go-http-client/1.1" {
351 if r.URL.Query().Get("go-get") == "1" {
352 html := fmt.Sprintf(
353 `<meta name="go-import" content="tangled.sh/%s git https://tangled.sh/%s"/>
354<meta name="go-import" content="tangled.org/%s git https://tangled.org/%s"/>`,
355 fullName, fullName,
356 fullName, fullName,
357 )
358 w.Header().Set("Content-Type", "text/html")
359 w.Write([]byte(html))
360 return
361 }
362 }
363
364 next.ServeHTTP(w, r)
365 })
366 }
367}