A community-based topic aggregation platform built on atproto.
1package oauth
2
3import (
4 "net"
5 "net/http"
6 "testing"
7)
8
9func TestIsPrivateIP(t *testing.T) {
10 tests := []struct {
11 name string
12 ip string
13 expected bool
14 }{
15 // Loopback addresses
16 {"IPv4 loopback", "127.0.0.1", true},
17 {"IPv6 loopback", "::1", true},
18
19 // Private IPv4 ranges
20 {"Private 10.x.x.x", "10.0.0.1", true},
21 {"Private 10.x.x.x edge", "10.255.255.255", true},
22 {"Private 172.16.x.x", "172.16.0.1", true},
23 {"Private 172.31.x.x edge", "172.31.255.255", true},
24 {"Private 192.168.x.x", "192.168.1.1", true},
25 {"Private 192.168.x.x edge", "192.168.255.255", true},
26
27 // Link-local addresses
28 {"Link-local IPv4", "169.254.1.1", true},
29 {"Link-local IPv6", "fe80::1", true},
30
31 // IPv6 private ranges
32 {"IPv6 unique local fc00", "fc00::1", true},
33 {"IPv6 unique local fd00", "fd00::1", true},
34
35 // Public addresses
36 {"Public IP 1.1.1.1", "1.1.1.1", false},
37 {"Public IP 8.8.8.8", "8.8.8.8", false},
38 {"Public IP 172.15.0.1", "172.15.0.1", false}, // Just before 172.16/12
39 {"Public IP 172.32.0.1", "172.32.0.1", false}, // Just after 172.31/12
40 {"Public IP 11.0.0.1", "11.0.0.1", false}, // Just after 10/8
41 {"Public IPv6", "2001:4860:4860::8888", false}, // Google DNS
42 }
43
44 for _, tt := range tests {
45 t.Run(tt.name, func(t *testing.T) {
46 ip := net.ParseIP(tt.ip)
47 if ip == nil {
48 t.Fatalf("Failed to parse IP: %s", tt.ip)
49 }
50
51 result := isPrivateIP(ip)
52 if result != tt.expected {
53 t.Errorf("isPrivateIP(%s) = %v, expected %v", tt.ip, result, tt.expected)
54 }
55 })
56 }
57}
58
59func TestIsPrivateIP_NilIP(t *testing.T) {
60 result := isPrivateIP(nil)
61 if result != false {
62 t.Errorf("isPrivateIP(nil) = %v, expected false", result)
63 }
64}
65
66func TestNewSSRFSafeHTTPClient(t *testing.T) {
67 tests := []struct {
68 name string
69 allowPrivate bool
70 }{
71 {"Production client (no private IPs)", false},
72 {"Development client (allow private IPs)", true},
73 }
74
75 for _, tt := range tests {
76 t.Run(tt.name, func(t *testing.T) {
77 client := NewSSRFSafeHTTPClient(tt.allowPrivate)
78
79 if client == nil {
80 t.Fatal("NewSSRFSafeHTTPClient returned nil")
81 }
82
83 if client.Timeout == 0 {
84 t.Error("Expected timeout to be set")
85 }
86
87 if client.Transport == nil {
88 t.Error("Expected transport to be set")
89 }
90
91 transport, ok := client.Transport.(*ssrfSafeTransport)
92 if !ok {
93 t.Error("Expected ssrfSafeTransport")
94 }
95
96 if transport.allowPrivate != tt.allowPrivate {
97 t.Errorf("Expected allowPrivate=%v, got %v", tt.allowPrivate, transport.allowPrivate)
98 }
99 })
100 }
101}
102
103func TestSSRFSafeHTTPClient_RedirectLimit(t *testing.T) {
104 client := NewSSRFSafeHTTPClient(false)
105
106 // Simulate checking redirect limit
107 if client.CheckRedirect == nil {
108 t.Fatal("Expected CheckRedirect to be set")
109 }
110
111 // Test redirect limit (5 redirects)
112 var via []*http.Request
113 for i := 0; i < 5; i++ {
114 req := &http.Request{}
115 via = append(via, req)
116 }
117
118 err := client.CheckRedirect(nil, via)
119 if err == nil {
120 t.Error("Expected error for too many redirects")
121 }
122 if err.Error() != "too many redirects" {
123 t.Errorf("Expected 'too many redirects' error, got: %v", err)
124 }
125
126 // Test within limit (4 redirects)
127 via = via[:4]
128 err = client.CheckRedirect(nil, via)
129 if err != nil {
130 t.Errorf("Expected no error for 4 redirects, got: %v", err)
131 }
132}