forked from tangled.org/core
this repo has no description
1package state 2 3import ( 4 "context" 5 "log" 6 "net/http" 7 "strconv" 8 "strings" 9 "time" 10 11 comatproto "github.com/bluesky-social/indigo/api/atproto" 12 "github.com/bluesky-social/indigo/atproto/identity" 13 "github.com/bluesky-social/indigo/xrpc" 14 "github.com/go-chi/chi/v5" 15 "tangled.sh/tangled.sh/core/appview" 16 "tangled.sh/tangled.sh/core/appview/auth" 17 "tangled.sh/tangled.sh/core/appview/db" 18) 19 20type Middleware func(http.Handler) http.Handler 21 22func AuthMiddleware(s *State) Middleware { 23 return func(next http.Handler) http.Handler { 24 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 25 redirectFunc := func(w http.ResponseWriter, r *http.Request) { 26 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) 27 } 28 if r.Header.Get("HX-Request") == "true" { 29 redirectFunc = func(w http.ResponseWriter, _ *http.Request) { 30 w.Header().Set("HX-Redirect", "/login") 31 w.WriteHeader(http.StatusOK) 32 } 33 } 34 35 session, err := s.auth.GetSession(r) 36 if session.IsNew || err != nil { 37 log.Printf("not logged in, redirecting") 38 redirectFunc(w, r) 39 return 40 } 41 42 authorized, ok := session.Values[appview.SessionAuthenticated].(bool) 43 if !ok || !authorized { 44 log.Printf("not logged in, redirecting") 45 redirectFunc(w, r) 46 return 47 } 48 49 // refresh if nearing expiry 50 // TODO: dedup with /login 51 expiryStr := session.Values[appview.SessionExpiry].(string) 52 expiry, err := time.Parse(time.RFC3339, expiryStr) 53 if err != nil { 54 log.Println("invalid expiry time", err) 55 redirectFunc(w, r) 56 return 57 } 58 pdsUrl, ok1 := session.Values[appview.SessionPds].(string) 59 did, ok2 := session.Values[appview.SessionDid].(string) 60 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string) 61 62 if !ok1 || !ok2 || !ok3 { 63 log.Println("invalid expiry time", err) 64 redirectFunc(w, r) 65 return 66 } 67 68 if time.Now().After(expiry) { 69 log.Println("token expired, refreshing ...") 70 71 client := xrpc.Client{ 72 Host: pdsUrl, 73 Auth: &xrpc.AuthInfo{ 74 Did: did, 75 AccessJwt: refreshJwt, 76 RefreshJwt: refreshJwt, 77 }, 78 } 79 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client) 80 if err != nil { 81 log.Println("failed to refresh session", err) 82 redirectFunc(w, r) 83 return 84 } 85 86 sessionish := auth.RefreshSessionWrapper{atSession} 87 88 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl) 89 if err != nil { 90 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err) 91 return 92 } 93 94 log.Println("successfully refreshed token") 95 } 96 97 next.ServeHTTP(w, r) 98 }) 99 } 100} 101 102func knotRoleMiddleware(s *State, group string) Middleware { 103 return func(next http.Handler) http.Handler { 104 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 105 // requires auth also 106 actor := s.auth.GetUser(r) 107 if actor == nil { 108 // we need a logged in user 109 log.Printf("not logged in, redirecting") 110 http.Error(w, "Forbiden", http.StatusUnauthorized) 111 return 112 } 113 domain := chi.URLParam(r, "domain") 114 if domain == "" { 115 http.Error(w, "malformed url", http.StatusBadRequest) 116 return 117 } 118 119 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain) 120 if err != nil || !ok { 121 // we need a logged in user 122 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain) 123 http.Error(w, "Forbiden", http.StatusUnauthorized) 124 return 125 } 126 127 next.ServeHTTP(w, r) 128 }) 129 } 130} 131 132func KnotOwner(s *State) Middleware { 133 return knotRoleMiddleware(s, "server:owner") 134} 135 136func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware { 137 return func(next http.Handler) http.Handler { 138 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 139 // requires auth also 140 actor := s.auth.GetUser(r) 141 if actor == nil { 142 // we need a logged in user 143 log.Printf("not logged in, redirecting") 144 http.Error(w, "Forbiden", http.StatusUnauthorized) 145 return 146 } 147 f, err := fullyResolvedRepo(r) 148 if err != nil { 149 http.Error(w, "malformed url", http.StatusBadRequest) 150 return 151 } 152 153 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm) 154 if err != nil || !ok { 155 // we need a logged in user 156 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo()) 157 http.Error(w, "Forbiden", http.StatusUnauthorized) 158 return 159 } 160 161 next.ServeHTTP(w, r) 162 }) 163 } 164} 165 166func StripLeadingAt(next http.Handler) http.Handler { 167 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 168 path := req.URL.EscapedPath() 169 if strings.HasPrefix(path, "/@") { 170 req.URL.RawPath = "/" + strings.TrimPrefix(path, "/@") 171 } 172 next.ServeHTTP(w, req) 173 }) 174} 175 176func ResolveIdent(s *State) Middleware { 177 return func(next http.Handler) http.Handler { 178 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 179 didOrHandle := chi.URLParam(req, "user") 180 181 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle) 182 if err != nil { 183 // invalid did or handle 184 log.Println("failed to resolve did/handle:", err) 185 w.WriteHeader(http.StatusNotFound) 186 return 187 } 188 189 ctx := context.WithValue(req.Context(), "resolvedId", *id) 190 191 next.ServeHTTP(w, req.WithContext(ctx)) 192 }) 193 } 194} 195 196func ResolveRepo(s *State) Middleware { 197 return func(next http.Handler) http.Handler { 198 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 199 repoName := chi.URLParam(req, "repo") 200 id, ok := req.Context().Value("resolvedId").(identity.Identity) 201 if !ok { 202 log.Println("malformed middleware") 203 w.WriteHeader(http.StatusInternalServerError) 204 return 205 } 206 207 repo, err := db.GetRepo(s.db, id.DID.String(), repoName) 208 if err != nil { 209 // invalid did or handle 210 log.Println("failed to resolve repo") 211 w.WriteHeader(http.StatusNotFound) 212 return 213 } 214 215 ctx := context.WithValue(req.Context(), "knot", repo.Knot) 216 ctx = context.WithValue(ctx, "repoAt", repo.AtUri) 217 ctx = context.WithValue(ctx, "repoDescription", repo.Description) 218 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339)) 219 next.ServeHTTP(w, req.WithContext(ctx)) 220 }) 221 } 222} 223 224// middleware that is tacked on top of /{user}/{repo}/pulls/{pull} 225func ResolvePull(s *State) Middleware { 226 return func(next http.Handler) http.Handler { 227 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 228 f, err := fullyResolvedRepo(r) 229 if err != nil { 230 log.Println("failed to fully resolve repo", err) 231 http.Error(w, "invalid repo url", http.StatusNotFound) 232 return 233 } 234 235 prId := chi.URLParam(r, "pull") 236 prIdInt, err := strconv.Atoi(prId) 237 if err != nil { 238 http.Error(w, "bad pr id", http.StatusBadRequest) 239 log.Println("failed to parse pr id", err) 240 return 241 } 242 243 pr, err := db.GetPull(s.db, f.RepoAt, prIdInt) 244 if err != nil { 245 log.Println("failed to get pull and comments", err) 246 return 247 } 248 249 ctx := context.WithValue(r.Context(), "pull", pr) 250 251 next.ServeHTTP(w, r.WithContext(ctx)) 252 }) 253 } 254}