1package oauth
2
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"slices"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/posthog/posthog-go"
	"tangled.org/core/api/tangled"
	"tangled.org/core/appview/db"
	"tangled.org/core/consts"
	"tangled.org/core/tid"
)
20
21func (o *OAuth) Router() http.Handler {
22 r := chi.NewRouter()
23
24 r.Get("/oauth/client-metadata.json", o.clientMetadata)
25 r.Get("/oauth/jwks.json", o.jwks)
26 r.Get("/oauth/callback", o.callback)
27 return r
28}
29
30func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) {
31 doc := o.ClientApp.Config.ClientMetadata()
32 doc.JWKSURI = &o.JwksUri
33
34 w.Header().Set("Content-Type", "application/json")
35 if err := json.NewEncoder(w).Encode(doc); err != nil {
36 http.Error(w, err.Error(), http.StatusInternalServerError)
37 return
38 }
39}
40
41func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) {
42 jwks := o.Config.OAuth.Jwks
43 pubKey, err := pubKeyFromJwk(jwks)
44 if err != nil {
45 o.Logger.Error("error parsing public key", "err", err)
46 http.Error(w, err.Error(), http.StatusInternalServerError)
47 return
48 }
49
50 response := map[string]any{
51 "keys": []jwk.Key{pubKey},
52 }
53
54 w.Header().Set("Content-Type", "application/json")
55 w.WriteHeader(http.StatusOK)
56 json.NewEncoder(w).Encode(response)
57}
58
59func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) {
60 ctx := r.Context()
61
62 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query())
63 if err != nil {
64 http.Error(w, err.Error(), http.StatusInternalServerError)
65 return
66 }
67
68 if err := o.SaveSession(w, r, sessData); err != nil {
69 http.Error(w, err.Error(), http.StatusInternalServerError)
70 return
71 }
72
73 o.Logger.Debug("session saved successfully")
74 go o.addToDefaultKnot(sessData.AccountDID.String())
75 go o.addToDefaultSpindle(sessData.AccountDID.String())
76
77 if !o.Config.Core.Dev {
78 err = o.Posthog.Enqueue(posthog.Capture{
79 DistinctId: sessData.AccountDID.String(),
80 Event: "signin",
81 })
82 if err != nil {
83 o.Logger.Error("failed to enqueue posthog event", "err", err)
84 }
85 }
86
87 http.Redirect(w, r, "/", http.StatusFound)
88}
89
90func (o *OAuth) addToDefaultSpindle(did string) {
91 l := o.Logger.With("subject", did)
92
93 // use the tangled.sh app password to get an accessJwt
94 // and create an sh.tangled.spindle.member record with that
95 spindleMembers, err := db.GetSpindleMembers(
96 o.Db,
97 db.FilterEq("instance", "spindle.tangled.sh"),
98 db.FilterEq("subject", did),
99 )
100 if err != nil {
101 l.Error("failed to get spindle members", "err", err)
102 return
103 }
104
105 if len(spindleMembers) != 0 {
106 l.Warn("already a member of the default spindle")
107 return
108 }
109
110 l.Debug("adding to default spindle")
111 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid)
112 if err != nil {
113 l.Error("failed to create session", "err", err)
114 return
115 }
116
117 record := tangled.SpindleMember{
118 LexiconTypeID: "sh.tangled.spindle.member",
119 Subject: did,
120 Instance: consts.DefaultSpindle,
121 CreatedAt: time.Now().Format(time.RFC3339),
122 }
123
124 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
125 l.Error("failed to add to default spindle", "err", err)
126 return
127 }
128
129 l.Debug("successfully added to default spindle", "did", did)
130}
131
132func (o *OAuth) addToDefaultKnot(did string) {
133 l := o.Logger.With("subject", did)
134
135 // use the tangled.sh app password to get an accessJwt
136 // and create an sh.tangled.spindle.member record with that
137
138 allKnots, err := o.Enforcer.GetKnotsForUser(did)
139 if err != nil {
140 l.Error("failed to get knot members for did", "err", err)
141 return
142 }
143
144 if slices.Contains(allKnots, consts.DefaultKnot) {
145 l.Warn("already a member of the default knot")
146 return
147 }
148
149 l.Debug("addings to default knot")
150 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid)
151 if err != nil {
152 l.Error("failed to create session", "err", err)
153 return
154 }
155
156 record := tangled.KnotMember{
157 LexiconTypeID: "sh.tangled.knot.member",
158 Subject: did,
159 Domain: consts.DefaultKnot,
160 CreatedAt: time.Now().Format(time.RFC3339),
161 }
162
163 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil {
164 l.Error("failed to add to default knot", "err", err)
165 return
166 }
167
168 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil {
169 l.Error("failed to set up enforcer rules", "err", err)
170 return
171 }
172
173 l.Debug("successfully addeds to default Knot")
174}
175
// session holds an access token obtained via an app-password login
// (see createAppPasswordSession). It also records the PDS and repo that the
// token is used to write records to.
type session struct {
	// AccessJwt is decoded from the "accessJwt" field of the
	// com.atproto.server.createSession response. putRecord sends it as a
	// Bearer token.
	AccessJwt string `json:"accessJwt"`
	// PdsEndpoint is the base URL that XRPC paths are appended to. It is
	// assigned by createAppPasswordSession after decoding.
	PdsEndpoint string
	// Did is the repo that putRecord writes to. It is assigned by
	// createAppPasswordSession after decoding.
	// NOTE(review): PdsEndpoint and Did have no json tags, so encoding/json
	// may fill them from case-insensitively matching response keys (e.g.
	// "did"). Any such value is overwritten right after decoding.
	Did string
}
182
183func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) {
184 if appPassword == "" {
185 return nil, fmt.Errorf("no app password configured, skipping member addition")
186 }
187
188 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did)
189 if err != nil {
190 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err)
191 }
192
193 pdsEndpoint := resolved.PDSEndpoint()
194 if pdsEndpoint == "" {
195 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
196 }
197
198 sessionPayload := map[string]string{
199 "identifier": did,
200 "password": appPassword,
201 }
202 sessionBytes, err := json.Marshal(sessionPayload)
203 if err != nil {
204 return nil, fmt.Errorf("failed to marshal session payload: %v", err)
205 }
206
207 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
208 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
209 if err != nil {
210 return nil, fmt.Errorf("failed to create session request: %v", err)
211 }
212 sessionReq.Header.Set("Content-Type", "application/json")
213
214 client := &http.Client{Timeout: 30 * time.Second}
215 sessionResp, err := client.Do(sessionReq)
216 if err != nil {
217 return nil, fmt.Errorf("failed to create session: %v", err)
218 }
219 defer sessionResp.Body.Close()
220
221 if sessionResp.StatusCode != http.StatusOK {
222 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
223 }
224
225 var session session
226 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
227 return nil, fmt.Errorf("failed to decode session response: %v", err)
228 }
229
230 session.PdsEndpoint = pdsEndpoint
231 session.Did = did
232
233 return &session, nil
234}
235
236func (s *session) putRecord(record any, collection string) error {
237 recordBytes, err := json.Marshal(record)
238 if err != nil {
239 return fmt.Errorf("failed to marshal knot member record: %w", err)
240 }
241
242 payload := map[string]any{
243 "repo": s.Did,
244 "collection": collection,
245 "rkey": tid.TID(),
246 "record": json.RawMessage(recordBytes),
247 }
248
249 payloadBytes, err := json.Marshal(payload)
250 if err != nil {
251 return fmt.Errorf("failed to marshal request payload: %w", err)
252 }
253
254 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
255 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
256 if err != nil {
257 return fmt.Errorf("failed to create HTTP request: %w", err)
258 }
259
260 req.Header.Set("Content-Type", "application/json")
261 req.Header.Set("Authorization", "Bearer "+s.AccessJwt)
262
263 client := &http.Client{Timeout: 30 * time.Second}
264 resp, err := client.Do(req)
265 if err != nil {
266 return fmt.Errorf("failed to add user to default service: %w", err)
267 }
268 defer resp.Body.Close()
269
270 if resp.StatusCode != http.StatusOK {
271 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode)
272 }
273
274 return nil
275}