this repo has no description
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io" 8 "log/slog" 9 "net" 10 "net/http" 11 "net/url" 12 "os" 13 "strings" 14 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 oauth "github.com/haileyok/atproto-oauth-golang" 17 _ "github.com/joho/godotenv/autoload" 18 "github.com/labstack/echo/v4" 19 "github.com/lestrrat-go/jwx/v2/jwk" 20 slogecho "github.com/samber/slog-echo" 21 "github.com/urfave/cli/v2" 22 "gorm.io/driver/sqlite" 23 "gorm.io/gorm" 24) 25 26var ( 27 ctx = context.Background() 28 serverAddr = os.Getenv("OAUTH_TEST_SERVER_ADDR") 29 serverUrlRoot = os.Getenv("OAUTH_TEST_SERVER_URL_ROOT") 30 staticFilePath = os.Getenv("OAUTH_TEST_SERVER_STATIC_PATH") 31 serverMetadataUrl = fmt.Sprintf("%s/oauth/client-metadata.json", serverUrlRoot) 32 serverCallbackUrl = fmt.Sprintf("%s/callback", serverUrlRoot) 33 pdsUrl = os.Getenv("OAUTH_TEST_PDS_URL") 34 scope = "atproto transition:generic" 35) 36 37func main() { 38 app := &cli.App{ 39 Name: "atproto-oauth-golang-tester", 40 Action: run, 41 } 42 43 if serverUrlRoot == "" { 44 panic(fmt.Errorf("no server url root set in env file")) 45 } 46 47 app.RunAndExitOnError() 48} 49 50type TestServer struct { 51 httpd *http.Server 52 e *echo.Echo 53 db *gorm.DB 54 oauthClient *oauth.OauthClient 55 jwksResponse *oauth.JwksResponseObject 56} 57 58func run(cmd *cli.Context) error { 59 s, err := NewServer() 60 if err != nil { 61 panic(err) 62 } 63 64 s.run() 65 66 return nil 67} 68 69func NewServer() (*TestServer, error) { 70 e := echo.New() 71 72 e.Use(slogecho.New(slog.Default())) 73 74 fmt.Println("atproto oauth golang tester server") 75 76 b, err := os.ReadFile("./jwks.json") 77 if err != nil { 78 if os.IsNotExist(err) { 79 return nil, fmt.Errorf( 80 "could not find jwks.json. does it exist? hint: run `go run ./cmd/cmd generate-jwks --prefix demo` to create one.", 81 ) 82 } 83 return nil, err 84 } 85 86 k, err := jwk.ParseKey(b) 87 if err != nil { 88 return nil, err 89 } 90 91 pubKey, err := k.PublicKey() 92 if err != nil { 93 return nil, err 94 } 95 96 c, err := oauth.NewOauthClient(oauth.OauthClientArgs{ 97 ClientJwk: k, 98 ClientId: serverMetadataUrl, 99 RedirectUri: serverCallbackUrl, 100 }) 101 if err != nil { 102 return nil, err 103 } 104 105 httpd := &http.Server{ 106 Addr: serverAddr, 107 Handler: e, 108 } 109 110 db, err := gorm.Open(sqlite.Open("oauth.db"), &gorm.Config{}) 111 if err != nil { 112 return nil, err 113 } 114 115 db.AutoMigrate(&OauthRequest{}) 116 117 return &TestServer{ 118 httpd: httpd, 119 e: e, 120 db: db, 121 oauthClient: c, 122 jwksResponse: oauth.CreateJwksResponseObject(pubKey), 123 }, nil 124} 125 126func (s *TestServer) run() error { 127 s.e.File("/", s.getFilePath("index.html")) 128 s.e.File("/login", s.getFilePath("login.html")) 129 s.e.POST("/login", s.handleLoginSubmit) 130 s.e.GET("/oauth/client-metadata.json", s.handleClientMetadata) 131 s.e.GET("/oauth/jwks.json", s.handleJwks) 132 133 if err := s.httpd.ListenAndServe(); err != nil { 134 return err 135 } 136 137 return nil 138} 139 140func (s *TestServer) handleClientMetadata(e echo.Context) error { 141 metadata := map[string]any{ 142 "client_id": serverMetadataUrl, 143 "client_name": "Atproto Oauth Golang Tester", 144 "client_uri": serverUrlRoot, 145 "logo_uri": fmt.Sprintf("%s/logo.png", serverUrlRoot), 146 "tos_uri": fmt.Sprintf("%s/tos", serverUrlRoot), 147 "policy_url": fmt.Sprintf("%s/policy", serverUrlRoot), 148 "redirect_uris": []string{serverCallbackUrl}, 149 "grant_types": []string{"authorization_code", "refresh_token"}, 150 "response_types": []string{"code"}, 151 "application_type": "web", 152 "dpop_bound_access_tokens": true, 153 "jwks_uri": fmt.Sprintf("%s/oauth/jwks.json", serverUrlRoot), 154 "scope": "atproto transition:generic", 155 "token_endpoint_auth_method": "private_key_jwt", 156 "token_endpoint_auth_signing_alg": "ES256", 157 } 158 159 return e.JSON(200, metadata) 160} 161 162func (s *TestServer) handleJwks(e echo.Context) error { 163 return e.JSON(200, s.jwksResponse) 164} 165 166func (s *TestServer) handleLoginSubmit(e echo.Context) error { 167 handle := e.FormValue("handle") 168 if handle == "" { 169 return e.Redirect(302, "/login?e=handle-empty") 170 } 171 172 _, herr := syntax.ParseHandle(handle) 173 _, derr := syntax.ParseDID(handle) 174 175 if herr != nil && derr != nil { 176 return e.Redirect(302, "/login?e=handle-invalid") 177 } 178 179 var did string 180 181 if derr == nil { 182 did = handle 183 } else { 184 maybeDid, err := resolveHandle(e.Request().Context(), handle) 185 if err != nil { 186 return err 187 } 188 189 did = maybeDid 190 } 191 192 service, err := resolveService(ctx, did) 193 if err != nil { 194 return err 195 } 196 197 authserver, err := s.oauthClient.ResolvePDSAuthServer(ctx, service) 198 if err != nil { 199 return err 200 } 201 202 meta, err := s.oauthClient.FetchAuthServerMetadata(ctx, authserver) 203 if err != nil { 204 return err 205 } 206 207 dpopPrivateKey, err := oauth.GenerateKey(nil) 208 if err != nil { 209 return err 210 } 211 212 dpopPrivateKeyJson, err := json.Marshal(dpopPrivateKey) 213 if err != nil { 214 return err 215 } 216 217 parResp, err := s.oauthClient.SendParAuthRequest( 218 ctx, 219 authserver, 220 meta, 221 "", 222 scope, 223 dpopPrivateKey, 224 ) 225 226 oauthRequest := OauthRequest{ 227 State: "", 228 AuthserverIss: meta.Issuer, 229 Did: did, 230 PdsUrl: service, 231 PkceVerifier: parResp.PkceVerifier, 232 DpopAuthserverNonce: parResp.DpopAuthserverNonce, 233 DpopPrivateJwk: string(dpopPrivateKeyJson), 234 } 235 236 if err := s.db.Create(&oauthRequest).Error; err != nil { 237 return err 238 } 239 240 u, _ := url.Parse(meta.AuthorizationEndpoint) 241 u.RawQuery = fmt.Sprintf( 242 "client_id=%s&request_uri=%s", 243 url.QueryEscape(serverMetadataUrl), 244 parResp.Resp["request_uri"].(string), 245 ) 246 247 return e.Redirect(302, u.String()) 248} 249 250func resolveHandle(ctx context.Context, handle string) (string, error) { 251 var did string 252 253 _, err := syntax.ParseHandle(handle) 254 if err != nil { 255 return "", err 256 } 257 258 recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle)) 259 if err != nil { 260 return "", err 261 } 262 263 for _, rec := range recs { 264 if strings.HasPrefix(rec, "did=") { 265 did = strings.Split(rec, "did=")[1] 266 break 267 } 268 } 269 270 if did == "" { 271 req, err := http.NewRequestWithContext( 272 ctx, 273 "GET", 274 fmt.Sprintf("https://%s/.well-known/atproto-did", handle), 275 nil, 276 ) 277 if err != nil { 278 return "", err 279 } 280 281 resp, err := http.DefaultClient.Do(req) 282 if err != nil { 283 return "", err 284 } 285 defer resp.Body.Close() 286 287 if resp.StatusCode != http.StatusOK { 288 io.Copy(io.Discard, resp.Body) 289 return "", fmt.Errorf("unable to resolve handle") 290 } 291 292 b, err := io.ReadAll(resp.Body) 293 if err != nil { 294 return "", err 295 } 296 297 maybeDid := string(b) 298 299 if _, err := syntax.ParseDID(maybeDid); err != nil { 300 return "", fmt.Errorf("unable to resolve handle") 301 } 302 303 did = maybeDid 304 } 305 306 // TODO: we can also support did:web here 307 308 if did == "" { 309 return "", fmt.Errorf("unable to resolve handle") 310 } 311 312 return did, nil 313} 314 315func resolveService(ctx context.Context, did string) (string, error) { 316 type Identity struct { 317 Service []struct { 318 ID string `json:"id"` 319 Type string `json:"type"` 320 ServiceEndpoint string `json:"serviceEndpoint"` 321 } `json:"service"` 322 } 323 324 if strings.HasPrefix(did, "did:plc:") { 325 req, err := http.NewRequestWithContext( 326 ctx, 327 "GET", 328 fmt.Sprintf("https://plc.directory/%s", did), 329 nil, 330 ) 331 if err != nil { 332 return "", err 333 } 334 335 resp, err := http.DefaultClient.Do(req) 336 if err != nil { 337 return "", err 338 } 339 defer resp.Body.Close() 340 341 if resp.StatusCode != 200 { 342 io.Copy(io.Discard, resp.Body) 343 return "", fmt.Errorf("could not find identity in plc registry") 344 } 345 346 var identity Identity 347 if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil { 348 return "", err 349 } 350 351 var service string 352 for _, svc := range identity.Service { 353 if svc.ID == "#atproto_pds" { 354 service = svc.ServiceEndpoint 355 } 356 } 357 358 if service == "" { 359 return "", fmt.Errorf("could not find atproto_pds service in identity services") 360 } 361 362 return service, nil 363 } else if strings.HasPrefix(did, "did:web:") { 364 // TODO: needs more work 365 req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("https://%s/.well-known/did.json", did), nil) 366 if err != nil { 367 return "", err 368 } 369 370 resp, err := http.DefaultClient.Do(req) 371 if err != nil { 372 return "", err 373 } 374 defer resp.Body.Close() 375 376 if resp.StatusCode != 200 { 377 io.Copy(io.Discard, resp.Body) 378 return "", fmt.Errorf("could not find identity in plc registry") 379 } 380 381 var identity Identity 382 if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil { 383 return "", err 384 } 385 386 var service string 387 for _, svc := range identity.Service { 388 if svc.ID == "#atproto_pds" { 389 service = svc.ServiceEndpoint 390 } 391 } 392 393 if service == "" { 394 return "", fmt.Errorf("could not find atproto_pds service in identity services") 395 } 396 397 return service, nil 398 } else { 399 return "", fmt.Errorf("did was not a supported did type") 400 } 401} 402 403func (s *TestServer) getFilePath(file string) string { 404 return fmt.Sprintf("%s/%s", staticFilePath, file) 405}