A community-based topic aggregation platform built on atproto.
1package pds
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "net/http"
8 "net/http/httptest"
9 "strings"
10 "testing"
11
12 "github.com/bluesky-social/indigo/atproto/atclient"
13 "github.com/bluesky-social/indigo/atproto/auth/oauth"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15)
16
17// This test suite provides comprehensive unit tests for the PDS client package.
18//
19// Coverage:
20// - All Client interface methods: 100%
21// - bearerAuth implementation: 100%
22// - Factory function input validation: 100%
23// - NewFromAccessToken: 100%
24//
25// Not covered (requires integration tests with real infrastructure):
26// - NewFromPasswordAuth success path (requires live PDS server)
27// - NewFromOAuthSession success path (requires OAuth infrastructure)
28//
29// The untested code paths involve external dependencies (PDS authentication,
30// OAuth session resumption) which are appropriately tested in E2E/integration tests.
31
// TestClientImplementsInterface is a compile-time guard: the package stops
// building if *client no longer satisfies the Client interface. Nothing is
// executed at runtime.
func TestClientImplementsInterface(t *testing.T) {
	var iface Client = (*client)(nil)
	_ = iface
}
36
// TestBearerAuth_DoWithAuth verifies that bearerAuth.DoWithAuth sends the
// request with an "Authorization: Bearer <token>" header and returns the
// server's response.
func TestBearerAuth_DoWithAuth(t *testing.T) {
	tests := []struct {
		name  string
		token string
	}{
		{
			name:  "standard token",
			token: "test-access-token-12345",
		},
		{
			name:  "token with special characters",
			token: "token.with.dots_and-dashes",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The handler records the Authorization header it received so the
			// test goroutine can inspect it after the round trip completes.
			var capturedHeader string
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				capturedHeader = r.Header.Get("Authorization")
				w.WriteHeader(http.StatusOK)
			}))
			defer server.Close()

			auth := &bearerAuth{token: tt.token}

			req, err := http.NewRequest(http.MethodGet, server.URL, nil)
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}

			// Named httpClient so it does not shadow the package's client type.
			httpClient := server.Client()
			nsid := syntax.NSID("com.atproto.test")
			resp, err := auth.DoWithAuth(httpClient, req, nsid)
			if err != nil {
				t.Fatalf("DoWithAuth failed: %v", err)
			}
			// Release the underlying connection; deferred after server.Close so
			// it runs first.
			defer resp.Body.Close()

			if resp.StatusCode != http.StatusOK {
				t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
			}

			expectedHeader := "Bearer " + tt.token
			if capturedHeader != expectedHeader {
				t.Errorf("Authorization header = %q, want %q", capturedHeader, expectedHeader)
			}
		})
	}
}
88
// TestBearerAuth_ImplementsAuthMethod is a compile-time guard that fails the
// build if *bearerAuth stops satisfying atclient.AuthMethod.
func TestBearerAuth_ImplementsAuthMethod(t *testing.T) {
	var method atclient.AuthMethod = (*bearerAuth)(nil)
	_ = method
}
93
94// TestNewFromAccessToken validates factory function input validation.
95func TestNewFromAccessToken(t *testing.T) {
96 tests := []struct {
97 name string
98 host string
99 did string
100 accessToken string
101 wantErr bool
102 errContains string
103 }{
104 {
105 name: "valid inputs",
106 host: "https://pds.example.com",
107 did: "did:plc:12345",
108 accessToken: "test-token",
109 wantErr: false,
110 },
111 {
112 name: "empty host",
113 host: "",
114 did: "did:plc:12345",
115 accessToken: "test-token",
116 wantErr: true,
117 errContains: "host is required",
118 },
119 {
120 name: "empty did",
121 host: "https://pds.example.com",
122 did: "",
123 accessToken: "test-token",
124 wantErr: true,
125 errContains: "did is required",
126 },
127 {
128 name: "empty access token",
129 host: "https://pds.example.com",
130 did: "did:plc:12345",
131 accessToken: "",
132 wantErr: true,
133 errContains: "accessToken is required",
134 },
135 {
136 name: "all empty",
137 host: "",
138 did: "",
139 accessToken: "",
140 wantErr: true,
141 errContains: "host is required",
142 },
143 }
144
145 for _, tt := range tests {
146 t.Run(tt.name, func(t *testing.T) {
147 client, err := NewFromAccessToken(tt.host, tt.did, tt.accessToken)
148
149 if tt.wantErr {
150 if err == nil {
151 t.Fatal("expected error, got nil")
152 }
153 if !strings.Contains(err.Error(), tt.errContains) {
154 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains)
155 }
156 return
157 }
158
159 if err != nil {
160 t.Fatalf("unexpected error: %v", err)
161 }
162
163 if client == nil {
164 t.Fatal("expected client, got nil")
165 }
166
167 // Verify DID and HostURL methods
168 if client.DID() != tt.did {
169 t.Errorf("DID() = %q, want %q", client.DID(), tt.did)
170 }
171 if client.HostURL() != tt.host {
172 t.Errorf("HostURL() = %q, want %q", client.HostURL(), tt.host)
173 }
174 })
175 }
176}
177
// TestNewFromPasswordAuth checks that NewFromPasswordAuth rejects missing
// inputs with an error naming the missing field. When every input is empty,
// the host error is the one expected.
//
// The success path needs a live PDS and is covered by integration tests.
func TestNewFromPasswordAuth(t *testing.T) {
	// Every case here is a validation failure, so there is no wantErr flag.
	tests := []struct {
		name        string
		host        string
		handle      string
		password    string
		errContains string
	}{
		{
			name:        "empty host",
			host:        "",
			handle:      "user.bsky.social",
			password:    "password",
			errContains: "host is required",
		},
		{
			name:        "empty handle",
			host:        "https://pds.example.com",
			handle:      "",
			password:    "password",
			errContains: "handle is required",
		},
		{
			name:        "empty password",
			host:        "https://pds.example.com",
			handle:      "user.bsky.social",
			password:    "",
			errContains: "password is required",
		},
		{
			name:        "all empty",
			host:        "",
			handle:      "",
			password:    "",
			errContains: "host is required",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Use an already-cancelled context. Input validation should not
			// depend on it, but if validation ever regresses, any login request
			// to pds.example.com should fail fast instead of dialing out and
			// possibly hanging. The errContains check then reports the failure.
			ctx, cancel := context.WithCancel(context.Background())
			cancel()

			_, err := NewFromPasswordAuth(ctx, tt.host, tt.handle, tt.password)
			if err == nil {
				t.Fatal("expected error, got nil")
			}
			if !strings.Contains(err.Error(), tt.errContains) {
				t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains)
			}
		})
	}
}
242
243// TestNewFromOAuthSession validates factory function input validation.
244func TestNewFromOAuthSession(t *testing.T) {
245 ctx := context.Background()
246
247 tests := []struct {
248 name string
249 oauthClient *oauth.ClientApp
250 sessionData *oauth.ClientSessionData
251 wantErr bool
252 errContains string
253 }{
254 {
255 name: "nil oauth client",
256 oauthClient: nil,
257 sessionData: &oauth.ClientSessionData{},
258 wantErr: true,
259 errContains: "oauthClient is required",
260 },
261 {
262 name: "nil session data",
263 oauthClient: &oauth.ClientApp{},
264 sessionData: nil,
265 wantErr: true,
266 errContains: "sessionData is required",
267 },
268 {
269 name: "both nil",
270 oauthClient: nil,
271 sessionData: nil,
272 wantErr: true,
273 errContains: "oauthClient is required",
274 },
275 }
276
277 for _, tt := range tests {
278 t.Run(tt.name, func(t *testing.T) {
279 _, err := NewFromOAuthSession(ctx, tt.oauthClient, tt.sessionData)
280
281 if tt.wantErr {
282 if err == nil {
283 t.Fatal("expected error, got nil")
284 }
285 if !strings.Contains(err.Error(), tt.errContains) {
286 t.Errorf("error = %q, want contains %q", err.Error(), tt.errContains)
287 }
288 return
289 }
290
291 // Note: Success case requires proper OAuth setup, tested in integration tests
292 })
293 }
294}
295
296// TestClient_DIDAndHostURL verifies DID() and HostURL() return correct values.
297func TestClient_DIDAndHostURL(t *testing.T) {
298 expectedDID := "did:plc:test123"
299 expectedHost := "https://pds.test.com"
300
301 c := &client{
302 did: expectedDID,
303 host: expectedHost,
304 }
305
306 if got := c.DID(); got != expectedDID {
307 t.Errorf("DID() = %q, want %q", got, expectedDID)
308 }
309
310 if got := c.HostURL(); got != expectedHost {
311 t.Errorf("HostURL() = %q, want %q", got, expectedHost)
312 }
313}
314
// TestClient_CreateRecord tests the CreateRecord method against a mock server.
//
// The mock asserts that the client POSTs to /xrpc/com.atproto.repo.createRecord
// with the expected collection, and that "rkey" is present in the payload only
// when the caller supplied a non-empty one. It then replies with the case's
// status and JSON body.
func TestClient_CreateRecord(t *testing.T) {
	tests := []struct {
		name           string
		collection     string
		rkey           string
		record         map[string]any
		serverResponse map[string]any
		serverStatus   int
		wantURI        string
		wantCID        string
		wantErr        bool
	}{
		{
			name:       "successful creation with rkey",
			collection: "social.coves.vote",
			rkey:       "3kjzl5kcb2s2v",
			record: map[string]any{
				"$type":     "social.coves.vote",
				"subject":   "at://did:plc:abc123/social.coves.post/3kjzl5kc",
				"direction": "up",
			},
			serverResponse: map[string]any{
				"uri": "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v",
				"cid": "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua",
			},
			serverStatus: http.StatusOK,
			wantURI:      "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v",
			wantCID:      "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua",
			wantErr:      false,
		},
		{
			name:       "successful creation without rkey (TID generated)",
			collection: "social.coves.vote",
			rkey:       "",
			record: map[string]any{
				"$type":     "social.coves.vote",
				"subject":   "at://did:plc:abc123/social.coves.post/3kjzl5kc",
				"direction": "down",
			},
			serverResponse: map[string]any{
				"uri": "at://did:plc:test/social.coves.vote/3kjzl5kcc2a1b",
				"cid": "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5",
			},
			serverStatus: http.StatusOK,
			wantURI:      "at://did:plc:test/social.coves.vote/3kjzl5kcc2a1b",
			wantCID:      "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5",
			wantErr:      false,
		},
		{
			name:       "server error",
			collection: "social.coves.vote",
			rkey:       "test",
			record:     map[string]any{"$type": "social.coves.vote"},
			serverResponse: map[string]any{
				"error":   "InvalidRequest",
				"message": "Invalid record",
			},
			serverStatus: http.StatusBadRequest,
			wantErr:      true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// NOTE: this handler runs on the server's goroutine, not the test
			// goroutine, so it must use t.Errorf (never t.Fatalf/FailNow).
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.Method != http.MethodPost {
					t.Errorf("expected POST request, got %s", r.Method)
				}

				expectedPath := "/xrpc/com.atproto.repo.createRecord"
				if r.URL.Path != expectedPath {
					t.Errorf("path = %q, want %q", r.URL.Path, expectedPath)
				}

				var payload map[string]any
				if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
					t.Errorf("failed to decode request body: %v", err)
					http.Error(w, "invalid request body", http.StatusBadRequest)
					return
				}

				if payload["collection"] != tt.collection {
					t.Errorf("collection = %v, want %v", payload["collection"], tt.collection)
				}

				// rkey is optional: it must be sent only when the caller set one.
				if tt.rkey != "" {
					if payload["rkey"] != tt.rkey {
						t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey)
					}
				} else if _, exists := payload["rkey"]; exists {
					t.Error("rkey should not be included when empty")
				}

				// Headers must be set before WriteHeader to take effect.
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(tt.serverStatus)
				if err := json.NewEncoder(w).Encode(tt.serverResponse); err != nil {
					t.Errorf("failed to encode response: %v", err)
				}
			}))
			defer server.Close()

			apiClient := atclient.NewAPIClient(server.URL)
			apiClient.Auth = &bearerAuth{token: "test-token"}

			c := &client{
				apiClient: apiClient,
				did:       "did:plc:test",
				host:      server.URL,
			}

			ctx := context.Background()
			uri, cid, err := c.CreateRecord(ctx, tt.collection, tt.rkey, tt.record)

			if tt.wantErr {
				if err == nil {
					t.Fatal("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if uri != tt.wantURI {
				t.Errorf("uri = %q, want %q", uri, tt.wantURI)
			}

			if cid != tt.wantCID {
				t.Errorf("cid = %q, want %q", cid, tt.wantCID)
			}
		})
	}
}
457
// TestClient_DeleteRecord tests the DeleteRecord method against a mock server.
//
// The mock asserts that the client POSTs to /xrpc/com.atproto.repo.deleteRecord
// with the expected collection and rkey. On success it replies with an empty
// 200. On failure it replies with the case's status and a JSON error body.
func TestClient_DeleteRecord(t *testing.T) {
	tests := []struct {
		name         string
		collection   string
		rkey         string
		serverStatus int
		wantErr      bool
	}{
		{
			name:         "successful deletion",
			collection:   "social.coves.vote",
			rkey:         "3kjzl5kcb2s2v",
			serverStatus: http.StatusOK,
			wantErr:      false,
		},
		{
			name:         "not found error",
			collection:   "social.coves.vote",
			rkey:         "nonexistent",
			serverStatus: http.StatusNotFound,
			wantErr:      true,
		},
		{
			name:         "server error",
			collection:   "social.coves.vote",
			rkey:         "test",
			serverStatus: http.StatusInternalServerError,
			wantErr:      true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// NOTE: this handler runs on the server's goroutine, not the test
			// goroutine, so it must use t.Errorf (never t.Fatalf/FailNow).
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.Method != http.MethodPost {
					t.Errorf("expected POST request, got %s", r.Method)
				}

				expectedPath := "/xrpc/com.atproto.repo.deleteRecord"
				if r.URL.Path != expectedPath {
					t.Errorf("path = %q, want %q", r.URL.Path, expectedPath)
				}

				var payload map[string]any
				if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
					t.Errorf("failed to decode request body: %v", err)
					http.Error(w, "invalid request body", http.StatusBadRequest)
					return
				}

				if payload["collection"] != tt.collection {
					t.Errorf("collection = %v, want %v", payload["collection"], tt.collection)
				}
				if payload["rkey"] != tt.rkey {
					t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey)
				}

				if tt.serverStatus == http.StatusOK {
					w.WriteHeader(http.StatusOK)
					return
				}

				// Content-Type must be set BEFORE WriteHeader. Header changes
				// made after WriteHeader are silently discarded.
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(tt.serverStatus)
				if err := json.NewEncoder(w).Encode(map[string]any{
					"error":   "Error",
					"message": "Operation failed",
				}); err != nil {
					t.Errorf("failed to encode response: %v", err)
				}
			}))
			defer server.Close()

			apiClient := atclient.NewAPIClient(server.URL)
			apiClient.Auth = &bearerAuth{token: "test-token"}

			c := &client{
				apiClient: apiClient,
				did:       "did:plc:test",
				host:      server.URL,
			}

			ctx := context.Background()
			err := c.DeleteRecord(ctx, tt.collection, tt.rkey)

			if tt.wantErr {
				if err == nil {
					t.Fatal("expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
		})
	}
}
557
558// TestClient_ListRecords tests the ListRecords method with pagination.
559func TestClient_ListRecords(t *testing.T) {
560 tests := []struct {
561 name string
562 collection string
563 limit int
564 cursor string
565 serverResponse map[string]any
566 serverStatus int
567 wantRecords int
568 wantCursor string
569 wantErr bool
570 }{
571 {
572 name: "successful list with records",
573 collection: "social.coves.vote",
574 limit: 10,
575 cursor: "",
576 serverResponse: map[string]any{
577 "cursor": "next-cursor-123",
578 "records": []map[string]any{
579 {
580 "uri": "at://did:plc:test/social.coves.vote/1",
581 "cid": "bafyreiabc123",
582 "value": map[string]any{"$type": "social.coves.vote", "direction": "up"},
583 },
584 {
585 "uri": "at://did:plc:test/social.coves.vote/2",
586 "cid": "bafyreiabc456",
587 "value": map[string]any{"$type": "social.coves.vote", "direction": "down"},
588 },
589 },
590 },
591 serverStatus: http.StatusOK,
592 wantRecords: 2,
593 wantCursor: "next-cursor-123",
594 wantErr: false,
595 },
596 {
597 name: "empty list",
598 collection: "social.coves.vote",
599 limit: 10,
600 cursor: "",
601 serverResponse: map[string]any{
602 "cursor": "",
603 "records": []map[string]any{},
604 },
605 serverStatus: http.StatusOK,
606 wantRecords: 0,
607 wantCursor: "",
608 wantErr: false,
609 },
610 {
611 name: "with cursor pagination",
612 collection: "social.coves.vote",
613 limit: 5,
614 cursor: "existing-cursor",
615 serverResponse: map[string]any{
616 "cursor": "final-cursor",
617 "records": []map[string]any{
618 {
619 "uri": "at://did:plc:test/social.coves.vote/3",
620 "cid": "bafyreiabc789",
621 "value": map[string]any{"$type": "social.coves.vote", "direction": "up"},
622 },
623 },
624 },
625 serverStatus: http.StatusOK,
626 wantRecords: 1,
627 wantCursor: "final-cursor",
628 wantErr: false,
629 },
630 {
631 name: "server error",
632 collection: "social.coves.vote",
633 limit: 10,
634 cursor: "",
635 serverResponse: map[string]any{"error": "Internal error"},
636 serverStatus: http.StatusInternalServerError,
637 wantErr: true,
638 },
639 }
640
641 for _, tt := range tests {
642 t.Run(tt.name, func(t *testing.T) {
643 // Create mock server
644 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
645 // Verify method
646 if r.Method != http.MethodGet {
647 t.Errorf("expected GET request, got %s", r.Method)
648 }
649
650 // Verify path
651 expectedPath := "/xrpc/com.atproto.repo.listRecords"
652 if r.URL.Path != expectedPath {
653 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath)
654 }
655
656 // Verify query parameters
657 query := r.URL.Query()
658 if query.Get("collection") != tt.collection {
659 t.Errorf("collection param = %q, want %q", query.Get("collection"), tt.collection)
660 }
661
662 if tt.cursor != "" {
663 if query.Get("cursor") != tt.cursor {
664 t.Errorf("cursor param = %q, want %q", query.Get("cursor"), tt.cursor)
665 }
666 }
667
668 // Send response
669 w.Header().Set("Content-Type", "application/json")
670 w.WriteHeader(tt.serverStatus)
671 json.NewEncoder(w).Encode(tt.serverResponse)
672 }))
673 defer server.Close()
674
675 // Create client
676 apiClient := atclient.NewAPIClient(server.URL)
677 apiClient.Auth = &bearerAuth{token: "test-token"}
678
679 c := &client{
680 apiClient: apiClient,
681 did: "did:plc:test",
682 host: server.URL,
683 }
684
685 // Execute ListRecords
686 ctx := context.Background()
687 resp, err := c.ListRecords(ctx, tt.collection, tt.limit, tt.cursor)
688
689 if tt.wantErr {
690 if err == nil {
691 t.Fatal("expected error, got nil")
692 }
693 return
694 }
695
696 if err != nil {
697 t.Fatalf("unexpected error: %v", err)
698 }
699
700 if resp == nil {
701 t.Fatal("expected response, got nil")
702 }
703
704 if len(resp.Records) != tt.wantRecords {
705 t.Errorf("records count = %d, want %d", len(resp.Records), tt.wantRecords)
706 }
707
708 if resp.Cursor != tt.wantCursor {
709 t.Errorf("cursor = %q, want %q", resp.Cursor, tt.wantCursor)
710 }
711
712 // Verify record structure if we have records
713 if tt.wantRecords > 0 {
714 for i, record := range resp.Records {
715 if record.URI == "" {
716 t.Errorf("record[%d].URI is empty", i)
717 }
718 if record.CID == "" {
719 t.Errorf("record[%d].CID is empty", i)
720 }
721 if record.Value == nil {
722 t.Errorf("record[%d].Value is nil", i)
723 }
724 }
725 }
726 })
727 }
728}
729
730// TestClient_GetRecord tests the GetRecord method with a mock server.
731func TestClient_GetRecord(t *testing.T) {
732 tests := []struct {
733 name string
734 collection string
735 rkey string
736 serverResponse map[string]any
737 serverStatus int
738 wantURI string
739 wantCID string
740 wantErr bool
741 }{
742 {
743 name: "successful get",
744 collection: "social.coves.vote",
745 rkey: "3kjzl5kcb2s2v",
746 serverResponse: map[string]any{
747 "uri": "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v",
748 "cid": "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua",
749 "value": map[string]any{
750 "$type": "social.coves.vote",
751 "subject": "at://did:plc:abc/social.coves.post/123",
752 "direction": "up",
753 },
754 },
755 serverStatus: http.StatusOK,
756 wantURI: "at://did:plc:test/social.coves.vote/3kjzl5kcb2s2v",
757 wantCID: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua",
758 wantErr: false,
759 },
760 {
761 name: "record not found",
762 collection: "social.coves.vote",
763 rkey: "nonexistent",
764 serverResponse: map[string]any{
765 "error": "RecordNotFound",
766 "message": "Record not found",
767 },
768 serverStatus: http.StatusNotFound,
769 wantErr: true,
770 },
771 {
772 name: "server error",
773 collection: "social.coves.vote",
774 rkey: "test",
775 serverResponse: map[string]any{
776 "error": "Internal error",
777 },
778 serverStatus: http.StatusInternalServerError,
779 wantErr: true,
780 },
781 }
782
783 for _, tt := range tests {
784 t.Run(tt.name, func(t *testing.T) {
785 // Create mock server
786 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
787 // Verify method
788 if r.Method != http.MethodGet {
789 t.Errorf("expected GET request, got %s", r.Method)
790 }
791
792 // Verify path
793 expectedPath := "/xrpc/com.atproto.repo.getRecord"
794 if r.URL.Path != expectedPath {
795 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath)
796 }
797
798 // Verify query parameters
799 query := r.URL.Query()
800 if query.Get("collection") != tt.collection {
801 t.Errorf("collection param = %q, want %q", query.Get("collection"), tt.collection)
802 }
803 if query.Get("rkey") != tt.rkey {
804 t.Errorf("rkey param = %q, want %q", query.Get("rkey"), tt.rkey)
805 }
806
807 // Send response
808 w.Header().Set("Content-Type", "application/json")
809 w.WriteHeader(tt.serverStatus)
810 json.NewEncoder(w).Encode(tt.serverResponse)
811 }))
812 defer server.Close()
813
814 // Create client
815 apiClient := atclient.NewAPIClient(server.URL)
816 apiClient.Auth = &bearerAuth{token: "test-token"}
817
818 c := &client{
819 apiClient: apiClient,
820 did: "did:plc:test",
821 host: server.URL,
822 }
823
824 // Execute GetRecord
825 ctx := context.Background()
826 resp, err := c.GetRecord(ctx, tt.collection, tt.rkey)
827
828 if tt.wantErr {
829 if err == nil {
830 t.Fatal("expected error, got nil")
831 }
832 return
833 }
834
835 if err != nil {
836 t.Fatalf("unexpected error: %v", err)
837 }
838
839 if resp == nil {
840 t.Fatal("expected response, got nil")
841 }
842
843 if resp.URI != tt.wantURI {
844 t.Errorf("URI = %q, want %q", resp.URI, tt.wantURI)
845 }
846
847 if resp.CID != tt.wantCID {
848 t.Errorf("CID = %q, want %q", resp.CID, tt.wantCID)
849 }
850
851 if resp.Value == nil {
852 t.Error("Value is nil")
853 }
854 })
855 }
856}
857
// TestTypedErrors_IsAuthError tests the IsAuthError helper.
//
// ErrUnauthorized and ErrForbidden, and errors produced by wrapAPIError for
// 401/403, count as auth errors. Other sentinels, nil, unrelated errors, and a
// new error that merely copies ErrUnauthorized's message text do not.
func TestTypedErrors_IsAuthError(t *testing.T) {
	tests := []struct {
		name     string
		err      error
		wantAuth bool
	}{
		{
			name:     "ErrUnauthorized is auth error",
			err:      ErrUnauthorized,
			wantAuth: true,
		},
		{
			name:     "ErrForbidden is auth error",
			err:      ErrForbidden,
			wantAuth: true,
		},
		{
			name:     "ErrNotFound is not auth error",
			err:      ErrNotFound,
			wantAuth: false,
		},
		{
			name:     "ErrBadRequest is not auth error",
			err:      ErrBadRequest,
			wantAuth: false,
		},
		{
			// Copying the message text into a fresh error breaks the wrap chain,
			// so this must NOT be treated as an auth error.
			name:     "string-concatenated ErrUnauthorized text is not auth error",
			err:      errors.New("outer: " + ErrUnauthorized.Error()),
			wantAuth: false,
		},
		{
			name:     "wrapAPIError for 401 is auth error",
			err:      wrapAPIError(&atclient.APIError{StatusCode: 401, Message: "test"}, "op"),
			wantAuth: true,
		},
		{
			name:     "wrapAPIError for 403 is auth error",
			err:      wrapAPIError(&atclient.APIError{StatusCode: 403, Message: "test"}, "op"),
			wantAuth: true,
		},
		{
			name:     "nil error",
			err:      nil,
			wantAuth: false,
		},
		{
			name:     "generic error",
			err:      errors.New("something else"),
			wantAuth: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := IsAuthError(tt.err)
			if got != tt.wantAuth {
				t.Errorf("IsAuthError() = %v, want %v", got, tt.wantAuth)
			}
		})
	}
}
921
// TestWrapAPIError tests how wrapAPIError maps errors onto the package's
// sentinel errors:
//   - a nil error yields nil;
//   - *atclient.APIError with status 401/403/404/400/409 must satisfy
//     errors.Is against ErrUnauthorized/ErrForbidden/ErrNotFound/
//     ErrBadRequest/ErrConflict respectively;
//   - other statuses (500) and non-API errors must match none of those
//     sentinels.
//
// Every non-nil result must mention the operation name.
func TestWrapAPIError(t *testing.T) {
	// Every sentinel wrapAPIError can produce. Used to confirm that the
	// untyped cases really carry no typed error.
	sentinels := []error{ErrUnauthorized, ErrForbidden, ErrNotFound, ErrBadRequest, ErrConflict}

	tests := []struct {
		name      string
		err       error
		operation string
		wantTyped error
		wantNil   bool
	}{
		{
			name:      "nil error returns nil",
			err:       nil,
			operation: "test",
			wantNil:   true,
		},
		{
			name:      "401 maps to ErrUnauthorized",
			err:       &atclient.APIError{StatusCode: 401, Name: "AuthRequired", Message: "Not logged in"},
			operation: "createRecord",
			wantTyped: ErrUnauthorized,
		},
		{
			name:      "403 maps to ErrForbidden",
			err:       &atclient.APIError{StatusCode: 403, Name: "Forbidden", Message: "Access denied"},
			operation: "deleteRecord",
			wantTyped: ErrForbidden,
		},
		{
			name:      "404 maps to ErrNotFound",
			err:       &atclient.APIError{StatusCode: 404, Name: "NotFound", Message: "Record not found"},
			operation: "getRecord",
			wantTyped: ErrNotFound,
		},
		{
			name:      "400 maps to ErrBadRequest",
			err:       &atclient.APIError{StatusCode: 400, Name: "InvalidRequest", Message: "Bad input"},
			operation: "createRecord",
			wantTyped: ErrBadRequest,
		},
		{
			name:      "409 maps to ErrConflict",
			err:       &atclient.APIError{StatusCode: 409, Name: "InvalidSwap", Message: "Record CID mismatch"},
			operation: "putRecord",
			wantTyped: ErrConflict,
		},
		{
			name:      "500 wraps without typed error",
			err:       &atclient.APIError{StatusCode: 500, Name: "InternalError", Message: "Server error"},
			operation: "listRecords",
			wantTyped: nil, // No typed error for 500
		},
		{
			name:      "non-APIError wraps normally",
			err:       errors.New("network timeout"),
			operation: "createRecord",
			wantTyped: nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := wrapAPIError(tt.err, tt.operation)

			if tt.wantNil {
				if result != nil {
					t.Errorf("expected nil, got %v", result)
				}
				return
			}

			if result == nil {
				t.Fatal("expected error, got nil")
			}

			if tt.wantTyped != nil {
				if !errors.Is(result, tt.wantTyped) {
					t.Errorf("expected errors.Is(%v, %v) to be true", result, tt.wantTyped)
				}
			} else {
				// Untyped cases must not accidentally match any sentinel;
				// otherwise callers would misclassify e.g. a 500 as an auth error.
				for _, sentinel := range sentinels {
					if errors.Is(result, sentinel) {
						t.Errorf("expected no typed error, but errors.Is(%v, %v) is true", result, sentinel)
					}
				}
			}

			if !strings.Contains(result.Error(), tt.operation) {
				t.Errorf("error message %q should contain operation %q", result.Error(), tt.operation)
			}
		})
	}
}
1009
// TestClient_TypedErrors_CreateRecord checks that HTTP error statuses returned
// by the server surface from CreateRecord as the package's sentinel errors, so
// callers can match them with errors.Is.
func TestClient_TypedErrors_CreateRecord(t *testing.T) {
	cases := []struct {
		name         string
		serverStatus int
		wantErr      error
	}{
		{"401 returns ErrUnauthorized", http.StatusUnauthorized, ErrUnauthorized},
		{"403 returns ErrForbidden", http.StatusForbidden, ErrForbidden},
		{"400 returns ErrBadRequest", http.StatusBadRequest, ErrBadRequest},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The mock answers every request with the case's status and a
			// JSON error body.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(tc.serverStatus)
				body := map[string]any{
					"error":   "TestError",
					"message": "Test error message",
				}
				_ = json.NewEncoder(w).Encode(body)
			}))
			defer srv.Close()

			api := atclient.NewAPIClient(srv.URL)
			api.Auth = &bearerAuth{token: "test-token"}
			c := &client{apiClient: api, did: "did:plc:test", host: srv.URL}

			_, _, err := c.CreateRecord(context.Background(), "test.collection", "rkey", map[string]any{})

			switch {
			case err == nil:
				t.Fatal("expected error, got nil")
			case !errors.Is(err, tc.wantErr):
				t.Errorf("expected errors.Is(%v, %v) to be true", err, tc.wantErr)
			}
		})
	}
}
1068
// TestClient_TypedErrors_DeleteRecord checks that HTTP error statuses returned
// by the server surface from DeleteRecord as the package's sentinel errors, so
// callers can match them with errors.Is.
func TestClient_TypedErrors_DeleteRecord(t *testing.T) {
	cases := []struct {
		name         string
		serverStatus int
		wantErr      error
	}{
		{"401 returns ErrUnauthorized", http.StatusUnauthorized, ErrUnauthorized},
		{"403 returns ErrForbidden", http.StatusForbidden, ErrForbidden},
		{"404 returns ErrNotFound", http.StatusNotFound, ErrNotFound},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The mock answers every request with the case's status and a
			// JSON error body.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(tc.serverStatus)
				body := map[string]any{
					"error":   "TestError",
					"message": "Test error message",
				}
				_ = json.NewEncoder(w).Encode(body)
			}))
			defer srv.Close()

			api := atclient.NewAPIClient(srv.URL)
			api.Auth = &bearerAuth{token: "test-token"}
			c := &client{apiClient: api, did: "did:plc:test", host: srv.URL}

			err := c.DeleteRecord(context.Background(), "test.collection", "rkey")

			switch {
			case err == nil:
				t.Fatal("expected error, got nil")
			case !errors.Is(err, tc.wantErr):
				t.Errorf("expected errors.Is(%v, %v) to be true", err, tc.wantErr)
			}
		})
	}
}
1127
// TestClient_TypedErrors_ListRecords checks that auth-related HTTP statuses
// returned by the server surface from ListRecords as ErrUnauthorized and
// ErrForbidden, so callers can match them with errors.Is.
func TestClient_TypedErrors_ListRecords(t *testing.T) {
	cases := []struct {
		name         string
		serverStatus int
		wantErr      error
	}{
		{"401 returns ErrUnauthorized", http.StatusUnauthorized, ErrUnauthorized},
		{"403 returns ErrForbidden", http.StatusForbidden, ErrForbidden},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The mock answers every request with the case's status and a
			// JSON error body.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(tc.serverStatus)
				body := map[string]any{
					"error":   "TestError",
					"message": "Test error message",
				}
				_ = json.NewEncoder(w).Encode(body)
			}))
			defer srv.Close()

			api := atclient.NewAPIClient(srv.URL)
			api.Auth = &bearerAuth{token: "test-token"}
			c := &client{apiClient: api, did: "did:plc:test", host: srv.URL}

			_, err := c.ListRecords(context.Background(), "test.collection", 10, "")

			switch {
			case err == nil:
				t.Fatal("expected error, got nil")
			case !errors.Is(err, tc.wantErr):
				t.Errorf("expected errors.Is(%v, %v) to be true", err, tc.wantErr)
			}
		})
	}
}
1181
1182// TestClient_PutRecord tests the PutRecord method with a mock server.
1183func TestClient_PutRecord(t *testing.T) {
1184 tests := []struct {
1185 name string
1186 collection string
1187 rkey string
1188 record map[string]any
1189 swapRecord string
1190 serverResponse map[string]any
1191 serverStatus int
1192 wantURI string
1193 wantCID string
1194 wantErr bool
1195 }{
1196 {
1197 name: "successful put with swapRecord",
1198 collection: "social.coves.comment",
1199 rkey: "3kjzl5kcb2s2v",
1200 record: map[string]any{
1201 "$type": "social.coves.comment",
1202 "content": "Updated comment content",
1203 },
1204 swapRecord: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua",
1205 serverResponse: map[string]any{
1206 "uri": "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v",
1207 "cid": "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5",
1208 },
1209 serverStatus: http.StatusOK,
1210 wantURI: "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v",
1211 wantCID: "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5",
1212 wantErr: false,
1213 },
1214 {
1215 name: "successful put without swapRecord",
1216 collection: "social.coves.comment",
1217 rkey: "3kjzl5kcb2s2v",
1218 record: map[string]any{
1219 "$type": "social.coves.comment",
1220 "content": "Updated comment",
1221 },
1222 swapRecord: "",
1223 serverResponse: map[string]any{
1224 "uri": "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v",
1225 "cid": "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5",
1226 },
1227 serverStatus: http.StatusOK,
1228 wantURI: "at://did:plc:test/social.coves.comment/3kjzl5kcb2s2v",
1229 wantCID: "bafyreihd4q3yqcfvnv5zlp6n4fqzh6z4p4m3mwc7vvr6k2j6y6v2a3b4c5",
1230 wantErr: false,
1231 },
1232 {
1233 name: "conflict error (409)",
1234 collection: "social.coves.comment",
1235 rkey: "test",
1236 record: map[string]any{"$type": "social.coves.comment"},
1237 swapRecord: "bafyreigbtj4x7ip5legnfznufuopl4sg4knzc2cof6duas4b3q2fy6swua",
1238 serverResponse: map[string]any{
1239 "error": "InvalidSwap",
1240 "message": "Record CID does not match",
1241 },
1242 serverStatus: http.StatusConflict,
1243 wantErr: true,
1244 },
1245 {
1246 name: "server error",
1247 collection: "social.coves.comment",
1248 rkey: "test",
1249 record: map[string]any{"$type": "social.coves.comment"},
1250 swapRecord: "",
1251 serverResponse: map[string]any{
1252 "error": "InvalidRequest",
1253 "message": "Invalid record",
1254 },
1255 serverStatus: http.StatusBadRequest,
1256 wantErr: true,
1257 },
1258 }
1259
1260 for _, tt := range tests {
1261 t.Run(tt.name, func(t *testing.T) {
1262 // Create mock server
1263 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1264 // Verify method
1265 if r.Method != http.MethodPost {
1266 t.Errorf("expected POST request, got %s", r.Method)
1267 }
1268
1269 // Verify path
1270 expectedPath := "/xrpc/com.atproto.repo.putRecord"
1271 if r.URL.Path != expectedPath {
1272 t.Errorf("path = %q, want %q", r.URL.Path, expectedPath)
1273 }
1274
1275 // Verify request body
1276 var payload map[string]any
1277 if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
1278 t.Fatalf("failed to decode request body: %v", err)
1279 }
1280
1281 // Check required fields
1282 if payload["collection"] != tt.collection {
1283 t.Errorf("collection = %v, want %v", payload["collection"], tt.collection)
1284 }
1285 if payload["rkey"] != tt.rkey {
1286 t.Errorf("rkey = %v, want %v", payload["rkey"], tt.rkey)
1287 }
1288
1289 // Check swapRecord inclusion
1290 if tt.swapRecord != "" {
1291 if payload["swapRecord"] != tt.swapRecord {
1292 t.Errorf("swapRecord = %v, want %v", payload["swapRecord"], tt.swapRecord)
1293 }
1294 } else {
1295 if _, exists := payload["swapRecord"]; exists {
1296 t.Error("swapRecord should not be included when empty")
1297 }
1298 }
1299
1300 // Send response
1301 w.Header().Set("Content-Type", "application/json")
1302 w.WriteHeader(tt.serverStatus)
1303 json.NewEncoder(w).Encode(tt.serverResponse)
1304 }))
1305 defer server.Close()
1306
1307 // Create client
1308 apiClient := atclient.NewAPIClient(server.URL)
1309 apiClient.Auth = &bearerAuth{token: "test-token"}
1310
1311 c := &client{
1312 apiClient: apiClient,
1313 did: "did:plc:test",
1314 host: server.URL,
1315 }
1316
1317 // Execute PutRecord
1318 ctx := context.Background()
1319 uri, cid, err := c.PutRecord(ctx, tt.collection, tt.rkey, tt.record, tt.swapRecord)
1320
1321 if tt.wantErr {
1322 if err == nil {
1323 t.Fatal("expected error, got nil")
1324 }
1325 return
1326 }
1327
1328 if err != nil {
1329 t.Fatalf("unexpected error: %v", err)
1330 }
1331
1332 if uri != tt.wantURI {
1333 t.Errorf("uri = %q, want %q", uri, tt.wantURI)
1334 }
1335
1336 if cid != tt.wantCID {
1337 t.Errorf("cid = %q, want %q", cid, tt.wantCID)
1338 }
1339 })
1340 }
1341}
1342
1343// TestClient_TypedErrors_PutRecord tests that PutRecord returns typed errors.
1344func TestClient_TypedErrors_PutRecord(t *testing.T) {
1345 tests := []struct {
1346 name string
1347 serverStatus int
1348 wantErr error
1349 }{
1350 {
1351 name: "401 returns ErrUnauthorized",
1352 serverStatus: http.StatusUnauthorized,
1353 wantErr: ErrUnauthorized,
1354 },
1355 {
1356 name: "403 returns ErrForbidden",
1357 serverStatus: http.StatusForbidden,
1358 wantErr: ErrForbidden,
1359 },
1360 {
1361 name: "409 returns ErrConflict",
1362 serverStatus: http.StatusConflict,
1363 wantErr: ErrConflict,
1364 },
1365 {
1366 name: "400 returns ErrBadRequest",
1367 serverStatus: http.StatusBadRequest,
1368 wantErr: ErrBadRequest,
1369 },
1370 }
1371
1372 for _, tt := range tests {
1373 t.Run(tt.name, func(t *testing.T) {
1374 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1375 w.Header().Set("Content-Type", "application/json")
1376 w.WriteHeader(tt.serverStatus)
1377 json.NewEncoder(w).Encode(map[string]any{
1378 "error": "TestError",
1379 "message": "Test error message",
1380 })
1381 }))
1382 defer server.Close()
1383
1384 apiClient := atclient.NewAPIClient(server.URL)
1385 apiClient.Auth = &bearerAuth{token: "test-token"}
1386
1387 c := &client{
1388 apiClient: apiClient,
1389 did: "did:plc:test",
1390 host: server.URL,
1391 }
1392
1393 ctx := context.Background()
1394 _, _, err := c.PutRecord(ctx, "test.collection", "rkey", map[string]any{}, "")
1395
1396 if err == nil {
1397 t.Fatal("expected error, got nil")
1398 }
1399
1400 if !errors.Is(err, tt.wantErr) {
1401 t.Errorf("expected errors.Is(%v, %v) to be true", err, tt.wantErr)
1402 }
1403 })
1404 }
1405}