1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strconv"
8 "strings"
9 "time"
10
11 comatproto "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/identity"
13 "github.com/bluesky-social/indigo/xrpc"
14 "github.com/go-chi/chi/v5"
15 "github.com/sotangled/tangled/appview"
16 "github.com/sotangled/tangled/appview/auth"
17 "github.com/sotangled/tangled/appview/db"
18)
19
// Middleware is a function that takes an http.Handler and returns an
// http.Handler. The constructors in this file return a Middleware whose result
// runs a check or enriches the request before calling the handler it was given.
type Middleware func(http.Handler) http.Handler
21
22func AuthMiddleware(s *State) Middleware {
23 return func(next http.Handler) http.Handler {
24 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25 session, err := s.auth.GetSession(r)
26 if session.IsNew || err != nil {
27 log.Printf("not logged in, redirecting")
28 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
29 return
30 }
31
32 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
33 if !ok || !authorized {
34 log.Printf("not logged in, redirecting")
35 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
36 return
37 }
38
39 // refresh if nearing expiry
40 // TODO: dedup with /login
41 expiryStr := session.Values[appview.SessionExpiry].(string)
42 expiry, err := time.Parse(time.RFC3339, expiryStr)
43 if err != nil {
44 log.Println("invalid expiry time", err)
45 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
46 return
47 }
48 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
49 did, ok2 := session.Values[appview.SessionDid].(string)
50 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
51
52 if !ok1 || !ok2 || !ok3 {
53 log.Println("invalid expiry time", err)
54 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
55 return
56 }
57
58 if time.Now().After(expiry) {
59 log.Println("token expired, refreshing ...")
60
61 client := xrpc.Client{
62 Host: pdsUrl,
63 Auth: &xrpc.AuthInfo{
64 Did: did,
65 AccessJwt: refreshJwt,
66 RefreshJwt: refreshJwt,
67 },
68 }
69 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
70 if err != nil {
71 log.Println("failed to refresh session", err)
72 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
73 return
74 }
75
76 sessionish := auth.RefreshSessionWrapper{atSession}
77
78 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
79 if err != nil {
80 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
81 return
82 }
83
84 log.Println("successfully refreshed token")
85 }
86
87 next.ServeHTTP(w, r)
88 })
89 }
90}
91
92func knotRoleMiddleware(s *State, group string) Middleware {
93 return func(next http.Handler) http.Handler {
94 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
95 // requires auth also
96 actor := s.auth.GetUser(r)
97 if actor == nil {
98 // we need a logged in user
99 log.Printf("not logged in, redirecting")
100 http.Error(w, "Forbiden", http.StatusUnauthorized)
101 return
102 }
103 domain := chi.URLParam(r, "domain")
104 if domain == "" {
105 http.Error(w, "malformed url", http.StatusBadRequest)
106 return
107 }
108
109 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
110 if err != nil || !ok {
111 // we need a logged in user
112 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
113 http.Error(w, "Forbiden", http.StatusUnauthorized)
114 return
115 }
116
117 next.ServeHTTP(w, r)
118 })
119 }
120}
121
122func KnotOwner(s *State) Middleware {
123 return knotRoleMiddleware(s, "server:owner")
124}
125
126func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
127 return func(next http.Handler) http.Handler {
128 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
129 // requires auth also
130 actor := s.auth.GetUser(r)
131 if actor == nil {
132 // we need a logged in user
133 log.Printf("not logged in, redirecting")
134 http.Error(w, "Forbiden", http.StatusUnauthorized)
135 return
136 }
137 f, err := fullyResolvedRepo(r)
138 if err != nil {
139 http.Error(w, "malformed url", http.StatusBadRequest)
140 return
141 }
142
143 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
144 if err != nil || !ok {
145 // we need a logged in user
146 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
147 http.Error(w, "Forbiden", http.StatusUnauthorized)
148 return
149 }
150
151 next.ServeHTTP(w, r)
152 })
153 }
154}
155
// StripLeadingAt returns a handler that rewrites a request path beginning with
// "/@" so that it begins with "/" instead, then calls next. Paths without that
// prefix are passed through untouched. The request's URL is modified in place.
//
// Both URL.Path and URL.RawPath are rewritten. RawPath is only set when the
// request used a non-default percent-encoding, and routers that prefer RawPath
// when it is present would otherwise still see the "@".
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if strings.HasPrefix(req.URL.Path, "/@") {
			req.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, "/@")

			// In the encoded form the "@" may be literal or escaped as %40.
			for _, prefix := range []string{"/@", "/%40"} {
				if strings.HasPrefix(req.URL.RawPath, prefix) {
					req.URL.RawPath = "/" + strings.TrimPrefix(req.URL.RawPath, prefix)
					break
				}
			}
		}
		next.ServeHTTP(w, req)
	})
}
165
166func ResolveIdent(s *State) Middleware {
167 return func(next http.Handler) http.Handler {
168 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
169 didOrHandle := chi.URLParam(req, "user")
170
171 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
172 if err != nil {
173 // invalid did or handle
174 log.Println("failed to resolve did/handle:", err)
175 w.WriteHeader(http.StatusNotFound)
176 return
177 }
178
179 ctx := context.WithValue(req.Context(), "resolvedId", *id)
180
181 next.ServeHTTP(w, req.WithContext(ctx))
182 })
183 }
184}
185
186func ResolveRepo(s *State) Middleware {
187 return func(next http.Handler) http.Handler {
188 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
189 repoName := chi.URLParam(req, "repo")
190 id, ok := req.Context().Value("resolvedId").(identity.Identity)
191 if !ok {
192 log.Println("malformed middleware")
193 w.WriteHeader(http.StatusInternalServerError)
194 return
195 }
196
197 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
198 if err != nil {
199 // invalid did or handle
200 log.Println("failed to resolve repo")
201 w.WriteHeader(http.StatusNotFound)
202 return
203 }
204
205 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
206 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
207 ctx = context.WithValue(ctx, "repoDescription", repo.Description)
208 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
209 next.ServeHTTP(w, req.WithContext(ctx))
210 })
211 }
212}
213
214// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
215func ResolvePull(s *State) Middleware {
216 return func(next http.Handler) http.Handler {
217 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
218 f, err := fullyResolvedRepo(r)
219 if err != nil {
220 log.Println("failed to fully resolve repo", err)
221 http.Error(w, "invalid repo url", http.StatusNotFound)
222 return
223 }
224
225 prId := chi.URLParam(r, "pull")
226 prIdInt, err := strconv.Atoi(prId)
227 if err != nil {
228 http.Error(w, "bad pr id", http.StatusBadRequest)
229 log.Println("failed to parse pr id", err)
230 return
231 }
232
233 pr, err := db.GetPull(s.db, f.RepoAt, prIdInt)
234 if err != nil {
235 log.Println("failed to get pull and comments", err)
236 return
237 }
238
239 ctx := context.WithValue(r.Context(), "pull", pr)
240
241 next.ServeHTTP(w, r.WithContext(ctx))
242 })
243 }
244}