aboutsummaryrefslogtreecommitdiffstats
path: root/internal
diff options
context:
space:
mode:
authorAdam Mathes <adam@adammathes.com>2026-02-14 09:17:56 -0800
committerAdam Mathes <adam@adammathes.com>2026-02-14 09:17:56 -0800
commitcac85dc06b519d9bd6db4d017d501dffbbd8bac4 (patch)
treedc8024e501c0fbda6b9d28622ff2553475044487 /internal
parentca1418fc0135d52a009ab218d6e24187fb355a3c (diff)
downloadneko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.tar.gz
neko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.tar.bz2
neko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.zip
security: mitigate SSRF in image proxy and feed fetcher (fixing NK-0ca7nq)
Diffstat (limited to 'internal')
-rw-r--r--internal/crawler/crawler.go11
-rw-r--r--internal/safehttp/safehttp.go110
-rw-r--r--internal/safehttp/safehttp_test.go53
3 files changed, 166 insertions, 8 deletions
diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go
index 10253d8..fce2769 100644
--- a/internal/crawler/crawler.go
+++ b/internal/crawler/crawler.go
@@ -6,6 +6,7 @@ import (
"net/http"
"time"
+ "adammathes.com/neko/internal/safehttp"
"adammathes.com/neko/internal/vlog"
"adammathes.com/neko/models/feed"
"adammathes.com/neko/models/item"
@@ -58,10 +59,7 @@ func GetFeedContent(feedURL string) string {
// n := time.Duration(rand.Int63n(3))
// time.Sleep(n * time.Second)
- c := &http.Client{
- // give up after 5 seconds
- Timeout: 5 * time.Second,
- }
+ c := safehttp.NewSafeClient(5 * time.Second)
request, err := http.NewRequest("GET", feedURL, nil)
if err != nil {
@@ -100,10 +98,7 @@ func GetFeedContent(feedURL string) string {
TODO: sanitize input on crawl
*/
func CrawlFeed(f *feed.Feed, ch chan<- string) {
- c := &http.Client{
- // give up after 5 seconds
- Timeout: 5 * time.Second,
- }
+ c := safehttp.NewSafeClient(5 * time.Second)
fp := gofeed.NewParser()
fp.Client = c
diff --git a/internal/safehttp/safehttp.go b/internal/safehttp/safehttp.go
new file mode 100644
index 0000000..cfc70f1
--- /dev/null
+++ b/internal/safehttp/safehttp.go
@@ -0,0 +1,110 @@
+package safehttp
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "time"
+)
+
+var privateIPBlocks []*net.IPNet
+
+func init() {
+ for _, cidr := range []string{
+ "127.0.0.0/8", // IPv4 loopback
+ "10.0.0.0/8", // RFC1918
+ "172.16.0.0/12", // RFC1918
+ "192.168.0.0/16", // RFC1918
+ "169.254.0.0/16", // IPv4 link-local
+ "::1/128", // IPv6 loopback
+ "fe80::/10", // IPv6 link-local
+ "fc00::/7", // IPv6 unique local addr
+ } {
+ _, block, _ := net.ParseCIDR(cidr)
+ privateIPBlocks = append(privateIPBlocks, block)
+ }
+}
+
+func isPrivateIP(ip net.IP) bool {
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ return true
+ }
+ for _, block := range privateIPBlocks {
+ if block.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+func SafeDialer(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
+ return func(ctx context.Context, network, address string) (net.Conn, error) {
+ host, _, err := net.SplitHostPort(address)
+ if err != nil {
+ host = address
+ }
+
+ if ip := net.ParseIP(host); ip != nil {
+ if isPrivateIP(ip) {
+ return nil, fmt.Errorf("connection to private IP %s is not allowed", ip)
+ }
+ } else {
+ ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, ip := range ips {
+ if isPrivateIP(ip) {
+ return nil, fmt.Errorf("connection to private IP %s is not allowed", ip)
+ }
+ }
+ }
+
+ return dialer.DialContext(ctx, network, address)
+ }
+}
+
+func NewSafeClient(timeout time.Duration) *http.Client {
+ dialer := &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+ transport.DialContext = SafeDialer(dialer)
+
+ return &http.Client{
+ Timeout: timeout,
+ Transport: transport,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ if len(via) >= 10 {
+ return fmt.Errorf("too many redirects")
+ }
+
+ host, _, err := net.SplitHostPort(req.URL.Host)
+ if err != nil {
+ host = req.URL.Host
+ }
+
+ if ip := net.ParseIP(host); ip != nil {
+ if isPrivateIP(ip) {
+ return fmt.Errorf("redirect to private IP %s is not allowed", ip)
+ }
+ } else {
+ ips, err := net.DefaultResolver.LookupIP(req.Context(), "ip", host)
+ if err != nil {
+ return err
+ }
+
+ for _, ip := range ips {
+ if isPrivateIP(ip) {
+ return fmt.Errorf("redirect to private IP %s is not allowed", ip)
+ }
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/internal/safehttp/safehttp_test.go b/internal/safehttp/safehttp_test.go
new file mode 100644
index 0000000..b2636da
--- /dev/null
+++ b/internal/safehttp/safehttp_test.go
@@ -0,0 +1,53 @@
+package safehttp
+
+import (
+ "net"
+ "testing"
+ "time"
+)
+
+func TestSafeClient(t *testing.T) {
+ client := NewSafeClient(2 * time.Second)
+
+ // Localhost should fail
+ t.Log("Testing localhost...")
+ _, err := client.Get("http://127.0.0.1:8080")
+ if err == nil {
+ t.Error("Expected error for localhost request, got nil")
+ } else {
+ t.Logf("Got expected error: %v", err)
+ }
+
+ // Private IP should fail
+ t.Log("Testing private IP...")
+ _, err = client.Get("http://10.0.0.1")
+ if err == nil {
+ t.Error("Expected error for private IP request, got nil")
+ } else {
+ t.Logf("Got expected error: %v", err)
+ }
+}
+
+func TestIsPrivateIP(t *testing.T) {
+ tests := []struct {
+ ip string
+ expected bool
+ }{
+ {"127.0.0.1", true},
+ {"10.0.0.1", true},
+ {"172.16.0.1", true},
+ {"192.168.1.1", true},
+ {"169.254.1.1", true},
+ {"8.8.8.8", false},
+ {"1.1.1.1", false},
+ {"::1", true},
+ {"fe80::1", true},
+ {"fc00::1", true},
+ }
+
+ for _, tc := range tests {
+ if res := isPrivateIP(net.ParseIP(tc.ip)); res != tc.expected {
+ t.Errorf("isPrivateIP(%s) = %v, want %v", tc.ip, res, tc.expected)
+ }
+ }
+}