aboutsummaryrefslogtreecommitdiffstats
path: root/internal/safehttp/safehttp.go
diff options
context:
space:
mode:
authorAdam Mathes <adam@adammathes.com>2026-02-14 09:17:56 -0800
committerAdam Mathes <adam@adammathes.com>2026-02-14 09:17:56 -0800
commitcac85dc06b519d9bd6db4d017d501dffbbd8bac4 (patch)
treedc8024e501c0fbda6b9d28622ff2553475044487 /internal/safehttp/safehttp.go
parentca1418fc0135d52a009ab218d6e24187fb355a3c (diff)
downloadneko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.tar.gz
neko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.tar.bz2
neko-cac85dc06b519d9bd6db4d017d501dffbbd8bac4.zip
security: mitigate SSRF in image proxy and feed fetcher (fixing NK-0ca7nq)
Diffstat (limited to 'internal/safehttp/safehttp.go')
-rw-r--r--internal/safehttp/safehttp.go110
1 files changed, 110 insertions, 0 deletions
diff --git a/internal/safehttp/safehttp.go b/internal/safehttp/safehttp.go
new file mode 100644
index 0000000..cfc70f1
--- /dev/null
+++ b/internal/safehttp/safehttp.go
@@ -0,0 +1,110 @@
+package safehttp
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "time"
+)
+
+var privateIPBlocks []*net.IPNet
+
+func init() {
+ for _, cidr := range []string{
+ "127.0.0.0/8", // IPv4 loopback
+ "10.0.0.0/8", // RFC1918
+ "172.16.0.0/12", // RFC1918
+ "192.168.0.0/16", // RFC1918
+ "169.254.0.0/16", // IPv4 link-local
+ "::1/128", // IPv6 loopback
+ "fe80::/10", // IPv6 link-local
+ "fc00::/7", // IPv6 unique local addr
+ } {
+ _, block, _ := net.ParseCIDR(cidr)
+ privateIPBlocks = append(privateIPBlocks, block)
+ }
+}
+
+func isPrivateIP(ip net.IP) bool {
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ return true
+ }
+ for _, block := range privateIPBlocks {
+ if block.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+func SafeDialer(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
+ return func(ctx context.Context, network, address string) (net.Conn, error) {
+ host, _, err := net.SplitHostPort(address)
+ if err != nil {
+ host = address
+ }
+
+ if ip := net.ParseIP(host); ip != nil {
+ if isPrivateIP(ip) {
+ return nil, fmt.Errorf("connection to private IP %s is not allowed", ip)
+ }
+ } else {
+ ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, ip := range ips {
+ if isPrivateIP(ip) {
+ return nil, fmt.Errorf("connection to private IP %s is not allowed", ip)
+ }
+ }
+ }
+
+ return dialer.DialContext(ctx, network, address)
+ }
+}
+
+func NewSafeClient(timeout time.Duration) *http.Client {
+ dialer := &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+ transport.DialContext = SafeDialer(dialer)
+
+ return &http.Client{
+ Timeout: timeout,
+ Transport: transport,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ if len(via) >= 10 {
+ return fmt.Errorf("too many redirects")
+ }
+
+ host, _, err := net.SplitHostPort(req.URL.Host)
+ if err != nil {
+ host = req.URL.Host
+ }
+
+ if ip := net.ParseIP(host); ip != nil {
+ if isPrivateIP(ip) {
+ return fmt.Errorf("redirect to private IP %s is not allowed", ip)
+ }
+ } else {
+ ips, err := net.DefaultResolver.LookupIP(req.Context(), "ip", host)
+ if err != nil {
+ return err
+ }
+
+ for _, ip := range ips {
+ if isPrivateIP(ip) {
+ return fmt.Errorf("redirect to private IP %s is not allowed", ip)
+ }
+ }
+ }
+ return nil
+ },
+ }
+}