1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strings"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/identity"
12 "github.com/bluesky-social/indigo/xrpc"
13 "github.com/go-chi/chi/v5"
14 "github.com/sotangled/tangled/appview"
15 "github.com/sotangled/tangled/appview/auth"
16 "github.com/sotangled/tangled/appview/db"
17)
18
// Middleware is a function that takes an http.Handler and returns an
// http.Handler; every constructor in this file returns one (or has this
// shape itself).
type Middleware func(http.Handler) http.Handler
20
21func AuthMiddleware(s *State) Middleware {
22 return func(next http.Handler) http.Handler {
23 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24 session, _ := s.auth.Store.Get(r, appview.SessionName)
25 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
26 if !ok || !authorized {
27 log.Printf("not logged in, redirecting")
28 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
29 return
30 }
31
32 // refresh if nearing expiry
33 // TODO: dedup with /login
34 expiryStr := session.Values[appview.SessionExpiry].(string)
35 expiry, err := time.Parse(time.RFC3339, expiryStr)
36 if err != nil {
37 log.Println("invalid expiry time", err)
38 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
39 return
40 }
41 pdsUrl := session.Values[appview.SessionPds].(string)
42 did := session.Values[appview.SessionDid].(string)
43 refreshJwt := session.Values[appview.SessionRefreshJwt].(string)
44
45 if time.Now().After(expiry) {
46 log.Println("token expired, refreshing ...")
47
48 client := xrpc.Client{
49 Host: pdsUrl,
50 Auth: &xrpc.AuthInfo{
51 Did: did,
52 AccessJwt: refreshJwt,
53 RefreshJwt: refreshJwt,
54 },
55 }
56 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
57 if err != nil {
58 log.Println("failed to refresh session", err)
59 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
60 return
61 }
62
63 sessionish := auth.RefreshSessionWrapper{atSession}
64
65 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
66 if err != nil {
67 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
68 return
69 }
70
71 log.Println("successfully refreshed token")
72 }
73
74 next.ServeHTTP(w, r)
75 })
76 }
77}
78
79func RoleMiddleware(s *State, group string) Middleware {
80 return func(next http.Handler) http.Handler {
81 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82 // requires auth also
83 actor := s.auth.GetUser(r)
84 if actor == nil {
85 // we need a logged in user
86 log.Printf("not logged in, redirecting")
87 http.Error(w, "Forbiden", http.StatusUnauthorized)
88 return
89 }
90 domain := chi.URLParam(r, "domain")
91 if domain == "" {
92 http.Error(w, "malformed url", http.StatusBadRequest)
93 return
94 }
95
96 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
97 if err != nil || !ok {
98 // we need a logged in user
99 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
100 http.Error(w, "Forbiden", http.StatusUnauthorized)
101 return
102 }
103
104 next.ServeHTTP(w, r)
105 })
106 }
107}
108
109func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
110 return func(next http.Handler) http.Handler {
111 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
112 // requires auth also
113 actor := s.auth.GetUser(r)
114 if actor == nil {
115 // we need a logged in user
116 log.Printf("not logged in, redirecting")
117 http.Error(w, "Forbiden", http.StatusUnauthorized)
118 return
119 }
120 f, err := fullyResolvedRepo(r)
121 if err != nil {
122 http.Error(w, "malformed url", http.StatusBadRequest)
123 return
124 }
125
126 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
127 if err != nil || !ok {
128 // we need a logged in user
129 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
130 http.Error(w, "Forbiden", http.StatusUnauthorized)
131 return
132 }
133
134 next.ServeHTTP(w, r)
135 })
136 }
137}
138
// StripLeadingAt is a middleware that rewrites a request path beginning with
// "/@" so that it begins with "/" instead (e.g. "/@alice/repo" becomes
// "/alice/repo") before calling next. Other paths are passed through
// untouched.
//
// The request's URL is modified in place. When URL.RawPath is set, its
// leading "/@" (or the escaped form "/%40") is removed as well, so that code
// reading RawPath sees the same rewritten path as code reading Path.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if strings.HasPrefix(req.URL.Path, "/@") {
			req.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, "/@")

			// Keep the escaped form in sync; an empty RawPath matches neither
			// case and is left empty.
			switch raw := req.URL.RawPath; {
			case strings.HasPrefix(raw, "/@"):
				req.URL.RawPath = "/" + strings.TrimPrefix(raw, "/@")
			case strings.HasPrefix(raw, "/%40"):
				req.URL.RawPath = "/" + strings.TrimPrefix(raw, "/%40")
			}
		}
		next.ServeHTTP(w, req)
	})
}
148
149func ResolveIdent(s *State) Middleware {
150 return func(next http.Handler) http.Handler {
151 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
152 didOrHandle := chi.URLParam(req, "user")
153
154 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
155 if err != nil {
156 // invalid did or handle
157 log.Println("failed to resolve did/handle:", err)
158 w.WriteHeader(http.StatusNotFound)
159 return
160 }
161
162 ctx := context.WithValue(req.Context(), "resolvedId", *id)
163
164 next.ServeHTTP(w, req.WithContext(ctx))
165 })
166 }
167}
168
169func ResolveRepoKnot(s *State) Middleware {
170 return func(next http.Handler) http.Handler {
171 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
172 repoName := chi.URLParam(req, "repo")
173 id, ok := req.Context().Value("resolvedId").(identity.Identity)
174 if !ok {
175 log.Println("malformed middleware")
176 w.WriteHeader(http.StatusInternalServerError)
177 return
178 }
179
180 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
181 if err != nil {
182 // invalid did or handle
183 log.Println("failed to resolve repo")
184 w.WriteHeader(http.StatusNotFound)
185 return
186 }
187
188 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
189 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
190 next.ServeHTTP(w, req.WithContext(ctx))
191 })
192 }
193}