1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strconv"
8 "strings"
9 "time"
10
11 "slices"
12
13 comatproto "github.com/bluesky-social/indigo/api/atproto"
14 "github.com/bluesky-social/indigo/atproto/identity"
15 "github.com/bluesky-social/indigo/xrpc"
16 "github.com/go-chi/chi/v5"
17 "tangled.sh/tangled.sh/core/appview"
18 "tangled.sh/tangled.sh/core/appview/auth"
19 "tangled.sh/tangled.sh/core/appview/db"
20)
21
// Middleware wraps an http.Handler and returns a handler that runs extra
// logic around (or instead of) the wrapped one.
type Middleware func(http.Handler) http.Handler
23
24func AuthMiddleware(s *State) Middleware {
25 return func(next http.Handler) http.Handler {
26 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
28 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
29 }
30 if r.Header.Get("HX-Request") == "true" {
31 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
32 w.Header().Set("HX-Redirect", "/login")
33 w.WriteHeader(http.StatusOK)
34 }
35 }
36
37 session, err := s.auth.GetSession(r)
38 if session.IsNew || err != nil {
39 log.Printf("not logged in, redirecting")
40 redirectFunc(w, r)
41 return
42 }
43
44 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
45 if !ok || !authorized {
46 log.Printf("not logged in, redirecting")
47 redirectFunc(w, r)
48 return
49 }
50
51 // refresh if nearing expiry
52 // TODO: dedup with /login
53 expiryStr := session.Values[appview.SessionExpiry].(string)
54 expiry, err := time.Parse(time.RFC3339, expiryStr)
55 if err != nil {
56 log.Println("invalid expiry time", err)
57 redirectFunc(w, r)
58 return
59 }
60 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
61 did, ok2 := session.Values[appview.SessionDid].(string)
62 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
63
64 if !ok1 || !ok2 || !ok3 {
65 log.Println("invalid expiry time", err)
66 redirectFunc(w, r)
67 return
68 }
69
70 if time.Now().After(expiry) {
71 log.Println("token expired, refreshing ...")
72
73 client := xrpc.Client{
74 Host: pdsUrl,
75 Auth: &xrpc.AuthInfo{
76 Did: did,
77 AccessJwt: refreshJwt,
78 RefreshJwt: refreshJwt,
79 },
80 }
81 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
82 if err != nil {
83 log.Println("failed to refresh session", err)
84 redirectFunc(w, r)
85 return
86 }
87
88 sessionish := auth.RefreshSessionWrapper{atSession}
89
90 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
91 if err != nil {
92 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
93 return
94 }
95
96 log.Println("successfully refreshed token")
97 }
98
99 next.ServeHTTP(w, r)
100 })
101 }
102}
103
104func knotRoleMiddleware(s *State, group string) Middleware {
105 return func(next http.Handler) http.Handler {
106 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
107 // requires auth also
108 actor := s.auth.GetUser(r)
109 if actor == nil {
110 // we need a logged in user
111 log.Printf("not logged in, redirecting")
112 http.Error(w, "Forbiden", http.StatusUnauthorized)
113 return
114 }
115 domain := chi.URLParam(r, "domain")
116 if domain == "" {
117 http.Error(w, "malformed url", http.StatusBadRequest)
118 return
119 }
120
121 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
122 if err != nil || !ok {
123 // we need a logged in user
124 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
125 http.Error(w, "Forbiden", http.StatusUnauthorized)
126 return
127 }
128
129 next.ServeHTTP(w, r)
130 })
131 }
132}
133
134func KnotOwner(s *State) Middleware {
135 return knotRoleMiddleware(s, "server:owner")
136}
137
138func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
139 return func(next http.Handler) http.Handler {
140 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
141 // requires auth also
142 actor := s.auth.GetUser(r)
143 if actor == nil {
144 // we need a logged in user
145 log.Printf("not logged in, redirecting")
146 http.Error(w, "Forbiden", http.StatusUnauthorized)
147 return
148 }
149 f, err := fullyResolvedRepo(r)
150 if err != nil {
151 http.Error(w, "malformed url", http.StatusBadRequest)
152 return
153 }
154
155 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm)
156 if err != nil || !ok {
157 // we need a logged in user
158 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
159 http.Error(w, "Forbiden", http.StatusUnauthorized)
160 return
161 }
162
163 next.ServeHTTP(w, r)
164 })
165 }
166}
167
// StripLeadingAt rewrites an escaped request path of the form "/@rest" to
// "/rest" before calling next, so "@"-prefixed user URLs can be routed like
// un-prefixed ones. Other paths pass through untouched.
//
// Only req.URL.RawPath is rewritten; req.URL.Path keeps the "@". The request's
// URL is modified in place.
// NOTE(review): this relies on the router preferring RawPath when it is set.
// Code that reads URL.Path or URL.EscapedPath later still sees the "@" form.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		const atPrefix = "/@"
		if escaped := req.URL.EscapedPath(); strings.HasPrefix(escaped, atPrefix) {
			req.URL.RawPath = "/" + escaped[len(atPrefix):]
		}
		next.ServeHTTP(w, req)
	})
}
177
178func ResolveIdent(s *State) Middleware {
179 excluded := []string{"favicon.ico"}
180
181 return func(next http.Handler) http.Handler {
182 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
183 didOrHandle := chi.URLParam(req, "user")
184 if slices.Contains(excluded, didOrHandle) {
185 next.ServeHTTP(w, req)
186 return
187 }
188
189 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
190 if err != nil {
191 // invalid did or handle
192 log.Println("failed to resolve did/handle:", err)
193 w.WriteHeader(http.StatusNotFound)
194 return
195 }
196
197 ctx := context.WithValue(req.Context(), "resolvedId", *id)
198
199 next.ServeHTTP(w, req.WithContext(ctx))
200 })
201 }
202}
203
204func ResolveRepo(s *State) Middleware {
205 return func(next http.Handler) http.Handler {
206 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
207 repoName := chi.URLParam(req, "repo")
208 id, ok := req.Context().Value("resolvedId").(identity.Identity)
209 if !ok {
210 log.Println("malformed middleware")
211 w.WriteHeader(http.StatusInternalServerError)
212 return
213 }
214
215 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
216 if err != nil {
217 // invalid did or handle
218 log.Println("failed to resolve repo")
219 w.WriteHeader(http.StatusNotFound)
220 return
221 }
222
223 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
224 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
225 ctx = context.WithValue(ctx, "repoDescription", repo.Description)
226 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
227 next.ServeHTTP(w, req.WithContext(ctx))
228 })
229 }
230}
231
232// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
233func ResolvePull(s *State) Middleware {
234 return func(next http.Handler) http.Handler {
235 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
236 f, err := fullyResolvedRepo(r)
237 if err != nil {
238 log.Println("failed to fully resolve repo", err)
239 http.Error(w, "invalid repo url", http.StatusNotFound)
240 return
241 }
242
243 prId := chi.URLParam(r, "pull")
244 prIdInt, err := strconv.Atoi(prId)
245 if err != nil {
246 http.Error(w, "bad pr id", http.StatusBadRequest)
247 log.Println("failed to parse pr id", err)
248 return
249 }
250
251 pr, err := db.GetPull(s.db, f.RepoAt, prIdInt)
252 if err != nil {
253 log.Println("failed to get pull and comments", err)
254 return
255 }
256
257 ctx := context.WithValue(r.Context(), "pull", pr)
258
259 next.ServeHTTP(w, r.WithContext(ctx))
260 })
261 }
262}