forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package state 2 3import ( 4 "context" 5 "log" 6 "net/http" 7 "strings" 8 "time" 9 10 comatproto "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/atproto/identity" 12 "github.com/bluesky-social/indigo/xrpc" 13 "github.com/go-chi/chi/v5" 14 "github.com/sotangled/tangled/appview" 15 "github.com/sotangled/tangled/appview/auth" 16 "github.com/sotangled/tangled/appview/db" 17) 18 19type Middleware func(http.Handler) http.Handler 20 21func AuthMiddleware(s *State) Middleware { 22 return func(next http.Handler) http.Handler { 23 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 24 session, _ := s.auth.Store.Get(r, appview.SessionName) 25 authorized, ok := session.Values[appview.SessionAuthenticated].(bool) 26 if !ok || !authorized { 27 log.Printf("not logged in, redirecting") 28 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) 29 return 30 } 31 32 // refresh if nearing expiry 33 // TODO: dedup with /login 34 expiryStr := session.Values[appview.SessionExpiry].(string) 35 expiry, err := time.Parse(time.RFC3339, expiryStr) 36 if err != nil { 37 log.Println("invalid expiry time", err) 38 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) 39 return 40 } 41 pdsUrl := session.Values[appview.SessionPds].(string) 42 did := session.Values[appview.SessionDid].(string) 43 refreshJwt := session.Values[appview.SessionRefreshJwt].(string) 44 45 if time.Now().After(expiry) { 46 log.Println("token expired, refreshing ...") 47 48 client := xrpc.Client{ 49 Host: pdsUrl, 50 Auth: &xrpc.AuthInfo{ 51 Did: did, 52 AccessJwt: refreshJwt, 53 RefreshJwt: refreshJwt, 54 }, 55 } 56 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client) 57 if err != nil { 58 log.Println("failed to refresh session", err) 59 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) 60 return 61 } 62 63 sessionish := auth.RefreshSessionWrapper{atSession} 64 65 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl) 66 if err != nil { 67 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err) 68 return 69 } 70 71 log.Println("successfully refreshed token") 72 } 73 74 next.ServeHTTP(w, r) 75 }) 76 } 77} 78 79func RoleMiddleware(s *State, group string) Middleware { 80 return func(next http.Handler) http.Handler { 81 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 82 // requires auth also 83 actor := s.auth.GetUser(r) 84 if actor == nil { 85 // we need a logged in user 86 log.Printf("not logged in, redirecting") 87 http.Error(w, "Forbiden", http.StatusUnauthorized) 88 return 89 } 90 domain := chi.URLParam(r, "domain") 91 if domain == "" { 92 http.Error(w, "malformed url", http.StatusBadRequest) 93 return 94 } 95 96 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain) 97 if err != nil || !ok { 98 // we need a logged in user 99 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain) 100 http.Error(w, "Forbiden", http.StatusUnauthorized) 101 return 102 } 103 104 next.ServeHTTP(w, r) 105 }) 106 } 107} 108 109func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware { 110 return func(next http.Handler) http.Handler { 111 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 112 // requires auth also 113 actor := s.auth.GetUser(r) 114 if actor == nil { 115 // we need a logged in user 116 log.Printf("not logged in, redirecting") 117 http.Error(w, "Forbiden", http.StatusUnauthorized) 118 return 119 } 120 f, err := fullyResolvedRepo(r) 121 if err != nil { 122 http.Error(w, "malformed url", http.StatusBadRequest) 123 return 124 } 125 126 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm) 127 if err != nil || !ok { 128 // we need a logged in user 129 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo()) 130 http.Error(w, "Forbiden", http.StatusUnauthorized) 131 return 132 } 133 134 next.ServeHTTP(w, r) 135 }) 136 } 137} 138 139func StripLeadingAt(next http.Handler) http.Handler { 140 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 141 path := req.URL.Path 142 if strings.HasPrefix(path, "/@") { 143 req.URL.Path = "/" + strings.TrimPrefix(path, "/@") 144 } 145 next.ServeHTTP(w, req) 146 }) 147} 148 149func ResolveIdent(s *State) Middleware { 150 return func(next http.Handler) http.Handler { 151 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 152 didOrHandle := chi.URLParam(req, "user") 153 154 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle) 155 if err != nil { 156 // invalid did or handle 157 log.Println("failed to resolve did/handle:", err) 158 w.WriteHeader(http.StatusNotFound) 159 return 160 } 161 162 ctx := context.WithValue(req.Context(), "resolvedId", *id) 163 164 next.ServeHTTP(w, req.WithContext(ctx)) 165 }) 166 } 167} 168 169func ResolveRepoKnot(s *State) Middleware { 170 return func(next http.Handler) http.Handler { 171 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 172 repoName := chi.URLParam(req, "repo") 173 id, ok := req.Context().Value("resolvedId").(identity.Identity) 174 if !ok { 175 log.Println("malformed middleware") 176 w.WriteHeader(http.StatusInternalServerError) 177 return 178 } 179 180 repo, err := db.GetRepo(s.db, id.DID.String(), repoName) 181 if err != nil { 182 // invalid did or handle 183 log.Println("failed to resolve repo") 184 w.WriteHeader(http.StatusNotFound) 185 return 186 } 187 188 ctx := context.WithValue(req.Context(), "knot", repo.Knot) 189 ctx = context.WithValue(ctx, "repoAt", repo.AtUri) 190 next.ServeHTTP(w, req.WithContext(ctx)) 191 }) 192 } 193}