1package oauth
2
3import (
4 "encoding/json"
5 "fmt"
6 "log"
7 "net/http"
8 "net/url"
9 "strings"
10
11 "github.com/go-chi/chi/v5"
12 "github.com/gorilla/sessions"
13 "github.com/haileyok/atproto-oauth-golang/helpers"
14 "github.com/lestrrat-go/jwx/v2/jwk"
15 "github.com/posthog/posthog-go"
16 "tangled.sh/tangled.sh/core/appview"
17 "tangled.sh/tangled.sh/core/appview/db"
18 "tangled.sh/tangled.sh/core/appview/middleware"
19 "tangled.sh/tangled.sh/core/appview/oauth"
20 "tangled.sh/tangled.sh/core/appview/oauth/client"
21 "tangled.sh/tangled.sh/core/appview/pages"
22 "tangled.sh/tangled.sh/core/knotclient"
23 "tangled.sh/tangled.sh/core/rbac"
24)
25
// oauthScope is the OAuth scope requested in the pushed authorization request
// sent by login, and the scope the token response in callback is checked against.
const oauthScope = "atproto transition:generic"
29
30type OAuthHandler struct {
31 Config *appview.Config
32 Pages *pages.Pages
33 Resolver *appview.Resolver
34 Db *db.DB
35 Store *sessions.CookieStore
36 OAuth *oauth.OAuth
37 Enforcer *rbac.Enforcer
38 Posthog posthog.Client
39}
40
41func (o *OAuthHandler) Router() http.Handler {
42 r := chi.NewRouter()
43
44 r.Get("/login", o.login)
45 r.Post("/login", o.login)
46
47 r.With(middleware.AuthMiddleware(o.OAuth)).Post("/logout", o.logout)
48
49 r.Get("/oauth/client-metadata.json", o.clientMetadata)
50 r.Get("/oauth/jwks.json", o.jwks)
51 r.Get("/oauth/callback", o.callback)
52 return r
53}
54
55func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
56 w.Header().Set("Content-Type", "application/json")
57 w.WriteHeader(http.StatusOK)
58 json.NewEncoder(w).Encode(o.OAuth.ClientMetadata())
59}
60
61func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
62 jwks := o.Config.OAuth.Jwks
63 pubKey, err := pubKeyFromJwk(jwks)
64 if err != nil {
65 log.Printf("error parsing public key: %v", err)
66 http.Error(w, err.Error(), http.StatusInternalServerError)
67 return
68 }
69
70 response := helpers.CreateJwksResponseObject(pubKey)
71
72 w.Header().Set("Content-Type", "application/json")
73 w.WriteHeader(http.StatusOK)
74 json.NewEncoder(w).Encode(response)
75}
76
77func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
78 switch r.Method {
79 case http.MethodGet:
80 o.Pages.Login(w, pages.LoginParams{})
81 case http.MethodPost:
82 handle := strings.TrimPrefix(r.FormValue("handle"), "@")
83
84 resolved, err := o.Resolver.ResolveIdent(r.Context(), handle)
85 if err != nil {
86 log.Println("failed to resolve handle:", err)
87 o.Pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
88 return
89 }
90 self := o.OAuth.ClientMetadata()
91 oauthClient, err := client.NewClient(
92 self.ClientID,
93 o.Config.OAuth.Jwks,
94 self.RedirectURIs[0],
95 )
96
97 if err != nil {
98 log.Println("failed to create oauth client:", err)
99 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
100 return
101 }
102
103 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
104 if err != nil {
105 log.Println("failed to resolve auth server:", err)
106 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
107 return
108 }
109
110 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
111 if err != nil {
112 log.Println("failed to fetch auth server metadata:", err)
113 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
114 return
115 }
116
117 dpopKey, err := helpers.GenerateKey(nil)
118 if err != nil {
119 log.Println("failed to generate dpop key:", err)
120 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
121 return
122 }
123
124 dpopKeyJson, err := json.Marshal(dpopKey)
125 if err != nil {
126 log.Println("failed to marshal dpop key:", err)
127 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
128 return
129 }
130
131 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
132 if err != nil {
133 log.Println("failed to send par auth request:", err)
134 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
135 return
136 }
137
138 err = db.SaveOAuthRequest(o.Db, db.OAuthRequest{
139 Did: resolved.DID.String(),
140 PdsUrl: resolved.PDSEndpoint(),
141 Handle: handle,
142 AuthserverIss: authMeta.Issuer,
143 PkceVerifier: parResp.PkceVerifier,
144 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
145 DpopPrivateJwk: string(dpopKeyJson),
146 State: parResp.State,
147 })
148 if err != nil {
149 log.Println("failed to save oauth request:", err)
150 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
151 return
152 }
153
154 u, _ := url.Parse(authMeta.AuthorizationEndpoint)
155 query := url.Values{}
156 query.Add("client_id", self.ClientID)
157 query.Add("request_uri", parResp.RequestUri)
158 u.RawQuery = query.Encode()
159 o.Pages.HxRedirect(w, u.String())
160 }
161}
162
163func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
164 state := r.FormValue("state")
165
166 oauthRequest, err := db.GetOAuthRequestByState(o.Db, state)
167 if err != nil {
168 log.Println("failed to get oauth request:", err)
169 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
170 return
171 }
172
173 defer func() {
174 err := db.DeleteOAuthRequestByState(o.Db, state)
175 if err != nil {
176 log.Println("failed to delete oauth request for state:", state, err)
177 }
178 }()
179
180 error := r.FormValue("error")
181 errorDescription := r.FormValue("error_description")
182 if error != "" || errorDescription != "" {
183 log.Printf("error: %s, %s", error, errorDescription)
184 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
185 return
186 }
187
188 code := r.FormValue("code")
189 if code == "" {
190 log.Println("missing code for state: ", state)
191 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
192 return
193 }
194
195 iss := r.FormValue("iss")
196 if iss == "" {
197 log.Println("missing iss for state: ", state)
198 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
199 return
200 }
201
202 self := o.OAuth.ClientMetadata()
203
204 oauthClient, err := client.NewClient(
205 self.ClientID,
206 o.Config.OAuth.Jwks,
207 self.RedirectURIs[0],
208 )
209
210 if err != nil {
211 log.Println("failed to create oauth client:", err)
212 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
213 return
214 }
215
216 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
217 if err != nil {
218 log.Println("failed to parse jwk:", err)
219 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
220 return
221 }
222
223 tokenResp, err := oauthClient.InitialTokenRequest(
224 r.Context(),
225 code,
226 oauthRequest.AuthserverIss,
227 oauthRequest.PkceVerifier,
228 oauthRequest.DpopAuthserverNonce,
229 jwk,
230 )
231 if err != nil {
232 log.Println("failed to get token:", err)
233 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
234 return
235 }
236
237 if tokenResp.Scope != oauthScope {
238 log.Println("scope doesn't match:", tokenResp.Scope)
239 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
240 return
241 }
242
243 err = o.OAuth.SaveSession(w, r, oauthRequest, tokenResp)
244 if err != nil {
245 log.Println("failed to save session:", err)
246 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
247 return
248 }
249
250 log.Println("session saved successfully")
251 go o.addToDefaultKnot(oauthRequest.Did)
252
253 if !o.Config.Core.Dev {
254 err = o.Posthog.Enqueue(posthog.Capture{
255 DistinctId: oauthRequest.Did,
256 Event: "signin",
257 })
258 if err != nil {
259 log.Println("failed to enqueue posthog event:", err)
260 }
261 }
262
263 http.Redirect(w, r, "/", http.StatusFound)
264}
265
266func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
267 err := o.OAuth.ClearSession(r, w)
268 if err != nil {
269 log.Println("failed to clear session:", err)
270 http.Redirect(w, r, "/", http.StatusFound)
271 return
272 }
273
274 log.Println("session cleared successfully")
275 http.Redirect(w, r, "/", http.StatusFound)
276}
277
278func pubKeyFromJwk(jwks string) (jwk.Key, error) {
279 k, err := helpers.ParseJWKFromBytes([]byte(jwks))
280 if err != nil {
281 return nil, err
282 }
283 pubKey, err := k.PublicKey()
284 if err != nil {
285 return nil, err
286 }
287 return pubKey, nil
288}
289
290func (o *OAuthHandler) addToDefaultKnot(did string) {
291 defaultKnot := "knot1.tangled.sh"
292
293 log.Printf("adding %s to default knot", did)
294 err := o.Enforcer.AddMember(defaultKnot, did)
295 if err != nil {
296 log.Println("failed to add user to knot1.tangled.sh: ", err)
297 return
298 }
299 err = o.Enforcer.E.SavePolicy()
300 if err != nil {
301 log.Println("failed to add user to knot1.tangled.sh: ", err)
302 return
303 }
304
305 secret, err := db.GetRegistrationKey(o.Db, defaultKnot)
306 if err != nil {
307 log.Println("failed to get registration key for knot1.tangled.sh")
308 return
309 }
310 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.Config.Core.Dev)
311 resp, err := signedClient.AddMember(did)
312 if err != nil {
313 log.Println("failed to add user to knot1.tangled.sh: ", err)
314 return
315 }
316
317 if resp.StatusCode != http.StatusNoContent {
318 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode)
319 return
320 }
321}