diff options
Diffstat (limited to 'web')
| -rw-r--r-- | web/web.go | 48 | ||||
| -rw-r--r-- | web/web_test.go | 44 |
2 files changed, 87 insertions, 5 deletions
@@ -1,7 +1,9 @@ package web import ( + "crypto/rand" "encoding/base64" + "encoding/hex" "fmt" "io/fs" "io/ioutil" @@ -114,7 +116,7 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { password := r.FormValue("password") if password == config.Config.DigestPassword { v, _ := bcrypt.GenerateFromPassword([]byte(password), 0) - c := http.Cookie{Name: AuthCookie, Value: string(v), Path: "/", MaxAge: SecondsInAYear, HttpOnly: false} + c := http.Cookie{Name: AuthCookie, Value: string(v), Path: "/", MaxAge: SecondsInAYear, HttpOnly: true} http.SetCookie(w, &c) http.Redirect(w, r, "/", 307) } else { @@ -126,7 +128,7 @@ func loginHandler(w http.ResponseWriter, r *http.Request) { } func logoutHandler(w http.ResponseWriter, r *http.Request) { - c := http.Cookie{Name: AuthCookie, MaxAge: 0, Path: "/", HttpOnly: false} + c := http.Cookie{Name: AuthCookie, MaxAge: 0, Path: "/", HttpOnly: true} http.SetCookie(w, &c) fmt.Fprintf(w, "you are logged out") } @@ -195,7 +197,7 @@ func apiLoginHandler(w http.ResponseWriter, r *http.Request) { if password == config.Config.DigestPassword { v, _ := bcrypt.GenerateFromPassword([]byte(password), 0) - c := http.Cookie{Name: AuthCookie, Value: string(v), Path: "/", MaxAge: SecondsInAYear, HttpOnly: false} + c := http.Cookie{Name: AuthCookie, Value: string(v), Path: "/", MaxAge: SecondsInAYear, HttpOnly: true} http.SetCookie(w, &c) w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, `{"status":"ok"}`) @@ -258,7 +260,7 @@ func NewRouter(cfg *config.Settings) http.Handler { mux.Handle("/", GzipMiddleware(AuthWrap(http.HandlerFunc(indexHandler)))) - return mux + return CSRFMiddleware(mux) } func Serve(cfg *config.Settings) { @@ -341,8 +343,44 @@ func GzipMiddleware(next http.Handler) http.Handler { } func apiLogoutHandler(w http.ResponseWriter, r *http.Request) { - c := http.Cookie{Name: AuthCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: false} + c := http.Cookie{Name: AuthCookie, Value: "", Path: "/", MaxAge: -1, HttpOnly: true} http.SetCookie(w, &c) w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, `{"status":"ok"}`) } + +func generateRandomToken(n int) string { + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + return "" + } + return hex.EncodeToString(b) +} + +func CSRFMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("csrf_token") + var token string + if err != nil { + token = generateRandomToken(16) + http.SetCookie(w, &http.Cookie{ + Name: "csrf_token", + Value: token, + Path: "/", + HttpOnly: false, // accessible by JS + SameSite: http.SameSiteLaxMode, + }) + } else { + token = cookie.Value + } + + if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete { + headerToken := r.Header.Get("X-CSRF-Token") + if headerToken == "" || headerToken != token { + http.Error(w, "CSRF token mismatch", http.StatusForbidden) + return + } + } + next.ServeHTTP(w, r) + }) +} diff --git a/web/web_test.go b/web/web_test.go index aca3aed..89ca998 100644 --- a/web/web_test.go +++ b/web/web_test.go @@ -730,3 +730,47 @@ func TestGzipMiddlewareNonCompressible(t *testing.T) { t.Error("Expected no gzip for image/png") } } + +func TestCSRFMiddleware(t *testing.T) { + handler := CSRFMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Case 1: GET should succeed and set a cookie + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("Expected 200 for GET, got %d", rr.Code) + } + cookies := rr.Result().Cookies() + var csrfCookie *http.Cookie + for _, c := range cookies { + if c.Name == "csrf_token" { + csrfCookie = c + break + } + } + if csrfCookie == nil { + t.Fatal("Expected csrf_token cookie to be set on first GET") + } + + // Case 2: POST without token should fail + req = httptest.NewRequest("POST", "/", nil) + req.AddCookie(csrfCookie) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusForbidden { + t.Errorf("Expected 403 for POST without token, got %d", rr.Code) + } + + // Case 3: POST with valid token should succeed + req = httptest.NewRequest("POST", "/", nil) + req.AddCookie(csrfCookie) + req.Header.Set("X-CSRF-Token", csrfCookie.Value) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("Expected 200 for POST with valid token, got %d", rr.Code) + } +} |
