diff options
Diffstat (limited to 'web')
| -rw-r--r-- | web/web.go | 93 | ||||
| -rw-r--r-- | web/web_test.go | 52 |
2 files changed, 136 insertions, 9 deletions
@@ -10,12 +10,23 @@ import ( "strings" "time" + "compress/gzip" + "io" + "sync" + "adammathes.com/neko/api" "adammathes.com/neko/config" rice "github.com/GeertJohan/go.rice" "golang.org/x/crypto/bcrypt" ) +var gzPool = sync.Pool{ + New: func() interface{} { + gz, _ := gzip.NewWriterLevel(io.Discard, gzip.BestSpeed) + return gz + }, +} + func indexHandler(w http.ResponseWriter, r *http.Request) { serveBoxedFile(w, r, "ui.html") } @@ -183,14 +194,14 @@ func apiAuthStatusHandler(w http.ResponseWriter, r *http.Request) { func Serve() { box := rice.MustFindBox("../static") staticFileServer := http.StripPrefix("/static/", http.FileServer(box.HTTPBox())) - http.Handle("/static/", staticFileServer) + http.Handle("/static/", GzipMiddleware(staticFileServer)) // New Frontend - http.Handle("/v2/", http.StripPrefix("/v2/", http.HandlerFunc(ServeFrontend))) + http.Handle("/v2/", GzipMiddleware(http.StripPrefix("/v2/", http.HandlerFunc(ServeFrontend)))) // New REST API apiRouter := api.NewRouter() - http.Handle("/api/", http.StripPrefix("/api", AuthWrapHandler(apiRouter))) + http.Handle("/api/", GzipMiddleware(http.StripPrefix("/api", AuthWrapHandler(apiRouter)))) // Legacy routes for backward compatibility http.HandleFunc("/stream/", AuthWrap(api.HandleStream)) @@ -208,11 +219,85 @@ func Serve() { http.HandleFunc("/api/logout", apiLogoutHandler) http.HandleFunc("/api/auth", apiAuthStatusHandler) - http.HandleFunc("/", AuthWrap(indexHandler)) + http.Handle("/", GzipMiddleware(AuthWrap(http.HandlerFunc(indexHandler)))) log.Fatal(http.ListenAndServe(":"+strconv.Itoa(config.Config.Port), nil)) } +type gzipWriter struct { + http.ResponseWriter + gz *gzip.Writer +} + +func (w *gzipWriter) Write(b []byte) (int, error) { + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", http.DetectContentType(b)) + } + contentType := w.Header().Get("Content-Type") + if w.gz == nil && isCompressible(contentType) && w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Encoding", "gzip") + w.Header().Del("Content-Length") + gz := gzPool.Get().(*gzip.Writer) + gz.Reset(w.ResponseWriter) + w.gz = gz + } + if w.gz != nil { + return w.gz.Write(b) + } + return w.ResponseWriter.Write(b) +} + +func (w *gzipWriter) WriteHeader(status int) { + if status != http.StatusOK && status != http.StatusCreated && status != http.StatusAccepted { + w.ResponseWriter.WriteHeader(status) + return + } + contentType := w.Header().Get("Content-Type") + if isCompressible(contentType) && w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Encoding", "gzip") + w.Header().Del("Content-Length") + gz := gzPool.Get().(*gzip.Writer) + gz.Reset(w.ResponseWriter) + w.gz = gz + } + w.ResponseWriter.WriteHeader(status) +} + +func (w *gzipWriter) Flush() { + if w.gz != nil { + w.gz.Flush() + } + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func isCompressible(contentType string) bool { + ct := strings.ToLower(contentType) + return strings.Contains(ct, "text/") || + strings.Contains(ct, "javascript") || + strings.Contains(ct, "json") || + strings.Contains(ct, "xml") || + strings.Contains(ct, "rss") || + strings.Contains(ct, "xhtml") +} + +func GzipMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + next.ServeHTTP(w, r) + return + } + + gzw := &gzipWriter{ResponseWriter: w} + next.ServeHTTP(gzw, r) + if gzw.gz != nil { + gzw.gz.Close() + gzPool.Put(gzw.gz) + } + }) +} + func apiLogoutHandler(w http.ResponseWriter, r *http.Request) { c := http.Cookie{Name: AuthCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: false} http.SetCookie(w, &c) diff --git a/web/web_test.go b/web/web_test.go index 9030947..039ad89 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -422,13 +422,55 @@ func TestServeFrontend(t *testing.T) { // We expect 200 if built, or maybe panic if box not found (rice.MustFindBox) // But rice usually works in dev mode by looking at disk. if rr.Code != http.StatusOK { - // If 404/500, it might be that dist is missing, but for this specific verification + // If 404/500, it might be that dist is missing, but for this specific verification // where we know we built it, we expect 200. // However, protecting against CI failures where build might not happen: t.Logf("Got code %d for frontend request", rr.Code) } - // Check for unauthenticated access (no cookie needed) - if rr.Code == http.StatusTemporaryRedirect { - t.Error("Frontend should not redirect to login") - } + // Check for unauthenticated access (no cookie needed) + if rr.Code == http.StatusTemporaryRedirect { + t.Error("Frontend should not redirect to login") + } +} + +func TestGzipCompression(t *testing.T) { + handler := GzipMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("this is a test string that should be compressed when gzip is enabled and the client supports it")) + })) + + // Case 1: Client supports gzip + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Header().Get("Content-Encoding") != "gzip" { + t.Errorf("Expected Content-Encoding: gzip, got %q", rr.Header().Get("Content-Encoding")) + } + + // Case 2: Client does NOT support gzip + req = httptest.NewRequest("GET", "/", nil) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Header().Get("Content-Encoding") == "gzip" { + t.Error("Expected no Content-Encoding: gzip for client without support") + } + + // Case 3: 304 Not Modified (Should NOT be gzipped) + handler304 := GzipMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotModified) + })) + req = httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + rr = httptest.NewRecorder() + handler304.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotModified { + t.Errorf("Expected 304, got %d", rr.Code) + } + if rr.Header().Get("Content-Encoding") == "gzip" { + t.Error("Expected no Content-Encoding for 304 response") + } } |
