forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package state 2 3import ( 4 "context" 5 "log" 6 "net/http" 7 "strings" 8 "time" 9 10 comatproto "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/atproto/identity" 12 "github.com/bluesky-social/indigo/xrpc" 13 "github.com/go-chi/chi/v5" 14 "github.com/sotangled/tangled/appview" 15 "github.com/sotangled/tangled/appview/auth" 16) 17 18type Middleware func(http.Handler) http.Handler 19 20func AuthMiddleware(s *State) Middleware { 21 return func(next http.Handler) http.Handler { 22 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 23 session, _ := s.auth.Store.Get(r, appview.SessionName) 24 authorized, ok := session.Values[appview.SessionAuthenticated].(bool) 25 if !ok || !authorized { 26 log.Printf("not logged in, redirecting") 27 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) 28 return 29 } 30 31 // refresh if nearing expiry 32 // TODO: dedup with /login 33 expiryStr := session.Values[appview.SessionExpiry].(string) 34 expiry, err := time.Parse(time.RFC3339, expiryStr) 35 if err != nil { 36 log.Println("invalid expiry time", err) 37 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) 38 return 39 } 40 pdsUrl := session.Values[appview.SessionPds].(string) 41 did := session.Values[appview.SessionDid].(string) 42 refreshJwt := session.Values[appview.SessionRefreshJwt].(string) 43 44 if time.Now().After(expiry) { 45 log.Println("token expired, refreshing ...") 46 47 client := xrpc.Client{ 48 Host: pdsUrl, 49 Auth: &xrpc.AuthInfo{ 50 Did: did, 51 AccessJwt: refreshJwt, 52 RefreshJwt: refreshJwt, 53 }, 54 } 55 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client) 56 if err != nil { 57 log.Println(err) 58 return 59 } 60 61 sessionish := auth.RefreshSessionWrapper{atSession} 62 63 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl) 64 if err != nil { 65 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err) 66 return 67 } 68 69 log.Println("successfully refreshed token") 70 } 71 72 next.ServeHTTP(w, r) 73 }) 74 } 75} 76 77func RoleMiddleware(s *State, group string) Middleware { 78 return func(next http.Handler) http.Handler { 79 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 80 // requires auth also 81 actor := s.auth.GetUser(r) 82 if actor == nil { 83 // we need a logged in user 84 log.Printf("not logged in, redirecting") 85 http.Error(w, "Forbiden", http.StatusUnauthorized) 86 return 87 } 88 domain := chi.URLParam(r, "domain") 89 if domain == "" { 90 http.Error(w, "malformed url", http.StatusBadRequest) 91 return 92 } 93 94 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain) 95 if err != nil || !ok { 96 // we need a logged in user 97 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain) 98 http.Error(w, "Forbiden", http.StatusUnauthorized) 99 return 100 } 101 102 next.ServeHTTP(w, r) 103 }) 104 } 105} 106 107func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware { 108 return func(next http.Handler) http.Handler { 109 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 110 // requires auth also 111 actor := s.auth.GetUser(r) 112 if actor == nil { 113 // we need a logged in user 114 log.Printf("not logged in, redirecting") 115 http.Error(w, "Forbiden", http.StatusUnauthorized) 116 return 117 } 118 f, err := fullyResolvedRepo(r) 119 if err != nil { 120 http.Error(w, "malformed url", http.StatusBadRequest) 121 return 122 } 123 124 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm) 125 if err != nil || !ok { 126 // we need a logged in user 127 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo()) 128 http.Error(w, "Forbiden", http.StatusUnauthorized) 129 return 130 } 131 132 next.ServeHTTP(w, r) 133 }) 134 } 135} 136 137func StripLeadingAt(next http.Handler) http.Handler { 138 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 139 path := req.URL.Path 140 if strings.HasPrefix(path, "/@") { 141 req.URL.Path = "/" + strings.TrimPrefix(path, "/@") 142 } 143 next.ServeHTTP(w, req) 144 }) 145} 146 147func ResolveIdent(s *State) Middleware { 148 return func(next http.Handler) http.Handler { 149 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 150 start := time.Now() 151 didOrHandle := chi.URLParam(req, "user") 152 153 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle) 154 if err != nil { 155 // invalid did or handle 156 log.Println("failed to resolve did/handle:", err) 157 w.WriteHeader(http.StatusNotFound) 158 return 159 } 160 161 ctx := context.WithValue(req.Context(), "resolvedId", *id) 162 163 elapsed := time.Since(start) 164 log.Println("Execution time:", elapsed) 165 next.ServeHTTP(w, req.WithContext(ctx)) 166 }) 167 } 168} 169 170func ResolveRepoKnot(s *State) Middleware { 171 return func(next http.Handler) http.Handler { 172 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 173 repoName := chi.URLParam(req, "repo") 174 id, ok := req.Context().Value("resolvedId").(identity.Identity) 175 if !ok { 176 log.Println("malformed middleware") 177 w.WriteHeader(http.StatusInternalServerError) 178 return 179 } 180 181 repo, err := s.db.GetRepo(id.DID.String(), repoName) 182 if err != nil { 183 // invalid did or handle 184 log.Println("failed to resolve repo") 185 w.WriteHeader(http.StatusNotFound) 186 return 187 } 188 189 ctx := context.WithValue(req.Context(), "knot", repo.Knot) 190 ctx = context.WithValue(ctx, "repoAt", repo.AtUri) 191 next.ServeHTTP(w, req.WithContext(ctx)) 192 }) 193 } 194}