1package middleware
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net/http"
8 "net/url"
9 "slices"
10 "strconv"
11 "strings"
12
13 "github.com/bluesky-social/indigo/atproto/identity"
14 "github.com/go-chi/chi/v5"
15 "tangled.org/core/appview/db"
16 "tangled.org/core/appview/oauth"
17 "tangled.org/core/appview/pages"
18 "tangled.org/core/appview/pagination"
19 "tangled.org/core/appview/reporesolver"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
// Middleware bundles the dependencies shared by the request middlewares
// defined on it (login/permission checks and identity, repo, pull and issue
// resolution). Construct it with New.
type Middleware struct {
	oauth        *oauth.OAuth               // looks up the logged-in user (GetUser)
	db           *db.DB                     // backs the repo, pull, stack and issue lookups
	enforcer     *rbac.Enforcer             // rbac checks (HasGroupingPolicy, Enforce)
	repoResolver *reporesolver.RepoResolver // resolves the repo addressed by a request
	idResolver   *idresolver.Resolver       // resolves a DID or handle to an identity
	pages        *pages.Pages               // renders the 404 error pages
}
32
33func New(oauth *oauth.OAuth, db *db.DB, enforcer *rbac.Enforcer, repoResolver *reporesolver.RepoResolver, idResolver *idresolver.Resolver, pages *pages.Pages) Middleware {
34 return Middleware{
35 oauth: oauth,
36 db: db,
37 enforcer: enforcer,
38 repoResolver: repoResolver,
39 idResolver: idResolver,
40 pages: pages,
41 }
42}
43
// middlewareFunc is the shape of the middlewares built in this file: it takes
// the next handler in the chain and returns a handler that wraps it.
type middlewareFunc func(http.Handler) http.Handler
45
46func AuthMiddleware(a *oauth.OAuth) middlewareFunc {
47 return func(next http.Handler) http.Handler {
48 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49 returnURL := "/"
50 if u, err := url.Parse(r.Header.Get("Referer")); err == nil {
51 returnURL = u.RequestURI()
52 }
53
54 loginURL := fmt.Sprintf("/login?return_url=%s", url.QueryEscape(returnURL))
55
56 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
57 http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
58 }
59 if r.Header.Get("HX-Request") == "true" {
60 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
61 w.Header().Set("HX-Redirect", loginURL)
62 w.WriteHeader(http.StatusOK)
63 }
64 }
65
66 _, auth, err := a.GetSession(r)
67 if err != nil {
68 log.Println("not logged in, redirecting", "err", err)
69 redirectFunc(w, r)
70 return
71 }
72
73 if !auth {
74 log.Printf("not logged in, redirecting")
75 redirectFunc(w, r)
76 return
77 }
78
79 next.ServeHTTP(w, r)
80 })
81 }
82}
83
84func Paginate(next http.Handler) http.Handler {
85 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
86 page := pagination.FirstPage()
87
88 offsetVal := r.URL.Query().Get("offset")
89 if offsetVal != "" {
90 offset, err := strconv.Atoi(offsetVal)
91 if err != nil {
92 log.Println("invalid offset")
93 } else {
94 page.Offset = offset
95 }
96 }
97
98 limitVal := r.URL.Query().Get("limit")
99 if limitVal != "" {
100 limit, err := strconv.Atoi(limitVal)
101 if err != nil {
102 log.Println("invalid limit")
103 } else {
104 page.Limit = limit
105 }
106 }
107
108 ctx := context.WithValue(r.Context(), "page", page)
109 next.ServeHTTP(w, r.WithContext(ctx))
110 })
111}
112
113func (mw Middleware) knotRoleMiddleware(group string) middlewareFunc {
114 return func(next http.Handler) http.Handler {
115 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
116 // requires auth also
117 actor := mw.oauth.GetUser(r)
118 if actor == nil {
119 // we need a logged in user
120 log.Printf("not logged in, redirecting")
121 http.Error(w, "Forbiden", http.StatusUnauthorized)
122 return
123 }
124 domain := chi.URLParam(r, "domain")
125 if domain == "" {
126 http.Error(w, "malformed url", http.StatusBadRequest)
127 return
128 }
129
130 ok, err := mw.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
131 if err != nil || !ok {
132 // we need a logged in user
133 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
134 http.Error(w, "Forbiden", http.StatusUnauthorized)
135 return
136 }
137
138 next.ServeHTTP(w, r)
139 })
140 }
141}
142
143func (mw Middleware) KnotOwner() middlewareFunc {
144 return mw.knotRoleMiddleware("server:owner")
145}
146
147func (mw Middleware) RepoPermissionMiddleware(requiredPerm string) middlewareFunc {
148 return func(next http.Handler) http.Handler {
149 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
150 // requires auth also
151 actor := mw.oauth.GetUser(r)
152 if actor == nil {
153 // we need a logged in user
154 log.Printf("not logged in, redirecting")
155 http.Error(w, "Forbiden", http.StatusUnauthorized)
156 return
157 }
158 f, err := mw.repoResolver.Resolve(r)
159 if err != nil {
160 http.Error(w, "malformed url", http.StatusBadRequest)
161 return
162 }
163
164 ok, err := mw.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm)
165 if err != nil || !ok {
166 // we need a logged in user
167 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
168 http.Error(w, "Forbiden", http.StatusUnauthorized)
169 return
170 }
171
172 next.ServeHTTP(w, r)
173 })
174 }
175}
176
177func (mw Middleware) ResolveIdent() middlewareFunc {
178 excluded := []string{"favicon.ico"}
179
180 return func(next http.Handler) http.Handler {
181 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
182 didOrHandle := chi.URLParam(req, "user")
183 if slices.Contains(excluded, didOrHandle) {
184 next.ServeHTTP(w, req)
185 return
186 }
187
188 didOrHandle = strings.TrimPrefix(didOrHandle, "@")
189
190 id, err := mw.idResolver.ResolveIdent(req.Context(), didOrHandle)
191 if err != nil {
192 // invalid did or handle
193 log.Printf("failed to resolve did/handle '%s': %s\n", didOrHandle, err)
194 mw.pages.Error404(w)
195 return
196 }
197
198 ctx := context.WithValue(req.Context(), "resolvedId", *id)
199
200 next.ServeHTTP(w, req.WithContext(ctx))
201 })
202 }
203}
204
205func (mw Middleware) ResolveRepo() middlewareFunc {
206 return func(next http.Handler) http.Handler {
207 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
208 repoName := chi.URLParam(req, "repo")
209 id, ok := req.Context().Value("resolvedId").(identity.Identity)
210 if !ok {
211 log.Println("malformed middleware")
212 w.WriteHeader(http.StatusInternalServerError)
213 return
214 }
215
216 repo, err := db.GetRepo(
217 mw.db,
218 db.FilterEq("did", id.DID.String()),
219 db.FilterEq("name", repoName),
220 )
221 if err != nil {
222 log.Println("failed to resolve repo", "err", err)
223 mw.pages.ErrorKnot404(w)
224 return
225 }
226
227 ctx := context.WithValue(req.Context(), "repo", repo)
228 next.ServeHTTP(w, req.WithContext(ctx))
229 })
230 }
231}
232
233// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
234func (mw Middleware) ResolvePull() middlewareFunc {
235 return func(next http.Handler) http.Handler {
236 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
237 f, err := mw.repoResolver.Resolve(r)
238 if err != nil {
239 log.Println("failed to fully resolve repo", err)
240 mw.pages.ErrorKnot404(w)
241 return
242 }
243
244 prId := chi.URLParam(r, "pull")
245 prIdInt, err := strconv.Atoi(prId)
246 if err != nil {
247 http.Error(w, "bad pr id", http.StatusBadRequest)
248 log.Println("failed to parse pr id", err)
249 return
250 }
251
252 pr, err := db.GetPull(mw.db, f.RepoAt(), prIdInt)
253 if err != nil {
254 log.Println("failed to get pull and comments", err)
255 return
256 }
257
258 ctx := context.WithValue(r.Context(), "pull", pr)
259
260 if pr.IsStacked() {
261 stack, err := db.GetStack(mw.db, pr.StackId)
262 if err != nil {
263 log.Println("failed to get stack", err)
264 return
265 }
266 abandonedPulls, err := db.GetAbandonedPulls(mw.db, pr.StackId)
267 if err != nil {
268 log.Println("failed to get abandoned pulls", err)
269 return
270 }
271
272 ctx = context.WithValue(ctx, "stack", stack)
273 ctx = context.WithValue(ctx, "abandonedPulls", abandonedPulls)
274 }
275
276 next.ServeHTTP(w, r.WithContext(ctx))
277 })
278 }
279}
280
281// middleware that is tacked on top of /{user}/{repo}/issues/{issue}
282func (mw Middleware) ResolveIssue(next http.Handler) http.Handler {
283 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
284 f, err := mw.repoResolver.Resolve(r)
285 if err != nil {
286 log.Println("failed to fully resolve repo", err)
287 mw.pages.ErrorKnot404(w)
288 return
289 }
290
291 issueIdStr := chi.URLParam(r, "issue")
292 issueId, err := strconv.Atoi(issueIdStr)
293 if err != nil {
294 log.Println("failed to fully resolve issue ID", err)
295 mw.pages.ErrorKnot404(w)
296 return
297 }
298
299 issues, err := db.GetIssues(
300 mw.db,
301 db.FilterEq("repo_at", f.RepoAt()),
302 db.FilterEq("issue_id", issueId),
303 )
304 if err != nil {
305 log.Println("failed to get issues", "err", err)
306 return
307 }
308 if len(issues) != 1 {
309 log.Println("got incorrect number of issues", "len(issuse)", len(issues))
310 return
311 }
312 issue := issues[0]
313
314 ctx := context.WithValue(r.Context(), "issue", &issue)
315 next.ServeHTTP(w, r.WithContext(ctx))
316 })
317}
318
319// this should serve the go-import meta tag even if the path is technically
320// a 404 like tangled.sh/oppi.li/go-git/v5
321//
322// we're keeping the tangled.sh go-import tag too to maintain backward
323// compatiblity for modules that still point there. they will be redirected
324// to fetch source from tangled.org
325func (mw Middleware) GoImport() middlewareFunc {
326 return func(next http.Handler) http.Handler {
327 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
328 f, err := mw.repoResolver.Resolve(r)
329 if err != nil {
330 log.Println("failed to fully resolve repo", err)
331 mw.pages.ErrorKnot404(w)
332 return
333 }
334
335 fullName := f.OwnerHandle() + "/" + f.Name
336
337 if r.Header.Get("User-Agent") == "Go-http-client/1.1" {
338 if r.URL.Query().Get("go-get") == "1" {
339 html := fmt.Sprintf(
340 `<meta name="go-import" content="tangled.sh/%s git https://tangled.sh/%s"/>
341<meta name="go-import" content="tangled.org/%s git https://tangled.org/%s"/>`,
342 fullName, fullName,
343 fullName, fullName,
344 )
345 w.Header().Set("Content-Type", "text/html")
346 w.Write([]byte(html))
347 return
348 }
349 }
350
351 next.ServeHTTP(w, r)
352 })
353 }
354}