aboutsummaryrefslogtreecommitdiffstats
path: root/web
diff options
context:
space:
mode:
Diffstat (limited to 'web')
-rw-r--r--web/web.go93
-rw-r--r--web/web_test.go52
2 files changed, 136 insertions, 9 deletions
diff --git a/web/web.go b/web/web.go
index 10e9b2f..77f449f 100644
--- a/web/web.go
+++ b/web/web.go
@@ -10,12 +10,23 @@ import (
"strings"
"time"
+ "compress/gzip"
+ "io"
+ "sync"
+
"adammathes.com/neko/api"
"adammathes.com/neko/config"
rice "github.com/GeertJohan/go.rice"
"golang.org/x/crypto/bcrypt"
)
+var gzPool = sync.Pool{
+ New: func() interface{} {
+ gz, _ := gzip.NewWriterLevel(io.Discard, gzip.BestSpeed)
+ return gz
+ },
+}
+
func indexHandler(w http.ResponseWriter, r *http.Request) {
serveBoxedFile(w, r, "ui.html")
}
@@ -183,14 +194,14 @@ func apiAuthStatusHandler(w http.ResponseWriter, r *http.Request) {
func Serve() {
box := rice.MustFindBox("../static")
staticFileServer := http.StripPrefix("/static/", http.FileServer(box.HTTPBox()))
- http.Handle("/static/", staticFileServer)
+ http.Handle("/static/", GzipMiddleware(staticFileServer))
// New Frontend
- http.Handle("/v2/", http.StripPrefix("/v2/", http.HandlerFunc(ServeFrontend)))
+ http.Handle("/v2/", GzipMiddleware(http.StripPrefix("/v2/", http.HandlerFunc(ServeFrontend))))
// New REST API
apiRouter := api.NewRouter()
- http.Handle("/api/", http.StripPrefix("/api", AuthWrapHandler(apiRouter)))
+ http.Handle("/api/", GzipMiddleware(http.StripPrefix("/api", AuthWrapHandler(apiRouter))))
// Legacy routes for backward compatibility
http.HandleFunc("/stream/", AuthWrap(api.HandleStream))
@@ -208,11 +219,85 @@ func Serve() {
http.HandleFunc("/api/logout", apiLogoutHandler)
http.HandleFunc("/api/auth", apiAuthStatusHandler)
- http.HandleFunc("/", AuthWrap(indexHandler))
+ http.Handle("/", GzipMiddleware(AuthWrap(http.HandlerFunc(indexHandler))))
log.Fatal(http.ListenAndServe(":"+strconv.Itoa(config.Config.Port), nil))
}
+type gzipWriter struct {
+ http.ResponseWriter
+ gz *gzip.Writer
+}
+
+func (w *gzipWriter) Write(b []byte) (int, error) {
+ if w.Header().Get("Content-Type") == "" {
+ w.Header().Set("Content-Type", http.DetectContentType(b))
+ }
+ contentType := w.Header().Get("Content-Type")
+ if w.gz == nil && isCompressible(contentType) && w.Header().Get("Content-Encoding") == "" {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Header().Del("Content-Length")
+ gz := gzPool.Get().(*gzip.Writer)
+ gz.Reset(w.ResponseWriter)
+ w.gz = gz
+ }
+ if w.gz != nil {
+ return w.gz.Write(b)
+ }
+ return w.ResponseWriter.Write(b)
+}
+
+func (w *gzipWriter) WriteHeader(status int) {
+ if status != http.StatusOK && status != http.StatusCreated && status != http.StatusAccepted {
+ w.ResponseWriter.WriteHeader(status)
+ return
+ }
+ contentType := w.Header().Get("Content-Type")
+ if isCompressible(contentType) && w.Header().Get("Content-Encoding") == "" {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Header().Del("Content-Length")
+ gz := gzPool.Get().(*gzip.Writer)
+ gz.Reset(w.ResponseWriter)
+ w.gz = gz
+ }
+ w.ResponseWriter.WriteHeader(status)
+}
+
+func (w *gzipWriter) Flush() {
+ if w.gz != nil {
+ w.gz.Flush()
+ }
+ if f, ok := w.ResponseWriter.(http.Flusher); ok {
+ f.Flush()
+ }
+}
+
+func isCompressible(contentType string) bool {
+ ct := strings.ToLower(contentType)
+ return strings.Contains(ct, "text/") ||
+ strings.Contains(ct, "javascript") ||
+ strings.Contains(ct, "json") ||
+ strings.Contains(ct, "xml") ||
+ strings.Contains(ct, "rss") ||
+ strings.Contains(ct, "xhtml")
+}
+
+func GzipMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ gzw := &gzipWriter{ResponseWriter: w}
+ next.ServeHTTP(gzw, r)
+ if gzw.gz != nil {
+ gzw.gz.Close()
+ gzPool.Put(gzw.gz)
+ }
+ })
+}
+
func apiLogoutHandler(w http.ResponseWriter, r *http.Request) {
c := http.Cookie{Name: AuthCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: false}
http.SetCookie(w, &c)
diff --git a/web/web_test.go b/web/web_test.go
index 9030947..039ad89 100644
--- a/web/web_test.go
+++ b/web/web_test.go
@@ -422,13 +422,55 @@ func TestServeFrontend(t *testing.T) {
// We expect 200 if built, or maybe panic if box not found (rice.MustFindBox)
// But rice usually works in dev mode by looking at disk.
if rr.Code != http.StatusOK {
- // If 404/500, it might be that dist is missing, but for this specific verification
+ // If 404/500, it might be that dist is missing, but for this specific verification
// where we know we built it, we expect 200.
// However, protecting against CI failures where build might not happen:
t.Logf("Got code %d for frontend request", rr.Code)
}
- // Check for unauthenticated access (no cookie needed)
- if rr.Code == http.StatusTemporaryRedirect {
- t.Error("Frontend should not redirect to login")
- }
+ // Check for unauthenticated access (no cookie needed)
+ if rr.Code == http.StatusTemporaryRedirect {
+ t.Error("Frontend should not redirect to login")
+ }
+}
+
+func TestGzipCompression(t *testing.T) {
+ handler := GzipMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ w.Write([]byte("this is a test string that should be compressed when gzip is enabled and the client supports it"))
+ }))
+
+ // Case 1: Client supports gzip
+ req := httptest.NewRequest("GET", "/", nil)
+ req.Header.Set("Accept-Encoding", "gzip")
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Header().Get("Content-Encoding") != "gzip" {
+ t.Errorf("Expected Content-Encoding: gzip, got %q", rr.Header().Get("Content-Encoding"))
+ }
+
+ // Case 2: Client does NOT support gzip
+ req = httptest.NewRequest("GET", "/", nil)
+ rr = httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Header().Get("Content-Encoding") == "gzip" {
+ t.Error("Expected no Content-Encoding: gzip for client without support")
+ }
+
+ // Case 3: 304 Not Modified (Should NOT be gzipped)
+ handler304 := GzipMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotModified)
+ }))
+ req = httptest.NewRequest("GET", "/", nil)
+ req.Header.Set("Accept-Encoding", "gzip")
+ rr = httptest.NewRecorder()
+ handler304.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusNotModified {
+ t.Errorf("Expected 304, got %d", rr.Code)
+ }
+ if rr.Header().Get("Content-Encoding") == "gzip" {
+ t.Error("Expected no Content-Encoding for 304 response")
+ }
}