A community-based topic aggregation platform built on atproto
at main 37 kB view raw
1package pds 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 12 "github.com/bluesky-social/indigo/atproto/atclient" 13 "github.com/bluesky-social/indigo/atproto/auth/oauth" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15) 16 17// This test suite provides comprehensive unit tests for the PDS client package. 18// 19// Coverage: 20// - All Client interface methods: 100% 21// - bearerAuth implementation: 100% 22// - Factory function input validation: 100% 23// - NewFromAccessToken: 100% 24// 25// Not covered (requires integration tests with real infrastructure): 26// - NewFromPasswordAuth success path (requires live PDS server) 27// - NewFromOAuthSession success path (requires OAuth infrastructure) 28// 29// The untested code paths involve external dependencies (PDS authentication, 30// OAuth session resumption) which are appropriately tested in E2E/integration tests. 31 32// TestClientImplementsInterface verifies that client implements the Client interface. 33func TestClientImplementsInterface(t *testing.T) { 34 var _ Client = (*client)(nil) 35} 36 37// TestBearerAuth_DoWithAuth verifies that bearerAuth correctly adds Authorization header. 
38func TestBearerAuth_DoWithAuth(t *testing.T) { 39 tests := []struct { 40 name string 41 token string 42 }{ 43 { 44 name: "standard token", 45 token: "test-access-token-12345", 46 }, 47 { 48 name: "token with special characters", 49 token: "token.with.dots_and-dashes", 50 }, 51 } 52 53 for _, tt := range tests { 54 t.Run(tt.name, func(t *testing.T) { 55 // Create a test server that captures the Authorization header 56 var capturedHeader string 57 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 capturedHeader = r.Header.Get("Authorization") 59 w.WriteHeader(http.StatusOK) 60 })) 61 defer server.Close() 62 63 // Create bearerAuth instance 64 auth := &bearerAuth{token: tt.token} 65 66 // Create request 67 req, err := http.NewRequest(http.MethodGet, server.URL, nil) 68 if err != nil { 69 t.Fatalf("failed to create request: %v", err) 70 } 71 72 // Execute with auth 73 client := &http.Client{} 74 nsid := syntax.NSID("com.atproto.test") 75 _, err = auth.DoWithAuth(client, req, nsid) 76 if err != nil { 77 t.Fatalf("DoWithAuth failed: %v", err) 78 } 79 80 // Verify Authorization header 81 expectedHeader := "Bearer " + tt.token 82 if capturedHeader != expectedHeader { 83 t.Errorf("Authorization header = %q, want %q", capturedHeader, expectedHeader) 84 } 85 }) 86 } 87} 88 89// TestBearerAuth_ImplementsAuthMethod verifies bearerAuth implements atclient.AuthMethod. 90func TestBearerAuth_ImplementsAuthMethod(t *testing.T) { 91 var _ atclient.AuthMethod = (*bearerAuth)(nil) 92} 93 94// TestNewFromAccessToken validates factory function input validation. 
95func TestNewFromAccessToken(t *testing.T) { 96 tests := []struct { 97 name string 98 host string 99 did string 100 accessToken string 101 wantErr bool 102 errContains string 103 }{ 104 { 105 name: "valid inputs", 106 host: "https://pds.example.com", 107 did: "did:plc:12345", 108 accessToken: "test-token", 109 wantErr: false, 110 }, 111 { 112 name: "empty host", 113 host: "", 114 did: "did:plc:12345", 115 accessToken: "test-token", 116 wantErr: true, 117 errContains: "host is required", 118 }, 119 { 120 name: "empty did", 121 host: "https://pds.example.com", 122 did: "", 123 accessToken: "test-token", 124 wantErr: true, 125 errContains: "did is required", 126 }, 127 { 128 name: "empty access token", 129 host: "https://pds.example.com", 130 did: "did:plc:12345", 131 accessToken: "", 132 wantErr: true, 133 errContains: "accessToken is required", 134 }, 135 { 136 name: "all empty", 137 host: "", 138 did: "", 139 accessToken: "", 140 wantErr: true, 141 errContains: "host is required", 142 }, 143 } 144 145 for _, tt := range tests { 146 t.Run(tt.name, func(t *testing.T) { 147 client, err := NewFromAccessToken(tt.host, tt.did, tt.accessToken) 148 149 if tt.wantErr { 150 if err == nil { 151 t.Fatal("expected error, got nil") 152 } 153 if !strings.Contains(err.Error(), tt.errContains) { 154 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains) 155 } 156 return 157 } 158 159 if err != nil { 160 t.Fatalf("unexpected error: %v", err) 161 } 162 163 if client == nil { 164 t.Fatal("expected client, got nil") 165 } 166 167 // Verify DID and HostURL methods 168 if client.DID() != tt.did { 169 t.Errorf("DID() = %q, want %q", client.DID(), tt.did) 170 } 171 if client.HostURL() != tt.host { 172 t.Errorf("HostURL() = %q, want %q", client.HostURL(), tt.host) 173 } 174 }) 175 } 176} 177 178// TestNewFromPasswordAuth validates factory function input validation. 
179func TestNewFromPasswordAuth(t *testing.T) { 180 tests := []struct { 181 name string 182 host string 183 handle string 184 password string 185 wantErr bool 186 errContains string 187 }{ 188 { 189 name: "empty host", 190 host: "", 191 handle: "user.bsky.social", 192 password: "password", 193 wantErr: true, 194 errContains: "host is required", 195 }, 196 { 197 name: "empty handle", 198 host: "https://pds.example.com", 199 handle: "", 200 password: "password", 201 wantErr: true, 202 errContains: "handle is required", 203 }, 204 { 205 name: "empty password", 206 host: "https://pds.example.com", 207 handle: "user.bsky.social", 208 password: "", 209 wantErr: true, 210 errContains: "password is required", 211 }, 212 { 213 name: "all empty", 214 host: "", 215 handle: "", 216 password: "", 217 wantErr: true, 218 errContains: "host is required", 219 }, 220 } 221 222 for _, tt := range tests { 223 t.Run(tt.name, func(t *testing.T) { 224 ctx := context.Background() 225 _, err := NewFromPasswordAuth(ctx, tt.host, tt.handle, tt.password) 226 227 if tt.wantErr { 228 if err == nil { 229 t.Fatal("expected error, got nil") 230 } 231 if !strings.Contains(err.Error(), tt.errContains) { 232 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains) 233 } 234 return 235 } 236 237 // Note: We don't test success case here because it requires a real PDS 238 // Those are covered in integration tests 239 }) 240 } 241} 242 243// TestNewFromOAuthSession validates factory function input validation. 
244func TestNewFromOAuthSession(t *testing.T) { 245 ctx := context.Background() 246 247 tests := []struct { 248 name string 249 oauthClient *oauth.ClientApp 250 sessionData *oauth.ClientSessionData 251 wantErr bool 252 errContains string 253 }{ 254 { 255 name: "nil oauth client", 256 oauthClient: nil, 257 sessionData: &oauth.ClientSessionData{}, 258 wantErr: true, 259 errContains: "oauthClient is required", 260 }, 261 { 262 name: "nil session data", 263 oauthClient: &oauth.ClientApp{}, 264 sessionData: nil, 265 wantErr: true, 266 errContains: "sessionData is required", 267 }, 268 { 269 name: "both nil", 270 oauthClient: nil, 271 sessionData: nil, 272 wantErr: true, 273 errContains: "oauthClient is required", 274 }, 275 } 276 277 for _, tt := range tests { 278 t.Run(tt.name, func(t *testing.T) { 279 _, err := NewFromOAuthSession(ctx, tt.oauthClient, tt.sessionData) 280 281 if tt.wantErr { 282 if err == nil { 283 t.Fatal("expected error, got nil") 284 } 285 if !strings.Contains(err.Error(), tt.errContains) { 286 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains) 287 } 288 return 289 } 290 291 // Note: Success case requires proper OAuth setup, tested in integration tests 292 }) 293 } 294} 295 296// TestClient_DIDAndHostURL verifies DID() and HostURL() return correct values. 297func TestClient_DIDAndHostURL(t *testing.T) { 298 expectedDID := "did:plc:test123" 299 expectedHost := "https://pds.test.com" 300 301 c := &client{ 302 did: expectedDID, 303 host: expectedHost, 304 } 305 306 if got := c.DID(); got != expectedDID { 307 t.Errorf("DID() = %q, want %q", got, expectedDID) 308 } 309 310 if got := c.HostURL(); got != expectedHost { 311 t.Errorf("HostURL() = %q, want %q", got, expectedHost) 312 } 313} 314 315// TestClient_CreateRecord tests the CreateRecord method with a mock server. 
316func TestClient_CreateRecord(t *testing.T) { 317 tests := []struct { 318 name string 319 collection string 320 rkey string 321 record map[string]any 322 serverResponse map[string]any 323 serverStatus int 324 wantURI string 325 wantCID string 326 wantErr bool 327 }{ 328 { 329 name: "successful creation with rkey", 330 collection: "social.coves.vote", 331 rkey: "3kjzl5kcb2s2v", 332 record: map[string]any{ 333 "$type": "social.coves.vote", 334 "subject": "at://did:plc:abc123/social.coves.post/3kjzl5kc", 335 "direction": "up", 336 }, 337 serverResponse: map[string]any{ 338 "uri": "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 339 "cid": "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 340 }, 341 serverStatus: http.StatusOK, 342 wantURI: "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 343 wantCID: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 344 wantErr: false, 345 }, 346 { 347 name: "successful creation without rkey (TID generated)", 348 collection: "social.coves.vote", 349 rkey: "", 350 record: map[string]any{ 351 "$type": "social.coves.vote", 352 "subject": "at://did:plc:abc123/social.coves.post/3kjzl5kc", 353 "direction": "down", 354 }, 355 serverResponse: map[string]any{ 356 "uri": "at://did:plc:test/social.coves.vote/3kjzl5kcc2a1b", 357 "cid": "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 358 }, 359 serverStatus: http.StatusOK, 360 wantURI: "at://did:plc:test/social.coves.vote/3kjzl5kcc2a1b", 361 wantCID: "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 362 wantErr: false, 363 }, 364 { 365 name: "server error", 366 collection: "social.coves.vote", 367 rkey: "test", 368 record: map[string]any{"$type": "social.coves.vote"}, 369 serverResponse: map[string]any{ 370 "error": "InvalidRequest", 371 "message": "Invalid record", 372 }, 373 serverStatus: http.StatusBadRequest, 374 wantErr: true, 375 }, 376 } 377 378 for _, tt := range tests { 379 t.Run(tt.name, func(t *testing.T) { 380 // Create 
mock server 381 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 382 // Verify method 383 if r.Method != http.MethodPost { 384 t.Errorf("expected POST request, got %s", r.Method) 385 } 386 387 // Verify path 388 expectedPath := "/xrpc/com.atproto.repo.createRecord" 389 if r.URL.Path != expectedPath { 390 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 391 } 392 393 // Verify request body 394 var payload map[string]any 395 if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { 396 t.Fatalf("failed to decode request body: %v", err) 397 } 398 399 // Check required fields 400 if payload["collection"] != tt.collection { 401 t.Errorf("collection = %v, want %v", payload["collection"], tt.collection) 402 } 403 404 // Check rkey inclusion 405 if tt.rkey != "" { 406 if payload["rkey"] != tt.rkey { 407 t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey) 408 } 409 } else { 410 if _, exists := payload["rkey"]; exists { 411 t.Error("rkey should not be included when empty") 412 } 413 } 414 415 // Send response 416 w.Header().Set("Content-Type", "application/json") 417 w.WriteHeader(tt.serverStatus) 418 json.NewEncoder(w).Encode(tt.serverResponse) 419 })) 420 defer server.Close() 421 422 // Create client 423 apiClient := atclient.NewAPIClient(server.URL) 424 apiClient.Auth = &bearerAuth{token: "test-token"} 425 426 c := &client{ 427 apiClient: apiClient, 428 did: "did:plc:test", 429 host: server.URL, 430 } 431 432 // Execute CreateRecord 433 ctx := context.Background() 434 uri, cid, err := c.CreateRecord(ctx, tt.collection, tt.rkey, tt.record) 435 436 if tt.wantErr { 437 if err == nil { 438 t.Fatal("expected error, got nil") 439 } 440 return 441 } 442 443 if err != nil { 444 t.Fatalf("unexpected error: %v", err) 445 } 446 447 if uri != tt.wantURI { 448 t.Errorf("uri = %q, want %q", uri, tt.wantURI) 449 } 450 451 if cid != tt.wantCID { 452 t.Errorf("cid = %q, want %q", cid, tt.wantCID) 453 } 454 }) 455 } 456} 457 
458// TestClient_DeleteRecord tests the DeleteRecord method with a mock server. 459func TestClient_DeleteRecord(t *testing.T) { 460 tests := []struct { 461 name string 462 collection string 463 rkey string 464 serverStatus int 465 wantErr bool 466 }{ 467 { 468 name: "successful deletion", 469 collection: "social.coves.vote", 470 rkey: "3kjzl5kcb2s2v", 471 serverStatus: http.StatusOK, 472 wantErr: false, 473 }, 474 { 475 name: "not found error", 476 collection: "social.coves.vote", 477 rkey: "nonexistent", 478 serverStatus: http.StatusNotFound, 479 wantErr: true, 480 }, 481 { 482 name: "server error", 483 collection: "social.coves.vote", 484 rkey: "test", 485 serverStatus: http.StatusInternalServerError, 486 wantErr: true, 487 }, 488 } 489 490 for _, tt := range tests { 491 t.Run(tt.name, func(t *testing.T) { 492 // Create mock server 493 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 494 // Verify method 495 if r.Method != http.MethodPost { 496 t.Errorf("expected POST request, got %s", r.Method) 497 } 498 499 // Verify path 500 expectedPath := "/xrpc/com.atproto.repo.deleteRecord" 501 if r.URL.Path != expectedPath { 502 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 503 } 504 505 // Verify request body 506 var payload map[string]any 507 if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { 508 t.Fatalf("failed to decode request body: %v", err) 509 } 510 511 if payload["collection"] != tt.collection { 512 t.Errorf("collection = %v, want %v", payload["collection"], tt.collection) 513 } 514 if payload["rkey"] != tt.rkey { 515 t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey) 516 } 517 518 // Send response 519 w.WriteHeader(tt.serverStatus) 520 if tt.serverStatus != http.StatusOK { 521 w.Header().Set("Content-Type", "application/json") 522 json.NewEncoder(w).Encode(map[string]any{ 523 "error": "Error", 524 "message": "Operation failed", 525 }) 526 } 527 })) 528 defer server.Close() 529 530 // 
Create client 531 apiClient := atclient.NewAPIClient(server.URL) 532 apiClient.Auth = &bearerAuth{token: "test-token"} 533 534 c := &client{ 535 apiClient: apiClient, 536 did: "did:plc:test", 537 host: server.URL, 538 } 539 540 // Execute DeleteRecord 541 ctx := context.Background() 542 err := c.DeleteRecord(ctx, tt.collection, tt.rkey) 543 544 if tt.wantErr { 545 if err == nil { 546 t.Fatal("expected error, got nil") 547 } 548 return 549 } 550 551 if err != nil { 552 t.Fatalf("unexpected error: %v", err) 553 } 554 }) 555 } 556} 557 558// TestClient_ListRecords tests the ListRecords method with pagination. 559func TestClient_ListRecords(t *testing.T) { 560 tests := []struct { 561 name string 562 collection string 563 limit int 564 cursor string 565 serverResponse map[string]any 566 serverStatus int 567 wantRecords int 568 wantCursor string 569 wantErr bool 570 }{ 571 { 572 name: "successful list with records", 573 collection: "social.coves.vote", 574 limit: 10, 575 cursor: "", 576 serverResponse: map[string]any{ 577 "cursor": "next-cursor-123", 578 "records": []map[string]any{ 579 { 580 "uri": "at://did:plc:test/social.coves.vote/1", 581 "cid": "bafyreiabc123", 582 "value": map[string]any{"$type": "social.coves.vote", "direction": "up"}, 583 }, 584 { 585 "uri": "at://did:plc:test/social.coves.vote/2", 586 "cid": "bafyreiabc456", 587 "value": map[string]any{"$type": "social.coves.vote", "direction": "down"}, 588 }, 589 }, 590 }, 591 serverStatus: http.StatusOK, 592 wantRecords: 2, 593 wantCursor: "next-cursor-123", 594 wantErr: false, 595 }, 596 { 597 name: "empty list", 598 collection: "social.coves.vote", 599 limit: 10, 600 cursor: "", 601 serverResponse: map[string]any{ 602 "cursor": "", 603 "records": []map[string]any{}, 604 }, 605 serverStatus: http.StatusOK, 606 wantRecords: 0, 607 wantCursor: "", 608 wantErr: false, 609 }, 610 { 611 name: "with cursor pagination", 612 collection: "social.coves.vote", 613 limit: 5, 614 cursor: "existing-cursor", 615 
serverResponse: map[string]any{ 616 "cursor": "final-cursor", 617 "records": []map[string]any{ 618 { 619 "uri": "at://did:plc:test/social.coves.vote/3", 620 "cid": "bafyreiabc789", 621 "value": map[string]any{"$type": "social.coves.vote", "direction": "up"}, 622 }, 623 }, 624 }, 625 serverStatus: http.StatusOK, 626 wantRecords: 1, 627 wantCursor: "final-cursor", 628 wantErr: false, 629 }, 630 { 631 name: "server error", 632 collection: "social.coves.vote", 633 limit: 10, 634 cursor: "", 635 serverResponse: map[string]any{"error": "Internal error"}, 636 serverStatus: http.StatusInternalServerError, 637 wantErr: true, 638 }, 639 } 640 641 for _, tt := range tests { 642 t.Run(tt.name, func(t *testing.T) { 643 // Create mock server 644 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 645 // Verify method 646 if r.Method != http.MethodGet { 647 t.Errorf("expected GET request, got %s", r.Method) 648 } 649 650 // Verify path 651 expectedPath := "/xrpc/com.atproto.repo.listRecords" 652 if r.URL.Path != expectedPath { 653 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 654 } 655 656 // Verify query parameters 657 query := r.URL.Query() 658 if query.Get("collection") != tt.collection { 659 t.Errorf("collection param = %q, want %q", query.Get("collection"), tt.collection) 660 } 661 662 if tt.cursor != "" { 663 if query.Get("cursor") != tt.cursor { 664 t.Errorf("cursor param = %q, want %q", query.Get("cursor"), tt.cursor) 665 } 666 } 667 668 // Send response 669 w.Header().Set("Content-Type", "application/json") 670 w.WriteHeader(tt.serverStatus) 671 json.NewEncoder(w).Encode(tt.serverResponse) 672 })) 673 defer server.Close() 674 675 // Create client 676 apiClient := atclient.NewAPIClient(server.URL) 677 apiClient.Auth = &bearerAuth{token: "test-token"} 678 679 c := &client{ 680 apiClient: apiClient, 681 did: "did:plc:test", 682 host: server.URL, 683 } 684 685 // Execute ListRecords 686 ctx := context.Background() 687 resp, 
err := c.ListRecords(ctx, tt.collection, tt.limit, tt.cursor) 688 689 if tt.wantErr { 690 if err == nil { 691 t.Fatal("expected error, got nil") 692 } 693 return 694 } 695 696 if err != nil { 697 t.Fatalf("unexpected error: %v", err) 698 } 699 700 if resp == nil { 701 t.Fatal("expected response, got nil") 702 } 703 704 if len(resp.Records) != tt.wantRecords { 705 t.Errorf("records count = %d, want %d", len(resp.Records), tt.wantRecords) 706 } 707 708 if resp.Cursor != tt.wantCursor { 709 t.Errorf("cursor = %q, want %q", resp.Cursor, tt.wantCursor) 710 } 711 712 // Verify record structure if we have records 713 if tt.wantRecords > 0 { 714 for i, record := range resp.Records { 715 if record.URI == "" { 716 t.Errorf("record[%d].URI is empty", i) 717 } 718 if record.CID == "" { 719 t.Errorf("record[%d].CID is empty", i) 720 } 721 if record.Value == nil { 722 t.Errorf("record[%d].Value is nil", i) 723 } 724 } 725 } 726 }) 727 } 728} 729 730// TestClient_GetRecord tests the GetRecord method with a mock server. 
731func TestClient_GetRecord(t *testing.T) { 732 tests := []struct { 733 name string 734 collection string 735 rkey string 736 serverResponse map[string]any 737 serverStatus int 738 wantURI string 739 wantCID string 740 wantErr bool 741 }{ 742 { 743 name: "successful get", 744 collection: "social.coves.vote", 745 rkey: "3kjzl5kcb2s2v", 746 serverResponse: map[string]any{ 747 "uri": "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 748 "cid": "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 749 "value": map[string]any{ 750 "$type": "social.coves.vote", 751 "subject": "at://did:plc:abc/social.coves.post/123", 752 "direction": "up", 753 }, 754 }, 755 serverStatus: http.StatusOK, 756 wantURI: "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v", 757 wantCID: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 758 wantErr: false, 759 }, 760 { 761 name: "record not found", 762 collection: "social.coves.vote", 763 rkey: "nonexistent", 764 serverResponse: map[string]any{ 765 "error": "RecordNotFound", 766 "message": "Record not found", 767 }, 768 serverStatus: http.StatusNotFound, 769 wantErr: true, 770 }, 771 { 772 name: "server error", 773 collection: "social.coves.vote", 774 rkey: "test", 775 serverResponse: map[string]any{ 776 "error": "Internal error", 777 }, 778 serverStatus: http.StatusInternalServerError, 779 wantErr: true, 780 }, 781 } 782 783 for _, tt := range tests { 784 t.Run(tt.name, func(t *testing.T) { 785 // Create mock server 786 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 787 // Verify method 788 if r.Method != http.MethodGet { 789 t.Errorf("expected GET request, got %s", r.Method) 790 } 791 792 // Verify path 793 expectedPath := "/xrpc/com.atproto.repo.getRecord" 794 if r.URL.Path != expectedPath { 795 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 796 } 797 798 // Verify query parameters 799 query := r.URL.Query() 800 if query.Get("collection") != tt.collection { 801 
t.Errorf("collection param = %q, want %q", query.Get("collection"), tt.collection) 802 } 803 if query.Get("rkey") != tt.rkey { 804 t.Errorf("rkey param = %q, want %q", query.Get("rkey"), tt.rkey) 805 } 806 807 // Send response 808 w.Header().Set("Content-Type", "application/json") 809 w.WriteHeader(tt.serverStatus) 810 json.NewEncoder(w).Encode(tt.serverResponse) 811 })) 812 defer server.Close() 813 814 // Create client 815 apiClient := atclient.NewAPIClient(server.URL) 816 apiClient.Auth = &bearerAuth{token: "test-token"} 817 818 c := &client{ 819 apiClient: apiClient, 820 did: "did:plc:test", 821 host: server.URL, 822 } 823 824 // Execute GetRecord 825 ctx := context.Background() 826 resp, err := c.GetRecord(ctx, tt.collection, tt.rkey) 827 828 if tt.wantErr { 829 if err == nil { 830 t.Fatal("expected error, got nil") 831 } 832 return 833 } 834 835 if err != nil { 836 t.Fatalf("unexpected error: %v", err) 837 } 838 839 if resp == nil { 840 t.Fatal("expected response, got nil") 841 } 842 843 if resp.URI != tt.wantURI { 844 t.Errorf("URI = %q, want %q", resp.URI, tt.wantURI) 845 } 846 847 if resp.CID != tt.wantCID { 848 t.Errorf("CID = %q, want %q", resp.CID, tt.wantCID) 849 } 850 851 if resp.Value == nil { 852 t.Error("Value is nil") 853 } 854 }) 855 } 856} 857 858// TestTypedErrors_IsAuthError tests the IsAuthError helper function. 
859func TestTypedErrors_IsAuthError(t *testing.T) { 860 tests := []struct { 861 name string 862 err error 863 wantAuth bool 864 }{ 865 { 866 name: "ErrUnauthorized is auth error", 867 err: ErrUnauthorized, 868 wantAuth: true, 869 }, 870 { 871 name: "ErrForbidden is auth error", 872 err: ErrForbidden, 873 wantAuth: true, 874 }, 875 { 876 name: "ErrNotFound is not auth error", 877 err: ErrNotFound, 878 wantAuth: false, 879 }, 880 { 881 name: "ErrBadRequest is not auth error", 882 err: ErrBadRequest, 883 wantAuth: false, 884 }, 885 { 886 name: "wrapped ErrUnauthorized is auth error", 887 err: errors.New("outer: " + ErrUnauthorized.Error()), 888 wantAuth: false, // Plain string wrap doesn't work 889 }, 890 { 891 name: "fmt.Errorf wrapped ErrUnauthorized is auth error", 892 err: wrapAPIError(&atclient.APIError{StatusCode: 401, Message: "test"}, "op"), 893 wantAuth: true, 894 }, 895 { 896 name: "fmt.Errorf wrapped ErrForbidden is auth error", 897 err: wrapAPIError(&atclient.APIError{StatusCode: 403, Message: "test"}, "op"), 898 wantAuth: true, 899 }, 900 { 901 name: "nil error", 902 err: nil, 903 wantAuth: false, 904 }, 905 { 906 name: "generic error", 907 err: errors.New("something else"), 908 wantAuth: false, 909 }, 910 } 911 912 for _, tt := range tests { 913 t.Run(tt.name, func(t *testing.T) { 914 got := IsAuthError(tt.err) 915 if got != tt.wantAuth { 916 t.Errorf("IsAuthError() = %v, want %v", got, tt.wantAuth) 917 } 918 }) 919 } 920} 921 922// TestWrapAPIError tests error wrapping for HTTP status codes. 
923func TestWrapAPIError(t *testing.T) { 924 tests := []struct { 925 name string 926 err error 927 operation string 928 wantTyped error 929 wantNil bool 930 }{ 931 { 932 name: "nil error returns nil", 933 err: nil, 934 operation: "test", 935 wantNil: true, 936 }, 937 { 938 name: "401 maps to ErrUnauthorized", 939 err: &atclient.APIError{StatusCode: 401, Name: "AuthRequired", Message: "Not logged in"}, 940 operation: "createRecord", 941 wantTyped: ErrUnauthorized, 942 }, 943 { 944 name: "403 maps to ErrForbidden", 945 err: &atclient.APIError{StatusCode: 403, Name: "Forbidden", Message: "Access denied"}, 946 operation: "deleteRecord", 947 wantTyped: ErrForbidden, 948 }, 949 { 950 name: "404 maps to ErrNotFound", 951 err: &atclient.APIError{StatusCode: 404, Name: "NotFound", Message: "Record not found"}, 952 operation: "getRecord", 953 wantTyped: ErrNotFound, 954 }, 955 { 956 name: "400 maps to ErrBadRequest", 957 err: &atclient.APIError{StatusCode: 400, Name: "InvalidRequest", Message: "Bad input"}, 958 operation: "createRecord", 959 wantTyped: ErrBadRequest, 960 }, 961 { 962 name: "409 maps to ErrConflict", 963 err: &atclient.APIError{StatusCode: 409, Name: "InvalidSwap", Message: "Record CID mismatch"}, 964 operation: "putRecord", 965 wantTyped: ErrConflict, 966 }, 967 { 968 name: "500 wraps without typed error", 969 err: &atclient.APIError{StatusCode: 500, Name: "InternalError", Message: "Server error"}, 970 operation: "listRecords", 971 wantTyped: nil, // No typed error for 500 972 }, 973 { 974 name: "non-APIError wraps normally", 975 err: errors.New("network timeout"), 976 operation: "createRecord", 977 wantTyped: nil, 978 }, 979 } 980 981 for _, tt := range tests { 982 t.Run(tt.name, func(t *testing.T) { 983 result := wrapAPIError(tt.err, tt.operation) 984 985 if tt.wantNil { 986 if result != nil { 987 t.Errorf("expected nil, got %v", result) 988 } 989 return 990 } 991 992 if result == nil { 993 t.Fatal("expected error, got nil") 994 } 995 996 if tt.wantTyped 
!= nil { 997 if !errors.Is(result, tt.wantTyped) { 998 t.Errorf("expected errors.Is(%v, %v) to be true", result, tt.wantTyped) 999 } 1000 } 1001 1002 // Verify operation is included in error message 1003 if !strings.Contains(result.Error(), tt.operation) { 1004 t.Errorf("error message %q should contain operation %q", result.Error(), tt.operation) 1005 } 1006 }) 1007 } 1008} 1009 1010// TestClient_TypedErrors_CreateRecord tests that CreateRecord returns typed errors. 1011func TestClient_TypedErrors_CreateRecord(t *testing.T) { 1012 tests := []struct { 1013 name string 1014 serverStatus int 1015 wantErr error 1016 }{ 1017 { 1018 name: "401 returns ErrUnauthorized", 1019 serverStatus: http.StatusUnauthorized, 1020 wantErr: ErrUnauthorized, 1021 }, 1022 { 1023 name: "403 returns ErrForbidden", 1024 serverStatus: http.StatusForbidden, 1025 wantErr: ErrForbidden, 1026 }, 1027 { 1028 name: "400 returns ErrBadRequest", 1029 serverStatus: http.StatusBadRequest, 1030 wantErr: ErrBadRequest, 1031 }, 1032 } 1033 1034 for _, tt := range tests { 1035 t.Run(tt.name, func(t *testing.T) { 1036 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1037 w.Header().Set("Content-Type", "application/json") 1038 w.WriteHeader(tt.serverStatus) 1039 json.NewEncoder(w).Encode(map[string]any{ 1040 "error": "TestError", 1041 "message": "Test error message", 1042 }) 1043 })) 1044 defer server.Close() 1045 1046 apiClient := atclient.NewAPIClient(server.URL) 1047 apiClient.Auth = &bearerAuth{token: "test-token"} 1048 1049 c := &client{ 1050 apiClient: apiClient, 1051 did: "did:plc:test", 1052 host: server.URL, 1053 } 1054 1055 ctx := context.Background() 1056 _, _, err := c.CreateRecord(ctx, "test.collection", "rkey", map[string]any{}) 1057 1058 if err == nil { 1059 t.Fatal("expected error, got nil") 1060 } 1061 1062 if !errors.Is(err, tt.wantErr) { 1063 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr) 1064 } 1065 }) 1066 } 1067} 1068 
1069// TestClient_TypedErrors_DeleteRecord tests that DeleteRecord returns typed errors. 1070func TestClient_TypedErrors_DeleteRecord(t *testing.T) { 1071 tests := []struct { 1072 name string 1073 serverStatus int 1074 wantErr error 1075 }{ 1076 { 1077 name: "401 returns ErrUnauthorized", 1078 serverStatus: http.StatusUnauthorized, 1079 wantErr: ErrUnauthorized, 1080 }, 1081 { 1082 name: "403 returns ErrForbidden", 1083 serverStatus: http.StatusForbidden, 1084 wantErr: ErrForbidden, 1085 }, 1086 { 1087 name: "404 returns ErrNotFound", 1088 serverStatus: http.StatusNotFound, 1089 wantErr: ErrNotFound, 1090 }, 1091 } 1092 1093 for _, tt := range tests { 1094 t.Run(tt.name, func(t *testing.T) { 1095 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1096 w.Header().Set("Content-Type", "application/json") 1097 w.WriteHeader(tt.serverStatus) 1098 json.NewEncoder(w).Encode(map[string]any{ 1099 "error": "TestError", 1100 "message": "Test error message", 1101 }) 1102 })) 1103 defer server.Close() 1104 1105 apiClient := atclient.NewAPIClient(server.URL) 1106 apiClient.Auth = &bearerAuth{token: "test-token"} 1107 1108 c := &client{ 1109 apiClient: apiClient, 1110 did: "did:plc:test", 1111 host: server.URL, 1112 } 1113 1114 ctx := context.Background() 1115 err := c.DeleteRecord(ctx, "test.collection", "rkey") 1116 1117 if err == nil { 1118 t.Fatal("expected error, got nil") 1119 } 1120 1121 if !errors.Is(err, tt.wantErr) { 1122 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr) 1123 } 1124 }) 1125 } 1126} 1127 1128// TestClient_TypedErrors_ListRecords tests that ListRecords returns typed errors. 
1129func TestClient_TypedErrors_ListRecords(t *testing.T) { 1130 tests := []struct { 1131 name string 1132 serverStatus int 1133 wantErr error 1134 }{ 1135 { 1136 name: "401 returns ErrUnauthorized", 1137 serverStatus: http.StatusUnauthorized, 1138 wantErr: ErrUnauthorized, 1139 }, 1140 { 1141 name: "403 returns ErrForbidden", 1142 serverStatus: http.StatusForbidden, 1143 wantErr: ErrForbidden, 1144 }, 1145 } 1146 1147 for _, tt := range tests { 1148 t.Run(tt.name, func(t *testing.T) { 1149 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1150 w.Header().Set("Content-Type", "application/json") 1151 w.WriteHeader(tt.serverStatus) 1152 json.NewEncoder(w).Encode(map[string]any{ 1153 "error": "TestError", 1154 "message": "Test error message", 1155 }) 1156 })) 1157 defer server.Close() 1158 1159 apiClient := atclient.NewAPIClient(server.URL) 1160 apiClient.Auth = &bearerAuth{token: "test-token"} 1161 1162 c := &client{ 1163 apiClient: apiClient, 1164 did: "did:plc:test", 1165 host: server.URL, 1166 } 1167 1168 ctx := context.Background() 1169 _, err := c.ListRecords(ctx, "test.collection", 10, "") 1170 1171 if err == nil { 1172 t.Fatal("expected error, got nil") 1173 } 1174 1175 if !errors.Is(err, tt.wantErr) { 1176 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr) 1177 } 1178 }) 1179 } 1180} 1181 1182// TestClient_PutRecord tests the PutRecord method with a mock server. 
1183func TestClient_PutRecord(t *testing.T) { 1184 tests := []struct { 1185 name string 1186 collection string 1187 rkey string 1188 record map[string]any 1189 swapRecord string 1190 serverResponse map[string]any 1191 serverStatus int 1192 wantURI string 1193 wantCID string 1194 wantErr bool 1195 }{ 1196 { 1197 name: "successful put with swapRecord", 1198 collection: "social.coves.comment", 1199 rkey: "3kjzl5kcb2s2v", 1200 record: map[string]any{ 1201 "$type": "social.coves.comment", 1202 "content": "Updated comment content", 1203 }, 1204 swapRecord: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 1205 serverResponse: map[string]any{ 1206 "uri": "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v", 1207 "cid": "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 1208 }, 1209 serverStatus: http.StatusOK, 1210 wantURI: "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v", 1211 wantCID: "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 1212 wantErr: false, 1213 }, 1214 { 1215 name: "successful put without swapRecord", 1216 collection: "social.coves.comment", 1217 rkey: "3kjzl5kcb2s2v", 1218 record: map[string]any{ 1219 "$type": "social.coves.comment", 1220 "content": "Updated comment", 1221 }, 1222 swapRecord: "", 1223 serverResponse: map[string]any{ 1224 "uri": "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v", 1225 "cid": "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 1226 }, 1227 serverStatus: http.StatusOK, 1228 wantURI: "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v", 1229 wantCID: "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5", 1230 wantErr: false, 1231 }, 1232 { 1233 name: "conflict error (409)", 1234 collection: "social.coves.comment", 1235 rkey: "test", 1236 record: map[string]any{"$type": "social.coves.comment"}, 1237 swapRecord: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua", 1238 serverResponse: map[string]any{ 1239 "error": "InvalidSwap", 1240 "message": 
"Record CID does not match", 1241 }, 1242 serverStatus: http.StatusConflict, 1243 wantErr: true, 1244 }, 1245 { 1246 name: "server error", 1247 collection: "social.coves.comment", 1248 rkey: "test", 1249 record: map[string]any{"$type": "social.coves.comment"}, 1250 swapRecord: "", 1251 serverResponse: map[string]any{ 1252 "error": "InvalidRequest", 1253 "message": "Invalid record", 1254 }, 1255 serverStatus: http.StatusBadRequest, 1256 wantErr: true, 1257 }, 1258 } 1259 1260 for _, tt := range tests { 1261 t.Run(tt.name, func(t *testing.T) { 1262 // Create mock server 1263 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1264 // Verify method 1265 if r.Method != http.MethodPost { 1266 t.Errorf("expected POST request, got %s", r.Method) 1267 } 1268 1269 // Verify path 1270 expectedPath := "/xrpc/com.atproto.repo.putRecord" 1271 if r.URL.Path != expectedPath { 1272 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath) 1273 } 1274 1275 // Verify request body 1276 var payload map[string]any 1277 if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { 1278 t.Fatalf("failed to decode request body: %v", err) 1279 } 1280 1281 // Check required fields 1282 if payload["collection"] != tt.collection { 1283 t.Errorf("collection = %v, want %v", payload["collection"], tt.collection) 1284 } 1285 if payload["rkey"] != tt.rkey { 1286 t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey) 1287 } 1288 1289 // Check swapRecord inclusion 1290 if tt.swapRecord != "" { 1291 if payload["swapRecord"] != tt.swapRecord { 1292 t.Errorf("swapRecord = %v, want %v", payload["swapRecord"], tt.swapRecord) 1293 } 1294 } else { 1295 if _, exists := payload["swapRecord"]; exists { 1296 t.Error("swapRecord should not be included when empty") 1297 } 1298 } 1299 1300 // Send response 1301 w.Header().Set("Content-Type", "application/json") 1302 w.WriteHeader(tt.serverStatus) 1303 json.NewEncoder(w).Encode(tt.serverResponse) 1304 })) 1305 defer 
server.Close() 1306 1307 // Create client 1308 apiClient := atclient.NewAPIClient(server.URL) 1309 apiClient.Auth = &bearerAuth{token: "test-token"} 1310 1311 c := &client{ 1312 apiClient: apiClient, 1313 did: "did:plc:test", 1314 host: server.URL, 1315 } 1316 1317 // Execute PutRecord 1318 ctx := context.Background() 1319 uri, cid, err := c.PutRecord(ctx, tt.collection, tt.rkey, tt.record, tt.swapRecord) 1320 1321 if tt.wantErr { 1322 if err == nil { 1323 t.Fatal("expected error, got nil") 1324 } 1325 return 1326 } 1327 1328 if err != nil { 1329 t.Fatalf("unexpected error: %v", err) 1330 } 1331 1332 if uri != tt.wantURI { 1333 t.Errorf("uri = %q, want %q", uri, tt.wantURI) 1334 } 1335 1336 if cid != tt.wantCID { 1337 t.Errorf("cid = %q, want %q", cid, tt.wantCID) 1338 } 1339 }) 1340 } 1341} 1342 1343// TestClient_TypedErrors_PutRecord tests that PutRecord returns typed errors. 1344func TestClient_TypedErrors_PutRecord(t *testing.T) { 1345 tests := []struct { 1346 name string 1347 serverStatus int 1348 wantErr error 1349 }{ 1350 { 1351 name: "401 returns ErrUnauthorized", 1352 serverStatus: http.StatusUnauthorized, 1353 wantErr: ErrUnauthorized, 1354 }, 1355 { 1356 name: "403 returns ErrForbidden", 1357 serverStatus: http.StatusForbidden, 1358 wantErr: ErrForbidden, 1359 }, 1360 { 1361 name: "409 returns ErrConflict", 1362 serverStatus: http.StatusConflict, 1363 wantErr: ErrConflict, 1364 }, 1365 { 1366 name: "400 returns ErrBadRequest", 1367 serverStatus: http.StatusBadRequest, 1368 wantErr: ErrBadRequest, 1369 }, 1370 } 1371 1372 for _, tt := range tests { 1373 t.Run(tt.name, func(t *testing.T) { 1374 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1375 w.Header().Set("Content-Type", "application/json") 1376 w.WriteHeader(tt.serverStatus) 1377 json.NewEncoder(w).Encode(map[string]any{ 1378 "error": "TestError", 1379 "message": "Test error message", 1380 }) 1381 })) 1382 defer server.Close() 1383 1384 apiClient := 
atclient.NewAPIClient(server.URL) 1385 apiClient.Auth = &bearerAuth{token: "test-token"} 1386 1387 c := &client{ 1388 apiClient: apiClient, 1389 did: "did:plc:test", 1390 host: server.URL, 1391 } 1392 1393 ctx := context.Background() 1394 _, _, err := c.PutRecord(ctx, "test.collection", "rkey", map[string]any{}, "") 1395 1396 if err == nil { 1397 t.Fatal("expected error, got nil") 1398 } 1399 1400 if !errors.Is(err, tt.wantErr) { 1401 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr) 1402 } 1403 }) 1404 } 1405}