1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strconv"
8 "strings"
9 "time"
10
11 comatproto "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/identity"
13 "github.com/bluesky-social/indigo/xrpc"
14 "github.com/go-chi/chi/v5"
15 "tangled.sh/tangled.sh/core/appview"
16 "tangled.sh/tangled.sh/core/appview/auth"
17 "tangled.sh/tangled.sh/core/appview/db"
18)
19
// Middleware wraps an http.Handler with additional behavior and returns the
// wrapped handler.
type Middleware func(next http.Handler) http.Handler
21
22func AuthMiddleware(s *State) Middleware {
23 return func(next http.Handler) http.Handler {
24 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
26 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
27 }
28 if r.Header.Get("HX-Request") == "true" {
29 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
30 w.Header().Set("HX-Redirect", "/login")
31 w.WriteHeader(http.StatusOK)
32 }
33 }
34
35 session, err := s.auth.GetSession(r)
36 if session.IsNew || err != nil {
37 log.Printf("not logged in, redirecting")
38 redirectFunc(w, r)
39 return
40 }
41
42 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
43 if !ok || !authorized {
44 log.Printf("not logged in, redirecting")
45 redirectFunc(w, r)
46 return
47 }
48
49 // refresh if nearing expiry
50 // TODO: dedup with /login
51 expiryStr := session.Values[appview.SessionExpiry].(string)
52 expiry, err := time.Parse(time.RFC3339, expiryStr)
53 if err != nil {
54 log.Println("invalid expiry time", err)
55 redirectFunc(w, r)
56 return
57 }
58 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
59 did, ok2 := session.Values[appview.SessionDid].(string)
60 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
61
62 if !ok1 || !ok2 || !ok3 {
63 log.Println("invalid expiry time", err)
64 redirectFunc(w, r)
65 return
66 }
67
68 if time.Now().After(expiry) {
69 log.Println("token expired, refreshing ...")
70
71 client := xrpc.Client{
72 Host: pdsUrl,
73 Auth: &xrpc.AuthInfo{
74 Did: did,
75 AccessJwt: refreshJwt,
76 RefreshJwt: refreshJwt,
77 },
78 }
79 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
80 if err != nil {
81 log.Println("failed to refresh session", err)
82 redirectFunc(w, r)
83 return
84 }
85
86 sessionish := auth.RefreshSessionWrapper{atSession}
87
88 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
89 if err != nil {
90 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
91 return
92 }
93
94 log.Println("successfully refreshed token")
95 }
96
97 next.ServeHTTP(w, r)
98 })
99 }
100}
101
102func knotRoleMiddleware(s *State, group string) Middleware {
103 return func(next http.Handler) http.Handler {
104 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105 // requires auth also
106 actor := s.auth.GetUser(r)
107 if actor == nil {
108 // we need a logged in user
109 log.Printf("not logged in, redirecting")
110 http.Error(w, "Forbiden", http.StatusUnauthorized)
111 return
112 }
113 domain := chi.URLParam(r, "domain")
114 if domain == "" {
115 http.Error(w, "malformed url", http.StatusBadRequest)
116 return
117 }
118
119 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
120 if err != nil || !ok {
121 // we need a logged in user
122 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
123 http.Error(w, "Forbiden", http.StatusUnauthorized)
124 return
125 }
126
127 next.ServeHTTP(w, r)
128 })
129 }
130}
131
132func KnotOwner(s *State) Middleware {
133 return knotRoleMiddleware(s, "server:owner")
134}
135
136func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
137 return func(next http.Handler) http.Handler {
138 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
139 // requires auth also
140 actor := s.auth.GetUser(r)
141 if actor == nil {
142 // we need a logged in user
143 log.Printf("not logged in, redirecting")
144 http.Error(w, "Forbiden", http.StatusUnauthorized)
145 return
146 }
147 f, err := fullyResolvedRepo(r)
148 if err != nil {
149 http.Error(w, "malformed url", http.StatusBadRequest)
150 return
151 }
152
153 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
154 if err != nil || !ok {
155 // we need a logged in user
156 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
157 http.Error(w, "Forbiden", http.StatusUnauthorized)
158 return
159 }
160
161 next.ServeHTTP(w, r)
162 })
163 }
164}
165
// StripLeadingAt is a middleware that rewrites the request URL's RawPath so
// that a leading "/@" becomes "/" before handing the request to next.
//
// Only RawPath is rewritten; URL.Path is left as received. Requests whose
// escaped path does not start with "/@" are passed through unchanged.
func StripLeadingAt(next http.Handler) http.Handler {
	const atPrefix = "/@"

	strip := func(w http.ResponseWriter, req *http.Request) {
		if escaped := req.URL.EscapedPath(); strings.HasPrefix(escaped, atPrefix) {
			// Drop the "/@" and put the leading slash back.
			req.URL.RawPath = "/" + escaped[len(atPrefix):]
		}
		next.ServeHTTP(w, req)
	}

	return http.HandlerFunc(strip)
}
175
176func ResolveIdent(s *State) Middleware {
177 return func(next http.Handler) http.Handler {
178 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
179 didOrHandle := chi.URLParam(req, "user")
180
181 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
182 if err != nil {
183 // invalid did or handle
184 log.Println("failed to resolve did/handle:", err)
185 w.WriteHeader(http.StatusNotFound)
186 return
187 }
188
189 ctx := context.WithValue(req.Context(), "resolvedId", *id)
190
191 next.ServeHTTP(w, req.WithContext(ctx))
192 })
193 }
194}
195
196func ResolveRepo(s *State) Middleware {
197 return func(next http.Handler) http.Handler {
198 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
199 repoName := chi.URLParam(req, "repo")
200 id, ok := req.Context().Value("resolvedId").(identity.Identity)
201 if !ok {
202 log.Println("malformed middleware")
203 w.WriteHeader(http.StatusInternalServerError)
204 return
205 }
206
207 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
208 if err != nil {
209 // invalid did or handle
210 log.Println("failed to resolve repo")
211 w.WriteHeader(http.StatusNotFound)
212 return
213 }
214
215 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
216 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
217 ctx = context.WithValue(ctx, "repoDescription", repo.Description)
218 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
219 next.ServeHTTP(w, req.WithContext(ctx))
220 })
221 }
222}
223
224// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
225func ResolvePull(s *State) Middleware {
226 return func(next http.Handler) http.Handler {
227 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
228 f, err := fullyResolvedRepo(r)
229 if err != nil {
230 log.Println("failed to fully resolve repo", err)
231 http.Error(w, "invalid repo url", http.StatusNotFound)
232 return
233 }
234
235 prId := chi.URLParam(r, "pull")
236 prIdInt, err := strconv.Atoi(prId)
237 if err != nil {
238 http.Error(w, "bad pr id", http.StatusBadRequest)
239 log.Println("failed to parse pr id", err)
240 return
241 }
242
243 pr, err := db.GetPull(s.db, f.RepoAt, prIdInt)
244 if err != nil {
245 log.Println("failed to get pull and comments", err)
246 return
247 }
248
249 ctx := context.WithValue(r.Context(), "pull", pr)
250
251 next.ServeHTTP(w, r.WithContext(ctx))
252 })
253 }
254}