1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strings"
8
9 "github.com/bluesky-social/indigo/atproto/identity"
10 "github.com/go-chi/chi/v5"
11 "github.com/sotangled/tangled/appview"
12 "github.com/sotangled/tangled/appview/db"
13)
14
// Middleware wraps an http.Handler and returns the wrapping handler. The
// middleware constructors in this file (AuthMiddleware, RoleMiddleware,
// RepoPermissionMiddleware, ResolveIdent, ResolveRepoKnot) return values of
// this type.
type Middleware func(http.Handler) http.Handler
16
17func AuthMiddleware(s *State) Middleware {
18 return func(next http.Handler) http.Handler {
19 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20 if s.auth == nil {
21 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
22 return
23 }
24 err := s.RestoreSessionIfNeeded(r, w)
25 if err != nil {
26 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
27 return
28 }
29
30 session, _ := s.auth.Store.Get(r, appview.SessionName)
31 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
32 if !ok || !authorized {
33 log.Printf("not logged in, redirecting")
34 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
35 return
36 }
37
38 // refresh if nearing expiry
39 next.ServeHTTP(w, r)
40 })
41 }
42}
43
44func RoleMiddleware(s *State, group string) Middleware {
45 return func(next http.Handler) http.Handler {
46 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
47 // requires auth also
48 actor := s.auth.GetUser(r)
49 if actor == nil {
50 // we need a logged in user
51 log.Printf("not logged in, redirecting")
52 http.Error(w, "Forbiden", http.StatusUnauthorized)
53 return
54 }
55 domain := chi.URLParam(r, "domain")
56 if domain == "" {
57 http.Error(w, "malformed url", http.StatusBadRequest)
58 return
59 }
60
61 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
62 if err != nil || !ok {
63 // we need a logged in user
64 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
65 http.Error(w, "Forbiden", http.StatusUnauthorized)
66 return
67 }
68
69 next.ServeHTTP(w, r)
70 })
71 }
72}
73
74func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
75 return func(next http.Handler) http.Handler {
76 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77 // requires auth also
78 actor := s.auth.GetUser(r)
79 if actor == nil {
80 // we need a logged in user
81 log.Printf("not logged in, redirecting")
82 http.Error(w, "Forbiden", http.StatusUnauthorized)
83 return
84 }
85 f, err := fullyResolvedRepo(r)
86 if err != nil {
87 http.Error(w, "malformed url", http.StatusBadRequest)
88 return
89 }
90
91 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
92 if err != nil || !ok {
93 // we need a logged in user
94 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
95 http.Error(w, "Forbiden", http.StatusUnauthorized)
96 return
97 }
98
99 next.ServeHTTP(w, r)
100 })
101 }
102}
103
// StripLeadingAt rewrites request paths of the form "/@name/..." to
// "/name/..." before calling next, so URLs written with a leading "@" reach
// the same routes as the bare form. Other requests are passed through as-is.
//
// The rewrite is applied to a shallow copy of the request, as http.StripPrefix
// does, so the caller's *http.Request is left untouched. Both URL.Path and
// URL.RawPath are updated. chi routes on RawPath whenever it is set (for
// example when the path contains %2F, or the "@" arrived as %40), so stripping
// only Path would leave the "@" in the path that actually gets routed.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if !strings.HasPrefix(req.URL.Path, "/@") {
			next.ServeHTTP(w, req)
			return
		}

		u := *req.URL
		u.Path = "/" + strings.TrimPrefix(u.Path, "/@")
		// When set, RawPath is an encoding of Path, so its "@" is either
		// literal or percent-encoded.
		switch {
		case strings.HasPrefix(u.RawPath, "/@"):
			u.RawPath = "/" + strings.TrimPrefix(u.RawPath, "/@")
		case strings.HasPrefix(u.RawPath, "/%40"):
			u.RawPath = "/" + strings.TrimPrefix(u.RawPath, "/%40")
		}

		r2 := *req
		r2.URL = &u
		next.ServeHTTP(w, &r2)
	})
}
113
114func ResolveIdent(s *State) Middleware {
115 return func(next http.Handler) http.Handler {
116 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
117 didOrHandle := chi.URLParam(req, "user")
118
119 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
120 if err != nil {
121 // invalid did or handle
122 log.Println("failed to resolve did/handle:", err)
123 w.WriteHeader(http.StatusNotFound)
124 return
125 }
126
127 ctx := context.WithValue(req.Context(), "resolvedId", *id)
128
129 next.ServeHTTP(w, req.WithContext(ctx))
130 })
131 }
132}
133
134func ResolveRepoKnot(s *State) Middleware {
135 return func(next http.Handler) http.Handler {
136 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
137 repoName := chi.URLParam(req, "repo")
138 id, ok := req.Context().Value("resolvedId").(identity.Identity)
139 if !ok {
140 log.Println("malformed middleware")
141 w.WriteHeader(http.StatusInternalServerError)
142 return
143 }
144
145 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
146 if err != nil {
147 // invalid did or handle
148 log.Println("failed to resolve repo")
149 w.WriteHeader(http.StatusNotFound)
150 return
151 }
152
153 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
154 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
155 next.ServeHTTP(w, req.WithContext(ctx))
156 })
157 }
158}