this repo has no description
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "html/template" 8 "io" 9 "log/slog" 10 "net" 11 "net/http" 12 "net/url" 13 "os" 14 "strings" 15 16 "github.com/bluesky-social/indigo/api/atproto" 17 "github.com/bluesky-social/indigo/api/bsky" 18 "github.com/bluesky-social/indigo/atproto/syntax" 19 "github.com/bluesky-social/indigo/lex/util" 20 "github.com/bluesky-social/indigo/xrpc" 21 "github.com/gorilla/sessions" 22 oauth "github.com/haileyok/atproto-oauth-golang" 23 _ "github.com/joho/godotenv/autoload" 24 "github.com/labstack/echo-contrib/session" 25 "github.com/labstack/echo/v4" 26 "github.com/lestrrat-go/jwx/v2/jwk" 27 slogecho "github.com/samber/slog-echo" 28 "github.com/urfave/cli/v2" 29 "gorm.io/driver/sqlite" 30 "gorm.io/gorm" 31 "gorm.io/gorm/clause" 32) 33 34var ( 35 ctx = context.Background() 36 serverAddr = os.Getenv("OAUTH_TEST_SERVER_ADDR") 37 serverUrlRoot = os.Getenv("OAUTH_TEST_SERVER_URL_ROOT") 38 staticFilePath = os.Getenv("OAUTH_TEST_SERVER_STATIC_PATH") 39 sessionSecret = os.Getenv("OAUTH_TEST_SESSION_SECRET") 40 serverMetadataUrl = fmt.Sprintf("%s/oauth/client-metadata.json", serverUrlRoot) 41 serverCallbackUrl = fmt.Sprintf("%s/callback", serverUrlRoot) 42 pdsUrl = os.Getenv("OAUTH_TEST_PDS_URL") 43 scope = "atproto transition:generic" 44) 45 46func main() { 47 app := &cli.App{ 48 Name: "atproto-oauth-golang-tester", 49 Action: run, 50 } 51 52 if serverUrlRoot == "" { 53 panic(fmt.Errorf("no server url root set in env file")) 54 } 55 56 app.RunAndExitOnError() 57} 58 59type TestServer struct { 60 httpd *http.Server 61 e *echo.Echo 62 db *gorm.DB 63 oauthClient *oauth.OauthClient 64 xrpcCli *oauth.XrpcClient 65 jwksResponse *oauth.JwksResponseObject 66} 67 68type TemplateRenderer struct { 69 templates *template.Template 70} 71 72func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error { 73 if viewContext, isMap := data.(map[string]interface{}); isMap { 74 viewContext["reverse"] = c.Echo().Reverse 75 } 76 77 return t.templates.ExecuteTemplate(w, name, data) 78} 79 80func run(cmd *cli.Context) error { 81 s, err := NewServer() 82 if err != nil { 83 panic(err) 84 } 85 86 s.run() 87 88 return nil 89} 90 91func NewServer() (*TestServer, error) { 92 e := echo.New() 93 94 e.Use(slogecho.New(slog.Default())) 95 e.Use(session.Middleware(sessions.NewCookieStore([]byte(sessionSecret)))) 96 97 renderer := &TemplateRenderer{ 98 templates: template.Must(template.ParseGlob(getFilePath("*.html"))), 99 } 100 e.Renderer = renderer 101 102 fmt.Println("atproto oauth golang tester server") 103 104 b, err := os.ReadFile("./jwks.json") 105 if err != nil { 106 if os.IsNotExist(err) { 107 return nil, fmt.Errorf( 108 "could not find jwks.json. does it exist? hint: run `go run ./cmd/cmd generate-jwks --prefix demo` to create one.", 109 ) 110 } 111 return nil, err 112 } 113 114 k, err := jwk.ParseKey(b) 115 if err != nil { 116 return nil, err 117 } 118 119 pubKey, err := k.PublicKey() 120 if err != nil { 121 return nil, err 122 } 123 124 c, err := oauth.NewOauthClient(oauth.OauthClientArgs{ 125 ClientJwk: k, 126 ClientId: serverMetadataUrl, 127 RedirectUri: serverCallbackUrl, 128 }) 129 if err != nil { 130 return nil, err 131 } 132 133 httpd := &http.Server{ 134 Addr: serverAddr, 135 Handler: e, 136 } 137 138 db, err := gorm.Open(sqlite.Open("oauth.db"), &gorm.Config{}) 139 if err != nil { 140 return nil, err 141 } 142 143 db.AutoMigrate(&OauthRequest{}, &OauthSession{}) 144 145 xrpcCli := &oauth.XrpcClient{ 146 OnDPoPNonceChanged: func(did, newNonce string) { 147 if err := db.Exec("UPDATE oauth_sessions SET dpop_pds_nonce = ? WHERE did = ?", newNonce, did).Error; err != nil { 148 slog.Default().Error("error updating pds nonce", "err", err) 149 } 150 }, 151 } 152 153 return &TestServer{ 154 httpd: httpd, 155 e: e, 156 db: db, 157 oauthClient: c, 158 xrpcCli: xrpcCli, 159 jwksResponse: oauth.CreateJwksResponseObject(pubKey), 160 }, nil 161} 162 163func (s *TestServer) run() error { 164 s.e.GET("/", s.handleHome) 165 s.e.File("/login", getFilePath("login.html")) 166 s.e.POST("/login", s.handleLoginSubmit) 167 s.e.GET("/logout", s.handleLogout) 168 s.e.GET("/make-post", s.handleMakePost) 169 s.e.GET("/callback", s.handleCallback) 170 s.e.GET("/oauth/client-metadata.json", s.handleClientMetadata) 171 s.e.GET("/oauth/jwks.json", s.handleJwks) 172 173 if err := s.httpd.ListenAndServe(); err != nil { 174 return err 175 } 176 177 return nil 178} 179 180func (s *TestServer) handleHome(e echo.Context) error { 181 sess, err := session.Get("session", e) 182 if err != nil { 183 return err 184 } 185 186 return e.Render(200, "index.html", map[string]any{ 187 "Did": sess.Values["did"], 188 }) 189} 190 191func (s *TestServer) handleClientMetadata(e echo.Context) error { 192 metadata := map[string]any{ 193 "client_id": serverMetadataUrl, 194 "client_name": "Atproto Oauth Golang Tester", 195 "client_uri": serverUrlRoot, 196 "logo_uri": fmt.Sprintf("%s/logo.png", serverUrlRoot), 197 "tos_uri": fmt.Sprintf("%s/tos", serverUrlRoot), 198 "policy_url": fmt.Sprintf("%s/policy", serverUrlRoot), 199 "redirect_uris": []string{serverCallbackUrl}, 200 "grant_types": []string{"authorization_code", "refresh_token"}, 201 "response_types": []string{"code"}, 202 "application_type": "web", 203 "dpop_bound_access_tokens": true, 204 "jwks_uri": fmt.Sprintf("%s/oauth/jwks.json", serverUrlRoot), 205 "scope": "atproto transition:generic", 206 "token_endpoint_auth_method": "private_key_jwt", 207 "token_endpoint_auth_signing_alg": "ES256", 208 } 209 210 return e.JSON(200, metadata) 211} 212 213func (s *TestServer) handleJwks(e echo.Context) error { 214 return e.JSON(200, s.jwksResponse) 215} 216 217func (s *TestServer) handleLoginSubmit(e echo.Context) error { 218 handle := e.FormValue("handle") 219 if handle == "" { 220 return e.Redirect(302, "/login?e=handle-empty") 221 } 222 223 _, herr := syntax.ParseHandle(handle) 224 _, derr := syntax.ParseDID(handle) 225 226 if herr != nil && derr != nil { 227 return e.Redirect(302, "/login?e=handle-invalid") 228 } 229 230 var did string 231 232 if derr == nil { 233 did = handle 234 } else { 235 maybeDid, err := resolveHandle(e.Request().Context(), handle) 236 if err != nil { 237 return err 238 } 239 240 did = maybeDid 241 } 242 243 service, err := resolveService(ctx, did) 244 if err != nil { 245 return err 246 } 247 248 authserver, err := s.oauthClient.ResolvePDSAuthServer(ctx, service) 249 if err != nil { 250 return err 251 } 252 253 meta, err := s.oauthClient.FetchAuthServerMetadata(ctx, authserver) 254 if err != nil { 255 return err 256 } 257 258 dpopPrivateKey, err := oauth.GenerateKey(nil) 259 if err != nil { 260 return err 261 } 262 263 dpopPrivateKeyJson, err := json.Marshal(dpopPrivateKey) 264 if err != nil { 265 return err 266 } 267 268 parResp, err := s.oauthClient.SendParAuthRequest( 269 ctx, 270 authserver, 271 meta, 272 "", 273 scope, 274 dpopPrivateKey, 275 ) 276 277 oauthRequest := &OauthRequest{ 278 State: parResp.State, 279 AuthserverIss: meta.Issuer, 280 Did: did, 281 PdsUrl: service, 282 PkceVerifier: parResp.PkceVerifier, 283 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 284 DpopPrivateJwk: string(dpopPrivateKeyJson), 285 } 286 287 if err := s.db.Create(oauthRequest).Error; err != nil { 288 return err 289 } 290 291 u, _ := url.Parse(meta.AuthorizationEndpoint) 292 u.RawQuery = fmt.Sprintf( 293 "client_id=%s&request_uri=%s", 294 url.QueryEscape(serverMetadataUrl), 295 parResp.Resp["request_uri"].(string), 296 ) 297 298 sess, err := session.Get("session", e) 299 if err != nil { 300 return err 301 } 302 303 sess.Options = &sessions.Options{ 304 Path: "/", 305 MaxAge: 300, // save for five minutes 306 HttpOnly: true, 307 } 308 309 // make sure the session is empty 310 sess.Values = map[interface{}]interface{}{} 311 sess.Values["oauth_state"] = parResp.State 312 sess.Values["oauth_did"] = did 313 314 if err := sess.Save(e.Request(), e.Response()); err != nil { 315 return err 316 } 317 318 return e.Redirect(302, u.String()) 319} 320 321func (s *TestServer) handleCallback(e echo.Context) error { 322 resState := e.QueryParam("state") 323 resIss := e.QueryParam("iss") 324 resCode := e.QueryParam("code") 325 326 sess, err := session.Get("session", e) 327 if err != nil { 328 return err 329 } 330 331 sessState := sess.Values["oauth_state"] 332 sessDid := sess.Values["oauth_did"] 333 334 if resState == "" || resIss == "" || resCode == "" || sessState == "" || sessDid == "" { 335 return fmt.Errorf("request missing needed parameters") 336 } 337 338 if resState != sessState { 339 return fmt.Errorf("session state does not match response state") 340 } 341 342 var oauthRequest OauthRequest 343 if err := s.db.Raw("SELECT * FROM oauth_requests WHERE state = ? AND did = ?", sessState, sessDid).Scan(&oauthRequest).Error; err != nil { 344 return err 345 } 346 347 if err := s.db.Exec("DELETE FROM oauth_requests WHERE state = ? AND did = ?", sessState, sessDid).Error; err != nil { 348 return err 349 } 350 351 if resIss != oauthRequest.AuthserverIss { 352 return fmt.Errorf("incoming iss did not match authserver iss") 353 } 354 355 jwk, err := oauth.ParseKeyFromBytes([]byte(oauthRequest.DpopPrivateJwk)) 356 if err != nil { 357 return err 358 } 359 360 initialTokenResp, err := s.oauthClient.InitialTokenRequest( 361 e.Request().Context(), 362 resCode, 363 resIss, 364 resIss, 365 oauthRequest.PkceVerifier, 366 oauthRequest.DpopAuthserverNonce, 367 jwk, 368 ) 369 if err != nil { 370 return err 371 } 372 373 // TODO: resolve if needed 374 375 if initialTokenResp.Resp["scope"] != scope { 376 return fmt.Errorf("did not receive correct scopes from token request") 377 } 378 379 oauthSession := &OauthSession{ 380 Did: oauthRequest.Did, 381 PdsUrl: oauthRequest.PdsUrl, 382 AuthserverIss: oauthRequest.AuthserverIss, 383 AccessToken: initialTokenResp.Resp["access_token"].(string), 384 RefreshToken: initialTokenResp.Resp["refresh_token"].(string), 385 DpopAuthserverNonce: initialTokenResp.DpopAuthserverNonce, 386 DpopPrivateJwk: oauthRequest.DpopPrivateJwk, 387 } 388 389 if err := s.db.Clauses(clause.OnConflict{ 390 Columns: []clause.Column{{Name: "did"}}, 391 UpdateAll: true, 392 }).Create(oauthSession).Error; err != nil { 393 return err 394 } 395 396 sess.Options = &sessions.Options{ 397 Path: "/", 398 MaxAge: 86400 * 7, 399 HttpOnly: true, 400 } 401 402 // make sure the session is empty 403 sess.Values = map[interface{}]interface{}{} 404 sess.Values["did"] = oauthRequest.Did 405 406 if err := sess.Save(e.Request(), e.Response()); err != nil { 407 return err 408 } 409 410 return e.Redirect(302, "/") 411} 412 413func (s *TestServer) handleLogout(e echo.Context) error { 414 sess, err := session.Get("session", e) 415 if err != nil { 416 return err 417 } 418 419 sess.Options = &sessions.Options{ 420 Path: "/", 421 MaxAge: -1, 422 HttpOnly: true, 423 } 424 425 if err := sess.Save(e.Request(), e.Response()); err != nil { 426 return err 427 } 428 429 return e.Redirect(302, "/") 430} 431 432func (s *TestServer) handleMakePost(e echo.Context) error { 433 sess, err := session.Get("session", e) 434 if err != nil { 435 return err 436 } 437 438 did, ok := sess.Values["did"] 439 if !ok { 440 return e.Redirect(302, "/login") 441 } 442 443 var oauthSession OauthSession 444 if err := s.db.Raw("SELECT * FROM oauth_sessions WHERE did = ?", did).Scan(&oauthSession).Error; err != nil { 445 return err 446 } 447 448 args, err := authedReqArgsFromSession(&oauthSession) 449 if err != nil { 450 return err 451 } 452 453 post := bsky.FeedPost{ 454 Text: "hello from atproto golang oauth client", 455 CreatedAt: syntax.DatetimeNow().String(), 456 } 457 458 input := atproto.RepoCreateRecord_Input{ 459 Collection: "app.bsky.feed.post", 460 Repo: oauthSession.Did, 461 Record: &util.LexiconTypeDecoder{Val: &post}, 462 } 463 464 var out atproto.RepoCreateRecord_Output 465 if err := s.xrpcCli.Do(e.Request().Context(), args, xrpc.Procedure, "application/json", "com.atproto.repo.createRecord", nil, input, &out); err != nil { 466 return err 467 } 468 469 return e.File(getFilePath("make-post.html")) 470} 471 472func authedReqArgsFromSession(session *OauthSession) (*oauth.XrpcAuthedRequestArgs, error) { 473 privateJwk, err := oauth.ParseKeyFromBytes([]byte(session.DpopPrivateJwk)) 474 if err != nil { 475 return nil, err 476 } 477 478 return &oauth.XrpcAuthedRequestArgs{ 479 Did: session.Did, 480 AccessToken: session.AccessToken, 481 PdsUrl: session.PdsUrl, 482 Issuer: session.AuthserverIss, 483 DpopPdsNonce: session.DpopPdsNonce, 484 DpopPrivateJwk: privateJwk, 485 }, nil 486} 487 488func resolveHandle(ctx context.Context, handle string) (string, error) { 489 var did string 490 491 _, err := syntax.ParseHandle(handle) 492 if err != nil { 493 return "", err 494 } 495 496 recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle)) 497 if err != nil { 498 return "", err 499 } 500 501 for _, rec := range recs { 502 if strings.HasPrefix(rec, "did=") { 503 did = strings.Split(rec, "did=")[1] 504 break 505 } 506 } 507 508 if did == "" { 509 req, err := http.NewRequestWithContext( 510 ctx, 511 "GET", 512 fmt.Sprintf("https://%s/.well-known/atproto-did", handle), 513 nil, 514 ) 515 if err != nil { 516 return "", err 517 } 518 519 resp, err := http.DefaultClient.Do(req) 520 if err != nil { 521 return "", err 522 } 523 defer resp.Body.Close() 524 525 if resp.StatusCode != http.StatusOK { 526 io.Copy(io.Discard, resp.Body) 527 return "", fmt.Errorf("unable to resolve handle") 528 } 529 530 b, err := io.ReadAll(resp.Body) 531 if err != nil { 532 return "", err 533 } 534 535 maybeDid := string(b) 536 537 if _, err := syntax.ParseDID(maybeDid); err != nil { 538 return "", fmt.Errorf("unable to resolve handle") 539 } 540 541 did = maybeDid 542 } 543 544 return did, nil 545} 546 547func resolveService(ctx context.Context, did string) (string, error) { 548 type Identity struct { 549 Service []struct { 550 ID string `json:"id"` 551 Type string `json:"type"` 552 ServiceEndpoint string `json:"serviceEndpoint"` 553 } `json:"service"` 554 } 555 556 var ustr string 557 if strings.HasPrefix(did, "did:plc:") { 558 ustr = fmt.Sprintf("https://plc.directory/%s", did) 559 } else if strings.HasPrefix(did, "did:web:") { 560 ustr = fmt.Sprintf("https://%s/.well-known/did.json", did) 561 } else { 562 return "", fmt.Errorf("did was not a supported did type") 563 } 564 565 req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil) 566 if err != nil { 567 return "", err 568 } 569 570 resp, err := http.DefaultClient.Do(req) 571 if err != nil { 572 return "", err 573 } 574 defer resp.Body.Close() 575 576 if resp.StatusCode != 200 { 577 io.Copy(io.Discard, resp.Body) 578 return "", fmt.Errorf("could not find identity in plc registry") 579 } 580 581 var identity Identity 582 if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil { 583 return "", err 584 } 585 586 var service string 587 for _, svc := range identity.Service { 588 if svc.ID == "#atproto_pds" { 589 service = svc.ServiceEndpoint 590 } 591 } 592 593 if service == "" { 594 return "", fmt.Errorf("could not find atproto_pds service in identity services") 595 } 596 597 return service, nil 598} 599 600func getFilePath(file string) string { 601 return fmt.Sprintf("%s/%s", staticFilePath, file) 602}