A community-based topic aggregation platform built on atproto
at main 3.3 kB view raw
1package oauth 2 3import ( 4 "net" 5 "net/http" 6 "testing" 7) 8 9func TestIsPrivateIP(t *testing.T) { 10 tests := []struct { 11 name string 12 ip string 13 expected bool 14 }{ 15 // Loopback addresses 16 {"IPv4 loopback", "127.0.0.1", true}, 17 {"IPv6 loopback", "::1", true}, 18 19 // Private IPv4 ranges 20 {"Private 10.x.x.x", "10.0.0.1", true}, 21 {"Private 10.x.x.x edge", "10.255.255.255", true}, 22 {"Private 172.16.x.x", "172.16.0.1", true}, 23 {"Private 172.31.x.x edge", "172.31.255.255", true}, 24 {"Private 192.168.x.x", "192.168.1.1", true}, 25 {"Private 192.168.x.x edge", "192.168.255.255", true}, 26 27 // Link-local addresses 28 {"Link-local IPv4", "169.254.1.1", true}, 29 {"Link-local IPv6", "fe80::1", true}, 30 31 // IPv6 private ranges 32 {"IPv6 unique local fc00", "fc00::1", true}, 33 {"IPv6 unique local fd00", "fd00::1", true}, 34 35 // Public addresses 36 {"Public IP 1.1.1.1", "1.1.1.1", false}, 37 {"Public IP 8.8.8.8", "8.8.8.8", false}, 38 {"Public IP 172.15.0.1", "172.15.0.1", false}, // Just before 172.16/12 39 {"Public IP 172.32.0.1", "172.32.0.1", false}, // Just after 172.31/12 40 {"Public IP 11.0.0.1", "11.0.0.1", false}, // Just after 10/8 41 {"Public IPv6", "2001:4860:4860::8888", false}, // Google DNS 42 } 43 44 for _, tt := range tests { 45 t.Run(tt.name, func(t *testing.T) { 46 ip := net.ParseIP(tt.ip) 47 if ip == nil { 48 t.Fatalf("Failed to parse IP: %s", tt.ip) 49 } 50 51 result := isPrivateIP(ip) 52 if result != tt.expected { 53 t.Errorf("isPrivateIP(%s) = %v, expected %v", tt.ip, result, tt.expected) 54 } 55 }) 56 } 57} 58 59func TestIsPrivateIP_NilIP(t *testing.T) { 60 result := isPrivateIP(nil) 61 if result != false { 62 t.Errorf("isPrivateIP(nil) = %v, expected false", result) 63 } 64} 65 66func TestNewSSRFSafeHTTPClient(t *testing.T) { 67 tests := []struct { 68 name string 69 allowPrivate bool 70 }{ 71 {"Production client (no private IPs)", false}, 72 {"Development client (allow private IPs)", true}, 73 } 74 75 for _, tt := 
range tests { 76 t.Run(tt.name, func(t *testing.T) { 77 client := NewSSRFSafeHTTPClient(tt.allowPrivate) 78 79 if client == nil { 80 t.Fatal("NewSSRFSafeHTTPClient returned nil") 81 } 82 83 if client.Timeout == 0 { 84 t.Error("Expected timeout to be set") 85 } 86 87 if client.Transport == nil { 88 t.Error("Expected transport to be set") 89 } 90 91 transport, ok := client.Transport.(*ssrfSafeTransport) 92 if !ok { 93 t.Error("Expected ssrfSafeTransport") 94 } 95 96 if transport.allowPrivate != tt.allowPrivate { 97 t.Errorf("Expected allowPrivate=%v, got %v", tt.allowPrivate, transport.allowPrivate) 98 } 99 }) 100 } 101} 102 103func TestSSRFSafeHTTPClient_RedirectLimit(t *testing.T) { 104 client := NewSSRFSafeHTTPClient(false) 105 106 // Simulate checking redirect limit 107 if client.CheckRedirect == nil { 108 t.Fatal("Expected CheckRedirect to be set") 109 } 110 111 // Test redirect limit (5 redirects) 112 var via []*http.Request 113 for i := 0; i < 5; i++ { 114 req := &http.Request{} 115 via = append(via, req) 116 } 117 118 err := client.CheckRedirect(nil, via) 119 if err == nil { 120 t.Error("Expected error for too many redirects") 121 } 122 if err.Error() != "too many redirects" { 123 t.Errorf("Expected 'too many redirects' error, got: %v", err) 124 } 125 126 // Test within limit (4 redirects) 127 via = via[:4] 128 err = client.CheckRedirect(nil, via) 129 if err != nil { 130 t.Errorf("Expected no error for 4 redirects, got: %v", err) 131 } 132}