A community-based topic aggregation platform built on atproto
1package oauth 2 3import ( 4 "fmt" 5 "net" 6 "net/http" 7 "time" 8) 9 10// ssrfSafeTransport wraps http.Transport to prevent SSRF attacks 11type ssrfSafeTransport struct { 12 base *http.Transport 13 allowPrivate bool // For dev/testing only 14} 15 16// isPrivateIP checks if an IP is in a private/reserved range 17func isPrivateIP(ip net.IP) bool { 18 if ip == nil { 19 return false 20 } 21 22 // Check for loopback 23 if ip.IsLoopback() { 24 return true 25 } 26 27 // Check for link-local 28 if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { 29 return true 30 } 31 32 // Check for private ranges 33 privateRanges := []string{ 34 "10.0.0.0/8", 35 "172.16.0.0/12", 36 "192.168.0.0/16", 37 "169.254.0.0/16", 38 "::1/128", 39 "fc00::/7", 40 "fe80::/10", 41 } 42 43 for _, cidr := range privateRanges { 44 _, network, err := net.ParseCIDR(cidr) 45 if err == nil && network.Contains(ip) { 46 return true 47 } 48 } 49 50 return false 51} 52 53func (t *ssrfSafeTransport) RoundTrip(req *http.Request) (*http.Response, error) { 54 host := req.URL.Hostname() 55 56 // Resolve hostname to IP 57 ips, err := net.LookupIP(host) 58 if err != nil { 59 return nil, fmt.Errorf("failed to resolve host: %w", err) 60 } 61 62 // Check all resolved IPs 63 if !t.allowPrivate { 64 for _, ip := range ips { 65 if isPrivateIP(ip) { 66 return nil, fmt.Errorf("SSRF blocked: %s resolves to private IP %s", host, ip) 67 } 68 } 69 } 70 71 return t.base.RoundTrip(req) 72} 73 74// NewSSRFSafeHTTPClient creates an HTTP client with SSRF protections 75func NewSSRFSafeHTTPClient(allowPrivate bool) *http.Client { 76 transport := &ssrfSafeTransport{ 77 base: &http.Transport{ 78 DialContext: (&net.Dialer{ 79 Timeout: 10 * time.Second, 80 KeepAlive: 30 * time.Second, 81 }).DialContext, 82 MaxIdleConns: 100, 83 IdleConnTimeout: 90 * time.Second, 84 TLSHandshakeTimeout: 10 * time.Second, 85 }, 86 allowPrivate: allowPrivate, 87 } 88 89 return &http.Client{ 90 Timeout: 15 * time.Second, 91 Transport: transport, 92 
CheckRedirect: func(req *http.Request, via []*http.Request) error { 93 if len(via) >= 5 { 94 return fmt.Errorf("too many redirects") 95 } 96 return nil 97 }, 98 } 99}