forked from tangled.org/core
this repo has no description
at master 9.8 kB view raw
1package middleware 2 3import ( 4 "context" 5 "fmt" 6 "log" 7 "net/http" 8 "net/url" 9 "slices" 10 "strconv" 11 "strings" 12 13 "github.com/bluesky-social/indigo/atproto/identity" 14 "github.com/go-chi/chi/v5" 15 "tangled.org/core/appview/db" 16 "tangled.org/core/appview/oauth" 17 "tangled.org/core/appview/pages" 18 "tangled.org/core/appview/pagination" 19 "tangled.org/core/appview/reporesolver" 20 "tangled.org/core/idresolver" 21 "tangled.org/core/rbac" 22) 23 24type Middleware struct { 25 oauth *oauth.OAuth 26 db *db.DB 27 enforcer *rbac.Enforcer 28 repoResolver *reporesolver.RepoResolver 29 idResolver *idresolver.Resolver 30 pages *pages.Pages 31} 32 33func New(oauth *oauth.OAuth, db *db.DB, enforcer *rbac.Enforcer, repoResolver *reporesolver.RepoResolver, idResolver *idresolver.Resolver, pages *pages.Pages) Middleware { 34 return Middleware{ 35 oauth: oauth, 36 db: db, 37 enforcer: enforcer, 38 repoResolver: repoResolver, 39 idResolver: idResolver, 40 pages: pages, 41 } 42} 43 44type middlewareFunc func(http.Handler) http.Handler 45 46func (mw *Middleware) TryRefreshSession() middlewareFunc { 47 return func(next http.Handler) http.Handler { 48 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 49 _, _, _ = mw.oauth.GetSession(r) 50 next.ServeHTTP(w, r) 51 }) 52 } 53} 54 55func AuthMiddleware(a *oauth.OAuth) middlewareFunc { 56 return func(next http.Handler) http.Handler { 57 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 returnURL := "/" 59 if u, err := url.Parse(r.Header.Get("Referer")); err == nil { 60 returnURL = u.RequestURI() 61 } 62 63 loginURL := fmt.Sprintf("/login?return_url=%s", url.QueryEscape(returnURL)) 64 65 redirectFunc := func(w http.ResponseWriter, r *http.Request) { 66 http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect) 67 } 68 if r.Header.Get("HX-Request") == "true" { 69 redirectFunc = func(w http.ResponseWriter, _ *http.Request) { 70 w.Header().Set("HX-Redirect", loginURL) 71 w.WriteHeader(http.StatusOK) 72 } 73 } 74 75 _, auth, err := a.GetSession(r) 76 if err != nil { 77 log.Println("not logged in, redirecting", "err", err) 78 redirectFunc(w, r) 79 return 80 } 81 82 if !auth { 83 log.Printf("not logged in, redirecting") 84 redirectFunc(w, r) 85 return 86 } 87 88 next.ServeHTTP(w, r) 89 }) 90 } 91} 92 93func Paginate(next http.Handler) http.Handler { 94 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 95 page := pagination.FirstPage() 96 97 offsetVal := r.URL.Query().Get("offset") 98 if offsetVal != "" { 99 offset, err := strconv.Atoi(offsetVal) 100 if err != nil { 101 log.Println("invalid offset") 102 } else { 103 page.Offset = offset 104 } 105 } 106 107 limitVal := r.URL.Query().Get("limit") 108 if limitVal != "" { 109 limit, err := strconv.Atoi(limitVal) 110 if err != nil { 111 log.Println("invalid limit") 112 } else { 113 page.Limit = limit 114 } 115 } 116 117 ctx := context.WithValue(r.Context(), "page", page) 118 next.ServeHTTP(w, r.WithContext(ctx)) 119 }) 120} 121 122func (mw Middleware) knotRoleMiddleware(group string) middlewareFunc { 123 return func(next http.Handler) http.Handler { 124 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 125 // requires auth also 126 actor := mw.oauth.GetUser(r) 127 if actor == nil { 128 // we need a logged in user 129 log.Printf("not logged in, redirecting") 130 http.Error(w, "Forbiden", http.StatusUnauthorized) 131 return 132 } 133 domain := chi.URLParam(r, "domain") 134 if domain == "" { 135 http.Error(w, "malformed url", http.StatusBadRequest) 136 return 137 } 138 139 ok, err := mw.enforcer.E.HasGroupingPolicy(actor.Did, group, domain) 140 if err != nil || !ok { 141 // we need a logged in user 142 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain) 143 http.Error(w, "Forbiden", http.StatusUnauthorized) 144 return 145 } 146 147 next.ServeHTTP(w, r) 148 }) 149 } 150} 151 152func (mw Middleware) KnotOwner() middlewareFunc { 153 return mw.knotRoleMiddleware("server:owner") 154} 155 156func (mw Middleware) RepoPermissionMiddleware(requiredPerm string) middlewareFunc { 157 return func(next http.Handler) http.Handler { 158 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 159 // requires auth also 160 actor := mw.oauth.GetUser(r) 161 if actor == nil { 162 // we need a logged in user 163 log.Printf("not logged in, redirecting") 164 http.Error(w, "Forbiden", http.StatusUnauthorized) 165 return 166 } 167 f, err := mw.repoResolver.Resolve(r) 168 if err != nil { 169 http.Error(w, "malformed url", http.StatusBadRequest) 170 return 171 } 172 173 ok, err := mw.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm) 174 if err != nil || !ok { 175 // we need a logged in user 176 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo()) 177 http.Error(w, "Forbiden", http.StatusUnauthorized) 178 return 179 } 180 181 next.ServeHTTP(w, r) 182 }) 183 } 184} 185 186func (mw Middleware) ResolveIdent() middlewareFunc { 187 excluded := []string{"favicon.ico"} 188 189 return func(next http.Handler) http.Handler { 190 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 191 didOrHandle := chi.URLParam(req, "user") 192 if slices.Contains(excluded, didOrHandle) { 193 next.ServeHTTP(w, req) 194 return 195 } 196 197 didOrHandle = strings.TrimPrefix(didOrHandle, "@") 198 199 id, err := mw.idResolver.ResolveIdent(req.Context(), didOrHandle) 200 if err != nil { 201 // invalid did or handle 202 log.Printf("failed to resolve did/handle '%s': %s\n", didOrHandle, err) 203 mw.pages.Error404(w) 204 return 205 } 206 207 ctx := context.WithValue(req.Context(), "resolvedId", *id) 208 209 next.ServeHTTP(w, req.WithContext(ctx)) 210 }) 211 } 212} 213 214func (mw Middleware) ResolveRepo() middlewareFunc { 215 return func(next http.Handler) http.Handler { 216 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 217 repoName := chi.URLParam(req, "repo") 218 id, ok := req.Context().Value("resolvedId").(identity.Identity) 219 if !ok { 220 log.Println("malformed middleware") 221 w.WriteHeader(http.StatusInternalServerError) 222 return 223 } 224 225 repo, err := db.GetRepo( 226 mw.db, 227 db.FilterEq("did", id.DID.String()), 228 db.FilterEq("name", repoName), 229 ) 230 if err != nil { 231 log.Println("failed to resolve repo", "err", err) 232 mw.pages.ErrorKnot404(w) 233 return 234 } 235 236 ctx := context.WithValue(req.Context(), "repo", repo) 237 next.ServeHTTP(w, req.WithContext(ctx)) 238 }) 239 } 240} 241 242// middleware that is tacked on top of /{user}/{repo}/pulls/{pull} 243func (mw Middleware) ResolvePull() middlewareFunc { 244 return func(next http.Handler) http.Handler { 245 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 246 f, err := mw.repoResolver.Resolve(r) 247 if err != nil { 248 log.Println("failed to fully resolve repo", err) 249 mw.pages.ErrorKnot404(w) 250 return 251 } 252 253 prId := chi.URLParam(r, "pull") 254 prIdInt, err := strconv.Atoi(prId) 255 if err != nil { 256 http.Error(w, "bad pr id", http.StatusBadRequest) 257 log.Println("failed to parse pr id", err) 258 return 259 } 260 261 pr, err := db.GetPull(mw.db, f.RepoAt(), prIdInt) 262 if err != nil { 263 log.Println("failed to get pull and comments", err) 264 return 265 } 266 267 ctx := context.WithValue(r.Context(), "pull", pr) 268 269 if pr.IsStacked() { 270 stack, err := db.GetStack(mw.db, pr.StackId) 271 if err != nil { 272 log.Println("failed to get stack", err) 273 return 274 } 275 abandonedPulls, err := db.GetAbandonedPulls(mw.db, pr.StackId) 276 if err != nil { 277 log.Println("failed to get abandoned pulls", err) 278 return 279 } 280 281 ctx = context.WithValue(ctx, "stack", stack) 282 ctx = context.WithValue(ctx, "abandonedPulls", abandonedPulls) 283 } 284 285 next.ServeHTTP(w, r.WithContext(ctx)) 286 }) 287 } 288} 289 290// middleware that is tacked on top of /{user}/{repo}/issues/{issue} 291func (mw Middleware) ResolveIssue(next http.Handler) http.Handler { 292 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 293 f, err := mw.repoResolver.Resolve(r) 294 if err != nil { 295 log.Println("failed to fully resolve repo", err) 296 mw.pages.ErrorKnot404(w) 297 return 298 } 299 300 issueIdStr := chi.URLParam(r, "issue") 301 issueId, err := strconv.Atoi(issueIdStr) 302 if err != nil { 303 log.Println("failed to fully resolve issue ID", err) 304 mw.pages.ErrorKnot404(w) 305 return 306 } 307 308 issues, err := db.GetIssues( 309 mw.db, 310 db.FilterEq("repo_at", f.RepoAt()), 311 db.FilterEq("issue_id", issueId), 312 ) 313 if err != nil { 314 log.Println("failed to get issues", "err", err) 315 return 316 } 317 if len(issues) != 1 { 318 log.Println("got incorrect number of issues", "len(issuse)", len(issues)) 319 return 320 } 321 issue := issues[0] 322 323 ctx := context.WithValue(r.Context(), "issue", &issue) 324 next.ServeHTTP(w, r.WithContext(ctx)) 325 }) 326} 327 328// this should serve the go-import meta tag even if the path is technically 329// a 404 like tangled.sh/oppi.li/go-git/v5 330// 331// we're keeping the tangled.sh go-import tag too to maintain backward 332// compatiblity for modules that still point there. they will be redirected 333// to fetch source from tangled.org 334func (mw Middleware) GoImport() middlewareFunc { 335 return func(next http.Handler) http.Handler { 336 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 337 f, err := mw.repoResolver.Resolve(r) 338 if err != nil { 339 log.Println("failed to fully resolve repo", err) 340 mw.pages.ErrorKnot404(w) 341 return 342 } 343 344 fullName := f.OwnerHandle() + "/" + f.Name 345 346 if r.Header.Get("User-Agent") == "Go-http-client/1.1" { 347 if r.URL.Query().Get("go-get") == "1" { 348 html := fmt.Sprintf( 349 `<meta name="go-import" content="tangled.sh/%s git https://tangled.sh/%s"/> 350<meta name="go-import" content="tangled.org/%s git https://tangled.org/%s"/>`, 351 fullName, fullName, 352 fullName, fullName, 353 ) 354 w.Header().Set("Content-Type", "text/html") 355 w.Write([]byte(html)) 356 return 357 } 358 } 359 360 next.ServeHTTP(w, r) 361 }) 362 } 363}