forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "slices" 11 "time" 12 13 "github.com/go-chi/chi/v5" 14 "github.com/lestrrat-go/jwx/v2/jwk" 15 "github.com/posthog/posthog-go" 16 "tangled.org/core/api/tangled" 17 "tangled.org/core/appview/db" 18 "tangled.org/core/consts" 19 "tangled.org/core/tid" 20) 21 22func (o *OAuth) Router() http.Handler { 23 r := chi.NewRouter() 24 25 r.Get("/oauth/client-metadata.json", o.clientMetadata) 26 r.Get("/oauth/jwks.json", o.jwks) 27 r.Get("/oauth/callback", o.callback) 28 return r 29} 30 31func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) { 32 doc := o.ClientApp.Config.ClientMetadata() 33 doc.JWKSURI = &o.JwksUri 34 35 w.Header().Set("Content-Type", "application/json") 36 if err := json.NewEncoder(w).Encode(doc); err != nil { 37 http.Error(w, err.Error(), http.StatusInternalServerError) 38 return 39 } 40} 41 42func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) { 43 jwks := o.Config.OAuth.Jwks 44 pubKey, err := pubKeyFromJwk(jwks) 45 if err != nil { 46 log.Printf("error parsing public key: %v", err) 47 http.Error(w, err.Error(), http.StatusInternalServerError) 48 return 49 } 50 51 response := map[string]any{ 52 "keys": []jwk.Key{pubKey}, 53 } 54 55 w.Header().Set("Content-Type", "application/json") 56 w.WriteHeader(http.StatusOK) 57 json.NewEncoder(w).Encode(response) 58} 59 60func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { 61 ctx := r.Context() 62 63 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) 64 if err != nil { 65 http.Error(w, err.Error(), http.StatusInternalServerError) 66 return 67 } 68 69 if err := o.SaveSession(w, r, sessData); err != nil { 70 http.Error(w, err.Error(), http.StatusInternalServerError) 71 return 72 } 73 74 log.Println("session saved successfully") 75 go o.addToDefaultKnot(sessData.AccountDID.String()) 76 go o.addToDefaultSpindle(sessData.AccountDID.String()) 77 78 if !o.Config.Core.Dev { 79 err = o.Posthog.Enqueue(posthog.Capture{ 80 DistinctId: sessData.AccountDID.String(), 81 Event: "signin", 82 }) 83 if err != nil { 84 log.Println("failed to enqueue posthog event:", err) 85 } 86 } 87 88 http.Redirect(w, r, "/", http.StatusFound) 89} 90 91func (o *OAuth) addToDefaultSpindle(did string) { 92 // use the tangled.sh app password to get an accessJwt 93 // and create an sh.tangled.spindle.member record with that 94 spindleMembers, err := db.GetSpindleMembers( 95 o.Db, 96 db.FilterEq("instance", "spindle.tangled.sh"), 97 db.FilterEq("subject", did), 98 ) 99 if err != nil { 100 log.Printf("failed to get spindle members for did %s: %v", did, err) 101 return 102 } 103 104 if len(spindleMembers) != 0 { 105 log.Printf("did %s is already a member of the default spindle", did) 106 return 107 } 108 109 log.Printf("adding %s to default spindle", did) 110 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid) 111 if err != nil { 112 log.Printf("failed to create session: %s", err) 113 return 114 } 115 116 record := tangled.SpindleMember{ 117 LexiconTypeID: "sh.tangled.spindle.member", 118 Subject: did, 119 Instance: consts.DefaultSpindle, 120 CreatedAt: time.Now().Format(time.RFC3339), 121 } 122 123 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 124 log.Printf("failed to add member to default spindle: %s", err) 125 return 126 } 127 128 log.Printf("successfully added %s to default spindle", did) 129} 130 131func (o *OAuth) addToDefaultKnot(did string) { 132 // use the tangled.sh app password to get an accessJwt 133 // and create an sh.tangled.spindle.member record with that 134 135 allKnots, err := o.Enforcer.GetKnotsForUser(did) 136 if err != nil { 137 log.Printf("failed to get knot members for did %s: %v", did, err) 138 return 139 } 140 141 if slices.Contains(allKnots, consts.DefaultKnot) { 142 log.Printf("did %s is already a member of the default knot", did) 143 return 144 } 145 146 log.Printf("adding %s to default knot", did) 147 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid) 148 if err != nil { 149 log.Printf("failed to create session: %s", err) 150 return 151 } 152 153 record := tangled.KnotMember{ 154 LexiconTypeID: "sh.tangled.knot.member", 155 Subject: did, 156 Domain: consts.DefaultKnot, 157 CreatedAt: time.Now().Format(time.RFC3339), 158 } 159 160 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 161 log.Printf("failed to add member to default knot: %s", err) 162 return 163 } 164 165 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil { 166 log.Printf("failed to set up enforcer rules: %s", err) 167 return 168 } 169 170 log.Printf("successfully added %s to default Knot", did) 171} 172 173// create a session using apppasswords 174type session struct { 175 AccessJwt string `json:"accessJwt"` 176 PdsEndpoint string 177 Did string 178} 179 180func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) { 181 if appPassword == "" { 182 return nil, fmt.Errorf("no app password configured, skipping member addition") 183 } 184 185 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did) 186 if err != nil { 187 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 188 } 189 190 pdsEndpoint := resolved.PDSEndpoint() 191 if pdsEndpoint == "" { 192 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 193 } 194 195 sessionPayload := map[string]string{ 196 "identifier": did, 197 "password": appPassword, 198 } 199 sessionBytes, err := json.Marshal(sessionPayload) 200 if err != nil { 201 return nil, fmt.Errorf("failed to marshal session payload: %v", err) 202 } 203 204 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 205 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 206 if err != nil { 207 return nil, fmt.Errorf("failed to create session request: %v", err) 208 } 209 sessionReq.Header.Set("Content-Type", "application/json") 210 211 client := &http.Client{Timeout: 30 * time.Second} 212 sessionResp, err := client.Do(sessionReq) 213 if err != nil { 214 return nil, fmt.Errorf("failed to create session: %v", err) 215 } 216 defer sessionResp.Body.Close() 217 218 if sessionResp.StatusCode != http.StatusOK { 219 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 220 } 221 222 var session session 223 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 224 return nil, fmt.Errorf("failed to decode session response: %v", err) 225 } 226 227 session.PdsEndpoint = pdsEndpoint 228 session.Did = did 229 230 return &session, nil 231} 232 233func (s *session) putRecord(record any, collection string) error { 234 recordBytes, err := json.Marshal(record) 235 if err != nil { 236 return fmt.Errorf("failed to marshal knot member record: %w", err) 237 } 238 239 payload := map[string]any{ 240 "repo": s.Did, 241 "collection": collection, 242 "rkey": tid.TID(), 243 "record": json.RawMessage(recordBytes), 244 } 245 246 payloadBytes, err := json.Marshal(payload) 247 if err != nil { 248 return fmt.Errorf("failed to marshal request payload: %w", err) 249 } 250 251 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 252 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 253 if err != nil { 254 return fmt.Errorf("failed to create HTTP request: %w", err) 255 } 256 257 req.Header.Set("Content-Type", "application/json") 258 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 259 260 client := &http.Client{Timeout: 30 * time.Second} 261 resp, err := client.Do(req) 262 if err != nil { 263 return fmt.Errorf("failed to add user to default service: %w", err) 264 } 265 defer resp.Body.Close() 266 267 if resp.StatusCode != http.StatusOK { 268 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode) 269 } 270 271 return nil 272}