A community-based topic aggregation platform built on atproto.
1package pds 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 12 "github.com/bluesky-social/indigo/atproto/atclient" 13 "github.com/bluesky-social/indigo/atproto/auth/oauth" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15) 16 17// This test suite provides comprehensive unit tests for the PDS client package. 18// 19// Coverage: 20// - All Client interface methods: 100% 21// - bearerAuth implementation: 100% 22// - Factory function input validation: 100% 23// - NewFromAccessToken: 100% 24// 25// Not covered (requires integration tests with real infrastructure): 26// - NewFromPasswordAuth success path (requires live PDS server) 27// - NewFromOAuthSession success path (requires OAuth infrastructure) 28// 29// The untested code paths involve external dependencies (PDS authentication, 30// OAuth session resumption) which are appropriately tested in E2E/integration tests. 31 32// TestClientImplementsInterface verifies that client implements the Client interface. 33func TestClientImplementsInterface(t *testing.T) { 34 var _ Client = (*client)(nil) 35} 36 37// TestBearerAuth_DoWithAuth verifies that bearerAuth correctly adds Authorization header. 
38func TestBearerAuth_DoWithAuth(t *testing.T) { 39 tests := []struct { 40 name string 41 token string 42 }{ 43 { 44 name: "standard token", 45 token: "test-access-token-12345", 46 }, 47 { 48 name: "token with special characters", 49 token: "token.with.dots_and-dashes", 50 }, 51 } 52 53 for _, tt := range tests { 54 t.Run(tt.name, func(t *testing.T) { 55 // Create a test server that captures the Authorization header 56 var capturedHeader string 57 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 capturedHeader = r.Header.Get("Authorization") 59 w.WriteHeader(http.StatusOK) 60 })) 61 defer server.Close() 62 63 // Create bearerAuth instance 64 auth := &bearerAuth{token: tt.token} 65 66 // Create request 67 req, err := http.NewRequest(http.MethodGet, server.URL, nil) 68 if err != nil { 69 t.Fatalf("failed to create request: %v", err) 70 } 71 72 // Execute with auth 73 client := &http.Client{} 74 nsid := syntax.NSID("com.atproto.test") 75 _, err = auth.DoWithAuth(client, req, nsid) 76 if err != nil { 77 t.Fatalf("DoWithAuth failed: %v", err) 78 } 79 80 // Verify Authorization header 81 expectedHeader := "Bearer " + tt.token 82 if capturedHeader != expectedHeader { 83 t.Errorf("Authorization header = %q, want %q", capturedHeader, expectedHeader) 84 } 85 }) 86 } 87} 88 89// TestBearerAuth_ImplementsAuthMethod verifies bearerAuth implements atclient.AuthMethod. 90func TestBearerAuth_ImplementsAuthMethod(t *testing.T) { 91 var _ atclient.AuthMethod = (*bearerAuth)(nil) 92} 93 94// TestNewFromAccessToken validates factory function input validation. 
95func TestNewFromAccessToken(t *testing.T) { 96 tests := []struct { 97 name string 98 host string 99 did string 100 accessToken string 101 wantErr bool 102 errContains string 103 }{ 104 { 105 name: "valid inputs", 106 host: "https://pds.example.com", 107 did: "did:plc:12345", 108 accessToken: "test-token", 109 wantErr: false, 110 }, 111 { 112 name: "empty host", 113 host: "", 114 did: "did:plc:12345", 115 accessToken: "test-token", 116 wantErr: true, 117 errContains: "host is required", 118 }, 119 { 120 name: "empty did", 121 host: "https://pds.example.com", 122 did: "", 123 accessToken: "test-token", 124 wantErr: true, 125 errContains: "did is required", 126 }, 127 { 128 name: "empty access token", 129 host: "https://pds.example.com", 130 did: "did:plc:12345", 131 accessToken: "", 132 wantErr: true, 133 errContains: "accessToken is required", 134 }, 135 { 136 name: "all empty", 137 host: "", 138 did: "", 139 accessToken: "", 140 wantErr: true, 141 errContains: "host is required", 142 }, 143 } 144 145 for _, tt := range tests { 146 t.Run(tt.name, func(t *testing.T) { 147 client, err := NewFromAccessToken(tt.host, tt.did, tt.accessToken) 148 149 if tt.wantErr { 150 if err == nil { 151 t.Fatal("expected error, got nil") 152 } 153 if !strings.Contains(err.Error(), tt.errContains) { 154 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains) 155 } 156 return 157 } 158 159 if err != nil { 160 t.Fatalf("unexpected error: %v", err) 161 } 162 163 if client == nil { 164 t.Fatal("expected client, got nil") 165 } 166 167 // Verify DID and HostURL methods 168 if client.DID() != tt.did { 169 t.Errorf("DID() = %q, want %q", client.DID(), tt.did) 170 } 171 if client.HostURL() != tt.host { 172 t.Errorf("HostURL() = %q, want %q", client.HostURL(), tt.host) 173 } 174 }) 175 } 176} 177 178// TestNewFromPasswordAuth validates factory function input validation. 
179func TestNewFromPasswordAuth(t *testing.T) { 180 tests := []struct { 181 name string 182 host string 183 handle string 184 password string 185 wantErr bool 186 errContains string 187 }{ 188 { 189 name: "empty host", 190 host: "", 191 handle: "user.bsky.social", 192 password: "password", 193 wantErr: true, 194 errContains: "host is required", 195 }, 196 { 197 name: "empty handle", 198 host: "https://pds.example.com", 199 handle: "", 200 password: "password", 201 wantErr: true, 202 errContains: "handle is required", 203 }, 204 { 205 name: "empty password", 206 host: "https://pds.example.com", 207 handle: "user.bsky.social", 208 password: "", 209 wantErr: true, 210 errContains: "password is required", 211 }, 212 { 213 name: "all empty", 214 host: "", 215 handle: "", 216 password: "", 217 wantErr: true, 218 errContains: "host is required", 219 }, 220 } 221 222 for _, tt := range tests { 223 t.Run(tt.name, func(t *testing.T) { 224 ctx := context.Background() 225 _, err := NewFromPasswordAuth(ctx, tt.host, tt.handle, tt.password) 226 227 if tt.wantErr { 228 if err == nil { 229 t.Fatal("expected error, got nil") 230 } 231 if !strings.Contains(err.Error(), tt.errContains) { 232 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains) 233 } 234 return 235 } 236 237 // Note: We don't test success case here because it requires a real PDS 238 // Those are covered in integration tests 239 }) 240 } 241} 242 243// TestNewFromOAuthSession validates factory function input validation. 
244func TestNewFromOAuthSession(t *testing.T) { 245 ctx := context.Background() 246 247 tests := []struct { 248 name string 249 oauthClient *oauth.ClientApp 250 sessionData *oauth.ClientSessionData 251 wantErr bool 252 errContains string 253 }{ 254 { 255 name: "nil oauth client", 256 oauthClient: nil, 257 sessionData: &oauth.ClientSessionData{}, 258 wantErr: true, 259 errContains: "oauthClient is required", 260 }, 261 { 262 name: "nil session data", 263 oauthClient: &oauth.ClientApp{}, 264 sessionData: nil, 265 wantErr: true, 266 errContains: "sessionData is required", 267 }, 268 { 269 name: "both nil", 270 oauthClient: nil, 271 sessionData: nil, 272 wantErr: true, 273 errContains: "oauthClient is required", 274 }, 275 } 276 277 for _, tt := range tests { 278 t.Run(tt.name, func(t *testing.T) { 279 _, err := NewFromOAuthSession(ctx, tt.oauthClient, tt.sessionData) 280 281 if tt.wantErr { 282 if err == nil { 283 t.Fatal("expected error, got nil") 284 } 285 if !strings.Contains(err.Error(), tt.errContains) { 286 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains) 287 } 288 return 289 } 290 291 // Note: Success case requires proper OAuth setup, tested in integration tests 292 }) 293 } 294} 295 296// TestClient_DIDAndHostURL verifies DID() and HostURL() return correct values. 297func TestClient_DIDAndHostURL(t *testing.T) { 298 expectedDID := "did:plc:test123" 299 expectedHost := "https://pds.test.com" 300 301 c := &client{ 302 did: expectedDID, 303 host: expectedHost, 304 } 305 306 if got := c.DID(); got != expectedDID { 307 t.Errorf("DID() = %q, want %q", got, expectedDID) 308 } 309 310 if got := c.HostURL(); got != expectedHost { 311 t.Errorf("HostURL() = %q, want %q", got, expectedHost) 312 } 313} 314 315// TestClient_CreateRecord tests the CreateRecord method with a mock server. 
316func TestClient_CreateRecord(t *testing.T) { 317 tests := []struct { 318 name string 319 collection string 320 rkey string 321 record map[string]any 322 serverResponse map[string]any 323 serverStatus int 324 wantURI string 325 wantCID string 326 wantErr bool 327 }{ 328 { 329 name: "successful creation with rkey", 330 collection: "social.coves.vote", 331 rkey: "3kjzl5kcb2s2v", 332 record: map[string]any{ 333 "$type": "social.coves.vote", 334 "subject": "at://did:plc:abc123/social.coves.post/3kjzl5kc", 335 "direction": "up", 336 }, 337 serverResponse: map[string]any{ 338 "uri": "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 339 "cid": "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 340 }, 341 serverStatus: http.StatusOK, 342 wantURI: "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 343 wantCID: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 344 wantErr: false, 345 }, 346 { 347 name: "successful creation without rkey (TID generated)", 348 collection: "social.coves.vote", 349 rkey: "", 350 record: map[string]any{ 351 "$type": "social.coves.vote", 352 "subject": "at://did:plc:abc123/social.coves.post/3kjzl5kc", 353 "direction": "down", 354 }, 355 serverResponse: map[string]any{ 356 "uri": "at://did:plc:test/social.coves.vote/3kjzl5kcc2a1b", 357 "cid": "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 358 }, 359 serverStatus: http.StatusOK, 360 wantURI: "at://did:plc:test/social.coves.vote/3kjzl5kcc2a1b", 361 wantCID: "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 362 wantErr: false, 363 }, 364 { 365 name: "server error", 366 collection: "social.coves.vote", 367 rkey: "test", 368 record: map[string]any{"$type": "social.coves.vote"}, 369 serverResponse: map[string]any{ 370 "error": "InvalidRequest", 371 "message": "Invalid record", 372 }, 373 serverStatus: http.StatusBadRequest, 374 wantErr: true, 375 }, 376 } 377 378 for _, tt := range tests { 379 t.Run(tt.name, func(t *testing.T) { 380 // Create 
mock server 381 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 382 // Verify method 383 if r.Method != http.MethodPost { 384 t.Errorf("expected POST request, got %s", r.Method) 385 } 386 387 // Verify path 388 expectedPath := "/xrpc/com.atproto.repo.createRecord" 389 if r.URL.Path != expectedPath { 390 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 391 } 392 393 // Verify request body 394 var payload map[string]any 395 if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { 396 t.Fatalf("failed to decode request body: %v", err) 397 } 398 399 // Check required fields 400 if payload["collection"] != tt.collection { 401 t.Errorf("collection = %v, want %v", payload["collection"], tt.collection) 402 } 403 404 // Check rkey inclusion 405 if tt.rkey != "" { 406 if payload["rkey"] != tt.rkey { 407 t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey) 408 } 409 } else { 410 if _, exists := payload["rkey"]; exists { 411 t.Error("rkey should not be included when empty") 412 } 413 } 414 415 // Send response 416 w.Header().Set("Content-Type", "application/json") 417 w.WriteHeader(tt.serverStatus) 418 json.NewEncoder(w).Encode(tt.serverResponse) 419 })) 420 defer server.Close() 421 422 // Create client 423 apiClient := atclient.NewAPIClient(server.URL) 424 apiClient.Auth = &bearerAuth{token: "test-token"} 425 426 c := &client{ 427 apiClient: apiClient, 428 did: "did:plc:test", 429 host: server.URL, 430 } 431 432 // Execute CreateRecord 433 ctx := context.Background() 434 uri, cid, err := c.CreateRecord(ctx, tt.collection, tt.rkey, tt.record) 435 436 if tt.wantErr { 437 if err == nil { 438 t.Fatal("expected error, got nil") 439 } 440 return 441 } 442 443 if err != nil { 444 t.Fatalf("unexpected error: %v", err) 445 } 446 447 if uri != tt.wantURI { 448 t.Errorf("uri = %q, want %q", uri, tt.wantURI) 449 } 450 451 if cid != tt.wantCID { 452 t.Errorf("cid = %q, want %q", cid, tt.wantCID) 453 } 454 }) 455 } 456} 457 
458// TestClient_DeleteRecord tests the DeleteRecord method with a mock server. 459func TestClient_DeleteRecord(t *testing.T) { 460 tests := []struct { 461 name string 462 collection string 463 rkey string 464 serverStatus int 465 wantErr bool 466 }{ 467 { 468 name: "successful deletion", 469 collection: "social.coves.vote", 470 rkey: "3kjzl5kcb2s2v", 471 serverStatus: http.StatusOK, 472 wantErr: false, 473 }, 474 { 475 name: "not found error", 476 collection: "social.coves.vote", 477 rkey: "nonexistent", 478 serverStatus: http.StatusNotFound, 479 wantErr: true, 480 }, 481 { 482 name: "server error", 483 collection: "social.coves.vote", 484 rkey: "test", 485 serverStatus: http.StatusInternalServerError, 486 wantErr: true, 487 }, 488 } 489 490 for _, tt := range tests { 491 t.Run(tt.name, func(t *testing.T) { 492 // Create mock server 493 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 494 // Verify method 495 if r.Method != http.MethodPost { 496 t.Errorf("expected POST request, got %s", r.Method) 497 } 498 499 // Verify path 500 expectedPath := "/xrpc/com.atproto.repo.deleteRecord" 501 if r.URL.Path != expectedPath { 502 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 503 } 504 505 // Verify request body 506 var payload map[string]any 507 if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { 508 t.Fatalf("failed to decode request body: %v", err) 509 } 510 511 if payload["collection"] != tt.collection { 512 t.Errorf("collection = %v, want %v", payload["collection"], tt.collection) 513 } 514 if payload["rkey"] != tt.rkey { 515 t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey) 516 } 517 518 // Send response 519 w.WriteHeader(tt.serverStatus) 520 if tt.serverStatus != http.StatusOK { 521 w.Header().Set("Content-Type", "application/json") 522 json.NewEncoder(w).Encode(map[string]any{ 523 "error": "Error", 524 "message": "Operation failed", 525 }) 526 } 527 })) 528 defer server.Close() 529 530 // 
Create client 531 apiClient := atclient.NewAPIClient(server.URL) 532 apiClient.Auth = &bearerAuth{token: "test-token"} 533 534 c := &client{ 535 apiClient: apiClient, 536 did: "did:plc:test", 537 host: server.URL, 538 } 539 540 // Execute DeleteRecord 541 ctx := context.Background() 542 err := c.DeleteRecord(ctx, tt.collection, tt.rkey) 543 544 if tt.wantErr { 545 if err == nil { 546 t.Fatal("expected error, got nil") 547 } 548 return 549 } 550 551 if err != nil { 552 t.Fatalf("unexpected error: %v", err) 553 } 554 }) 555 } 556} 557 558// TestClient_ListRecords tests the ListRecords method with pagination. 559func TestClient_ListRecords(t *testing.T) { 560 tests := []struct { 561 name string 562 collection string 563 limit int 564 cursor string 565 serverResponse map[string]any 566 serverStatus int 567 wantRecords int 568 wantCursor string 569 wantErr bool 570 }{ 571 { 572 name: "successful list with records", 573 collection: "social.coves.vote", 574 limit: 10, 575 cursor: "", 576 serverResponse: map[string]any{ 577 "cursor": "next-cursor-123", 578 "records": []map[string]any{ 579 { 580 "uri": "at://did:plc:test/social.coves.vote/1", 581 "cid": "bafyreiabc123", 582 "value": map[string]any{"$type": "social.coves.vote", "direction": "up"}, 583 }, 584 { 585 "uri": "at://did:plc:test/social.coves.vote/2", 586 "cid": "bafyreiabc456", 587 "value": map[string]any{"$type": "social.coves.vote", "direction": "down"}, 588 }, 589 }, 590 }, 591 serverStatus: http.StatusOK, 592 wantRecords: 2, 593 wantCursor: "next-cursor-123", 594 wantErr: false, 595 }, 596 { 597 name: "empty list", 598 collection: "social.coves.vote", 599 limit: 10, 600 cursor: "", 601 serverResponse: map[string]any{ 602 "cursor": "", 603 "records": []map[string]any{}, 604 }, 605 serverStatus: http.StatusOK, 606 wantRecords: 0, 607 wantCursor: "", 608 wantErr: false, 609 }, 610 { 611 name: "with cursor pagination", 612 collection: "social.coves.vote", 613 limit: 5, 614 cursor: "existing-cursor", 615 
serverResponse: map[string]any{ 616 "cursor": "final-cursor", 617 "records": []map[string]any{ 618 { 619 "uri": "at://did:plc:test/social.coves.vote/3", 620 "cid": "bafyreiabc789", 621 "value": map[string]any{"$type": "social.coves.vote", "direction": "up"}, 622 }, 623 }, 624 }, 625 serverStatus: http.StatusOK, 626 wantRecords: 1, 627 wantCursor: "final-cursor", 628 wantErr: false, 629 }, 630 { 631 name: "server error", 632 collection: "social.coves.vote", 633 limit: 10, 634 cursor: "", 635 serverResponse: map[string]any{"error": "Internal error"}, 636 serverStatus: http.StatusInternalServerError, 637 wantErr: true, 638 }, 639 } 640 641 for _, tt := range tests { 642 t.Run(tt.name, func(t *testing.T) { 643 // Create mock server 644 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 645 // Verify method 646 if r.Method != http.MethodGet { 647 t.Errorf("expected GET request, got %s", r.Method) 648 } 649 650 // Verify path 651 expectedPath := "/xrpc/com.atproto.repo.listRecords" 652 if r.URL.Path != expectedPath { 653 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 654 } 655 656 // Verify query parameters 657 query := r.URL.Query() 658 if query.Get("collection") != tt.collection { 659 t.Errorf("collection param = %q, want %q", query.Get("collection"), tt.collection) 660 } 661 662 if tt.cursor != "" { 663 if query.Get("cursor") != tt.cursor { 664 t.Errorf("cursor param = %q, want %q", query.Get("cursor"), tt.cursor) 665 } 666 } 667 668 // Send response 669 w.Header().Set("Content-Type", "application/json") 670 w.WriteHeader(tt.serverStatus) 671 json.NewEncoder(w).Encode(tt.serverResponse) 672 })) 673 defer server.Close() 674 675 // Create client 676 apiClient := atclient.NewAPIClient(server.URL) 677 apiClient.Auth = &bearerAuth{token: "test-token"} 678 679 c := &client{ 680 apiClient: apiClient, 681 did: "did:plc:test", 682 host: server.URL, 683 } 684 685 // Execute ListRecords 686 ctx := context.Background() 687 resp, 
err := c.ListRecords(ctx, tt.collection, tt.limit, tt.cursor) 688 689 if tt.wantErr { 690 if err == nil { 691 t.Fatal("expected error, got nil") 692 } 693 return 694 } 695 696 if err != nil { 697 t.Fatalf("unexpected error: %v", err) 698 } 699 700 if resp == nil { 701 t.Fatal("expected response, got nil") 702 } 703 704 if len(resp.Records) != tt.wantRecords { 705 t.Errorf("records count = %d, want %d", len(resp.Records), tt.wantRecords) 706 } 707 708 if resp.Cursor != tt.wantCursor { 709 t.Errorf("cursor = %q, want %q", resp.Cursor, tt.wantCursor) 710 } 711 712 // Verify record structure if we have records 713 if tt.wantRecords > 0 { 714 for i, record := range resp.Records { 715 if record.URI == "" { 716 t.Errorf("record[%d].URI is empty", i) 717 } 718 if record.CID == "" { 719 t.Errorf("record[%d].CID is empty", i) 720 } 721 if record.Value == nil { 722 t.Errorf("record[%d].Value is nil", i) 723 } 724 } 725 } 726 }) 727 } 728} 729 730// TestClient_GetRecord tests the GetRecord method with a mock server. 
731func TestClient_GetRecord(t *testing.T) { 732 tests := []struct { 733 name string 734 collection string 735 rkey string 736 serverResponse map[string]any 737 serverStatus int 738 wantURI string 739 wantCID string 740 wantErr bool 741 }{ 742 { 743 name: "successful get", 744 collection: "social.coves.vote", 745 rkey: "3kjzl5kcb2s2v", 746 serverResponse: map[string]any{ 747 "uri": "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 748 "cid": "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 749 "value": map[string]any{ 750 "$type": "social.coves.vote", 751 "subject": "at://did:plc:abc/social.coves.post/123", 752 "direction": "up", 753 }, 754 }, 755 serverStatus: http.StatusOK, 756 wantURI: "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 757 wantCID: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 758 wantErr: false, 759 }, 760 { 761 name: "record not found", 762 collection: "social.coves.vote", 763 rkey: "nonexistent", 764 serverResponse: map[string]any{ 765 "error": "RecordNotFound", 766 "message": "Record not found", 767 }, 768 serverStatus: http.StatusNotFound, 769 wantErr: true, 770 }, 771 { 772 name: "server error", 773 collection: "social.coves.vote", 774 rkey: "test", 775 serverResponse: map[string]any{ 776 "error": "Internal error", 777 }, 778 serverStatus: http.StatusInternalServerError, 779 wantErr: true, 780 }, 781 } 782 783 for _, tt := range tests { 784 t.Run(tt.name, func(t *testing.T) { 785 // Create mock server 786 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 787 // Verify method 788 if r.Method != http.MethodGet { 789 t.Errorf("expected GET request, got %s", r.Method) 790 } 791 792 // Verify path 793 expectedPath := "/xrpc/com.atproto.repo.getRecord" 794 if r.URL.Path != expectedPath { 795 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 796 } 797 798 // Verify query parameters 799 query := r.URL.Query() 800 if query.Get("collection") != tt.collection { 801 
t.Errorf("collection param = %q, want %q", query.Get("collection"), tt.collection) 802 } 803 if query.Get("rkey") != tt.rkey { 804 t.Errorf("rkey param = %q, want %q", query.Get("rkey"), tt.rkey) 805 } 806 807 // Send response 808 w.Header().Set("Content-Type", "application/json") 809 w.WriteHeader(tt.serverStatus) 810 json.NewEncoder(w).Encode(tt.serverResponse) 811 })) 812 defer server.Close() 813 814 // Create client 815 apiClient := atclient.NewAPIClient(server.URL) 816 apiClient.Auth = &bearerAuth{token: "test-token"} 817 818 c := &client{ 819 apiClient: apiClient, 820 did: "did:plc:test", 821 host: server.URL, 822 } 823 824 // Execute GetRecord 825 ctx := context.Background() 826 resp, err := c.GetRecord(ctx, tt.collection, tt.rkey) 827 828 if tt.wantErr { 829 if err == nil { 830 t.Fatal("expected error, got nil") 831 } 832 return 833 } 834 835 if err != nil { 836 t.Fatalf("unexpected error: %v", err) 837 } 838 839 if resp == nil { 840 t.Fatal("expected response, got nil") 841 } 842 843 if resp.URI != tt.wantURI { 844 t.Errorf("URI = %q, want %q", resp.URI, tt.wantURI) 845 } 846 847 if resp.CID != tt.wantCID { 848 t.Errorf("CID = %q, want %q", resp.CID, tt.wantCID) 849 } 850 851 if resp.Value == nil { 852 t.Error("Value is nil") 853 } 854 }) 855 } 856} 857 858// TestTypedErrors_IsAuthError tests the IsAuthError helper function. 
859func TestTypedErrors_IsAuthError(t *testing.T) { 860 tests := []struct { 861 name string 862 err error 863 wantAuth bool 864 }{ 865 { 866 name: "ErrUnauthorized is auth error", 867 err: ErrUnauthorized, 868 wantAuth: true, 869 }, 870 { 871 name: "ErrForbidden is auth error", 872 err: ErrForbidden, 873 wantAuth: true, 874 }, 875 { 876 name: "ErrNotFound is not auth error", 877 err: ErrNotFound, 878 wantAuth: false, 879 }, 880 { 881 name: "ErrBadRequest is not auth error", 882 err: ErrBadRequest, 883 wantAuth: false, 884 }, 885 { 886 name: "wrapped ErrUnauthorized is auth error", 887 err: errors.New("outer: " + ErrUnauthorized.Error()), 888 wantAuth: false, // Plain string wrap doesn't work 889 }, 890 { 891 name: "fmt.Errorf wrapped ErrUnauthorized is auth error", 892 err: wrapAPIError(&atclient.APIError{StatusCode: 401, Message: "test"}, "op"), 893 wantAuth: true, 894 }, 895 { 896 name: "fmt.Errorf wrapped ErrForbidden is auth error", 897 err: wrapAPIError(&atclient.APIError{StatusCode: 403, Message: "test"}, "op"), 898 wantAuth: true, 899 }, 900 { 901 name: "nil error", 902 err: nil, 903 wantAuth: false, 904 }, 905 { 906 name: "generic error", 907 err: errors.New("something else"), 908 wantAuth: false, 909 }, 910 } 911 912 for _, tt := range tests { 913 t.Run(tt.name, func(t *testing.T) { 914 got := IsAuthError(tt.err) 915 if got != tt.wantAuth { 916 t.Errorf("IsAuthError() = %v, want %v", got, tt.wantAuth) 917 } 918 }) 919 } 920} 921 922// TestWrapAPIError tests error wrapping for HTTP status codes. 
923func TestWrapAPIError(t *testing.T) { 924 tests := []struct { 925 name string 926 err error 927 operation string 928 wantTyped error 929 wantNil bool 930 }{ 931 { 932 name: "nil error returns nil", 933 err: nil, 934 operation: "test", 935 wantNil: true, 936 }, 937 { 938 name: "401 maps to ErrUnauthorized", 939 err: &atclient.APIError{StatusCode: 401, Name: "AuthRequired", Message: "Not logged in"}, 940 operation: "createRecord", 941 wantTyped: ErrUnauthorized, 942 }, 943 { 944 name: "403 maps to ErrForbidden", 945 err: &atclient.APIError{StatusCode: 403, Name: "Forbidden", Message: "Access denied"}, 946 operation: "deleteRecord", 947 wantTyped: ErrForbidden, 948 }, 949 { 950 name: "404 maps to ErrNotFound", 951 err: &atclient.APIError{StatusCode: 404, Name: "NotFound", Message: "Record not found"}, 952 operation: "getRecord", 953 wantTyped: ErrNotFound, 954 }, 955 { 956 name: "400 maps to ErrBadRequest", 957 err: &atclient.APIError{StatusCode: 400, Name: "InvalidRequest", Message: "Bad input"}, 958 operation: "createRecord", 959 wantTyped: ErrBadRequest, 960 }, 961 { 962 name: "500 wraps without typed error", 963 err: &atclient.APIError{StatusCode: 500, Name: "InternalError", Message: "Server error"}, 964 operation: "listRecords", 965 wantTyped: nil, // No typed error for 500 966 }, 967 { 968 name: "non-APIError wraps normally", 969 err: errors.New("network timeout"), 970 operation: "createRecord", 971 wantTyped: nil, 972 }, 973 } 974 975 for _, tt := range tests { 976 t.Run(tt.name, func(t *testing.T) { 977 result := wrapAPIError(tt.err, tt.operation) 978 979 if tt.wantNil { 980 if result != nil { 981 t.Errorf("expected nil, got %v", result) 982 } 983 return 984 } 985 986 if result == nil { 987 t.Fatal("expected error, got nil") 988 } 989 990 if tt.wantTyped != nil { 991 if !errors.Is(result, tt.wantTyped) { 992 t.Errorf("expected errors.Is(%v, %v) to be true", result, tt.wantTyped) 993 } 994 } 995 996 // Verify operation is included in error message 997 if 
!strings.Contains(result.Error(), tt.operation) { 998 t.Errorf("error message %q should contain operation %q", result.Error(), tt.operation) 999 } 1000 }) 1001 } 1002} 1003 1004// TestClient_TypedErrors_CreateRecord tests that CreateRecord returns typed errors. 1005func TestClient_TypedErrors_CreateRecord(t *testing.T) { 1006 tests := []struct { 1007 name string 1008 serverStatus int 1009 wantErr error 1010 }{ 1011 { 1012 name: "401 returns ErrUnauthorized", 1013 serverStatus: http.StatusUnauthorized, 1014 wantErr: ErrUnauthorized, 1015 }, 1016 { 1017 name: "403 returns ErrForbidden", 1018 serverStatus: http.StatusForbidden, 1019 wantErr: ErrForbidden, 1020 }, 1021 { 1022 name: "400 returns ErrBadRequest", 1023 serverStatus: http.StatusBadRequest, 1024 wantErr: ErrBadRequest, 1025 }, 1026 } 1027 1028 for _, tt := range tests { 1029 t.Run(tt.name, func(t *testing.T) { 1030 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1031 w.Header().Set("Content-Type", "application/json") 1032 w.WriteHeader(tt.serverStatus) 1033 json.NewEncoder(w).Encode(map[string]any{ 1034 "error": "TestError", 1035 "message": "Test error message", 1036 }) 1037 })) 1038 defer server.Close() 1039 1040 apiClient := atclient.NewAPIClient(server.URL) 1041 apiClient.Auth = &bearerAuth{token: "test-token"} 1042 1043 c := &client{ 1044 apiClient: apiClient, 1045 did: "did:plc:test", 1046 host: server.URL, 1047 } 1048 1049 ctx := context.Background() 1050 _, _, err := c.CreateRecord(ctx, "test.collection", "rkey", map[string]any{}) 1051 1052 if err == nil { 1053 t.Fatal("expected error, got nil") 1054 } 1055 1056 if !errors.Is(err, tt.wantErr) { 1057 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr) 1058 } 1059 }) 1060 } 1061} 1062 1063// TestClient_TypedErrors_DeleteRecord tests that DeleteRecord returns typed errors. 
1064func TestClient_TypedErrors_DeleteRecord(t *testing.T) { 1065 tests := []struct { 1066 name string 1067 serverStatus int 1068 wantErr error 1069 }{ 1070 { 1071 name: "401 returns ErrUnauthorized", 1072 serverStatus: http.StatusUnauthorized, 1073 wantErr: ErrUnauthorized, 1074 }, 1075 { 1076 name: "403 returns ErrForbidden", 1077 serverStatus: http.StatusForbidden, 1078 wantErr: ErrForbidden, 1079 }, 1080 { 1081 name: "404 returns ErrNotFound", 1082 serverStatus: http.StatusNotFound, 1083 wantErr: ErrNotFound, 1084 }, 1085 } 1086 1087 for _, tt := range tests { 1088 t.Run(tt.name, func(t *testing.T) { 1089 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1090 w.Header().Set("Content-Type", "application/json") 1091 w.WriteHeader(tt.serverStatus) 1092 json.NewEncoder(w).Encode(map[string]any{ 1093 "error": "TestError", 1094 "message": "Test error message", 1095 }) 1096 })) 1097 defer server.Close() 1098 1099 apiClient := atclient.NewAPIClient(server.URL) 1100 apiClient.Auth = &bearerAuth{token: "test-token"} 1101 1102 c := &client{ 1103 apiClient: apiClient, 1104 did: "did:plc:test", 1105 host: server.URL, 1106 } 1107 1108 ctx := context.Background() 1109 err := c.DeleteRecord(ctx, "test.collection", "rkey") 1110 1111 if err == nil { 1112 t.Fatal("expected error, got nil") 1113 } 1114 1115 if !errors.Is(err, tt.wantErr) { 1116 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr) 1117 } 1118 }) 1119 } 1120} 1121 1122// TestClient_TypedErrors_ListRecords tests that ListRecords returns typed errors. 
1123func TestClient_TypedErrors_ListRecords(t *testing.T) { 1124 tests := []struct { 1125 name string 1126 serverStatus int 1127 wantErr error 1128 }{ 1129 { 1130 name: "401 returns ErrUnauthorized", 1131 serverStatus: http.StatusUnauthorized, 1132 wantErr: ErrUnauthorized, 1133 }, 1134 { 1135 name: "403 returns ErrForbidden", 1136 serverStatus: http.StatusForbidden, 1137 wantErr: ErrForbidden, 1138 }, 1139 } 1140 1141 for _, tt := range tests { 1142 t.Run(tt.name, func(t *testing.T) { 1143 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1144 w.Header().Set("Content-Type", "application/json") 1145 w.WriteHeader(tt.serverStatus) 1146 json.NewEncoder(w).Encode(map[string]any{ 1147 "error": "TestError", 1148 "message": "Test error message", 1149 }) 1150 })) 1151 defer server.Close() 1152 1153 apiClient := atclient.NewAPIClient(server.URL) 1154 apiClient.Auth = &bearerAuth{token: "test-token"} 1155 1156 c := &client{ 1157 apiClient: apiClient, 1158 did: "did:plc:test", 1159 host: server.URL, 1160 } 1161 1162 ctx := context.Background() 1163 _, err := c.ListRecords(ctx, "test.collection", 10, "") 1164 1165 if err == nil { 1166 t.Fatal("expected error, got nil") 1167 } 1168 1169 if !errors.Is(err, tt.wantErr) { 1170 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr) 1171 } 1172 }) 1173 } 1174}