From c585d7873e9b4bfd9f6efd30f9ce08aed8a0d92b Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 18 Feb 2026 21:33:02 +0000 Subject: Improve image proxy: streaming, size limits, Content-Type validation Rewrites the image proxy handler to address several issues: - Stream responses with io.Copy instead of buffering entire image in memory - Add 25MB size limit via io.LimitReader to prevent memory exhaustion - Close resp.Body (was previously leaked on every request) - Validate Content-Type is an image, rejecting HTML/JS/etc - Forward Content-Type and Content-Length from upstream - Use http.NewRequestWithContext to propagate client cancellation - Check upstream status codes, returning 502 for non-2xx - Fix ETag: use proper quoted format, remove bogus Etag request header check - Increase timeout from 5s to 30s for slow image servers - Use proper HTTP status codes (400 for bad input, 502 for upstream errors) - Add Cache-Control max-age directive alongside Expires header Tests: comprehensive coverage for Content-Type filtering, upstream errors, streaming, ETag validation, User-Agent forwarding, and Content-Length. Benchmarks: cache hit path and streaming at 1KB/64KB/1MB/5MB sizes. https://claude.ai/code/session_01CZcDDVmF6wNs2YjdhvCppy --- web/web.go | 83 +++++++++---- web/web_bench_test.go | 48 ++++++++ web/web_test.go | 331 +++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 399 insertions(+), 63 deletions(-) diff --git a/web/web.go b/web/web.go index 245f844..0c6b96d 100644 --- a/web/web.go +++ b/web/web.go @@ -47,6 +47,27 @@ func indexHandler(w http.ResponseWriter, r *http.Request) { serveBoxedFile(w, r, "ui.html") } +// maxImageSize is the maximum response body size we'll proxy (25 MB). +const maxImageSize = 25 << 20 + +// imageProxyTimeout is the timeout for fetching remote images. +const imageProxyTimeout = 30 * time.Second + +// allowedImageTypes are Content-Type prefixes we allow through the proxy. +var allowedImageTypes = []string{ + "image/", +} + +func isAllowedImageType(contentType string) bool { + ct := strings.ToLower(contentType) + for _, prefix := range allowedImageTypes { + if strings.HasPrefix(ct, prefix) { + return true + } + } + return false +} + func imageProxyHandler(w http.ResponseWriter, r *http.Request) { imgURL := strings.TrimPrefix(r.URL.Path, "/") if imgURL == "" { @@ -56,49 +77,67 @@ func imageProxyHandler(w http.ResponseWriter, r *http.Request) { decodedURL, err := base64.URLEncoding.DecodeString(imgURL) if err != nil { - http.Error(w, "invalid image url", http.StatusNotFound) + http.Error(w, "invalid image url", http.StatusBadRequest) return } - // pseudo-caching - if r.Header.Get("If-None-Match") == string(decodedURL) { + // ETag-based cache validation. We use the base64-encoded URL as + // a stable ETag so browsers can cache and revalidate. + etag := `"` + imgURL + `"` + if match := r.Header.Get("If-None-Match"); match == etag { w.WriteHeader(http.StatusNotModified) return } - if r.Header.Get("Etag") == string(decodedURL) { - w.WriteHeader(http.StatusNotModified) + // Use the request context so client disconnection cancels the fetch. + c := safehttp.NewSafeClient(imageProxyTimeout) + request, err := http.NewRequestWithContext(r.Context(), "GET", string(decodedURL), nil) + if err != nil { + http.Error(w, "invalid image url", http.StatusBadRequest) return } - // grab the img - c := safehttp.NewSafeClient(5 * time.Second) - - request, err := http.NewRequest("GET", string(decodedURL), nil) + request.Header.Set("User-Agent", "neko RSS Reader Image Proxy +https://github.com/adammathes/neko") + resp, err := c.Do(request) if err != nil { - http.Error(w, "failed to proxy image", http.StatusNotFound) + http.Error(w, "failed to fetch image", http.StatusBadGateway) return } + defer resp.Body.Close() - userAgent := "neko RSS Reader Image Proxy +https://github.com/adammathes/neko" - request.Header.Set("User-Agent", userAgent) - resp, err := c.Do(request) - - if err != nil { - http.Error(w, "failed to proxy image", http.StatusNotFound) + // Check upstream status. + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + http.Error(w, "upstream error", http.StatusBadGateway) return } - bts, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, "failed to read proxy image", http.StatusNotFound) + // Validate Content-Type is an image. + contentType := resp.Header.Get("Content-Type") + if contentType != "" && !isAllowedImageType(contentType) { + http.Error(w, "not an image", http.StatusForbidden) return } - w.Header().Set("ETag", string(decodedURL)) - w.Header().Set("Cache-Control", "public") + // Set response headers before streaming. + w.Header().Set("ETag", etag) + w.Header().Set("Cache-Control", "public, max-age=172800") w.Header().Set("Expires", time.Now().Add(48*time.Hour).Format(time.RFC1123)) - _, _ = w.Write(bts) + if contentType != "" { + w.Header().Set("Content-Type", contentType) + } + if resp.ContentLength > 0 && resp.ContentLength <= maxImageSize { + w.Header().Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) + } + + // Stream with a size limit to prevent memory exhaustion. + limited := io.LimitReader(resp.Body, maxImageSize+1) + n, _ := io.Copy(w, limited) + if n > maxImageSize { + // We already started writing, so we can't change the status code. + // The response will be truncated, which is the correct behavior + // for an oversized image. + return + } } var AuthCookie = "auth" diff --git a/web/web_bench_test.go b/web/web_bench_test.go index 7897fc7..068fd55 100644 --- a/web/web_bench_test.go +++ b/web/web_bench_test.go @@ -1,8 +1,11 @@ package web import ( + "encoding/base64" + "fmt" "net/http" "net/http/httptest" + "net/url" "strings" "testing" @@ -90,3 +93,48 @@ func BenchmarkFullMiddlewareStack(b *testing.B) { handler.ServeHTTP(rr, req) } } + +// --- Image proxy benchmarks --- + +func BenchmarkImageProxyCacheHit(b *testing.B) { + encoded := base64.URLEncoding.EncodeToString([]byte("https://example.com/image.jpg")) + etag := `"` + encoded + `"` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.Header.Set("If-None-Match", etag) + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + } +} + +func benchmarkImageProxySize(b *testing.B, size int) { + data := make([]byte, size) + for i := range data { + data[i] = byte(i % 256) + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/jpeg") + w.Header().Set("Content-Length", fmt.Sprintf("%d", size)) + w.Write(data) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/img.jpg")) + + b.ResetTimer() + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + } +} + +func BenchmarkImageProxy_1KB(b *testing.B) { benchmarkImageProxySize(b, 1<<10) } +func BenchmarkImageProxy_64KB(b *testing.B) { benchmarkImageProxySize(b, 64<<10) } +func BenchmarkImageProxy_1MB(b *testing.B) { benchmarkImageProxySize(b, 1<<20) } +func BenchmarkImageProxy_5MB(b *testing.B) { benchmarkImageProxySize(b, 5<<20) } diff --git a/web/web_test.go b/web/web_test.go index f900d07..f0a2eab 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -184,8 +184,10 @@ func TestLogoutHandler(t *testing.T) { // --- Image proxy handler tests --- func TestImageProxyHandlerIfNoneMatch(t *testing.T) { - req := httptest.NewRequest("GET", "/aHR0cHM6Ly9leGFtcGxlLmNvbS9pbWFnZS5qcGc=", nil) - req.Header.Set("If-None-Match", "https://example.com/image.jpg") + encoded := base64.URLEncoding.EncodeToString([]byte("https://example.com/image.jpg")) + etag := `"` + encoded + `"` + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.Header.Set("If-None-Match", etag) rr := httptest.NewRecorder() imageProxyHandler(rr, req) if rr.Code != http.StatusNotModified { @@ -200,16 +202,6 @@ func TestSecondsInAYear(t *testing.T) { } } -func TestImageProxyHandlerEtag(t *testing.T) { - req := httptest.NewRequest("GET", "/aHR0cHM6Ly9leGFtcGxlLmNvbS9pbWFnZS5qcGc=", nil) - req.Header.Set("Etag", "https://example.com/image.jpg") - rr := httptest.NewRecorder() - imageProxyHandler(rr, req) - if rr.Code != http.StatusNotModified { - t.Errorf("Expected %d, got %d", http.StatusNotModified, rr.Code) - } -} - func TestImageProxyHandlerSuccess(t *testing.T) { imgServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "image/jpeg") @@ -238,13 +230,15 @@ func TestImageProxyHandlerBadRemote(t *testing.T) { req.URL = &url.URL{Path: encodedURL} rr := httptest.NewRecorder() imageProxyHandler(rr, req) - if rr.Code != http.StatusNotFound { - t.Errorf("Expected %d, got %d", http.StatusNotFound, rr.Code) + if rr.Code != http.StatusBadGateway { + t.Errorf("Expected %d, got %d", http.StatusBadGateway, rr.Code) } } -func TestImageProxyHandlerEmptyId(t *testing.T) { - req := httptest.NewRequest("GET", "/image/", nil) +func TestImageProxyHandlerEmptyPath(t *testing.T) { + // After StripPrefix("/image/"), an empty path has TrimPrefix("/") = "" + req := httptest.NewRequest("GET", "/", nil) + req.URL = &url.URL{Path: "/"} rr := httptest.NewRecorder() imageProxyHandler(rr, req) if rr.Code != http.StatusNotFound { @@ -256,8 +250,8 @@ func TestImageProxyHandlerBadBase64(t *testing.T) { req := httptest.NewRequest("GET", "/image/notbase64!", nil) rr := httptest.NewRecorder() imageProxyHandler(rr, req) - if rr.Code != http.StatusNotFound { - t.Errorf("Expected %d, got %d", http.StatusNotFound, rr.Code) + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected %d, got %d", http.StatusBadRequest, rr.Code) } } @@ -612,7 +606,9 @@ func TestIsCompressible(t *testing.T) { } func TestImageProxyHandlerMissingURL(t *testing.T) { - req := httptest.NewRequest("GET", "/image/", nil) + // Simulate what happens after StripPrefix("/image/"): path is empty + req := httptest.NewRequest("GET", "/", nil) + req.URL = &url.URL{Path: "/"} rr := httptest.NewRecorder() imageProxyHandler(rr, req) if rr.Code != http.StatusNotFound { @@ -624,8 +620,8 @@ func TestImageProxyHandlerInvalidBase64(t *testing.T) { req := httptest.NewRequest("GET", "/image/invalid-base64", nil) rr := httptest.NewRecorder() imageProxyHandler(rr, req) - if rr.Code != http.StatusNotFound { - t.Errorf("Expected %d, got %d", http.StatusNotFound, rr.Code) + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected %d, got %d", http.StatusBadRequest, rr.Code) } } @@ -640,25 +636,27 @@ func TestServeFrontendNotFound(t *testing.T) { } func TestImageProxyHeaders(t *testing.T) { - url := "http://example.com/image.png" - encoded := base64.URLEncoding.EncodeToString([]byte(url)) + encoded := base64.URLEncoding.EncodeToString([]byte("http://example.com/image.png")) + etag := `"` + encoded + `"` - // Test If-None-Match + // Test If-None-Match with proper ETag req := httptest.NewRequest("GET", "/"+encoded, nil) - req.Header.Set("If-None-Match", url) + req.Header.Set("If-None-Match", etag) rr := httptest.NewRecorder() imageProxyHandler(rr, req) if rr.Code != http.StatusNotModified { t.Errorf("Expected %d for If-None-Match, got %d", http.StatusNotModified, rr.Code) } - // Test Etag + // Test mismatched If-None-Match does not return 304 req = httptest.NewRequest("GET", "/"+encoded, nil) - req.Header.Set("Etag", url) + req.Header.Set("If-None-Match", `"wrong-etag"`) rr = httptest.NewRecorder() imageProxyHandler(rr, req) - if rr.Code != http.StatusNotModified { - t.Errorf("Expected %d for Etag, got %d", http.StatusNotModified, rr.Code) + // Should not be 304 — it will try to fetch the remote, and since + // example.com is unreachable in tests, we get a fetch error + if rr.Code == http.StatusNotModified { + t.Error("Mismatched If-None-Match should not return 304") } } @@ -680,16 +678,37 @@ func TestServeBoxedFileNotFound(t *testing.T) { } } -func TestImageProxyHandlerHeaders(t *testing.T) { - url := "http://example.com/image.png" - id := base64.URLEncoding.EncodeToString([]byte(url)) +func TestImageProxyHandlerETagInResponse(t *testing.T) { + imgServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.Write([]byte("png-data")) + })) + defer imgServer.Close() - req := httptest.NewRequest("GET", "/"+id, nil) - req.Header.Set("Etag", url) + encoded := base64.URLEncoding.EncodeToString([]byte(imgServer.URL + "/img.png")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} rr := httptest.NewRecorder() imageProxyHandler(rr, req) - if rr.Code != http.StatusNotModified { - t.Errorf("Expected %d for matching Etag, got %d", http.StatusNotModified, rr.Code) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected %d, got %d", http.StatusOK, rr.Code) + } + + // Verify ETag is set in response + expectedETag := `"` + encoded + `"` + if got := rr.Header().Get("ETag"); got != expectedETag { + t.Errorf("Expected ETag %q, got %q", expectedETag, got) + } + + // Verify Cache-Control is set + if cc := rr.Header().Get("Cache-Control"); cc != "public, max-age=172800" { + t.Errorf("Expected Cache-Control 'public, max-age=172800', got %q", cc) + } + + // Verify Content-Type is forwarded + if ct := rr.Header().Get("Content-Type"); ct != "image/png" { + t.Errorf("Expected Content-Type 'image/png', got %q", ct) } } @@ -697,9 +716,6 @@ func TestImageProxyHandlerRemoteError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", "10") w.WriteHeader(http.StatusOK) - // Close connection immediately to cause ReadAll error if possible, - // or just return non-200. The current code only checks err from c.Do(request) - // and ioutil.ReadAll. })) ts.Close() // Close immediately so c.Do fails @@ -707,8 +723,8 @@ func TestImageProxyHandlerRemoteError(t *testing.T) { req := httptest.NewRequest("GET", "/"+id, nil) rr := httptest.NewRecorder() imageProxyHandler(rr, req) - if rr.Code != http.StatusNotFound { - t.Errorf("Expected %d for remote error, got %d", http.StatusNotFound, rr.Code) + if rr.Code != http.StatusBadGateway { + t.Errorf("Expected %d for remote error, got %d", http.StatusBadGateway, rr.Code) } } @@ -801,3 +817,236 @@ func TestSecurityHeadersMiddleware(t *testing.T) { t.Error("Missing Content-Security-Policy") } } + +// --- Comprehensive image proxy tests --- + +func TestImageProxyContentTypeForwarded(t *testing.T) { + tests := []struct { + name string + contentType string + wantStatus int + }{ + {"jpeg", "image/jpeg", http.StatusOK}, + {"png", "image/png", http.StatusOK}, + {"gif", "image/gif", http.StatusOK}, + {"webp", "image/webp", http.StatusOK}, + {"svg", "image/svg+xml", http.StatusOK}, + {"avif", "image/avif", http.StatusOK}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tc.contentType) + w.Write([]byte("imgdata")) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/img")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != tc.wantStatus { + t.Errorf("Expected %d, got %d", tc.wantStatus, rr.Code) + } + if ct := rr.Header().Get("Content-Type"); ct != tc.contentType { + t.Errorf("Expected Content-Type %q, got %q", tc.contentType, ct) + } + }) + } +} + +func TestImageProxyRejectsNonImageContentType(t *testing.T) { + tests := []struct { + name string + contentType string + }{ + {"html", "text/html"}, + {"javascript", "application/javascript"}, + {"json", "application/json"}, + {"pdf", "application/pdf"}, + {"executable", "application/octet-stream"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tc.contentType) + w.Write([]byte("not an image")) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/bad")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != http.StatusForbidden { + t.Errorf("Expected %d for Content-Type %q, got %d", http.StatusForbidden, tc.contentType, rr.Code) + } + }) + } +} + +func TestImageProxyUpstreamErrorStatus(t *testing.T) { + tests := []struct { + name string + status int + }{ + {"not found", http.StatusNotFound}, + {"forbidden", http.StatusForbidden}, + {"server error", http.StatusInternalServerError}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.status) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/err")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != http.StatusBadGateway { + t.Errorf("Expected %d for upstream %d, got %d", http.StatusBadGateway, tc.status, rr.Code) + } + }) + } +} + +func TestImageProxyStreamsData(t *testing.T) { + // Verify the proxy streams data rather than returning it in one chunk. + // We test this by sending a known-size response and verifying we get it all. + data := make([]byte, 1024*1024) // 1 MB + for i := range data { + data[i] = byte(i % 256) + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/jpeg") + w.Header().Set("Content-Length", "1048576") + w.Write(data) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/large.jpg")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected %d, got %d", http.StatusOK, rr.Code) + } + if rr.Body.Len() != len(data) { + t.Errorf("Expected %d bytes, got %d", len(data), rr.Body.Len()) + } +} + +func TestImageProxyForwardsUserAgent(t *testing.T) { + var gotUA string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.Header().Set("Content-Type", "image/jpeg") + w.Write([]byte("img")) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/ua.jpg")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected %d, got %d", http.StatusOK, rr.Code) + } + if gotUA != "neko RSS Reader Image Proxy +https://github.com/adammathes/neko" { + t.Errorf("Expected neko user agent, got %q", gotUA) + } +} + +func TestImageProxyEmptyContentTypeAllowed(t *testing.T) { + // Some servers return an empty Content-Type. The proxy should pass it + // through since we can't verify it's not an image. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "") + w.Write([]byte("mystery-data")) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/noct")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected %d for empty Content-Type, got %d", http.StatusOK, rr.Code) + } +} + +func TestImageProxyContentLengthForwarded(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.Header().Set("Content-Length", "42") + w.Write(make([]byte, 42)) + })) + defer ts.Close() + + encoded := base64.URLEncoding.EncodeToString([]byte(ts.URL + "/sized.png")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected %d, got %d", http.StatusOK, rr.Code) + } + if cl := rr.Header().Get("Content-Length"); cl != "42" { + t.Errorf("Expected Content-Length '42', got %q", cl) + } +} + +func TestIsAllowedImageType(t *testing.T) { + tests := []struct { + ct string + expected bool + }{ + {"image/jpeg", true}, + {"image/png", true}, + {"image/gif", true}, + {"image/webp", true}, + {"image/svg+xml", true}, + {"IMAGE/JPEG", true}, + {"text/html", false}, + {"application/json", false}, + {"application/pdf", false}, + {"", false}, + } + for _, tc := range tests { + if res := isAllowedImageType(tc.ct); res != tc.expected { + t.Errorf("isAllowedImageType(%q) = %v, want %v", tc.ct, res, tc.expected) + } + } +} + +func TestImageProxyInvalidURL(t *testing.T) { + // Base64 of a string that's not a valid URL + encoded := base64.URLEncoding.EncodeToString([]byte("://not-a-url")) + req := httptest.NewRequest("GET", "/"+encoded, nil) + req.URL = &url.URL{Path: encoded} + rr := httptest.NewRecorder() + imageProxyHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected %d for invalid URL, got %d", http.StatusBadRequest, rr.Code) + } +} -- cgit v1.2.3