1package oauth
2
3import (
4 "encoding/json"
5 "fmt"
6 "log"
7 "net/http"
8 "net/url"
9 "strings"
10
11 "github.com/go-chi/chi/v5"
12 "github.com/gorilla/sessions"
13 "github.com/lestrrat-go/jwx/v2/jwk"
14 "github.com/posthog/posthog-go"
15 "tangled.sh/icyphox.sh/atproto-oauth/helpers"
16 sessioncache "tangled.sh/tangled.sh/core/appview/cache/session"
17 "tangled.sh/tangled.sh/core/appview/config"
18 "tangled.sh/tangled.sh/core/appview/db"
19 "tangled.sh/tangled.sh/core/appview/middleware"
20 "tangled.sh/tangled.sh/core/appview/oauth"
21 "tangled.sh/tangled.sh/core/appview/oauth/client"
22 "tangled.sh/tangled.sh/core/appview/pages"
23 "tangled.sh/tangled.sh/core/idresolver"
24 "tangled.sh/tangled.sh/core/knotclient"
25 "tangled.sh/tangled.sh/core/rbac"
26)
27
// oauthScope is the space-separated OAuth scope sent with every pushed
// authorization request and required back in the token response.
const oauthScope = "atproto transition:generic"
31
32type OAuthHandler struct {
33 config *config.Config
34 pages *pages.Pages
35 idResolver *idresolver.Resolver
36 sess *sessioncache.SessionStore
37 db *db.DB
38 store *sessions.CookieStore
39 oauth *oauth.OAuth
40 enforcer *rbac.Enforcer
41 posthog posthog.Client
42}
43
44func New(
45 config *config.Config,
46 pages *pages.Pages,
47 idResolver *idresolver.Resolver,
48 db *db.DB,
49 sess *sessioncache.SessionStore,
50 store *sessions.CookieStore,
51 oauth *oauth.OAuth,
52 enforcer *rbac.Enforcer,
53 posthog posthog.Client,
54) *OAuthHandler {
55 return &OAuthHandler{
56 config: config,
57 pages: pages,
58 idResolver: idResolver,
59 db: db,
60 sess: sess,
61 store: store,
62 oauth: oauth,
63 enforcer: enforcer,
64 posthog: posthog,
65 }
66}
67
// Router returns the chi router for the login/logout pages and the OAuth
// client endpoints. Only POST /logout sits behind the auth middleware.
func (o *OAuthHandler) Router() http.Handler {
	mux := chi.NewRouter()

	// One handler serves both: GET renders the form, POST starts the flow.
	mux.Get("/login", o.login)
	mux.Post("/login", o.login)

	// Logging out requires an authenticated session.
	mux.Group(func(authed chi.Router) {
		authed.Use(middleware.AuthMiddleware(o.oauth))
		authed.Post("/logout", o.logout)
	})

	// Public OAuth client endpoints.
	mux.Get("/oauth/client-metadata.json", o.clientMetadata)
	mux.Get("/oauth/jwks.json", o.jwks)
	mux.Get("/oauth/callback", o.callback)

	return mux
}
81
// clientMetadata serves o.oauth.ClientMetadata() as a JSON document (routed
// at /oauth/client-metadata.json).
//
// The document is encoded before anything is written. An encoding failure
// therefore produces a clean 500 instead of a 200 with a truncated body.
// Failures while writing the response are logged, since the status line has
// already been sent by then.
func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
	body, err := json.Marshal(o.oauth.ClientMetadata())
	if err != nil {
		log.Printf("error encoding client metadata: %v", err)
		http.Error(w, "failed to encode client metadata", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(body); err != nil {
		log.Printf("error writing client metadata response: %v", err)
	}
}
87
// jwks serves, as JSON, the key set built by helpers.CreateJwksResponseObject
// from the public half of the configured OAuth key (o.config.OAuth.Jwks).
// It is routed at /oauth/jwks.json.
//
// Key-parsing errors are logged but not echoed to the client, so details
// about the server's key material never leak into responses. The response is
// encoded before the 200 status is written, so an encoding failure still
// yields a proper 500.
func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
	pubKey, err := pubKeyFromJwk(o.config.OAuth.Jwks)
	if err != nil {
		log.Printf("error parsing public key: %v", err)
		http.Error(w, "failed to load public key", http.StatusInternalServerError)
		return
	}

	body, err := json.Marshal(helpers.CreateJwksResponseObject(pubKey))
	if err != nil {
		log.Printf("error encoding jwks response: %v", err)
		http.Error(w, "failed to encode jwks", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(body); err != nil {
		log.Printf("error writing jwks response: %v", err)
	}
}
103
// login serves the login page (GET) and starts the atproto OAuth
// authorization flow (POST).
//
// On POST it:
//   - cleans up the submitted "handle" form value and resolves it to a DID
//     and PDS;
//   - discovers the PDS's authorization server and fetches its metadata;
//   - sends a pushed authorization request signed with a freshly generated
//     DPoP key;
//   - stores the pending request in the session cache, where callback later
//     finds it by state;
//   - redirects the browser to the authorization endpoint via
//     o.pages.HxRedirect.
//
// Failures are reported through the "login-msg" notice. Other HTTP methods
// are ignored, because the router only registers GET and POST.
func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		o.pages.Login(w, pages.LoginParams{})
		return
	}
	if r.Method != http.MethodPost {
		return
	}

	// fail logs the underlying error and shows the generic failure notice.
	fail := func(msg string, err error) {
		log.Println(msg, err)
		o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
	}

	handle := sanitizeLoginHandle(r.FormValue("handle"))

	// basic handle validation
	if handle == "" {
		o.pages.Notice(w, "login-msg", "Please enter a handle.")
		return
	}
	if !strings.Contains(handle, ".") {
		log.Printf("invalid handle format: %q", handle)
		o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle. Did you mean %s.bsky.social?", handle, handle))
		return
	}

	resolved, err := o.idResolver.ResolveIdent(r.Context(), handle)
	if err != nil {
		log.Println("failed to resolve handle:", err)
		o.pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
		return
	}

	self := o.oauth.ClientMetadata()
	oauthClient, err := client.NewClient(
		self.ClientID,
		o.config.OAuth.Jwks,
		self.RedirectURIs[0],
	)
	if err != nil {
		fail("failed to create oauth client:", err)
		return
	}

	authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
	if err != nil {
		fail("failed to resolve auth server:", err)
		return
	}

	authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
	if err != nil {
		fail("failed to fetch auth server metadata:", err)
		return
	}

	// Validate the authorization endpoint before sending the PAR request or
	// saving any state. Ignoring this error would leave authURL nil and panic
	// when the redirect URL is built.
	authURL, err := url.Parse(authMeta.AuthorizationEndpoint)
	if err != nil {
		fail("failed to parse authorization endpoint:", err)
		return
	}

	dpopKey, err := helpers.GenerateKey(nil)
	if err != nil {
		fail("failed to generate dpop key:", err)
		return
	}

	// The DPoP private key is persisted so callback can sign the token request.
	dpopKeyJSON, err := json.Marshal(dpopKey)
	if err != nil {
		fail("failed to marshal dpop key:", err)
		return
	}

	parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
	if err != nil {
		fail("failed to send par auth request:", err)
		return
	}

	err = o.sess.SaveRequest(r.Context(), sessioncache.OAuthRequest{
		Did:                 resolved.DID.String(),
		PdsUrl:              resolved.PDSEndpoint(),
		Handle:              handle,
		AuthserverIss:       authMeta.Issuer,
		PkceVerifier:        parResp.PkceVerifier,
		DpopAuthserverNonce: parResp.DpopAuthserverNonce,
		DpopPrivateJwk:      string(dpopKeyJSON),
		State:               parResp.State,
	})
	if err != nil {
		fail("failed to save oauth request:", err)
		return
	}

	query := url.Values{}
	query.Set("client_id", self.ClientID)
	query.Set("request_uri", parResp.RequestUri)
	authURL.RawQuery = query.Encode()
	o.pages.HxRedirect(w, authURL.String())
}

// sanitizeLoginHandle strips decoration that commonly surrounds a handle
// pasted from elsewhere. It removes, in order:
//   - surrounding whitespace;
//   - the bidirectional marks bsky.app wraps handles in (a leading U+202A
//     LEFT-TO-RIGHT EMBEDDING and a trailing U+202C POP DIRECTIONAL
//     FORMATTING);
//   - a harmless leading "@".
func sanitizeLoginHandle(raw string) string {
	h := strings.TrimSpace(raw)
	h = strings.TrimPrefix(h, "\u202a")
	h = strings.TrimSuffix(h, "\u202c")
	h = strings.TrimPrefix(h, "@")
	return strings.TrimSpace(h)
}
207
// callback completes the authorization-code flow started by login.
//
// The authorization server redirects the browser here with "state" and either
// "code" + "iss" or "error"/"error_description". The handler:
//   - loads the pending request saved by login for that state, and deletes it
//     when done, whatever the outcome;
//   - rejects error responses, missing parameters, and an "iss" that differs
//     from the issuer the request was sent to;
//   - exchanges the code for tokens using the stored PKCE verifier and DPoP
//     key;
//   - checks that the granted scope set equals oauthScope;
//   - saves the session and redirects to "/".
//
// After a successful sign-in it also starts addToDefaultKnot in the
// background and, outside dev mode, enqueues a "signin" analytics event.
// Failures are logged and reported through the "login-msg" notice.
func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
	state := r.FormValue("state")

	// fail logs the reason and shows the generic login failure notice.
	fail := func(format string, args ...any) {
		log.Printf(format, args...)
		o.pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
	}

	oauthRequest, err := o.sess.GetRequestByState(r.Context(), state)
	if err != nil {
		fail("failed to get oauth request: %v", err)
		return
	}

	// A pending request is single-use: drop it whatever the outcome.
	defer func() {
		if err := o.sess.DeleteRequestByState(r.Context(), state); err != nil {
			log.Println("failed to delete oauth request for state:", state, err)
		}
	}()

	authErr := r.FormValue("error")
	authErrDescription := r.FormValue("error_description")
	if authErr != "" || authErrDescription != "" {
		fail("error: %s, %s", authErr, authErrDescription)
		return
	}

	code := r.FormValue("code")
	if code == "" {
		fail("missing code for state: %s", state)
		return
	}

	iss := r.FormValue("iss")
	if iss == "" {
		fail("missing iss for state: %s", state)
		return
	}
	// RFC 9207: the issuer reported in the redirect must be the authorization
	// server this request was sent to. Otherwise a different server could
	// inject its own code (mix-up attack).
	if iss != oauthRequest.AuthserverIss {
		fail("iss mismatch for state %s: got %q, expected %q", state, iss, oauthRequest.AuthserverIss)
		return
	}

	self := o.oauth.ClientMetadata()
	oauthClient, err := client.NewClient(
		self.ClientID,
		o.config.OAuth.Jwks,
		self.RedirectURIs[0],
	)
	if err != nil {
		fail("failed to create oauth client: %v", err)
		return
	}

	// The DPoP key generated during login; named to avoid shadowing the jwk package.
	dpopKey, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
	if err != nil {
		fail("failed to parse jwk: %v", err)
		return
	}

	tokenResp, err := oauthClient.InitialTokenRequest(
		r.Context(),
		code,
		oauthRequest.AuthserverIss,
		oauthRequest.PkceVerifier,
		oauthRequest.DpopAuthserverNonce,
		dpopKey,
	)
	if err != nil {
		fail("failed to get token: %v", err)
		return
	}

	// NOTE(review): the atproto OAuth profile also requires checking that the
	// token response's subject equals oauthRequest.Did. Confirm whether
	// InitialTokenRequest or SaveSession already enforces this.
	if !scopeSetsEqual(tokenResp.Scope, oauthScope) {
		fail("scope doesn't match: %s", tokenResp.Scope)
		return
	}

	if err := o.oauth.SaveSession(w, r, *oauthRequest, tokenResp); err != nil {
		fail("failed to save session: %v", err)
		return
	}
	log.Println("session saved successfully")

	go o.addToDefaultKnot(oauthRequest.Did)

	if !o.config.Core.Dev {
		err := o.posthog.Enqueue(posthog.Capture{
			DistinctId: oauthRequest.Did,
			Event:      "signin",
		})
		if err != nil {
			log.Println("failed to enqueue posthog event:", err)
		}
	}

	http.Redirect(w, r, "/", http.StatusFound)
}

// scopeSetsEqual reports whether two space-separated OAuth scope strings name
// the same set of scopes. Token order, duplicates and extra whitespace are
// ignored, because RFC 6749 §3.3 defines scope as an unordered list.
func scopeSetsEqual(a, b string) bool {
	want := make(map[string]bool)
	for _, s := range strings.Fields(a) {
		want[s] = true
	}
	got := make(map[string]bool)
	for _, s := range strings.Fields(b) {
		if !want[s] {
			return false
		}
		got[s] = true
	}
	return len(got) == len(want)
}
310
// logout clears the caller's OAuth session via o.oauth.ClearSession.
// On success it logs and sends the client to /login with o.pages.HxRedirect.
// If clearing fails, it logs the error and issues a plain 302 to "/".
func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
	if clearErr := o.oauth.ClearSession(r, w); clearErr != nil {
		// The session could not be cleared; fall back to the home page.
		log.Println("failed to clear session:", clearErr)
		http.Redirect(w, r, "/", http.StatusFound)
		return
	}

	log.Println("session cleared successfully")
	o.pages.HxRedirect(w, "/login")
}
322
323func pubKeyFromJwk(jwks string) (jwk.Key, error) {
324 k, err := helpers.ParseJWKFromBytes([]byte(jwks))
325 if err != nil {
326 return nil, err
327 }
328 pubKey, err := k.PublicKey()
329 if err != nil {
330 return nil, err
331 }
332 return pubKey, nil
333}
334
335func (o *OAuthHandler) addToDefaultKnot(did string) {
336 defaultKnot := "knot1.tangled.sh"
337
338 log.Printf("adding %s to default knot", did)
339 err := o.enforcer.AddKnotMember(defaultKnot, did)
340 if err != nil {
341 log.Println("failed to add user to knot1.tangled.sh: ", err)
342 return
343 }
344 err = o.enforcer.E.SavePolicy()
345 if err != nil {
346 log.Println("failed to add user to knot1.tangled.sh: ", err)
347 return
348 }
349
350 secret, err := db.GetRegistrationKey(o.db, defaultKnot)
351 if err != nil {
352 log.Println("failed to get registration key for knot1.tangled.sh")
353 return
354 }
355 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.config.Core.Dev)
356 resp, err := signedClient.AddMember(did)
357 if err != nil {
358 log.Println("failed to add user to knot1.tangled.sh: ", err)
359 return
360 }
361
362 if resp.StatusCode != http.StatusNoContent {
363 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode)
364 return
365 }
366}