forked from tangled.org/core
Monorepo for Tangled — https://tangled.org
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "net/http" 10 "slices" 11 "time" 12 13 "github.com/bluesky-social/indigo/atproto/auth/oauth" 14 "github.com/go-chi/chi/v5" 15 "github.com/posthog/posthog-go" 16 "tangled.org/core/api/tangled" 17 "tangled.org/core/appview/db" 18 "tangled.org/core/consts" 19 "tangled.org/core/tid" 20) 21 22func (o *OAuth) Router() http.Handler { 23 r := chi.NewRouter() 24 25 r.Get("/oauth/client-metadata.json", o.clientMetadata) 26 r.Get("/oauth/jwks.json", o.jwks) 27 r.Get("/oauth/callback", o.callback) 28 return r 29} 30 31func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) { 32 doc := o.ClientApp.Config.ClientMetadata() 33 doc.JWKSURI = &o.JwksUri 34 doc.ClientName = &o.ClientName 35 doc.ClientURI = &o.ClientUri 36 37 w.Header().Set("Content-Type", "application/json") 38 if err := json.NewEncoder(w).Encode(doc); err != nil { 39 http.Error(w, err.Error(), http.StatusInternalServerError) 40 return 41 } 42} 43 44func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) { 45 w.Header().Set("Content-Type", "application/json") 46 body := o.ClientApp.Config.PublicJWKS() 47 if err := json.NewEncoder(w).Encode(body); err != nil { 48 http.Error(w, err.Error(), http.StatusInternalServerError) 49 return 50 } 51} 52 53func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { 54 ctx := r.Context() 55 l := o.Logger.With("query", r.URL.Query()) 56 57 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) 58 if err != nil { 59 var callbackErr *oauth.AuthRequestCallbackError 60 if errors.As(err, &callbackErr) { 61 l.Debug("callback error", "err", callbackErr) 62 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound) 63 return 64 } 65 l.Error("failed to process callback", "err", err) 66 http.Redirect(w, r, "/login?error=oauth", http.StatusFound) 67 return 68 } 69 70 if err := o.SaveSession(w, r, sessData); err != nil { 71 l.Error("failed to save session", "data", sessData, "err", err) 72 http.Redirect(w, r, "/login?error=session", http.StatusFound) 73 return 74 } 75 76 o.Logger.Debug("session saved successfully") 77 go o.addToDefaultKnot(sessData.AccountDID.String()) 78 go o.addToDefaultSpindle(sessData.AccountDID.String()) 79 80 if !o.Config.Core.Dev { 81 err = o.Posthog.Enqueue(posthog.Capture{ 82 DistinctId: sessData.AccountDID.String(), 83 Event: "signin", 84 }) 85 if err != nil { 86 o.Logger.Error("failed to enqueue posthog event", "err", err) 87 } 88 } 89 90 http.Redirect(w, r, "/", http.StatusFound) 91} 92 93func (o *OAuth) addToDefaultSpindle(did string) { 94 l := o.Logger.With("subject", did) 95 96 // use the tangled.sh app password to get an accessJwt 97 // and create an sh.tangled.spindle.member record with that 98 spindleMembers, err := db.GetSpindleMembers( 99 o.Db, 100 db.FilterEq("instance", "spindle.tangled.sh"), 101 db.FilterEq("subject", did), 102 ) 103 if err != nil { 104 l.Error("failed to get spindle members", "err", err) 105 return 106 } 107 108 if len(spindleMembers) != 0 { 109 l.Warn("already a member of the default spindle") 110 return 111 } 112 113 l.Debug("adding to default spindle") 114 session, err := o.createAppPasswordSession(o.Config.Core.AppPassword, consts.TangledDid) 115 if err != nil { 116 l.Error("failed to create session", "err", err) 117 return 118 } 119 120 record := tangled.SpindleMember{ 121 LexiconTypeID: "sh.tangled.spindle.member", 122 Subject: did, 123 Instance: consts.DefaultSpindle, 124 CreatedAt: time.Now().Format(time.RFC3339), 125 } 126 127 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 128 l.Error("failed to add to default spindle", "err", err) 129 return 130 } 131 132 l.Debug("successfully added to default spindle", "did", did) 133} 134 135func (o *OAuth) addToDefaultKnot(did string) { 136 l := o.Logger.With("subject", did) 137 138 // use the tangled.sh app password to get an accessJwt 139 // and create an sh.tangled.spindle.member record with that 140 141 allKnots, err := o.Enforcer.GetKnotsForUser(did) 142 if err != nil { 143 l.Error("failed to get knot members for did", "err", err) 144 return 145 } 146 147 if slices.Contains(allKnots, consts.DefaultKnot) { 148 l.Warn("already a member of the default knot") 149 return 150 } 151 152 l.Debug("addings to default knot") 153 session, err := o.createAppPasswordSession(o.Config.Core.TmpAltAppPassword, consts.IcyDid) 154 if err != nil { 155 l.Error("failed to create session", "err", err) 156 return 157 } 158 159 record := tangled.KnotMember{ 160 LexiconTypeID: "sh.tangled.knot.member", 161 Subject: did, 162 Domain: consts.DefaultKnot, 163 CreatedAt: time.Now().Format(time.RFC3339), 164 } 165 166 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 167 l.Error("failed to add to default knot", "err", err) 168 return 169 } 170 171 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil { 172 l.Error("failed to set up enforcer rules", "err", err) 173 return 174 } 175 176 l.Debug("successfully addeds to default Knot") 177} 178 179// create a session using apppasswords 180type session struct { 181 AccessJwt string `json:"accessJwt"` 182 PdsEndpoint string 183 Did string 184} 185 186func (o *OAuth) createAppPasswordSession(appPassword, did string) (*session, error) { 187 if appPassword == "" { 188 return nil, fmt.Errorf("no app password configured, skipping member addition") 189 } 190 191 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did) 192 if err != nil { 193 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 194 } 195 196 pdsEndpoint := resolved.PDSEndpoint() 197 if pdsEndpoint == "" { 198 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 199 } 200 201 sessionPayload := map[string]string{ 202 "identifier": did, 203 "password": appPassword, 204 } 205 sessionBytes, err := json.Marshal(sessionPayload) 206 if err != nil { 207 return nil, fmt.Errorf("failed to marshal session payload: %v", err) 208 } 209 210 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 211 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 212 if err != nil { 213 return nil, fmt.Errorf("failed to create session request: %v", err) 214 } 215 sessionReq.Header.Set("Content-Type", "application/json") 216 217 client := &http.Client{Timeout: 30 * time.Second} 218 sessionResp, err := client.Do(sessionReq) 219 if err != nil { 220 return nil, fmt.Errorf("failed to create session: %v", err) 221 } 222 defer sessionResp.Body.Close() 223 224 if sessionResp.StatusCode != http.StatusOK { 225 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 226 } 227 228 var session session 229 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 230 return nil, fmt.Errorf("failed to decode session response: %v", err) 231 } 232 233 session.PdsEndpoint = pdsEndpoint 234 session.Did = did 235 236 return &session, nil 237} 238 239func (s *session) putRecord(record any, collection string) error { 240 recordBytes, err := json.Marshal(record) 241 if err != nil { 242 return fmt.Errorf("failed to marshal knot member record: %w", err) 243 } 244 245 payload := map[string]any{ 246 "repo": s.Did, 247 "collection": collection, 248 "rkey": tid.TID(), 249 "record": json.RawMessage(recordBytes), 250 } 251 252 payloadBytes, err := json.Marshal(payload) 253 if err != nil { 254 return fmt.Errorf("failed to marshal request payload: %w", err) 255 } 256 257 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 258 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 259 if err != nil { 260 return fmt.Errorf("failed to create HTTP request: %w", err) 261 } 262 263 req.Header.Set("Content-Type", "application/json") 264 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 265 266 client := &http.Client{Timeout: 30 * time.Second} 267 resp, err := client.Do(req) 268 if err != nil { 269 return fmt.Errorf("failed to add user to default service: %w", err) 270 } 271 defer resp.Body.Close() 272 273 if resp.StatusCode != http.StatusOK { 274 return fmt.Errorf("failed to add user to default service: HTTP %d", resp.StatusCode) 275 } 276 277 return nil 278}