diff options
| author | Adam Mathes <adam@adammathes.com> | 2026-02-18 08:06:42 -0800 |
|---|---|---|
| committer | Adam Mathes <adam@adammathes.com> | 2026-02-18 08:06:42 -0800 |
| commit | 64830820d23ba58b63509cb721e7551dd30c4997 (patch) | |
| tree | 40ed2d5e2ede7a761939d97cf1ae6d004ccc5e16 /internal/safehttp/safehttp_test.go | |
| parent | 20337a80775d81a69d8019430bb1f3b0d450e259 (diff) | |
| download | neko-64830820d23ba58b63509cb721e7551dd30c4997.tar.gz neko-64830820d23ba58b63509cb721e7551dd30c4997.tar.bz2 neko-64830820d23ba58b63509cb721e7551dd30c4997.zip | |
fix: implement HTTP/2 fallback for SafeClient on protocol errors
Diffstat (limited to 'internal/safehttp/safehttp_test.go')
| -rw-r--r-- | internal/safehttp/safehttp_test.go | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/internal/safehttp/safehttp_test.go b/internal/safehttp/safehttp_test.go index dc428e4..19f9f51 100644 --- a/internal/safehttp/safehttp_test.go +++ b/internal/safehttp/safehttp_test.go @@ -143,11 +143,13 @@ func TestNewSafeClientProperties(t *testing.T) { t.Errorf("expected timeout 5s, got %v", client.Timeout) } - transport, ok := client.Transport.(*http.Transport) + h2Transport, ok := client.Transport.(*H2FallbackTransport) if !ok { - t.Fatal("expected *http.Transport") + t.Fatal("expected *H2FallbackTransport") } + transport := h2Transport.Transport + // Proxy should be nil to prevent SSRF bypass if transport.Proxy != nil { t.Error("transport.Proxy should be nil to prevent SSRF bypass") @@ -159,6 +161,25 @@ func TestNewSafeClientProperties(t *testing.T) { } } +func TestIsHTTP2Error(t *testing.T) { + tests := []struct { + err error + expected bool + }{ + {fmt.Errorf("http2: stream error"), true}, + {fmt.Errorf("random error"), false}, + {fmt.Errorf("PROTOCOL_ERROR"), true}, + {fmt.Errorf("GOAWAY"), true}, + {nil, false}, + } + + for _, tc := range tests { + if res := isHTTP2Error(tc.err); res != tc.expected { + t.Errorf("isHTTP2Error(%v) = %v, want %v", tc.err, res, tc.expected) + } + } +} + func TestNewSafeClientRedirectToPrivateIP(t *testing.T) { // Create a server that redirects to a private IP ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
