A community-based topic aggregation platform built on atproto.
1package oauth
2
import (
	"fmt"
	"net"
	"net/http"
	"syscall"
	"time"
)
9
// ssrfSafeTransport is an http.RoundTripper that rejects requests whose host
// resolves to a private or reserved IP address and hands every other request
// to base.
type ssrfSafeTransport struct {
	// base is the underlying transport that performs the actual exchange.
	base *http.Transport
	// allowPrivate turns the private-address check off. Intended for local
	// development and tests only; never enable it in production.
	allowPrivate bool
}
15
// ssrfBlockedNetworks lists special-purpose IP ranges that a server-side
// fetch must never reach, beyond the classes isPrivateIP detects through
// net.IP methods. The list is parsed once at package initialization instead
// of on every call; a malformed entry is a programmer error and panics.
var ssrfBlockedNetworks = func() []*net.IPNet {
	cidrs := []string{
		"0.0.0.0/8",      // "this network"; 0.0.0.0 reaches the local host
		"10.0.0.0/8",     // RFC 1918 private
		"100.64.0.0/10",  // RFC 6598 shared address space (CGNAT)
		"169.254.0.0/16", // link-local, incl. cloud metadata endpoints
		"172.16.0.0/12",  // RFC 1918 private
		"192.0.0.0/24",   // IETF protocol assignments
		"192.168.0.0/16", // RFC 1918 private
		"198.18.0.0/15",  // benchmarking
		"240.0.0.0/4",    // reserved, incl. 255.255.255.255 broadcast
		"::1/128",        // IPv6 loopback
		"fc00::/7",       // IPv6 unique local
		"fe80::/10",      // IPv6 link-local
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			panic("oauth: invalid SSRF blocklist CIDR " + cidr + ": " + err.Error())
		}
		nets = append(nets, network)
	}
	return nets
}()

// isPrivateIP reports whether ip is loopback, unspecified, link-local,
// multicast, private (RFC 1918 / RFC 4193) or in another reserved range
// listed in ssrfBlockedNetworks. IPv4-mapped IPv6 addresses (::ffff:a.b.c.d)
// are classified by their embedded IPv4 address. A nil ip reports false.
//
// NOTE(review): NAT64 addresses (64:ff9b::/96) embedding a private IPv4
// address are not detected.
func isPrivateIP(ip net.IP) bool {
	if ip == nil {
		return false
	}

	// Unspecified (0.0.0.0, ::) is included because dialing it connects to
	// the local host on common platforms.
	if ip.IsLoopback() || ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsInterfaceLocalMulticast() || ip.IsMulticast() {
		return true
	}

	for _, network := range ssrfBlockedNetworks {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}
52
53func (t *ssrfSafeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
54 host := req.URL.Hostname()
55
56 // Resolve hostname to IP
57 ips, err := net.LookupIP(host)
58 if err != nil {
59 return nil, fmt.Errorf("failed to resolve host: %w", err)
60 }
61
62 // Check all resolved IPs
63 if !t.allowPrivate {
64 for _, ip := range ips {
65 if isPrivateIP(ip) {
66 return nil, fmt.Errorf("SSRF blocked: %s resolves to private IP %s", host, ip)
67 }
68 }
69 }
70
71 return t.base.RoundTrip(req)
72}
73
74// NewSSRFSafeHTTPClient creates an HTTP client with SSRF protections
75func NewSSRFSafeHTTPClient(allowPrivate bool) *http.Client {
76 transport := &ssrfSafeTransport{
77 base: &http.Transport{
78 DialContext: (&net.Dialer{
79 Timeout: 10 * time.Second,
80 KeepAlive: 30 * time.Second,
81 }).DialContext,
82 MaxIdleConns: 100,
83 IdleConnTimeout: 90 * time.Second,
84 TLSHandshakeTimeout: 10 * time.Second,
85 },
86 allowPrivate: allowPrivate,
87 }
88
89 return &http.Client{
90 Timeout: 15 * time.Second,
91 Transport: transport,
92 CheckRedirect: func(req *http.Request, via []*http.Request) error {
93 if len(via) >= 5 {
94 return fmt.Errorf("too many redirects")
95 }
96 return nil
97 },
98 }
99}