From 64830820d23ba58b63509cb721e7551dd30c4997 Mon Sep 17 00:00:00 2001 From: Adam Mathes Date: Wed, 18 Feb 2026 08:06:42 -0800 Subject: fix: implement HTTP/2 fallback for SafeClient on protocol errors --- internal/safehttp/safehttp.go | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) (limited to 'internal/safehttp/safehttp.go') diff --git a/internal/safehttp/safehttp.go b/internal/safehttp/safehttp.go index f2c316b..1072130 100644 --- a/internal/safehttp/safehttp.go +++ b/internal/safehttp/safehttp.go @@ -2,9 +2,11 @@ package safehttp import ( "context" + "crypto/tls" "fmt" "net" "net/http" + "strings" "time" ) @@ -84,7 +86,7 @@ func NewSafeClient(timeout time.Duration) *http.Client { return &http.Client{ Timeout: timeout, - Transport: transport, + Transport: &H2FallbackTransport{Transport: transport}, CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) >= 10 { return fmt.Errorf("too many redirects") @@ -115,3 +117,33 @@ func NewSafeClient(timeout time.Duration) *http.Client { }, } } + +// H2FallbackTransport wraps an *http.Transport and retries failed requests with HTTP/1.1 +// if an HTTP/2 protocol error is detected. This is useful for crawling external feeds +// where some servers may have buggy HTTP/2 implementations. +type H2FallbackTransport struct { + Transport *http.Transport +} + +func (t *H2FallbackTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.Transport.RoundTrip(req) + if err != nil && isHTTP2Error(err) && (req.Method == "GET" || req.Method == "HEAD" || req.Body == nil) { + // Clone the transport and disable HTTP/2 for the retry + h1Transport := t.Transport.Clone() + h1Transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) + h1Transport.ForceAttemptHTTP2 = false + return h1Transport.RoundTrip(req) + } + return resp, err +} + +func isHTTP2Error(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "http2") || + strings.Contains(msg, "stream error") || + strings.Contains(msg, "goaway") || + strings.Contains(msg, "protocol") +} -- cgit v1.2.3