1package oauth
2
3import (
4 "encoding/json"
5 "fmt"
6 "log"
7 "net/http"
8 "net/url"
9 "strings"
10
11 "github.com/go-chi/chi/v5"
12 "github.com/gorilla/sessions"
13 "github.com/haileyok/atproto-oauth-golang/helpers"
14 "github.com/lestrrat-go/jwx/v2/jwk"
15 "github.com/posthog/posthog-go"
16 "tangled.sh/tangled.sh/core/appview"
17 "tangled.sh/tangled.sh/core/appview/db"
18 "tangled.sh/tangled.sh/core/appview/middleware"
19 "tangled.sh/tangled.sh/core/appview/oauth"
20 "tangled.sh/tangled.sh/core/appview/oauth/client"
21 "tangled.sh/tangled.sh/core/appview/pages"
22 "tangled.sh/tangled.sh/core/knotclient"
23 "tangled.sh/tangled.sh/core/rbac"
24 "tangled.sh/tangled.sh/core/resolver"
25)
26
// oauthScope is the scope requested in the pushed authorization request and
// the exact scope string the token response must echo back in callback.
const oauthScope = "atproto transition:generic"
30
// OAuthHandler serves the login/logout pages and the OAuth client endpoints
// (client metadata, JWKS, authorization callback) mounted by Router.
type OAuthHandler struct {
	// Config supplies the OAuth JWK (Config.OAuth.Jwks) and the dev-mode flag
	// (Config.Core.Dev).
	Config *appview.Config
	// Pages renders the login page, "login-msg" notices and htmx redirects.
	Pages *pages.Pages
	// Resolver resolves a submitted handle to its DID and PDS endpoint.
	Resolver *resolver.Resolver
	// Db persists pending OAuth requests (keyed by state) and knot
	// registration keys.
	Db *db.DB
	// Store is not referenced by the handlers in this file.
	// NOTE(review): confirm whether it is still needed.
	Store *sessions.CookieStore
	// OAuth provides the client metadata and saves/clears user sessions.
	OAuth *oauth.OAuth
	// Enforcer is the RBAC enforcer used to add new users to the default knot.
	Enforcer *rbac.Enforcer
	// Posthog receives a "signin" event per successful login (skipped when
	// Config.Core.Dev is set).
	Posthog posthog.Client
}
41
42func (o *OAuthHandler) Router() http.Handler {
43 r := chi.NewRouter()
44
45 r.Get("/login", o.login)
46 r.Post("/login", o.login)
47
48 r.With(middleware.AuthMiddleware(o.OAuth)).Post("/logout", o.logout)
49
50 r.Get("/oauth/client-metadata.json", o.clientMetadata)
51 r.Get("/oauth/jwks.json", o.jwks)
52 r.Get("/oauth/callback", o.callback)
53 return r
54}
55
56func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
57 w.Header().Set("Content-Type", "application/json")
58 w.WriteHeader(http.StatusOK)
59 json.NewEncoder(w).Encode(o.OAuth.ClientMetadata())
60}
61
62func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
63 jwks := o.Config.OAuth.Jwks
64 pubKey, err := pubKeyFromJwk(jwks)
65 if err != nil {
66 log.Printf("error parsing public key: %v", err)
67 http.Error(w, err.Error(), http.StatusInternalServerError)
68 return
69 }
70
71 response := helpers.CreateJwksResponseObject(pubKey)
72
73 w.Header().Set("Content-Type", "application/json")
74 w.WriteHeader(http.StatusOK)
75 json.NewEncoder(w).Encode(response)
76}
77
78func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
79 switch r.Method {
80 case http.MethodGet:
81 o.Pages.Login(w, pages.LoginParams{})
82 case http.MethodPost:
83 handle := strings.TrimPrefix(r.FormValue("handle"), "@")
84
85 resolved, err := o.Resolver.ResolveIdent(r.Context(), handle)
86 if err != nil {
87 log.Println("failed to resolve handle:", err)
88 o.Pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
89 return
90 }
91 self := o.OAuth.ClientMetadata()
92 oauthClient, err := client.NewClient(
93 self.ClientID,
94 o.Config.OAuth.Jwks,
95 self.RedirectURIs[0],
96 )
97
98 if err != nil {
99 log.Println("failed to create oauth client:", err)
100 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
101 return
102 }
103
104 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
105 if err != nil {
106 log.Println("failed to resolve auth server:", err)
107 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
108 return
109 }
110
111 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
112 if err != nil {
113 log.Println("failed to fetch auth server metadata:", err)
114 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
115 return
116 }
117
118 dpopKey, err := helpers.GenerateKey(nil)
119 if err != nil {
120 log.Println("failed to generate dpop key:", err)
121 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
122 return
123 }
124
125 dpopKeyJson, err := json.Marshal(dpopKey)
126 if err != nil {
127 log.Println("failed to marshal dpop key:", err)
128 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
129 return
130 }
131
132 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
133 if err != nil {
134 log.Println("failed to send par auth request:", err)
135 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
136 return
137 }
138
139 err = db.SaveOAuthRequest(o.Db, db.OAuthRequest{
140 Did: resolved.DID.String(),
141 PdsUrl: resolved.PDSEndpoint(),
142 Handle: handle,
143 AuthserverIss: authMeta.Issuer,
144 PkceVerifier: parResp.PkceVerifier,
145 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
146 DpopPrivateJwk: string(dpopKeyJson),
147 State: parResp.State,
148 })
149 if err != nil {
150 log.Println("failed to save oauth request:", err)
151 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
152 return
153 }
154
155 u, _ := url.Parse(authMeta.AuthorizationEndpoint)
156 query := url.Values{}
157 query.Add("client_id", self.ClientID)
158 query.Add("request_uri", parResp.RequestUri)
159 u.RawQuery = query.Encode()
160 o.Pages.HxRedirect(w, u.String())
161 }
162}
163
164func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
165 state := r.FormValue("state")
166
167 oauthRequest, err := db.GetOAuthRequestByState(o.Db, state)
168 if err != nil {
169 log.Println("failed to get oauth request:", err)
170 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
171 return
172 }
173
174 defer func() {
175 err := db.DeleteOAuthRequestByState(o.Db, state)
176 if err != nil {
177 log.Println("failed to delete oauth request for state:", state, err)
178 }
179 }()
180
181 error := r.FormValue("error")
182 errorDescription := r.FormValue("error_description")
183 if error != "" || errorDescription != "" {
184 log.Printf("error: %s, %s", error, errorDescription)
185 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
186 return
187 }
188
189 code := r.FormValue("code")
190 if code == "" {
191 log.Println("missing code for state: ", state)
192 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
193 return
194 }
195
196 iss := r.FormValue("iss")
197 if iss == "" {
198 log.Println("missing iss for state: ", state)
199 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
200 return
201 }
202
203 self := o.OAuth.ClientMetadata()
204
205 oauthClient, err := client.NewClient(
206 self.ClientID,
207 o.Config.OAuth.Jwks,
208 self.RedirectURIs[0],
209 )
210
211 if err != nil {
212 log.Println("failed to create oauth client:", err)
213 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
214 return
215 }
216
217 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
218 if err != nil {
219 log.Println("failed to parse jwk:", err)
220 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
221 return
222 }
223
224 tokenResp, err := oauthClient.InitialTokenRequest(
225 r.Context(),
226 code,
227 oauthRequest.AuthserverIss,
228 oauthRequest.PkceVerifier,
229 oauthRequest.DpopAuthserverNonce,
230 jwk,
231 )
232 if err != nil {
233 log.Println("failed to get token:", err)
234 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
235 return
236 }
237
238 if tokenResp.Scope != oauthScope {
239 log.Println("scope doesn't match:", tokenResp.Scope)
240 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
241 return
242 }
243
244 err = o.OAuth.SaveSession(w, r, oauthRequest, tokenResp)
245 if err != nil {
246 log.Println("failed to save session:", err)
247 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
248 return
249 }
250
251 log.Println("session saved successfully")
252 go o.addToDefaultKnot(oauthRequest.Did)
253
254 if !o.Config.Core.Dev {
255 err = o.Posthog.Enqueue(posthog.Capture{
256 DistinctId: oauthRequest.Did,
257 Event: "signin",
258 })
259 if err != nil {
260 log.Println("failed to enqueue posthog event:", err)
261 }
262 }
263
264 http.Redirect(w, r, "/", http.StatusFound)
265}
266
267func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
268 err := o.OAuth.ClearSession(r, w)
269 if err != nil {
270 log.Println("failed to clear session:", err)
271 http.Redirect(w, r, "/", http.StatusFound)
272 return
273 }
274
275 log.Println("session cleared successfully")
276 http.Redirect(w, r, "/", http.StatusFound)
277}
278
279func pubKeyFromJwk(jwks string) (jwk.Key, error) {
280 k, err := helpers.ParseJWKFromBytes([]byte(jwks))
281 if err != nil {
282 return nil, err
283 }
284 pubKey, err := k.PublicKey()
285 if err != nil {
286 return nil, err
287 }
288 return pubKey, nil
289}
290
291func (o *OAuthHandler) addToDefaultKnot(did string) {
292 defaultKnot := "knot1.tangled.sh"
293
294 log.Printf("adding %s to default knot", did)
295 err := o.Enforcer.AddMember(defaultKnot, did)
296 if err != nil {
297 log.Println("failed to add user to knot1.tangled.sh: ", err)
298 return
299 }
300 err = o.Enforcer.E.SavePolicy()
301 if err != nil {
302 log.Println("failed to add user to knot1.tangled.sh: ", err)
303 return
304 }
305
306 secret, err := db.GetRegistrationKey(o.Db, defaultKnot)
307 if err != nil {
308 log.Println("failed to get registration key for knot1.tangled.sh")
309 return
310 }
311 signedClient, err := knotclient.NewSignedClient(defaultKnot, secret, o.Config.Core.Dev)
312 resp, err := signedClient.AddMember(did)
313 if err != nil {
314 log.Println("failed to add user to knot1.tangled.sh: ", err)
315 return
316 }
317
318 if resp.StatusCode != http.StatusNoContent {
319 log.Println("failed to add user to knot1.tangled.sh: ", resp.StatusCode)
320 return
321 }
322}