diff options
| author | Adam Mathes <adam@adammathes.com> | 2026-02-14 09:17:56 -0800 |
|---|---|---|
| committer | Adam Mathes <adam@adammathes.com> | 2026-02-14 09:17:56 -0800 |
| commit | cac85dc06b519d9bd6db4d017d501dffbbd8bac4 (patch) | |
| tree | dc8024e501c0fbda6b9d28622ff2553475044487 /internal | |
| parent | ca1418fc0135d52a009ab218d6e24187fb355a3c (diff) | |
| download | neko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.tar.gz neko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.tar.bz2 neko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.zip | |
security: mitigate SSRF in image proxy and feed fetcher (fixing NK-0ca7nq)
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/crawler/crawler.go | 11 | ||||
| -rw-r--r-- | internal/safehttp/safehttp.go | 110 | ||||
| -rw-r--r-- | internal/safehttp/safehttp_test.go | 53 |
3 files changed, 166 insertions, 8 deletions
diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go index 10253d8..fce2769 100644 --- a/internal/crawler/crawler.go +++ b/internal/crawler/crawler.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "adammathes.com/neko/internal/safehttp" "adammathes.com/neko/internal/vlog" "adammathes.com/neko/models/feed" "adammathes.com/neko/models/item" @@ -58,10 +59,7 @@ func GetFeedContent(feedURL string) string { // n := time.Duration(rand.Int63n(3)) // time.Sleep(n * time.Second) - c := &http.Client{ - // give up after 5 seconds - Timeout: 5 * time.Second, - } + c := safehttp.NewSafeClient(5 * time.Second) request, err := http.NewRequest("GET", feedURL, nil) if err != nil { @@ -100,10 +98,7 @@ func GetFeedContent(feedURL string) string { TODO: sanitize input on crawl */ func CrawlFeed(f *feed.Feed, ch chan<- string) { - c := &http.Client{ - // give up after 5 seconds - Timeout: 5 * time.Second, - } + c := safehttp.NewSafeClient(5 * time.Second) fp := gofeed.NewParser() fp.Client = c diff --git a/internal/safehttp/safehttp.go b/internal/safehttp/safehttp.go new file mode 100644 index 0000000..cfc70f1 --- /dev/null +++ b/internal/safehttp/safehttp.go @@ -0,0 +1,110 @@ +package safehttp + +import ( + "context" + "fmt" + "net" + "net/http" + "time" +) + +var privateIPBlocks []*net.IPNet + +func init() { + for _, cidr := range []string{ + "127.0.0.0/8", // IPv4 loopback + "10.0.0.0/8", // RFC1918 + "172.16.0.0/12", // RFC1918 + "192.168.0.0/16", // RFC1918 + "169.254.0.0/16", // IPv4 link-local + "::1/128", // IPv6 loopback + "fe80::/10", // IPv6 link-local + "fc00::/7", // IPv6 unique local addr + } { + _, block, _ := net.ParseCIDR(cidr) + privateIPBlocks = append(privateIPBlocks, block) + } +} + +func isPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + for _, block := range privateIPBlocks { + if block.Contains(ip) { + return true + } + } + return false +} + +func SafeDialer(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + host, _, err := net.SplitHostPort(address) + if err != nil { + host = address + } + + if ip := net.ParseIP(host); ip != nil { + if isPrivateIP(ip) { + return nil, fmt.Errorf("connection to private IP %s is not allowed", ip) + } + } else { + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + if isPrivateIP(ip) { + return nil, fmt.Errorf("connection to private IP %s is not allowed", ip) + } + } + } + + return dialer.DialContext(ctx, network, address) + } +} + +func NewSafeClient(timeout time.Duration) *http.Client { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DialContext = SafeDialer(dialer) + + return &http.Client{ + Timeout: timeout, + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + + host, _, err := net.SplitHostPort(req.URL.Host) + if err != nil { + host = req.URL.Host + } + + if ip := net.ParseIP(host); ip != nil { + if isPrivateIP(ip) { + return fmt.Errorf("redirect to private IP %s is not allowed", ip) + } + } else { + ips, err := net.DefaultResolver.LookupIP(req.Context(), "ip", host) + if err != nil { + return err + } + + for _, ip := range ips { + if isPrivateIP(ip) { + return fmt.Errorf("redirect to private IP %s is not allowed", ip) + } + } + } + return nil + }, + } +} diff --git a/internal/safehttp/safehttp_test.go b/internal/safehttp/safehttp_test.go new file mode 100644 index 0000000..b2636da --- /dev/null +++ b/internal/safehttp/safehttp_test.go @@ -0,0 +1,53 @@ +package safehttp + +import ( + "net" + "testing" + "time" +) + +func TestSafeClient(t *testing.T) { + client := NewSafeClient(2 * time.Second) + + // Localhost should fail + t.Log("Testing localhost...") + _, err := client.Get("http://127.0.0.1:8080") + if err == nil { + t.Error("Expected error for localhost request, got nil") + } else { + t.Logf("Got expected error: %v", err) + } + + // Private IP should fail + t.Log("Testing private IP...") + _, err = client.Get("http://10.0.0.1") + if err == nil { + t.Error("Expected error for private IP request, got nil") + } else { + t.Logf("Got expected error: %v", err) + } +} + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + ip string + expected bool + }{ + {"127.0.0.1", true}, + {"10.0.0.1", true}, + {"172.16.0.1", true}, + {"192.168.1.1", true}, + {"169.254.1.1", true}, + {"8.8.8.8", false}, + {"1.1.1.1", false}, + {"::1", true}, + {"fe80::1", true}, + {"fc00::1", true}, + } + + for _, tc := range tests { + if res := isPrivateIP(net.ParseIP(tc.ip)); res != tc.expected { + t.Errorf("isPrivateIP(%s) = %v, want %v", tc.ip, res, tc.expected) + } + } +} |
