1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strings"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/identity"
12 "github.com/bluesky-social/indigo/xrpc"
13 "github.com/go-chi/chi/v5"
14 "github.com/sotangled/tangled/appview"
15 "github.com/sotangled/tangled/appview/auth"
16 "github.com/sotangled/tangled/appview/db"
17)
18
// Middleware takes the next http.Handler in a chain and returns a handler
// that wraps it with extra behaviour (authentication, permission checks,
// URL rewriting, context population, ...).
type Middleware func(next http.Handler) http.Handler
20
// AuthMiddleware returns a Middleware that only lets a request through when
// the caller's session (from s.auth.GetSession) is an existing session whose
// appview.SessionAuthenticated value is true. Any request without a usable
// session is redirected to /login.
//
// The session must also carry an RFC 3339 expiry plus the PDS URL, DID and
// refresh JWT. Once the stored expiry has passed, the middleware calls
// com.atproto.server.refreshSession on that PDS and persists the new session
// with s.auth.StoreSession before handing the request to next. If persisting
// the refreshed session fails, the request is answered with 500.
func AuthMiddleware(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			redirectToLogin := func() {
				http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
			}

			// Check err before touching the session: on failure the returned
			// session must not be relied upon (it may be nil).
			session, err := s.auth.GetSession(r)
			if err != nil {
				log.Println("failed to load session, redirecting:", err)
				redirectToLogin()
				return
			}
			if session.IsNew {
				log.Printf("not logged in, redirecting")
				redirectToLogin()
				return
			}

			authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
			if !ok || !authorized {
				log.Printf("not logged in, redirecting")
				redirectToLogin()
				return
			}

			// TODO: dedup with /login
			// Comma-ok assertion: a missing or mistyped expiry must not panic.
			expiryStr, ok := session.Values[appview.SessionExpiry].(string)
			if !ok {
				log.Println("session has no expiry time, redirecting")
				redirectToLogin()
				return
			}
			expiry, err := time.Parse(time.RFC3339, expiryStr)
			if err != nil {
				log.Println("invalid expiry time", err)
				redirectToLogin()
				return
			}

			pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
			did, ok2 := session.Values[appview.SessionDid].(string)
			refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
			if !ok1 || !ok2 || !ok3 {
				log.Println("session is missing pds, did or refresh token, redirecting")
				redirectToLogin()
				return
			}

			// Refresh only once the stored expiry has passed.
			if time.Now().After(expiry) {
				log.Println("token expired, refreshing ...")

				// The refresh JWT is placed in AccessJwt as well as RefreshJwt.
				// NOTE(review): presumably because the xrpc client sends
				// AccessJwt as the bearer token and refreshSession must be
				// authenticated with the refresh token — confirm against indigo.
				client := xrpc.Client{
					Host: pdsUrl,
					Auth: &xrpc.AuthInfo{
						Did:        did,
						AccessJwt:  refreshJwt,
						RefreshJwt: refreshJwt,
					},
				}
				atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
				if err != nil {
					log.Println("failed to refresh session", err)
					redirectToLogin()
					return
				}

				sessionish := auth.RefreshSessionWrapper{atSession}

				if err := s.auth.StoreSession(r, w, &sessionish, pdsUrl); err != nil {
					log.Printf("failed to store session for did: %s: %s", atSession.Did, err)
					// Previously this returned without writing anything,
					// which produced an empty 200 response.
					http.Error(w, "failed to refresh session", http.StatusInternalServerError)
					return
				}

				log.Println("successfully refreshed token")
			}

			next.ServeHTTP(w, r)
		})
	}
}
90
// RoleMiddleware returns a Middleware that admits a request only when the
// logged-in user holds the given group role in the domain named by the
// "domain" URL parameter, according to s.enforcer's grouping policies.
//
// Responses when the request is rejected:
//   - 401 Unauthorized if there is no logged-in user,
//   - 400 Bad Request if the "domain" URL parameter is empty,
//   - 500 Internal Server Error if the policy lookup itself fails,
//   - 403 Forbidden if the user does not hold the role.
func RoleMiddleware(s *State, group string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Requires an authenticated user; this middleware does not log
			// anyone in itself.
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			domain := chi.URLParam(r, "domain")
			if domain == "" {
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
			if err != nil {
				// Fail closed, but report it as a server-side failure rather
				// than as a permission problem.
				log.Printf("failed to check role %s for %s in domain %s: %v", group, actor.Did, domain, err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
120
// RepoPermissionMiddleware returns a Middleware that admits a request only
// when s.enforcer grants the logged-in user requiredPerm on the repository
// resolved from the request by fullyResolvedRepo (checked against the repo's
// knot and its owner/repo path).
//
// Responses when the request is rejected:
//   - 401 Unauthorized if there is no logged-in user,
//   - 400 Bad Request if the repository cannot be resolved from the URL,
//   - 500 Internal Server Error if the enforcer itself fails,
//   - 403 Forbidden if the permission is not granted.
func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Requires an authenticated user; this middleware does not log
			// anyone in itself.
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			f, err := fullyResolvedRepo(r)
			if err != nil {
				log.Println("failed to resolve repo from url:", err)
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			repo := f.OwnerSlashRepo()
			ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, repo, requiredPerm)
			if err != nil {
				// Fail closed, but report it as a server-side failure rather
				// than as a permission problem.
				log.Printf("failed to check %s permission for %s on repo %s: %v", requiredPerm, actor.Did, repo, err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, repo)
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
150
// StripLeadingAt rewrites request paths of the form "/@<rest>" to "/<rest>"
// before calling next, so "/@handle/..." URLs reach the same routes as
// "/handle/...". Other paths are passed through unchanged. The request's URL
// is modified in place.
//
// req.URL.RawPath (the escaped form, set only when it differs from the
// default encoding of Path) is kept consistent with the rewritten Path.
// chi matches routes against RawPath when it is non-empty, so rewriting Path
// alone leaves the "@" in place for percent-encoded URLs.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		u := req.URL
		if strings.HasPrefix(u.Path, "/@") {
			u.Path = "/" + strings.TrimPrefix(u.Path, "/@")

			switch {
			case u.RawPath == "":
				// Path uses its default encoding; nothing else to update.
			case strings.HasPrefix(u.RawPath, "/@"):
				u.RawPath = "/" + strings.TrimPrefix(u.RawPath, "/@")
			default:
				// The "@" itself arrived percent-encoded (e.g. "/%40alice").
				// Drop RawPath so consumers fall back to the rewritten Path.
				u.RawPath = ""
			}
		}
		next.ServeHTTP(w, req)
	})
}
160
// ResolveIdent returns a Middleware that resolves the "user" URL parameter
// (a DID or a handle) with s.resolver. On success it stores the resolved
// identity, by value, in the request context under the string key
// "resolvedId" and calls next. On failure it logs the error and replies
// with an empty 404.
func ResolveIdent(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			userParam := chi.URLParam(req, "user")
			baseCtx := req.Context()

			resolved, err := s.resolver.ResolveIdent(baseCtx, userParam)
			if err != nil {
				// neither a valid DID nor a resolvable handle
				log.Println("failed to resolve did/handle:", err)
				w.WriteHeader(http.StatusNotFound)
				return
			}

			// Stored as a value (not a pointer): downstream readers assert
			// the context value as identity.Identity.
			enriched := context.WithValue(baseCtx, "resolvedId", *resolved)
			next.ServeHTTP(w, req.WithContext(enriched))
		})
	}
}
180
// ResolveRepoKnot returns a Middleware that looks up, via db.GetRepo, the
// repository named by the "repo" URL parameter and owned by the identity
// that ResolveIdent stored under the "resolvedId" context key. It then adds
// these context values before calling next:
//
//	"knot"            repo.Knot
//	"repoAt"          repo.AtUri
//	"repoDescription" repo.Description
//	"repoAddedAt"     repo.Created formatted as RFC 3339
//
// It replies with an empty 500 if no identity is in the context (i.e. it was
// mounted without ResolveIdent before it), and with an empty 404 if the
// repository lookup fails.
func ResolveRepoKnot(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			repoName := chi.URLParam(req, "repo")
			id, ok := req.Context().Value("resolvedId").(identity.Identity)
			if !ok {
				log.Println("malformed middleware")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
			if err != nil {
				// Unknown repo for this owner, or the lookup itself failed;
				// log the cause instead of discarding it.
				log.Printf("failed to resolve repo %q for %s: %v", repoName, id.DID.String(), err)
				w.WriteHeader(http.StatusNotFound)
				return
			}

			ctx := context.WithValue(req.Context(), "knot", repo.Knot)
			ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
			ctx = context.WithValue(ctx, "repoDescription", repo.Description)
			ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}
}