aboutsummaryrefslogtreecommitdiffstats
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/safehttp/safehttp.go34
-rw-r--r--internal/safehttp/safehttp_test.go25
2 files changed, 56 insertions, 3 deletions
diff --git a/internal/safehttp/safehttp.go b/internal/safehttp/safehttp.go
index f2c316b..1072130 100644
--- a/internal/safehttp/safehttp.go
+++ b/internal/safehttp/safehttp.go
@@ -2,9 +2,11 @@ package safehttp
import (
"context"
+ "crypto/tls"
"fmt"
"net"
"net/http"
+ "strings"
"time"
)
@@ -84,7 +86,7 @@ func NewSafeClient(timeout time.Duration) *http.Client {
return &http.Client{
Timeout: timeout,
- Transport: transport,
+ Transport: &H2FallbackTransport{Transport: transport},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return fmt.Errorf("too many redirects")
@@ -115,3 +117,33 @@ func NewSafeClient(timeout time.Duration) *http.Client {
},
}
}
+
+// H2FallbackTransport wraps an *http.Transport and retries failed requests with HTTP/1.1
+// if an HTTP/2 protocol error is detected. This is useful for crawling external feeds
+// where some servers may have buggy HTTP/2 implementations.
+type H2FallbackTransport struct {
+ Transport *http.Transport
+}
+
+func (t *H2FallbackTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ resp, err := t.Transport.RoundTrip(req)
+ if err != nil && isHTTP2Error(err) && (req.Method == "GET" || req.Method == "HEAD" || req.Body == nil) {
+ // Clone the transport and disable HTTP/2 for the retry
+ h1Transport := t.Transport.Clone()
+ h1Transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
+ h1Transport.ForceAttemptHTTP2 = false
+ return h1Transport.RoundTrip(req)
+ }
+ return resp, err
+}
+
+func isHTTP2Error(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := strings.ToLower(err.Error())
+ return strings.Contains(msg, "http2") ||
+ strings.Contains(msg, "stream error") ||
+ strings.Contains(msg, "goaway") ||
+ strings.Contains(msg, "protocol")
+}
diff --git a/internal/safehttp/safehttp_test.go b/internal/safehttp/safehttp_test.go
index dc428e4..19f9f51 100644
--- a/internal/safehttp/safehttp_test.go
+++ b/internal/safehttp/safehttp_test.go
@@ -143,11 +143,13 @@ func TestNewSafeClientProperties(t *testing.T) {
t.Errorf("expected timeout 5s, got %v", client.Timeout)
}
- transport, ok := client.Transport.(*http.Transport)
+ h2Transport, ok := client.Transport.(*H2FallbackTransport)
if !ok {
- t.Fatal("expected *http.Transport")
+ t.Fatal("expected *H2FallbackTransport")
}
+ transport := h2Transport.Transport
+
// Proxy should be nil to prevent SSRF bypass
if transport.Proxy != nil {
t.Error("transport.Proxy should be nil to prevent SSRF bypass")
@@ -159,6 +161,25 @@ func TestNewSafeClientProperties(t *testing.T) {
}
}
+func TestIsHTTP2Error(t *testing.T) {
+ tests := []struct {
+ err error
+ expected bool
+ }{
+ {fmt.Errorf("http2: stream error"), true},
+ {fmt.Errorf("random error"), false},
+ {fmt.Errorf("PROTOCOL_ERROR"), true},
+ {fmt.Errorf("GOAWAY"), true},
+ {nil, false},
+ }
+
+ for _, tc := range tests {
+ if res := isHTTP2Error(tc.err); res != tc.expected {
+ t.Errorf("isHTTP2Error(%v) = %v, want %v", tc.err, res, tc.expected)
+ }
+ }
+}
+
func TestNewSafeClientRedirectToPrivateIP(t *testing.T) {
// Create a server that redirects to a private IP
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {