1package middleware
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net/http"
8 "net/url"
9 "slices"
10 "strconv"
11 "strings"
12
13 "github.com/bluesky-social/indigo/atproto/identity"
14 "github.com/go-chi/chi/v5"
15 "tangled.org/core/appview/db"
16 "tangled.org/core/appview/oauth"
17 "tangled.org/core/appview/pages"
18 "tangled.org/core/appview/pagination"
19 "tangled.org/core/appview/reporesolver"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
24type Middleware struct {
25 oauth *oauth.OAuth
26 db *db.DB
27 enforcer *rbac.Enforcer
28 repoResolver *reporesolver.RepoResolver
29 idResolver *idresolver.Resolver
30 pages *pages.Pages
31}
32
33func New(oauth *oauth.OAuth, db *db.DB, enforcer *rbac.Enforcer, repoResolver *reporesolver.RepoResolver, idResolver *idresolver.Resolver, pages *pages.Pages) Middleware {
34 return Middleware{
35 oauth: oauth,
36 db: db,
37 enforcer: enforcer,
38 repoResolver: repoResolver,
39 idResolver: idResolver,
40 pages: pages,
41 }
42}
43
// middlewareFunc takes the next handler in the chain and returns the handler
// that wraps it.
type middlewareFunc func(next http.Handler) http.Handler
45
46func (mw *Middleware) TryRefreshSession() middlewareFunc {
47 return func(next http.Handler) http.Handler {
48 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49 _, _, _ = mw.oauth.GetSession(r)
50 next.ServeHTTP(w, r)
51 })
52 }
53}
54
55func AuthMiddleware(a *oauth.OAuth) middlewareFunc {
56 return func(next http.Handler) http.Handler {
57 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58 returnURL := "/"
59 if u, err := url.Parse(r.Header.Get("Referer")); err == nil {
60 returnURL = u.RequestURI()
61 }
62
63 loginURL := fmt.Sprintf("/login?return_url=%s", url.QueryEscape(returnURL))
64
65 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
66 http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
67 }
68 if r.Header.Get("HX-Request") == "true" {
69 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
70 w.Header().Set("HX-Redirect", loginURL)
71 w.WriteHeader(http.StatusOK)
72 }
73 }
74
75 _, auth, err := a.GetSession(r)
76 if err != nil {
77 log.Println("not logged in, redirecting", "err", err)
78 redirectFunc(w, r)
79 return
80 }
81
82 if !auth {
83 log.Printf("not logged in, redirecting")
84 redirectFunc(w, r)
85 return
86 }
87
88 next.ServeHTTP(w, r)
89 })
90 }
91}
92
93func Paginate(next http.Handler) http.Handler {
94 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
95 page := pagination.FirstPage()
96
97 offsetVal := r.URL.Query().Get("offset")
98 if offsetVal != "" {
99 offset, err := strconv.Atoi(offsetVal)
100 if err != nil {
101 log.Println("invalid offset")
102 } else {
103 page.Offset = offset
104 }
105 }
106
107 limitVal := r.URL.Query().Get("limit")
108 if limitVal != "" {
109 limit, err := strconv.Atoi(limitVal)
110 if err != nil {
111 log.Println("invalid limit")
112 } else {
113 page.Limit = limit
114 }
115 }
116
117 ctx := context.WithValue(r.Context(), "page", page)
118 next.ServeHTTP(w, r.WithContext(ctx))
119 })
120}
121
122func (mw Middleware) knotRoleMiddleware(group string) middlewareFunc {
123 return func(next http.Handler) http.Handler {
124 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125 // requires auth also
126 actor := mw.oauth.GetUser(r)
127 if actor == nil {
128 // we need a logged in user
129 log.Printf("not logged in, redirecting")
130 http.Error(w, "Forbiden", http.StatusUnauthorized)
131 return
132 }
133 domain := chi.URLParam(r, "domain")
134 if domain == "" {
135 http.Error(w, "malformed url", http.StatusBadRequest)
136 return
137 }
138
139 ok, err := mw.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
140 if err != nil || !ok {
141 // we need a logged in user
142 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
143 http.Error(w, "Forbiden", http.StatusUnauthorized)
144 return
145 }
146
147 next.ServeHTTP(w, r)
148 })
149 }
150}
151
152func (mw Middleware) KnotOwner() middlewareFunc {
153 return mw.knotRoleMiddleware("server:owner")
154}
155
156func (mw Middleware) RepoPermissionMiddleware(requiredPerm string) middlewareFunc {
157 return func(next http.Handler) http.Handler {
158 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159 // requires auth also
160 actor := mw.oauth.GetUser(r)
161 if actor == nil {
162 // we need a logged in user
163 log.Printf("not logged in, redirecting")
164 http.Error(w, "Forbiden", http.StatusUnauthorized)
165 return
166 }
167 f, err := mw.repoResolver.Resolve(r)
168 if err != nil {
169 http.Error(w, "malformed url", http.StatusBadRequest)
170 return
171 }
172
173 ok, err := mw.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm)
174 if err != nil || !ok {
175 // we need a logged in user
176 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
177 http.Error(w, "Forbiden", http.StatusUnauthorized)
178 return
179 }
180
181 next.ServeHTTP(w, r)
182 })
183 }
184}
185
186func (mw Middleware) ResolveIdent() middlewareFunc {
187 excluded := []string{"favicon.ico"}
188
189 return func(next http.Handler) http.Handler {
190 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
191 didOrHandle := chi.URLParam(req, "user")
192 if slices.Contains(excluded, didOrHandle) {
193 next.ServeHTTP(w, req)
194 return
195 }
196
197 didOrHandle = strings.TrimPrefix(didOrHandle, "@")
198
199 id, err := mw.idResolver.ResolveIdent(req.Context(), didOrHandle)
200 if err != nil {
201 // invalid did or handle
202 log.Printf("failed to resolve did/handle '%s': %s\n", didOrHandle, err)
203 mw.pages.Error404(w)
204 return
205 }
206
207 ctx := context.WithValue(req.Context(), "resolvedId", *id)
208
209 next.ServeHTTP(w, req.WithContext(ctx))
210 })
211 }
212}
213
214func (mw Middleware) ResolveRepo() middlewareFunc {
215 return func(next http.Handler) http.Handler {
216 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
217 repoName := chi.URLParam(req, "repo")
218 id, ok := req.Context().Value("resolvedId").(identity.Identity)
219 if !ok {
220 log.Println("malformed middleware")
221 w.WriteHeader(http.StatusInternalServerError)
222 return
223 }
224
225 repo, err := db.GetRepo(
226 mw.db,
227 db.FilterEq("did", id.DID.String()),
228 db.FilterEq("name", repoName),
229 )
230 if err != nil {
231 log.Println("failed to resolve repo", "err", err)
232 mw.pages.ErrorKnot404(w)
233 return
234 }
235
236 ctx := context.WithValue(req.Context(), "repo", repo)
237 next.ServeHTTP(w, req.WithContext(ctx))
238 })
239 }
240}
241
242// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
243func (mw Middleware) ResolvePull() middlewareFunc {
244 return func(next http.Handler) http.Handler {
245 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
246 f, err := mw.repoResolver.Resolve(r)
247 if err != nil {
248 log.Println("failed to fully resolve repo", err)
249 mw.pages.ErrorKnot404(w)
250 return
251 }
252
253 prId := chi.URLParam(r, "pull")
254 prIdInt, err := strconv.Atoi(prId)
255 if err != nil {
256 http.Error(w, "bad pr id", http.StatusBadRequest)
257 log.Println("failed to parse pr id", err)
258 return
259 }
260
261 pr, err := db.GetPull(mw.db, f.RepoAt(), prIdInt)
262 if err != nil {
263 log.Println("failed to get pull and comments", err)
264 return
265 }
266
267 ctx := context.WithValue(r.Context(), "pull", pr)
268
269 if pr.IsStacked() {
270 stack, err := db.GetStack(mw.db, pr.StackId)
271 if err != nil {
272 log.Println("failed to get stack", err)
273 return
274 }
275 abandonedPulls, err := db.GetAbandonedPulls(mw.db, pr.StackId)
276 if err != nil {
277 log.Println("failed to get abandoned pulls", err)
278 return
279 }
280
281 ctx = context.WithValue(ctx, "stack", stack)
282 ctx = context.WithValue(ctx, "abandonedPulls", abandonedPulls)
283 }
284
285 next.ServeHTTP(w, r.WithContext(ctx))
286 })
287 }
288}
289
290// middleware that is tacked on top of /{user}/{repo}/issues/{issue}
291func (mw Middleware) ResolveIssue(next http.Handler) http.Handler {
292 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
293 f, err := mw.repoResolver.Resolve(r)
294 if err != nil {
295 log.Println("failed to fully resolve repo", err)
296 mw.pages.ErrorKnot404(w)
297 return
298 }
299
300 issueIdStr := chi.URLParam(r, "issue")
301 issueId, err := strconv.Atoi(issueIdStr)
302 if err != nil {
303 log.Println("failed to fully resolve issue ID", err)
304 mw.pages.ErrorKnot404(w)
305 return
306 }
307
308 issues, err := db.GetIssues(
309 mw.db,
310 db.FilterEq("repo_at", f.RepoAt()),
311 db.FilterEq("issue_id", issueId),
312 )
313 if err != nil {
314 log.Println("failed to get issues", "err", err)
315 return
316 }
317 if len(issues) != 1 {
318 log.Println("got incorrect number of issues", "len(issuse)", len(issues))
319 return
320 }
321 issue := issues[0]
322
323 ctx := context.WithValue(r.Context(), "issue", &issue)
324 next.ServeHTTP(w, r.WithContext(ctx))
325 })
326}
327
328// this should serve the go-import meta tag even if the path is technically
329// a 404 like tangled.sh/oppi.li/go-git/v5
330//
331// we're keeping the tangled.sh go-import tag too to maintain backward
332// compatiblity for modules that still point there. they will be redirected
333// to fetch source from tangled.org
334func (mw Middleware) GoImport() middlewareFunc {
335 return func(next http.Handler) http.Handler {
336 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
337 f, err := mw.repoResolver.Resolve(r)
338 if err != nil {
339 log.Println("failed to fully resolve repo", err)
340 mw.pages.ErrorKnot404(w)
341 return
342 }
343
344 fullName := f.OwnerHandle() + "/" + f.Name
345
346 if r.Header.Get("User-Agent") == "Go-http-client/1.1" {
347 if r.URL.Query().Get("go-get") == "1" {
348 html := fmt.Sprintf(
349 `<meta name="go-import" content="tangled.sh/%s git https://tangled.sh/%s"/>
350<meta name="go-import" content="tangled.org/%s git https://tangled.org/%s"/>`,
351 fullName, fullName,
352 fullName, fullName,
353 )
354 w.Header().Set("Content-Type", "text/html")
355 w.Write([]byte(html))
356 return
357 }
358 }
359
360 next.ServeHTTP(w, r)
361 })
362 }
363}