forked from
tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "slices"
11 "time"
12
13 "github.com/go-chi/chi/v5"
14 "github.com/lestrrat-go/jwx/v2/jwk"
15 "github.com/posthog/posthog-go"
16 "tangled.org/core/api/tangled"
17 "tangled.org/core/appview/db"
18 "tangled.org/core/consts"
19 "tangled.org/core/tid"
20)
21
22func (o *OAuth) Router() http.Handler {
23 r := chi.NewRouter()
24
25 r.Get("/oauth/client-metadata.json", o.clientMetadata)
26 r.Get("/oauth/jwks.json", o.jwks)
27 r.Get("/oauth/callback", o.callback)
28 return r
29}
30
31func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) {
32 doc := o.ClientApp.Config.ClientMetadata()
33 doc.JWKSURI = &o.JwksUri
34
35 w.Header().Set("Content-Type", "application/json")
36 if err := json.NewEncoder(w).Encode(doc); err != nil {
37 http.Error(w, err.Error(), http.StatusInternalServerError)
38 return
39 }
40}
41
42func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) {
43 jwks := o.Config.OAuth.Jwks
44 pubKey, err := pubKeyFromJwk(jwks)
45 if err != nil {
46 log.Printf("error parsing public key: %v", err)
47 http.Error(w, err.Error(), http.StatusInternalServerError)
48 return
49 }
50
51 response := map[string]any{
52 "keys": []jwk.Key{pubKey},
53 }
54
55 w.Header().Set("Content-Type", "application/json")
56 w.WriteHeader(http.StatusOK)
57 json.NewEncoder(w).Encode(response)
58}
59
60func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) {
61 ctx := r.Context()
62
63 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query())
64 if err != nil {
65 http.Error(w, err.Error(), http.StatusInternalServerError)
66 return
67 }
68
69 if err := o.SaveSession(w, r, sessData); err != nil {
70 http.Error(w, err.Error(), http.StatusInternalServerError)
71 return
72 }
73
74 log.Println("session saved successfully")
75 go o.addToDefaultKnot(sessData.AccountDID.String())
76 go o.addToDefaultSpindle(sessData.AccountDID.String())
77
78 if !o.Config.Core.Dev {
79 err = o.Posthog.Enqueue(posthog.Capture{
80 DistinctId: sessData.AccountDID.String(),
81 Event: "signin",
82 })
83 if err != nil {
84 log.Println("failed to enqueue posthog event:", err)
85 }
86 }
87
88 http.Redirect(w, r, "/", http.StatusFound)
89}
90
91func (o *OAuth) addToDefaultSpindle(did string) {
92 // use the tangled.sh app password to get an accessJwt
93 // and create an sh.tangled.spindle.member record with that
94 spindleMembers, err := db.GetSpindleMembers(
95 o.Db,
96 db.FilterEq("instance", "spindle.tangled.sh"),
97 db.FilterEq("subject", did),
98 )
99 if err != nil {
100 log.Printf("failed to get spindle members for did %s: %v", did, err)
101 return
102 }
103
104 if len(spindleMembers) != 0 {
105 log.Printf("did %s is already a member of the default spindle", did)
106 return
107 }
108
109 log.Printf("adding %s to default spindle", did)
110 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid)
111 if err != nil {
112 log.Printf("failed to create session: %s", err)
113 return
114 }
115
116 record := tangled.SpindleMember{
117 LexiconTypeID: "sh.tangled.spindle.member",
118 Subject: did,
119 Instance: consts.DefaultSpindle,
120 CreatedAt: time.Now().Format(time.RFC3339),
121 }
122
123 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
124 log.Printf("failed to add member to default spindle: %s", err)
125 return
126 }
127
128 log.Printf("successfully added %s to default spindle", did)
129}
130
131func (o *OAuth) addToDefaultKnot(did string) {
132 // use the tangled.sh app password to get an accessJwt
133 // and create an sh.tangled.spindle.member record with that
134
135 allKnots, err := o.Enforcer.GetKnotsForUser(did)
136 if err != nil {
137 log.Printf("failed to get knot members for did %s: %v", did, err)
138 return
139 }
140
141 if slices.Contains(allKnots, consts.DefaultKnot) {
142 log.Printf("did %s is already a member of the default knot", did)
143 return
144 }
145
146 log.Printf("adding %s to default knot", did)
147 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid)
148 if err != nil {
149 log.Printf("failed to create session: %s", err)
150 return
151 }
152
153 record := tangled.KnotMember{
154 LexiconTypeID: "sh.tangled.knot.member",
155 Subject: did,
156 Domain: consts.DefaultKnot,
157 CreatedAt: time.Now().Format(time.RFC3339),
158 }
159
160 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil {
161 log.Printf("failed to add member to default knot: %s", err)
162 return
163 }
164
165 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil {
166 log.Printf("failed to set up enforcer rules: %s", err)
167 return
168 }
169
170 log.Printf("successfully added %s to default Knot", did)
171}
172
// session holds app-password credentials for acting on a single account's
// repo through its PDS.
//
// AccessJwt is decoded from the com.atproto.server.createSession response.
// PdsEndpoint and Did are filled in by createAppPasswordSession after
// decoding.
type session struct {
	// AccessJwt is sent as a Bearer token on authenticated XRPC calls.
	AccessJwt string `json:"accessJwt"`
	// PdsEndpoint is the base URL of the account's PDS; XRPC paths are
	// appended to it.
	PdsEndpoint string
	// Did identifies the account whose repo is written to.
	Did string
}
179
180func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) {
181 if appPassword == "" {
182 return nil, fmt.Errorf("no app password configured, skipping member addition")
183 }
184
185 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did)
186 if err != nil {
187 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err)
188 }
189
190 pdsEndpoint := resolved.PDSEndpoint()
191 if pdsEndpoint == "" {
192 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
193 }
194
195 sessionPayload := map[string]string{
196 "identifier": did,
197 "password": appPassword,
198 }
199 sessionBytes, err := json.Marshal(sessionPayload)
200 if err != nil {
201 return nil, fmt.Errorf("failed to marshal session payload: %v", err)
202 }
203
204 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
205 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
206 if err != nil {
207 return nil, fmt.Errorf("failed to create session request: %v", err)
208 }
209 sessionReq.Header.Set("Content-Type", "application/json")
210
211 client := &http.Client{Timeout: 30 * time.Second}
212 sessionResp, err := client.Do(sessionReq)
213 if err != nil {
214 return nil, fmt.Errorf("failed to create session: %v", err)
215 }
216 defer sessionResp.Body.Close()
217
218 if sessionResp.StatusCode != http.StatusOK {
219 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
220 }
221
222 var session session
223 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
224 return nil, fmt.Errorf("failed to decode session response: %v", err)
225 }
226
227 session.PdsEndpoint = pdsEndpoint
228 session.Did = did
229
230 return &session, nil
231}
232
233func (s *session) putRecord(record any, collection string) error {
234 recordBytes, err := json.Marshal(record)
235 if err != nil {
236 return fmt.Errorf("failed to marshal knot member record: %w", err)
237 }
238
239 payload := map[string]any{
240 "repo": s.Did,
241 "collection": collection,
242 "rkey": tid.TID(),
243 "record": json.RawMessage(recordBytes),
244 }
245
246 payloadBytes, err := json.Marshal(payload)
247 if err != nil {
248 return fmt.Errorf("failed to marshal request payload: %w", err)
249 }
250
251 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
252 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
253 if err != nil {
254 return fmt.Errorf("failed to create HTTP request: %w", err)
255 }
256
257 req.Header.Set("Content-Type", "application/json")
258 req.Header.Set("Authorization", "Bearer "+s.AccessJwt)
259
260 client := &http.Client{Timeout: 30 * time.Second}
261 resp, err := client.Do(req)
262 if err != nil {
263 return fmt.Errorf("failed to add user to default service: %w", err)
264 }
265 defer resp.Body.Close()
266
267 if resp.StatusCode != http.StatusOK {
268 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode)
269 }
270
271 return nil
272}