1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strings"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/identity"
12 "github.com/bluesky-social/indigo/xrpc"
13 "github.com/go-chi/chi/v5"
14 "github.com/sotangled/tangled/appview"
15 "github.com/sotangled/tangled/appview/auth"
16)
17
// Middleware is an HTTP middleware: it takes the next handler in the chain and
// returns a handler wrapping it. The middleware constructors in this file
// (AuthMiddleware, RoleMiddleware, RepoPermissionMiddleware, ResolveIdent,
// ResolveRepoKnot) return values of this type.
type Middleware func(http.Handler) http.Handler
19
20func AuthMiddleware(s *State) Middleware {
21 return func(next http.Handler) http.Handler {
22 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23 session, _ := s.auth.Store.Get(r, appview.SessionName)
24 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
25 if !ok || !authorized {
26 log.Printf("not logged in, redirecting")
27 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
28 return
29 }
30
31 // refresh if nearing expiry
32 // TODO: dedup with /login
33 expiryStr := session.Values[appview.SessionExpiry].(string)
34 expiry, err := time.Parse(time.RFC3339, expiryStr)
35 if err != nil {
36 log.Println("invalid expiry time", err)
37 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
38 return
39 }
40 pdsUrl := session.Values[appview.SessionPds].(string)
41 did := session.Values[appview.SessionDid].(string)
42 refreshJwt := session.Values[appview.SessionRefreshJwt].(string)
43
44 if time.Now().After(expiry) {
45 log.Println("token expired, refreshing ...")
46
47 client := xrpc.Client{
48 Host: pdsUrl,
49 Auth: &xrpc.AuthInfo{
50 Did: did,
51 AccessJwt: refreshJwt,
52 RefreshJwt: refreshJwt,
53 },
54 }
55 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
56 if err != nil {
57 log.Println(err)
58 return
59 }
60
61 sessionish := auth.RefreshSessionWrapper{atSession}
62
63 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
64 if err != nil {
65 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
66 return
67 }
68
69 log.Println("successfully refreshed token")
70 }
71
72 next.ServeHTTP(w, r)
73 })
74 }
75}
76
77func RoleMiddleware(s *State, group string) Middleware {
78 return func(next http.Handler) http.Handler {
79 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80 // requires auth also
81 actor := s.auth.GetUser(r)
82 if actor == nil {
83 // we need a logged in user
84 log.Printf("not logged in, redirecting")
85 http.Error(w, "Forbiden", http.StatusUnauthorized)
86 return
87 }
88 domain := chi.URLParam(r, "domain")
89 if domain == "" {
90 http.Error(w, "malformed url", http.StatusBadRequest)
91 return
92 }
93
94 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
95 if err != nil || !ok {
96 // we need a logged in user
97 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
98 http.Error(w, "Forbiden", http.StatusUnauthorized)
99 return
100 }
101
102 next.ServeHTTP(w, r)
103 })
104 }
105}
106
107func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
108 return func(next http.Handler) http.Handler {
109 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
110 // requires auth also
111 actor := s.auth.GetUser(r)
112 if actor == nil {
113 // we need a logged in user
114 log.Printf("not logged in, redirecting")
115 http.Error(w, "Forbiden", http.StatusUnauthorized)
116 return
117 }
118 f, err := fullyResolvedRepo(r)
119 if err != nil {
120 http.Error(w, "malformed url", http.StatusBadRequest)
121 return
122 }
123
124 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
125 if err != nil || !ok {
126 // we need a logged in user
127 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
128 http.Error(w, "Forbiden", http.StatusUnauthorized)
129 return
130 }
131
132 next.ServeHTTP(w, r)
133 })
134 }
135}
136
// StripLeadingAt rewrites request paths of the form "/@x..." to "/x..." before
// calling next, so "@handle" URLs route the same as "handle" URLs. Other paths
// pass through unchanged.
//
// The caller's request is not modified: a shallow copy with its own URL is
// handed to next, mirroring http.StripPrefix.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if !strings.HasPrefix(req.URL.Path, "/@") {
			next.ServeHTTP(w, req)
			return
		}

		stripped := req.WithContext(req.Context())
		u := *req.URL
		u.Path = "/" + strings.TrimPrefix(u.Path, "/@")

		// chi routes on RawPath whenever it is set (paths containing escapes
		// such as %2F), so it has to be rewritten too. The "@" may appear
		// literally or percent-encoded there.
		switch {
		case strings.HasPrefix(u.RawPath, "/@"):
			u.RawPath = "/" + strings.TrimPrefix(u.RawPath, "/@")
		case strings.HasPrefix(u.RawPath, "/%40"):
			u.RawPath = "/" + strings.TrimPrefix(u.RawPath, "/%40")
		}

		stripped.URL = &u
		next.ServeHTTP(w, stripped)
	})
}
146
147func ResolveIdent(s *State) Middleware {
148 return func(next http.Handler) http.Handler {
149 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
150 start := time.Now()
151 didOrHandle := chi.URLParam(req, "user")
152
153 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
154 if err != nil {
155 // invalid did or handle
156 log.Println("failed to resolve did/handle:", err)
157 w.WriteHeader(http.StatusNotFound)
158 return
159 }
160
161 ctx := context.WithValue(req.Context(), "resolvedId", *id)
162
163 elapsed := time.Since(start)
164 log.Println("Execution time:", elapsed)
165 next.ServeHTTP(w, req.WithContext(ctx))
166 })
167 }
168}
169
170func ResolveRepoKnot(s *State) Middleware {
171 return func(next http.Handler) http.Handler {
172 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
173 repoName := chi.URLParam(req, "repo")
174 id, ok := req.Context().Value("resolvedId").(identity.Identity)
175 if !ok {
176 log.Println("malformed middleware")
177 w.WriteHeader(http.StatusInternalServerError)
178 return
179 }
180
181 repo, err := s.db.GetRepo(id.DID.String(), repoName)
182 if err != nil {
183 // invalid did or handle
184 log.Println("failed to resolve repo")
185 w.WriteHeader(http.StatusNotFound)
186 return
187 }
188
189 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
190 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
191 next.ServeHTTP(w, req.WithContext(ctx))
192 })
193 }
194}