From 64830820d23ba58b63509cb721e7551dd30c4997 Mon Sep 17 00:00:00 2001 From: Adam Mathes Date: Wed, 18 Feb 2026 08:06:42 -0800 Subject: fix: implement HTTP/2 fallback for SafeClient on protocol errors --- internal/safehttp/safehttp.go | 34 +++++++++++++++++++++++++++++++++- internal/safehttp/safehttp_test.go | 25 +++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 3 deletions(-) (limited to 'internal/safehttp') diff --git a/internal/safehttp/safehttp.go b/internal/safehttp/safehttp.go index f2c316b..1072130 100644 --- a/internal/safehttp/safehttp.go +++ b/internal/safehttp/safehttp.go @@ -2,9 +2,11 @@ package safehttp import ( "context" + "crypto/tls" "fmt" "net" "net/http" + "strings" "time" ) @@ -84,7 +86,7 @@ func NewSafeClient(timeout time.Duration) *http.Client { return &http.Client{ Timeout: timeout, - Transport: transport, + Transport: &H2FallbackTransport{Transport: transport}, CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) >= 10 { return fmt.Errorf("too many redirects") @@ -115,3 +117,33 @@ func NewSafeClient(timeout time.Duration) *http.Client { }, } } + +// H2FallbackTransport wraps an *http.Transport and retries failed requests with HTTP/1.1 +// if an HTTP/2 protocol error is detected. This is useful for crawling external feeds +// where some servers may have buggy HTTP/2 implementations. +type H2FallbackTransport struct { + Transport *http.Transport +} + +func (t *H2FallbackTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.Transport.RoundTrip(req) + if err != nil && isHTTP2Error(err) && (req.Method == "GET" || req.Method == "HEAD" || req.Body == nil) { + // Clone the transport and disable HTTP/2 for the retry + h1Transport := t.Transport.Clone() + h1Transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) + h1Transport.ForceAttemptHTTP2 = false + return h1Transport.RoundTrip(req) + } + return resp, err +} + +func isHTTP2Error(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "http2") || + strings.Contains(msg, "stream error") || + strings.Contains(msg, "goaway") || + strings.Contains(msg, "protocol") +} diff --git a/internal/safehttp/safehttp_test.go b/internal/safehttp/safehttp_test.go index dc428e4..19f9f51 100644 --- a/internal/safehttp/safehttp_test.go +++ b/internal/safehttp/safehttp_test.go @@ -143,11 +143,13 @@ func TestNewSafeClientProperties(t *testing.T) { t.Errorf("expected timeout 5s, got %v", client.Timeout) } - transport, ok := client.Transport.(*http.Transport) + h2Transport, ok := client.Transport.(*H2FallbackTransport) if !ok { - t.Fatal("expected *http.Transport") + t.Fatal("expected *H2FallbackTransport") } + transport := h2Transport.Transport + // Proxy should be nil to prevent SSRF bypass if transport.Proxy != nil { t.Error("transport.Proxy should be nil to prevent SSRF bypass") @@ -159,6 +161,25 @@ func TestNewSafeClientProperties(t *testing.T) { } } +func TestIsHTTP2Error(t *testing.T) { + tests := []struct { + err error + expected bool + }{ + {fmt.Errorf("http2: stream error"), true}, + {fmt.Errorf("random error"), false}, + {fmt.Errorf("PROTOCOL_ERROR"), true}, + {fmt.Errorf("GOAWAY"), true}, + {nil, false}, + } + + for _, tc := range tests { + if res := isHTTP2Error(tc.err); res != tc.expected { + t.Errorf("isHTTP2Error(%v) = %v, want %v", tc.err, res, tc.expected) + } + } +} + func TestNewSafeClientRedirectToPrivateIP(t *testing.T) { // Create a server that redirects to a private IP ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -- cgit v1.2.3