this repo has no description
1package main
2
import (
	"context"
	"encoding/json"
	"fmt"
	"html/template"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/api/bsky"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/lex/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/gorilla/sessions"
	oauth "github.com/haileyok/atproto-oauth-golang"
	_ "github.com/joho/godotenv/autoload"
	"github.com/labstack/echo-contrib/session"
	"github.com/labstack/echo/v4"
	"github.com/lestrrat-go/jwx/v2/jwk"
	slogecho "github.com/samber/slog-echo"
	"github.com/urfave/cli/v2"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
33
// Process-wide configuration, read from the environment during package
// initialisation. Imported packages (including the blank godotenv/autoload
// import) are initialised first, so values from a .env file are visible here.
var (
	// ctx is a background context shared at package level.
	ctx = context.Background()

	// Settings taken directly from the environment.
	serverAddr     = os.Getenv("OAUTH_TEST_SERVER_ADDR")
	serverUrlRoot  = os.Getenv("OAUTH_TEST_SERVER_URL_ROOT")
	staticFilePath = os.Getenv("OAUTH_TEST_SERVER_STATIC_PATH")
	sessionSecret  = os.Getenv("OAUTH_TEST_SESSION_SECRET")
	pdsUrl         = os.Getenv("OAUTH_TEST_PDS_URL")

	// URLs derived from serverUrlRoot: the client metadata document URL
	// (used as the OAuth client_id) and the OAuth redirect URI.
	serverMetadataUrl = fmt.Sprintf("%s/oauth/client-metadata.json", serverUrlRoot)
	serverCallbackUrl = fmt.Sprintf("%s/callback", serverUrlRoot)

	// scope is the OAuth scope string requested during login.
	scope = "atproto transition:generic"
)
45
46func main() {
47 app := &cli.App{
48 Name: "atproto-oauth-golang-tester",
49 Action: run,
50 }
51
52 if serverUrlRoot == "" {
53 panic(fmt.Errorf("no server url root set in env file"))
54 }
55
56 app.RunAndExitOnError()
57}
58
59type TestServer struct {
60 httpd *http.Server
61 e *echo.Echo
62 db *gorm.DB
63 oauthClient *oauth.OauthClient
64 xrpcCli *oauth.XrpcClient
65 jwksResponse *oauth.JwksResponseObject
66}
67
// TemplateRenderer is installed as the echo instance's Renderer and renders
// pages from a parsed html/template set.
type TemplateRenderer struct {
	templates *template.Template // all page templates, looked up by name at render time
}
71
72func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
73 if viewContext, isMap := data.(map[string]interface{}); isMap {
74 viewContext["reverse"] = c.Echo().Reverse
75 }
76
77 return t.templates.ExecuteTemplate(w, name, data)
78}
79
80func run(cmd *cli.Context) error {
81 s, err := NewServer()
82 if err != nil {
83 panic(err)
84 }
85
86 s.run()
87
88 return nil
89}
90
91func NewServer() (*TestServer, error) {
92 e := echo.New()
93
94 e.Use(slogecho.New(slog.Default()))
95 e.Use(session.Middleware(sessions.NewCookieStore([]byte(sessionSecret))))
96
97 renderer := &TemplateRenderer{
98 templates: template.Must(template.ParseGlob(getFilePath("*.html"))),
99 }
100 e.Renderer = renderer
101
102 fmt.Println("atproto oauth golang tester server")
103
104 b, err := os.ReadFile("./jwks.json")
105 if err != nil {
106 if os.IsNotExist(err) {
107 return nil, fmt.Errorf(
108 "could not find jwks.json. does it exist? hint: run `go run ./cmd/cmd generate-jwks --prefix demo` to create one.",
109 )
110 }
111 return nil, err
112 }
113
114 k, err := jwk.ParseKey(b)
115 if err != nil {
116 return nil, err
117 }
118
119 pubKey, err := k.PublicKey()
120 if err != nil {
121 return nil, err
122 }
123
124 c, err := oauth.NewOauthClient(oauth.OauthClientArgs{
125 ClientJwk: k,
126 ClientId: serverMetadataUrl,
127 RedirectUri: serverCallbackUrl,
128 })
129 if err != nil {
130 return nil, err
131 }
132
133 httpd := &http.Server{
134 Addr: serverAddr,
135 Handler: e,
136 }
137
138 db, err := gorm.Open(sqlite.Open("oauth.db"), &gorm.Config{})
139 if err != nil {
140 return nil, err
141 }
142
143 db.AutoMigrate(&OauthRequest{}, &OauthSession{})
144
145 xrpcCli := &oauth.XrpcClient{
146 OnDPoPNonceChanged: func(did, newNonce string) {
147 if err := db.Exec("UPDATE oauth_sessions SET dpop_pds_nonce = ? WHERE did = ?", newNonce, did).Error; err != nil {
148 slog.Default().Error("error updating pds nonce", "err", err)
149 }
150 },
151 }
152
153 return &TestServer{
154 httpd: httpd,
155 e: e,
156 db: db,
157 oauthClient: c,
158 xrpcCli: xrpcCli,
159 jwksResponse: oauth.CreateJwksResponseObject(pubKey),
160 }, nil
161}
162
163func (s *TestServer) run() error {
164 s.e.GET("/", s.handleHome)
165 s.e.File("/login", getFilePath("login.html"))
166 s.e.POST("/login", s.handleLoginSubmit)
167 s.e.GET("/logout", s.handleLogout)
168 s.e.GET("/make-post", s.handleMakePost)
169 s.e.GET("/callback", s.handleCallback)
170 s.e.GET("/oauth/client-metadata.json", s.handleClientMetadata)
171 s.e.GET("/oauth/jwks.json", s.handleJwks)
172
173 if err := s.httpd.ListenAndServe(); err != nil {
174 return err
175 }
176
177 return nil
178}
179
180func (s *TestServer) handleHome(e echo.Context) error {
181 sess, err := session.Get("session", e)
182 if err != nil {
183 return err
184 }
185
186 return e.Render(200, "index.html", map[string]any{
187 "Did": sess.Values["did"],
188 })
189}
190
191func (s *TestServer) handleClientMetadata(e echo.Context) error {
192 metadata := map[string]any{
193 "client_id": serverMetadataUrl,
194 "client_name": "Atproto Oauth Golang Tester",
195 "client_uri": serverUrlRoot,
196 "logo_uri": fmt.Sprintf("%s/logo.png", serverUrlRoot),
197 "tos_uri": fmt.Sprintf("%s/tos", serverUrlRoot),
198 "policy_url": fmt.Sprintf("%s/policy", serverUrlRoot),
199 "redirect_uris": []string{serverCallbackUrl},
200 "grant_types": []string{"authorization_code", "refresh_token"},
201 "response_types": []string{"code"},
202 "application_type": "web",
203 "dpop_bound_access_tokens": true,
204 "jwks_uri": fmt.Sprintf("%s/oauth/jwks.json", serverUrlRoot),
205 "scope": "atproto transition:generic",
206 "token_endpoint_auth_method": "private_key_jwt",
207 "token_endpoint_auth_signing_alg": "ES256",
208 }
209
210 return e.JSON(200, metadata)
211}
212
213func (s *TestServer) handleJwks(e echo.Context) error {
214 return e.JSON(200, s.jwksResponse)
215}
216
217func (s *TestServer) handleLoginSubmit(e echo.Context) error {
218 handle := e.FormValue("handle")
219 if handle == "" {
220 return e.Redirect(302, "/login?e=handle-empty")
221 }
222
223 _, herr := syntax.ParseHandle(handle)
224 _, derr := syntax.ParseDID(handle)
225
226 if herr != nil && derr != nil {
227 return e.Redirect(302, "/login?e=handle-invalid")
228 }
229
230 var did string
231
232 if derr == nil {
233 did = handle
234 } else {
235 maybeDid, err := resolveHandle(e.Request().Context(), handle)
236 if err != nil {
237 return err
238 }
239
240 did = maybeDid
241 }
242
243 service, err := resolveService(ctx, did)
244 if err != nil {
245 return err
246 }
247
248 authserver, err := s.oauthClient.ResolvePDSAuthServer(ctx, service)
249 if err != nil {
250 return err
251 }
252
253 meta, err := s.oauthClient.FetchAuthServerMetadata(ctx, authserver)
254 if err != nil {
255 return err
256 }
257
258 dpopPrivateKey, err := oauth.GenerateKey(nil)
259 if err != nil {
260 return err
261 }
262
263 dpopPrivateKeyJson, err := json.Marshal(dpopPrivateKey)
264 if err != nil {
265 return err
266 }
267
268 parResp, err := s.oauthClient.SendParAuthRequest(
269 ctx,
270 authserver,
271 meta,
272 "",
273 scope,
274 dpopPrivateKey,
275 )
276
277 oauthRequest := &OauthRequest{
278 State: parResp.State,
279 AuthserverIss: meta.Issuer,
280 Did: did,
281 PdsUrl: service,
282 PkceVerifier: parResp.PkceVerifier,
283 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
284 DpopPrivateJwk: string(dpopPrivateKeyJson),
285 }
286
287 if err := s.db.Create(oauthRequest).Error; err != nil {
288 return err
289 }
290
291 u, _ := url.Parse(meta.AuthorizationEndpoint)
292 u.RawQuery = fmt.Sprintf(
293 "client_id=%s&request_uri=%s",
294 url.QueryEscape(serverMetadataUrl),
295 parResp.Resp["request_uri"].(string),
296 )
297
298 sess, err := session.Get("session", e)
299 if err != nil {
300 return err
301 }
302
303 sess.Options = &sessions.Options{
304 Path: "/",
305 MaxAge: 300, // save for five minutes
306 HttpOnly: true,
307 }
308
309 // make sure the session is empty
310 sess.Values = map[interface{}]interface{}{}
311 sess.Values["oauth_state"] = parResp.State
312 sess.Values["oauth_did"] = did
313
314 if err := sess.Save(e.Request(), e.Response()); err != nil {
315 return err
316 }
317
318 return e.Redirect(302, u.String())
319}
320
321func (s *TestServer) handleCallback(e echo.Context) error {
322 resState := e.QueryParam("state")
323 resIss := e.QueryParam("iss")
324 resCode := e.QueryParam("code")
325
326 sess, err := session.Get("session", e)
327 if err != nil {
328 return err
329 }
330
331 sessState := sess.Values["oauth_state"]
332 sessDid := sess.Values["oauth_did"]
333
334 if resState == "" || resIss == "" || resCode == "" || sessState == "" || sessDid == "" {
335 return fmt.Errorf("request missing needed parameters")
336 }
337
338 if resState != sessState {
339 return fmt.Errorf("session state does not match response state")
340 }
341
342 var oauthRequest OauthRequest
343 if err := s.db.Raw("SELECT * FROM oauth_requests WHERE state = ? AND did = ?", sessState, sessDid).Scan(&oauthRequest).Error; err != nil {
344 return err
345 }
346
347 if err := s.db.Exec("DELETE FROM oauth_requests WHERE state = ? AND did = ?", sessState, sessDid).Error; err != nil {
348 return err
349 }
350
351 if resIss != oauthRequest.AuthserverIss {
352 return fmt.Errorf("incoming iss did not match authserver iss")
353 }
354
355 jwk, err := oauth.ParseKeyFromBytes([]byte(oauthRequest.DpopPrivateJwk))
356 if err != nil {
357 return err
358 }
359
360 initialTokenResp, err := s.oauthClient.InitialTokenRequest(
361 e.Request().Context(),
362 resCode,
363 resIss,
364 resIss,
365 oauthRequest.PkceVerifier,
366 oauthRequest.DpopAuthserverNonce,
367 jwk,
368 )
369 if err != nil {
370 return err
371 }
372
373 // TODO: resolve if needed
374
375 if initialTokenResp.Resp["scope"] != scope {
376 return fmt.Errorf("did not receive correct scopes from token request")
377 }
378
379 oauthSession := &OauthSession{
380 Did: oauthRequest.Did,
381 PdsUrl: oauthRequest.PdsUrl,
382 AuthserverIss: oauthRequest.AuthserverIss,
383 AccessToken: initialTokenResp.Resp["access_token"].(string),
384 RefreshToken: initialTokenResp.Resp["refresh_token"].(string),
385 DpopAuthserverNonce: initialTokenResp.DpopAuthserverNonce,
386 DpopPrivateJwk: oauthRequest.DpopPrivateJwk,
387 }
388
389 if err := s.db.Clauses(clause.OnConflict{
390 Columns: []clause.Column{{Name: "did"}},
391 UpdateAll: true,
392 }).Create(oauthSession).Error; err != nil {
393 return err
394 }
395
396 sess.Options = &sessions.Options{
397 Path: "/",
398 MaxAge: 86400 * 7,
399 HttpOnly: true,
400 }
401
402 // make sure the session is empty
403 sess.Values = map[interface{}]interface{}{}
404 sess.Values["did"] = oauthRequest.Did
405
406 if err := sess.Save(e.Request(), e.Response()); err != nil {
407 return err
408 }
409
410 return e.Redirect(302, "/")
411}
412
413func (s *TestServer) handleLogout(e echo.Context) error {
414 sess, err := session.Get("session", e)
415 if err != nil {
416 return err
417 }
418
419 sess.Options = &sessions.Options{
420 Path: "/",
421 MaxAge: -1,
422 HttpOnly: true,
423 }
424
425 if err := sess.Save(e.Request(), e.Response()); err != nil {
426 return err
427 }
428
429 return e.Redirect(302, "/")
430}
431
432func (s *TestServer) handleMakePost(e echo.Context) error {
433 sess, err := session.Get("session", e)
434 if err != nil {
435 return err
436 }
437
438 did, ok := sess.Values["did"]
439 if !ok {
440 return e.Redirect(302, "/login")
441 }
442
443 var oauthSession OauthSession
444 if err := s.db.Raw("SELECT * FROM oauth_sessions WHERE did = ?", did).Scan(&oauthSession).Error; err != nil {
445 return err
446 }
447
448 args, err := authedReqArgsFromSession(&oauthSession)
449 if err != nil {
450 return err
451 }
452
453 post := bsky.FeedPost{
454 Text: "hello from atproto golang oauth client",
455 CreatedAt: syntax.DatetimeNow().String(),
456 }
457
458 input := atproto.RepoCreateRecord_Input{
459 Collection: "app.bsky.feed.post",
460 Repo: oauthSession.Did,
461 Record: &util.LexiconTypeDecoder{Val: &post},
462 }
463
464 var out atproto.RepoCreateRecord_Output
465 if err := s.xrpcCli.Do(e.Request().Context(), args, xrpc.Procedure, "application/json", "com.atproto.repo.createRecord", nil, input, &out); err != nil {
466 return err
467 }
468
469 return e.File(getFilePath("make-post.html"))
470}
471
472func authedReqArgsFromSession(session *OauthSession) (*oauth.XrpcAuthedRequestArgs, error) {
473 privateJwk, err := oauth.ParseKeyFromBytes([]byte(session.DpopPrivateJwk))
474 if err != nil {
475 return nil, err
476 }
477
478 return &oauth.XrpcAuthedRequestArgs{
479 Did: session.Did,
480 AccessToken: session.AccessToken,
481 PdsUrl: session.PdsUrl,
482 Issuer: session.AuthserverIss,
483 DpopPdsNonce: session.DpopPdsNonce,
484 DpopPrivateJwk: privateJwk,
485 }, nil
486}
487
488func resolveHandle(ctx context.Context, handle string) (string, error) {
489 var did string
490
491 _, err := syntax.ParseHandle(handle)
492 if err != nil {
493 return "", err
494 }
495
496 recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle))
497 if err != nil {
498 return "", err
499 }
500
501 for _, rec := range recs {
502 if strings.HasPrefix(rec, "did=") {
503 did = strings.Split(rec, "did=")[1]
504 break
505 }
506 }
507
508 if did == "" {
509 req, err := http.NewRequestWithContext(
510 ctx,
511 "GET",
512 fmt.Sprintf("https://%s/.well-known/atproto-did", handle),
513 nil,
514 )
515 if err != nil {
516 return "", err
517 }
518
519 resp, err := http.DefaultClient.Do(req)
520 if err != nil {
521 return "", err
522 }
523 defer resp.Body.Close()
524
525 if resp.StatusCode != http.StatusOK {
526 io.Copy(io.Discard, resp.Body)
527 return "", fmt.Errorf("unable to resolve handle")
528 }
529
530 b, err := io.ReadAll(resp.Body)
531 if err != nil {
532 return "", err
533 }
534
535 maybeDid := string(b)
536
537 if _, err := syntax.ParseDID(maybeDid); err != nil {
538 return "", fmt.Errorf("unable to resolve handle")
539 }
540
541 did = maybeDid
542 }
543
544 return did, nil
545}
546
// atprotoDidService is one entry of a DID document's "service" array.
type atprotoDidService struct {
	ID              string `json:"id"`
	Type            string `json:"type"`
	ServiceEndpoint string `json:"serviceEndpoint"`
}

// atprotoDidDocument is the subset of a DID document needed to locate the PDS.
type atprotoDidDocument struct {
	Service []atprotoDidService `json:"service"`
}

// resolveService fetches the DID document for did and returns the endpoint
// of its atproto PDS service. did:plc documents come from plc.directory;
// did:web documents come from https://<host>/.well-known/did.json.
func resolveService(ctx context.Context, did string) (string, error) {
	ustr, err := didDocumentURL(did)
	if err != nil {
		return "", err
	}

	req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil)
	if err != nil {
		return "", err
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		io.Copy(io.Discard, resp.Body)
		return "", fmt.Errorf("could not fetch did document for %s: status %d", did, resp.StatusCode)
	}

	var doc atprotoDidDocument
	if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
		return "", fmt.Errorf("decoding did document for %s: %w", did, err)
	}

	service := pdsEndpoint(did, doc.Service)
	if service == "" {
		return "", fmt.Errorf("could not find atproto_pds service in identity services")
	}

	return service, nil
}

// didDocumentURL returns the URL of the DID document for a did:plc or did:web
// DID. For did:web the method-specific id (the part after "did:web:") is the
// hostname. Path-based did:web ids, which contain further ':' separators, are
// rejected.
func didDocumentURL(did string) (string, error) {
	switch {
	case strings.HasPrefix(did, "did:plc:"):
		return fmt.Sprintf("https://plc.directory/%s", did), nil
	case strings.HasPrefix(did, "did:web:"):
		host := strings.TrimPrefix(did, "did:web:")
		if host == "" || strings.Contains(host, ":") {
			return "", fmt.Errorf("unsupported did:web identifier %q", did)
		}
		return fmt.Sprintf("https://%s/.well-known/did.json", host), nil
	default:
		return "", fmt.Errorf("did was not a supported did type")
	}
}

// pdsEndpoint returns the endpoint of the first service whose id names the
// atproto PDS, written either relative ("#atproto_pds") or fully qualified
// ("<did>#atproto_pds"). It returns "" if there is no such service.
func pdsEndpoint(did string, services []atprotoDidService) string {
	for _, svc := range services {
		if svc.ID == "#atproto_pds" || svc.ID == did+"#atproto_pds" {
			return svc.ServiceEndpoint
		}
	}
	return ""
}
599
600func getFilePath(file string) string {
601 return fmt.Sprintf("%s/%s", staticFilePath, file)
602}