1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strings"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/identity"
12 "github.com/bluesky-social/indigo/xrpc"
13 "github.com/go-chi/chi/v5"
14 "github.com/sotangled/tangled/appview"
15 "github.com/sotangled/tangled/appview/auth"
16)
17
type (
	// Middleware wraps an http.Handler and returns the handler to use in its
	// place. The middleware constructors in this file (AuthMiddleware,
	// RoleMiddleware, ResolveIdent, ...) all produce values of this type.
	Middleware func(next http.Handler) http.Handler
)
19
20func AuthMiddleware(s *State) Middleware {
21 return func(next http.Handler) http.Handler {
22 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23 session, _ := s.auth.Store.Get(r, appview.SessionName)
24 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
25 if !ok || !authorized {
26 log.Printf("not logged in, redirecting")
27 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
28 return
29 }
30
31 // refresh if nearing expiry
32 // TODO: dedup with /login
33 expiryStr := session.Values[appview.SessionExpiry].(string)
34 expiry, err := time.Parse(time.RFC3339, expiryStr)
35 if err != nil {
36 log.Println("invalid expiry time", err)
37 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
38 return
39 }
40 pdsUrl := session.Values[appview.SessionPds].(string)
41 did := session.Values[appview.SessionDid].(string)
42 refreshJwt := session.Values[appview.SessionRefreshJwt].(string)
43
44 if time.Now().After(expiry) {
45 log.Println("token expired, refreshing ...")
46
47 client := xrpc.Client{
48 Host: pdsUrl,
49 Auth: &xrpc.AuthInfo{
50 Did: did,
51 AccessJwt: refreshJwt,
52 RefreshJwt: refreshJwt,
53 },
54 }
55 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
56 if err != nil {
57 log.Println("failed to refresh session", err)
58 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
59 return
60 }
61
62 sessionish := auth.RefreshSessionWrapper{atSession}
63
64 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
65 if err != nil {
66 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
67 return
68 }
69
70 log.Println("successfully refreshed token")
71 }
72
73 next.ServeHTTP(w, r)
74 })
75 }
76}
77
78func RoleMiddleware(s *State, group string) Middleware {
79 return func(next http.Handler) http.Handler {
80 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81 // requires auth also
82 actor := s.auth.GetUser(r)
83 if actor == nil {
84 // we need a logged in user
85 log.Printf("not logged in, redirecting")
86 http.Error(w, "Forbiden", http.StatusUnauthorized)
87 return
88 }
89 domain := chi.URLParam(r, "domain")
90 if domain == "" {
91 http.Error(w, "malformed url", http.StatusBadRequest)
92 return
93 }
94
95 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
96 if err != nil || !ok {
97 // we need a logged in user
98 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
99 http.Error(w, "Forbiden", http.StatusUnauthorized)
100 return
101 }
102
103 next.ServeHTTP(w, r)
104 })
105 }
106}
107
108func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
109 return func(next http.Handler) http.Handler {
110 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
111 // requires auth also
112 actor := s.auth.GetUser(r)
113 if actor == nil {
114 // we need a logged in user
115 log.Printf("not logged in, redirecting")
116 http.Error(w, "Forbiden", http.StatusUnauthorized)
117 return
118 }
119 f, err := fullyResolvedRepo(r)
120 if err != nil {
121 http.Error(w, "malformed url", http.StatusBadRequest)
122 return
123 }
124
125 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
126 if err != nil || !ok {
127 // we need a logged in user
128 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
129 http.Error(w, "Forbiden", http.StatusUnauthorized)
130 return
131 }
132
133 next.ServeHTTP(w, r)
134 })
135 }
136}
137
// StripLeadingAt rewrites paths of the form "/@rest" to "/rest" (for example
// "/@alice/repo" becomes "/alice/repo") before calling next. This lets routes
// be declared without the "@" sigil. Other paths pass through unchanged.
//
// Both URL.Path and, when it is set, URL.RawPath are rewritten. chi matches
// routes against RawPath whenever it is non-empty (e.g. the path contains an
// escaped "/" such as "%2F"). Stripping only Path would leave such requests
// unrewritten for routing and leave RawPath out of sync with Path.
func StripLeadingAt(next http.Handler) http.Handler {
	const atPrefix = "/@"
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if strings.HasPrefix(req.URL.Path, atPrefix) {
			req.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, atPrefix)
			if strings.HasPrefix(req.URL.RawPath, atPrefix) {
				req.URL.RawPath = "/" + strings.TrimPrefix(req.URL.RawPath, atPrefix)
			}
		}
		next.ServeHTTP(w, req)
	})
}
147
148func ResolveIdent(s *State) Middleware {
149 return func(next http.Handler) http.Handler {
150 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
151 didOrHandle := chi.URLParam(req, "user")
152
153 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
154 if err != nil {
155 // invalid did or handle
156 log.Println("failed to resolve did/handle:", err)
157 w.WriteHeader(http.StatusNotFound)
158 return
159 }
160
161 ctx := context.WithValue(req.Context(), "resolvedId", *id)
162
163 next.ServeHTTP(w, req.WithContext(ctx))
164 })
165 }
166}
167
168func ResolveRepoKnot(s *State) Middleware {
169 return func(next http.Handler) http.Handler {
170 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
171 repoName := chi.URLParam(req, "repo")
172 id, ok := req.Context().Value("resolvedId").(identity.Identity)
173 if !ok {
174 log.Println("malformed middleware")
175 w.WriteHeader(http.StatusInternalServerError)
176 return
177 }
178
179 repo, err := s.db.GetRepo(id.DID.String(), repoName)
180 if err != nil {
181 // invalid did or handle
182 log.Println("failed to resolve repo")
183 w.WriteHeader(http.StatusNotFound)
184 return
185 }
186
187 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
188 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
189 next.ServeHTTP(w, req.WithContext(ctx))
190 })
191 }
192}