forked from tangled.org/core
this repo has no description
1package state 2 3import ( 4 "context" 5 "log" 6 "net/http" 7 "strconv" 8 "strings" 9 "time" 10 11 "slices" 12 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 "github.com/bluesky-social/indigo/atproto/identity" 15 "github.com/bluesky-social/indigo/xrpc" 16 "github.com/go-chi/chi/v5" 17 "tangled.sh/tangled.sh/core/appview" 18 "tangled.sh/tangled.sh/core/appview/auth" 19 "tangled.sh/tangled.sh/core/appview/db" 20) 21 22type Middleware func(http.Handler) http.Handler 23 24func AuthMiddleware(s *State) Middleware { 25 return func(next http.Handler) http.Handler { 26 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 redirectFunc := func(w http.ResponseWriter, r *http.Request) { 28 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) 29 } 30 if r.Header.Get("HX-Request") == "true" { 31 redirectFunc = func(w http.ResponseWriter, _ *http.Request) { 32 w.Header().Set("HX-Redirect", "/login") 33 w.WriteHeader(http.StatusOK) 34 } 35 } 36 37 session, err := s.auth.GetSession(r) 38 if session.IsNew || err != nil { 39 log.Printf("not logged in, redirecting") 40 redirectFunc(w, r) 41 return 42 } 43 44 authorized, ok := session.Values[appview.SessionAuthenticated].(bool) 45 if !ok || !authorized { 46 log.Printf("not logged in, redirecting") 47 redirectFunc(w, r) 48 return 49 } 50 51 // refresh if nearing expiry 52 // TODO: dedup with /login 53 expiryStr := session.Values[appview.SessionExpiry].(string) 54 expiry, err := time.Parse(time.RFC3339, expiryStr) 55 if err != nil { 56 log.Println("invalid expiry time", err) 57 redirectFunc(w, r) 58 return 59 } 60 pdsUrl, ok1 := session.Values[appview.SessionPds].(string) 61 did, ok2 := session.Values[appview.SessionDid].(string) 62 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string) 63 64 if !ok1 || !ok2 || !ok3 { 65 log.Println("invalid expiry time", err) 66 redirectFunc(w, r) 67 return 68 } 69 70 if time.Now().After(expiry) { 71 log.Println("token expired, refreshing ...") 72 73 client := xrpc.Client{ 74 Host: pdsUrl, 75 Auth: &xrpc.AuthInfo{ 76 Did: did, 77 AccessJwt: refreshJwt, 78 RefreshJwt: refreshJwt, 79 }, 80 } 81 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client) 82 if err != nil { 83 log.Println("failed to refresh session", err) 84 redirectFunc(w, r) 85 return 86 } 87 88 sessionish := auth.RefreshSessionWrapper{atSession} 89 90 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl) 91 if err != nil { 92 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err) 93 return 94 } 95 96 log.Println("successfully refreshed token") 97 } 98 99 next.ServeHTTP(w, r) 100 }) 101 } 102} 103 104func knotRoleMiddleware(s *State, group string) Middleware { 105 return func(next http.Handler) http.Handler { 106 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 107 // requires auth also 108 actor := s.auth.GetUser(r) 109 if actor == nil { 110 // we need a logged in user 111 log.Printf("not logged in, redirecting") 112 http.Error(w, "Forbiden", http.StatusUnauthorized) 113 return 114 } 115 domain := chi.URLParam(r, "domain") 116 if domain == "" { 117 http.Error(w, "malformed url", http.StatusBadRequest) 118 return 119 } 120 121 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain) 122 if err != nil || !ok { 123 // we need a logged in user 124 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain) 125 http.Error(w, "Forbiden", http.StatusUnauthorized) 126 return 127 } 128 129 next.ServeHTTP(w, r) 130 }) 131 } 132} 133 134func KnotOwner(s *State) Middleware { 135 return knotRoleMiddleware(s, "server:owner") 136} 137 138func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware { 139 return func(next http.Handler) http.Handler { 140 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 141 // requires auth also 142 actor := s.auth.GetUser(r) 143 if actor == nil { 144 // we need a logged in user 145 log.Printf("not logged in, redirecting") 146 http.Error(w, "Forbiden", http.StatusUnauthorized) 147 return 148 } 149 f, err := fullyResolvedRepo(r) 150 if err != nil { 151 http.Error(w, "malformed url", http.StatusBadRequest) 152 return 153 } 154 155 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm) 156 if err != nil || !ok { 157 // we need a logged in user 158 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo()) 159 http.Error(w, "Forbiden", http.StatusUnauthorized) 160 return 161 } 162 163 next.ServeHTTP(w, r) 164 }) 165 } 166} 167 168func StripLeadingAt(next http.Handler) http.Handler { 169 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 170 path := req.URL.EscapedPath() 171 if strings.HasPrefix(path, "/@") { 172 req.URL.RawPath = "/" + strings.TrimPrefix(path, "/@") 173 } 174 next.ServeHTTP(w, req) 175 }) 176} 177 178func ResolveIdent(s *State) Middleware { 179 excluded := []string{"favicon.ico"} 180 181 return func(next http.Handler) http.Handler { 182 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 183 didOrHandle := chi.URLParam(req, "user") 184 if slices.Contains(excluded, didOrHandle) { 185 next.ServeHTTP(w, req) 186 return 187 } 188 189 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle) 190 if err != nil { 191 // invalid did or handle 192 log.Println("failed to resolve did/handle:", err) 193 w.WriteHeader(http.StatusNotFound) 194 return 195 } 196 197 ctx := context.WithValue(req.Context(), "resolvedId", *id) 198 199 next.ServeHTTP(w, req.WithContext(ctx)) 200 }) 201 } 202} 203 204func ResolveRepo(s *State) Middleware { 205 return func(next http.Handler) http.Handler { 206 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 207 repoName := chi.URLParam(req, "repo") 208 id, ok := req.Context().Value("resolvedId").(identity.Identity) 209 if !ok { 210 log.Println("malformed middleware") 211 w.WriteHeader(http.StatusInternalServerError) 212 return 213 } 214 215 repo, err := db.GetRepo(s.db, id.DID.String(), repoName) 216 if err != nil { 217 // invalid did or handle 218 log.Println("failed to resolve repo") 219 w.WriteHeader(http.StatusNotFound) 220 return 221 } 222 223 ctx := context.WithValue(req.Context(), "knot", repo.Knot) 224 ctx = context.WithValue(ctx, "repoAt", repo.AtUri) 225 ctx = context.WithValue(ctx, "repoDescription", repo.Description) 226 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339)) 227 next.ServeHTTP(w, req.WithContext(ctx)) 228 }) 229 } 230} 231 232// middleware that is tacked on top of /{user}/{repo}/pulls/{pull} 233func ResolvePull(s *State) Middleware { 234 return func(next http.Handler) http.Handler { 235 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 236 f, err := fullyResolvedRepo(r) 237 if err != nil { 238 log.Println("failed to fully resolve repo", err) 239 http.Error(w, "invalid repo url", http.StatusNotFound) 240 return 241 } 242 243 prId := chi.URLParam(r, "pull") 244 prIdInt, err := strconv.Atoi(prId) 245 if err != nil { 246 http.Error(w, "bad pr id", http.StatusBadRequest) 247 log.Println("failed to parse pr id", err) 248 return 249 } 250 251 pr, err := db.GetPull(s.db, f.RepoAt, prIdInt) 252 if err != nil { 253 log.Println("failed to get pull and comments", err) 254 return 255 } 256 257 ctx := context.WithValue(r.Context(), "pull", pr) 258 259 next.ServeHTTP(w, r.WithContext(ctx)) 260 }) 261 } 262}