aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--api/api_stress_test.go134
-rw-r--r--api/api_test.go229
-rw-r--r--cmd/neko/main_test.go54
-rw-r--r--internal/importer/importer_test.go178
-rw-r--r--internal/safehttp/safehttp_test.go287
-rw-r--r--web/security_regression_test.go222
6 files changed, 1104 insertions, 0 deletions
diff --git a/api/api_stress_test.go b/api/api_stress_test.go
index a846f75..4fcbf5e 100644
--- a/api/api_stress_test.go
+++ b/api/api_stress_test.go
@@ -194,6 +194,140 @@ func TestStress_LargeDataset(t *testing.T) {
}
}
+func TestStress_ConcurrentMixedOperations(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping stress test in short mode")
+ }
+
+ setupTestDB(t)
+
+ // Create multiple feeds with items across categories
+ categories := []string{"tech", "news", "science", "art"}
+ for i, cat := range categories {
+ f := &feed.Feed{
+ Url: fmt.Sprintf("http://example.com/mixed/%d", i),
+ Title: fmt.Sprintf("Mixed Feed %d", i),
+ Category: cat,
+ }
+ f.Create()
+ for j := 0; j < 25; j++ {
+ it := &item.Item{
+ Title: fmt.Sprintf("Mixed Item %d-%d", i, j),
+ Url: fmt.Sprintf("http://example.com/mixed/%d/%d", i, j),
+ Description: fmt.Sprintf("<p>Mixed content %d-%d</p>", i, j),
+ PublishDate: "2024-01-01 00:00:00",
+ FeedId: f.Id,
+ }
+ _ = it.Create()
+ }
+ }
+
+ server := newTestServer()
+
+ const goroutines = 40
+ var wg sync.WaitGroup
+ errors := make(chan error, goroutines*2)
+
+ start := time.Now()
+
+ // Mix of reads, filtered reads, updates, and exports
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+
+ var req *http.Request
+ switch idx % 4 {
+ case 0:
+ // Stream with category filter
+ req = httptest.NewRequest("GET", "/stream?tag="+categories[idx%len(categories)], nil)
+ case 1:
+ // Stream with search
+ req = httptest.NewRequest("GET", "/stream?q=Mixed", nil)
+ case 2:
+ // Feed list
+ req = httptest.NewRequest("GET", "/feed", nil)
+ case 3:
+ // Export
+ req = httptest.NewRequest("GET", "/export/json", nil)
+ }
+
+ rr := httptest.NewRecorder()
+ server.ServeHTTP(rr, req)
+ if rr.Code != http.StatusOK {
+ errors <- fmt.Errorf("op %d (type %d) got status %d", idx, idx%4, rr.Code)
+ }
+ }(i)
+ }
+ wg.Wait()
+ close(errors)
+ elapsed := time.Since(start)
+
+ for err := range errors {
+ t.Errorf("concurrent mixed operation error: %v", err)
+ }
+
+ t.Logf("40 concurrent mixed operations completed in %v", elapsed)
+ if elapsed > 10*time.Second {
+ t.Errorf("concurrent mixed operations took too long: %v (threshold: 10s)", elapsed)
+ }
+}
+
+func TestStress_RapidReadMarkUnmark(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping stress test in short mode")
+ }
+
+ setupTestDB(t)
+
+ f := &feed.Feed{Url: "http://example.com/rapid", Title: "Rapid Feed"}
+ f.Create()
+ it := &item.Item{
+ Title: "Rapid Toggle",
+ Url: "http://example.com/rapid/1",
+ FeedId: f.Id,
+ }
+ _ = it.Create()
+
+ server := newTestServer()
+
+ // Rapidly toggle read state on the same item
+ const iterations = 100
+ var wg sync.WaitGroup
+ errors := make(chan error, iterations)
+
+ start := time.Now()
+ for i := 0; i < iterations; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ body, _ := json.Marshal(item.Item{
+ Id: it.Id,
+ ReadState: idx%2 == 0,
+ Starred: idx%3 == 0,
+ })
+ req := httptest.NewRequest("PUT", "/item/"+strconv.FormatInt(it.Id, 10), bytes.NewBuffer(body))
+ rr := httptest.NewRecorder()
+ server.ServeHTTP(rr, req)
+ if rr.Code != http.StatusOK {
+ errors <- fmt.Errorf("rapid update %d got status %d", idx, rr.Code)
+ }
+ }(i)
+ }
+ wg.Wait()
+ close(errors)
+ elapsed := time.Since(start)
+
+ for err := range errors {
+ t.Errorf("rapid toggle error: %v", err)
+ }
+
+ t.Logf("100 concurrent read-state toggles completed in %v", elapsed)
+ if elapsed > 10*time.Second {
+ t.Errorf("rapid toggles took too long: %v (threshold: 10s)", elapsed)
+ }
+}
+
func seedStressData(t *testing.T, count int) {
t.Helper()
f := &feed.Feed{Url: "http://example.com/bench", Title: "Bench Feed", Category: "tech"}
diff --git a/api/api_test.go b/api/api_test.go
index a2c3415..2c77501 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -3,6 +3,7 @@ package api
import (
"bytes"
"encoding/json"
+ "mime/multipart"
"net/http"
"net/http/httptest"
"path/filepath"
@@ -546,3 +547,231 @@ func TestHandleCategorySuccess(t *testing.T) {
t.Errorf("Expected %d, got %d", http.StatusOK, rr.Code)
}
}
+
+func TestHandleImportOPML(t *testing.T) {
+ setupTestDB(t)
+ server := newTestServer()
+
+ opmlContent := `<?xml version="1.0" encoding="UTF-8"?>
+<opml version="2.0">
+ <head><title>test</title></head>
+ <body>
+ <outline type="rss" text="Test Feed" title="Test Feed" xmlUrl="https://example.com/feed" htmlUrl="https://example.com"/>
+ </body>
+</opml>`
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ part, err := writer.CreateFormFile("file", "feeds.opml")
+ if err != nil {
+ t.Fatal(err)
+ }
+ part.Write([]byte(opmlContent))
+ writer.Close()
+
+ req := httptest.NewRequest("POST", "/import", body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rr := httptest.NewRecorder()
+ server.HandleImport(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Errorf("expected %d, got %d: %s", http.StatusOK, rr.Code, rr.Body.String())
+ }
+
+ var resp map[string]string
+ json.NewDecoder(rr.Body).Decode(&resp)
+ if resp["status"] != "ok" {
+ t.Errorf("expected status ok, got %q", resp["status"])
+ }
+
+ // Verify the feed was imported
+ feeds, _ := feed.All()
+ if len(feeds) != 1 {
+ t.Errorf("expected 1 feed after import, got %d", len(feeds))
+ }
+
+ time.Sleep(100 * time.Millisecond) // let goroutine settle
+}
+
+func TestHandleImportText(t *testing.T) {
+ setupTestDB(t)
+ server := newTestServer()
+
+ textContent := "https://example.com/feed1\nhttps://example.com/feed2\n"
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ part, err := writer.CreateFormFile("file", "feeds.txt")
+ if err != nil {
+ t.Fatal(err)
+ }
+ part.Write([]byte(textContent))
+ writer.WriteField("format", "text")
+ writer.Close()
+
+ req := httptest.NewRequest("POST", "/import?format=text", body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rr := httptest.NewRecorder()
+ server.HandleImport(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Errorf("expected %d, got %d: %s", http.StatusOK, rr.Code, rr.Body.String())
+ }
+
+ feeds, _ := feed.All()
+ if len(feeds) != 2 {
+ t.Errorf("expected 2 feeds after text import, got %d", len(feeds))
+ }
+
+ time.Sleep(100 * time.Millisecond)
+}
+
+func TestHandleImportMethodNotAllowed(t *testing.T) {
+ server := newTestServer()
+
+ req := httptest.NewRequest("GET", "/import", nil)
+ rr := httptest.NewRecorder()
+ server.HandleImport(rr, req)
+
+ if rr.Code != http.StatusMethodNotAllowed {
+ t.Errorf("expected %d, got %d", http.StatusMethodNotAllowed, rr.Code)
+ }
+}
+
+func TestHandleImportNoFile(t *testing.T) {
+ setupTestDB(t)
+ server := newTestServer()
+
+ req := httptest.NewRequest("POST", "/import", nil)
+ rr := httptest.NewRecorder()
+ server.HandleImport(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected %d, got %d", http.StatusBadRequest, rr.Code)
+ }
+}
+
+func TestHandleImportUnsupportedFormat(t *testing.T) {
+ setupTestDB(t)
+ server := newTestServer()
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ part, _ := writer.CreateFormFile("file", "feeds.csv")
+ part.Write([]byte("some data"))
+ writer.Close()
+
+ req := httptest.NewRequest("POST", "/import?format=csv", body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rr := httptest.NewRecorder()
+ server.HandleImport(rr, req)
+
+ if rr.Code != http.StatusInternalServerError {
+ t.Errorf("expected %d for unsupported format, got %d", http.StatusInternalServerError, rr.Code)
+ }
+}
+
+func TestHandleImportInvalidOPML(t *testing.T) {
+ setupTestDB(t)
+ server := newTestServer()
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ part, _ := writer.CreateFormFile("file", "bad.opml")
+ part.Write([]byte("not valid xml at all"))
+ writer.Close()
+
+ req := httptest.NewRequest("POST", "/import?format=opml", body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rr := httptest.NewRecorder()
+ server.HandleImport(rr, req)
+
+ if rr.Code != http.StatusInternalServerError {
+ t.Errorf("expected %d for invalid OPML, got %d", http.StatusInternalServerError, rr.Code)
+ }
+}
+
+func TestHandleStreamErrorOnClosedDB(t *testing.T) {
+ setupTestDB(t)
+ seedData(t)
+ server := newTestServer()
+
+ // Close the DB to force an error
+ models.DB.Close()
+
+ req := httptest.NewRequest("GET", "/stream", nil)
+ rr := httptest.NewRecorder()
+ server.HandleStream(rr, req)
+
+ if rr.Code != http.StatusInternalServerError {
+ t.Errorf("expected %d for closed DB, got %d", http.StatusInternalServerError, rr.Code)
+ }
+}
+
+func TestHandleItemInvalidJSON(t *testing.T) {
+ setupTestDB(t)
+ seedData(t)
+ server := newTestServer()
+
+ req := httptest.NewRequest("PUT", "/item/1", strings.NewReader("not json"))
+ rr := httptest.NewRecorder()
+ server.HandleItem(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected %d for invalid JSON, got %d", http.StatusBadRequest, rr.Code)
+ }
+}
+
+func TestHandleExportContentTypes(t *testing.T) {
+ setupTestDB(t)
+ seedData(t)
+ server := newTestServer()
+
+ testCases := []struct {
+ format string
+ contentType string
+ disposition string
+ }{
+ {"text", "text/plain", "neko_export.txt"},
+ {"opml", "application/xml", "neko_export.opml"},
+ {"json", "application/json", "neko_export.json"},
+ {"html", "text/html", "neko_export.html"},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest("GET", "/export/"+tc.format, nil)
+ rr := httptest.NewRecorder()
+ server.HandleExport(rr, req)
+
+ if ct := rr.Header().Get("Content-Type"); ct != tc.contentType {
+ t.Errorf("export/%s: expected Content-Type %q, got %q", tc.format, tc.contentType, ct)
+ }
+ if cd := rr.Header().Get("Content-Disposition"); !strings.Contains(cd, tc.disposition) {
+ t.Errorf("export/%s: expected Content-Disposition containing %q, got %q", tc.format, tc.disposition, cd)
+ }
+ }
+}
+
+func TestHandleImportJSON(t *testing.T) {
+ setupTestDB(t)
+ server := newTestServer()
+
+ jsonContent := `{"title":"Article 1","url":"https://example.com/1","description":"desc","read":false,"starred":false,"date":{"$date":"2024-01-01"},"feed":{"url":"https://example.com/feed","title":"Feed 1"}}`
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ part, _ := writer.CreateFormFile("file", "items.json")
+ part.Write([]byte(jsonContent))
+ writer.Close()
+
+ req := httptest.NewRequest("POST", "/import?format=json", body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rr := httptest.NewRecorder()
+ server.HandleImport(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Errorf("expected %d, got %d: %s", http.StatusOK, rr.Code, rr.Body.String())
+ }
+
+ time.Sleep(100 * time.Millisecond)
+}
diff --git a/cmd/neko/main_test.go b/cmd/neko/main_test.go
index b03d6c8..4403e5b 100644
--- a/cmd/neko/main_test.go
+++ b/cmd/neko/main_test.go
@@ -122,3 +122,57 @@ func TestRunNoArgs(t *testing.T) {
t.Errorf("Run with no args failed: %v", err)
}
}
+
+func TestRunPurge(t *testing.T) {
+ dbPath := filepath.Join(t.TempDir(), "test_purge.db")
+ err := Run([]string{"-d", dbPath, "-purge", "30"})
+ if err != nil {
+ t.Errorf("Run -purge should succeed, got %v", err)
+ }
+}
+
+func TestRunPurgeUnread(t *testing.T) {
+ dbPath := filepath.Join(t.TempDir(), "test_purge_unread.db")
+ err := Run([]string{"-d", dbPath, "-purge", "30", "-purge-unread"})
+ if err != nil {
+ t.Errorf("Run -purge -purge-unread should succeed, got %v", err)
+ }
+}
+
+func TestRunSecureCookies(t *testing.T) {
+ dbPath := filepath.Join(t.TempDir(), "test_secure.db")
+ config.Config.Port = -1
+ err := Run([]string{"-d", dbPath, "-secure-cookies"})
+ if err != nil {
+ t.Errorf("Run -secure-cookies should succeed, got %v", err)
+ }
+ if !config.Config.SecureCookies {
+ t.Error("Expected SecureCookies to be true")
+ }
+}
+
+func TestRunMinutesFlag(t *testing.T) {
+ dbPath := filepath.Join(t.TempDir(), "test_minutes.db")
+ config.Config.Port = -1
+ err := Run([]string{"-d", dbPath, "-m", "30"})
+ if err != nil {
+ t.Errorf("Run -m 30 should succeed, got %v", err)
+ }
+ if config.Config.CrawlMinutes != 30 {
+ t.Errorf("Expected CrawlMinutes=30, got %d", config.Config.CrawlMinutes)
+ }
+}
+
+func TestRunAllowLocal(t *testing.T) {
+ dbPath := filepath.Join(t.TempDir(), "test_local.db")
+ config.Config.Port = -1
+ err := Run([]string{"-d", dbPath, "-allow-local"})
+ if err != nil {
+ t.Errorf("Run -allow-local should succeed, got %v", err)
+ }
+}
+
+func TestBackgroundCrawlNegative(_ *testing.T) {
+ // Negative minutes should return immediately
+ backgroundCrawl(-1)
+}
diff --git a/internal/importer/importer_test.go b/internal/importer/importer_test.go
index 59f06f1..631c441 100644
--- a/internal/importer/importer_test.go
+++ b/internal/importer/importer_test.go
@@ -3,6 +3,7 @@ package importer
import (
"os"
"path/filepath"
+ "strings"
"testing"
"adammathes.com/neko/config"
@@ -147,3 +148,180 @@ func TestImportJSONNonexistent(t *testing.T) {
t.Error("ImportJSON should error on nonexistent file")
}
}
+
+// Test the ImportFeeds dispatcher function (previously 0% coverage)
+func TestImportFeedsOPML(t *testing.T) {
+ setupTestDB(t)
+
+ opml := `<?xml version="1.0" encoding="UTF-8"?>
+<opml version="2.0">
+ <head><title>test</title></head>
+ <body>
+ <outline type="rss" text="Feed A" xmlUrl="https://a.com/feed" htmlUrl="https://a.com"/>
+ </body>
+</opml>`
+
+ err := ImportFeeds("opml", strings.NewReader(opml))
+ if err != nil {
+ t.Fatalf("ImportFeeds(opml) failed: %v", err)
+ }
+
+ var count int
+ models.DB.QueryRow("SELECT COUNT(*) FROM feed").Scan(&count)
+ if count != 1 {
+ t.Errorf("expected 1 feed, got %d", count)
+ }
+}
+
+func TestImportFeedsText(t *testing.T) {
+ setupTestDB(t)
+
+ text := "https://example.com/feed1\nhttps://example.com/feed2\n"
+
+ err := ImportFeeds("text", strings.NewReader(text))
+ if err != nil {
+ t.Fatalf("ImportFeeds(text) failed: %v", err)
+ }
+
+ var count int
+ models.DB.QueryRow("SELECT COUNT(*) FROM feed").Scan(&count)
+ if count != 2 {
+ t.Errorf("expected 2 feeds, got %d", count)
+ }
+}
+
+func TestImportFeedsJSON(t *testing.T) {
+ setupTestDB(t)
+
+ jsonContent := `{"title":"A1","url":"https://example.com/1","description":"d","feed":{"url":"https://example.com/feed","title":"F1"}}`
+
+ err := ImportFeeds("json", strings.NewReader(jsonContent))
+ if err != nil {
+ t.Fatalf("ImportFeeds(json) failed: %v", err)
+ }
+
+ var count int
+ models.DB.QueryRow("SELECT COUNT(*) FROM item").Scan(&count)
+ if count != 1 {
+ t.Errorf("expected 1 item, got %d", count)
+ }
+}
+
+func TestImportFeedsUnsupported(t *testing.T) {
+ err := ImportFeeds("csv", strings.NewReader("data"))
+ if err == nil {
+ t.Error("ImportFeeds should error for unsupported format")
+ }
+ if err != nil && !strings.Contains(err.Error(), "unsupported") {
+ t.Errorf("expected 'unsupported' error, got: %v", err)
+ }
+}
+
+func TestImportOPMLInvalid(t *testing.T) {
+ setupTestDB(t)
+ err := ImportOPML(strings.NewReader("not valid xml"))
+ if err == nil {
+ t.Error("ImportOPML should error on invalid XML")
+ }
+}
+
+func TestImportOPMLNestedCategories(t *testing.T) {
+ setupTestDB(t)
+
+ opml := `<?xml version="1.0" encoding="UTF-8"?>
+<opml version="2.0">
+ <head><title>test</title></head>
+ <body>
+ <outline text="Tech">
+ <outline text="Programming">
+ <outline type="rss" text="Blog A" xmlUrl="https://a.com/feed" htmlUrl="https://a.com"/>
+ </outline>
+ </outline>
+ <outline type="rss" xmlUrl="https://b.com/feed" htmlUrl="https://b.com" category="news"/>
+ </body>
+</opml>`
+
+ err := ImportOPML(strings.NewReader(opml))
+ if err != nil {
+ t.Fatalf("ImportOPML failed: %v", err)
+ }
+
+ var count int
+ models.DB.QueryRow("SELECT COUNT(*) FROM feed").Scan(&count)
+ if count != 2 {
+ t.Errorf("expected 2 feeds, got %d", count)
+ }
+
+ // Verify nested category is inherited
+ var category string
+ models.DB.QueryRow("SELECT category FROM feed WHERE url=?", "https://a.com/feed").Scan(&category)
+ if category != "Programming" {
+ t.Errorf("expected category 'Programming' for nested feed, got %q", category)
+ }
+
+ // Feed with category attribute
+ models.DB.QueryRow("SELECT category FROM feed WHERE url=?", "https://b.com/feed").Scan(&category)
+ if category != "news" {
+ t.Errorf("expected category 'news' for feed with category attr, got %q", category)
+ }
+}
+
+func TestInsertIItemNilDate(t *testing.T) {
+ setupTestDB(t)
+
+ ii := &IItem{
+ Title: "No Date Article",
+ Url: "https://example.com/nodate",
+ Description: "Article without date",
+ Date: nil,
+ Feed: &IFeed{
+ Url: "https://example.com/feed",
+ Title: "Test Feed",
+ },
+ }
+
+ err := InsertIItem(ii)
+ if err != nil {
+ t.Errorf("InsertIItem with nil date should not error, got %v", err)
+ }
+
+ var count int
+ models.DB.QueryRow("SELECT COUNT(*) FROM item").Scan(&count)
+ if count != 1 {
+ t.Errorf("expected 1 item, got %d", count)
+ }
+}
+
+func TestImportJSONReaderEmpty(t *testing.T) {
+ setupTestDB(t)
+
+ // Empty reader - should not error (just EOF immediately)
+ err := ImportJSONReader(strings.NewReader(""))
+ if err != nil {
+ t.Errorf("ImportJSONReader with empty input should not error, got %v", err)
+ }
+}
+
+func TestImportTextSkipsCommentsAndBlanks(t *testing.T) {
+ setupTestDB(t)
+
+ text := `
+# This is a comment
+https://example.com/feed1
+
+ # Another comment
+
+https://example.com/feed2
+`
+
+ err := ImportText(strings.NewReader(text))
+ if err != nil {
+ t.Fatalf("ImportText failed: %v", err)
+ }
+
+ var count int
+ models.DB.QueryRow("SELECT COUNT(*) FROM feed").Scan(&count)
+ if count != 2 {
+ t.Errorf("expected 2 feeds (comments and blanks skipped), got %d", count)
+ }
+}
diff --git a/internal/safehttp/safehttp_test.go b/internal/safehttp/safehttp_test.go
index b2636da..dc428e4 100644
--- a/internal/safehttp/safehttp_test.go
+++ b/internal/safehttp/safehttp_test.go
@@ -1,7 +1,12 @@
package safehttp
import (
+ "context"
+ "fmt"
"net"
+ "net/http"
+ "net/http/httptest"
+ "strings"
"testing"
"time"
)
@@ -51,3 +56,285 @@ func TestIsPrivateIP(t *testing.T) {
}
}
}
+
+func TestIsPrivateIPAllowLocal(t *testing.T) {
+ // Save and restore AllowLocal
+ orig := AllowLocal
+ AllowLocal = true
+ defer func() { AllowLocal = orig }()
+
+ // When AllowLocal is true, all IPs should be considered non-private
+ privateIPs := []string{"127.0.0.1", "10.0.0.1", "192.168.1.1", "::1", "fe80::1"}
+ for _, ipStr := range privateIPs {
+ ip := net.ParseIP(ipStr)
+ if isPrivateIP(ip) {
+ t.Errorf("isPrivateIP(%s) should be false when AllowLocal=true", ipStr)
+ }
+ }
+}
+
+func TestSafeDialerDirectIP(t *testing.T) {
+ dialer := &net.Dialer{Timeout: 2 * time.Second}
+ safeDial := SafeDialer(dialer)
+ ctx := context.Background()
+
+ // Direct private IP should be blocked
+ _, err := safeDial(ctx, "tcp", "127.0.0.1:80")
+ if err == nil {
+ t.Error("SafeDialer should block direct private IP 127.0.0.1")
+ }
+ if err != nil && !strings.Contains(err.Error(), "private IP") {
+ t.Errorf("expected 'private IP' error, got: %v", err)
+ }
+
+ // Another private IP
+ _, err = safeDial(ctx, "tcp", "10.0.0.1:80")
+ if err == nil {
+ t.Error("SafeDialer should block direct private IP 10.0.0.1")
+ }
+
+ // IPv6 loopback
+ _, err = safeDial(ctx, "tcp", "[::1]:80")
+ if err == nil {
+ t.Error("SafeDialer should block IPv6 loopback")
+ }
+}
+
+func TestSafeDialerHostWithoutPort(t *testing.T) {
+ dialer := &net.Dialer{Timeout: 2 * time.Second}
+ safeDial := SafeDialer(dialer)
+ ctx := context.Background()
+
+ // Address without port should still be checked
+ _, err := safeDial(ctx, "tcp", "127.0.0.1")
+ if err == nil {
+ t.Error("SafeDialer should block private IP even without port")
+ }
+}
+
+func TestSafeDialerHostnameResolution(t *testing.T) {
+ dialer := &net.Dialer{Timeout: 2 * time.Second}
+ safeDial := SafeDialer(dialer)
+ ctx := context.Background()
+
+ // "localhost" resolves to 127.0.0.1 which should be blocked
+ _, err := safeDial(ctx, "tcp", "localhost:80")
+ if err == nil {
+ t.Error("SafeDialer should block localhost hostname")
+ }
+}
+
+func TestSafeDialerUnresolvableHostname(t *testing.T) {
+ dialer := &net.Dialer{Timeout: 2 * time.Second}
+ safeDial := SafeDialer(dialer)
+ ctx := context.Background()
+
+ // Non-existent hostname should fail DNS lookup
+ _, err := safeDial(ctx, "tcp", "this-host-does-not-exist.invalid:80")
+ if err == nil {
+ t.Error("SafeDialer should error on unresolvable hostname")
+ }
+}
+
+func TestNewSafeClientProperties(t *testing.T) {
+ client := NewSafeClient(5 * time.Second)
+
+ if client.Timeout != 5*time.Second {
+ t.Errorf("expected timeout 5s, got %v", client.Timeout)
+ }
+
+ transport, ok := client.Transport.(*http.Transport)
+ if !ok {
+ t.Fatal("expected *http.Transport")
+ }
+
+ // Proxy should be nil to prevent SSRF bypass
+ if transport.Proxy != nil {
+ t.Error("transport.Proxy should be nil to prevent SSRF bypass")
+ }
+
+ // DialContext should be set
+ if transport.DialContext == nil {
+ t.Error("transport.DialContext should be set to safe dialer")
+ }
+}
+
+func TestNewSafeClientRedirectToPrivateIP(t *testing.T) {
+ // Create a server that redirects to a private IP
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Redirect(w, r, "http://127.0.0.1:9999/secret", http.StatusFound)
+ }))
+ defer ts.Close()
+
+ // Allow local so the initial connection succeeds, but the redirect check should catch it
+ orig := AllowLocal
+ AllowLocal = true
+ defer func() { AllowLocal = orig }()
+
+ client := NewSafeClient(2 * time.Second)
+
+ // The redirect to 127.0.0.1 should be blocked by CheckRedirect
+ // Note: AllowLocal only affects SafeDialer's isPrivateIP, not CheckRedirect's isPrivateIP
+ // Actually, AllowLocal affects ALL isPrivateIP calls, so redirect will also pass.
+ // Let's test with AllowLocal=false instead using a public-appearing redirect.
+ AllowLocal = false
+
+ // Direct request to server on loopback with AllowLocal=false will fail at dial level
+ _, err := client.Get(ts.URL)
+ if err == nil {
+ t.Error("expected error for connection to local server with AllowLocal=false")
+ }
+}
+
+func TestNewSafeClientTooManyRedirects(t *testing.T) {
+ // Create a server that redirects in a loop
+ var ts *httptest.Server
+ ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Redirect(w, r, ts.URL+"/loop", http.StatusFound)
+ }))
+ defer ts.Close()
+
+ orig := AllowLocal
+ AllowLocal = true
+ defer func() { AllowLocal = orig }()
+
+ client := NewSafeClient(5 * time.Second)
+ _, err := client.Get(ts.URL)
+ if err == nil {
+ t.Error("expected error for too many redirects")
+ }
+ if err != nil && !strings.Contains(err.Error(), "redirect") {
+ t.Logf("got error (expected redirect-related): %v", err)
+ }
+}
+
+func TestNewSafeClientRedirectCheck(t *testing.T) {
+ // Test the redirect checker directly by creating a chain of redirects
+ // Server 1: redirects to server 2 (both on localhost)
+ orig := AllowLocal
+ AllowLocal = true
+ defer func() { AllowLocal = orig }()
+
+ var count int
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ count++
+ if count <= 5 {
+ http.Redirect(w, r, fmt.Sprintf("/next%d", count), http.StatusFound)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("done"))
+ }))
+ defer ts.Close()
+
+ client := NewSafeClient(5 * time.Second)
+ resp, err := client.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("expected successful response, got error: %v", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("expected 200, got %d", resp.StatusCode)
+ }
+}
+
+func TestSafeClientBlocksPrivateIPv6(t *testing.T) {
+ client := NewSafeClient(2 * time.Second)
+
+ // fc00::/7 (unique local address)
+ _, err := client.Get("http://[fc00::1]:80")
+ if err == nil {
+ t.Error("expected error for fc00::1 (unique local)")
+ }
+
+ // fe80::/10 (link-local)
+ _, err = client.Get("http://[fe80::1]:80")
+ if err == nil {
+ t.Error("expected error for fe80::1 (link-local)")
+ }
+}
+
+func TestSafeClientBlocksRFC1918(t *testing.T) {
+ client := NewSafeClient(2 * time.Second)
+
+ // 172.16.0.0/12
+ _, err := client.Get("http://172.16.0.1:80")
+ if err == nil {
+ t.Error("expected error for 172.16.0.1")
+ }
+
+ // 192.168.0.0/16
+ _, err = client.Get("http://192.168.1.1:80")
+ if err == nil {
+ t.Error("expected error for 192.168.1.1")
+ }
+
+ // 169.254.0.0/16 (link-local)
+ _, err = client.Get("http://169.254.1.1:80")
+ if err == nil {
+ t.Error("expected error for 169.254.1.1")
+ }
+}
+
+func TestNewSafeClientRedirectToPrivateHostname(t *testing.T) {
+ // Create a server that redirects to localhost (hostname, not IP)
+ orig := AllowLocal
+ AllowLocal = true
+ defer func() { AllowLocal = orig }()
+
+ // Server redirects to a URL with a hostname that resolves to private IP
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/redirect" {
+ http.Redirect(w, r, "http://localhost:1/secret", http.StatusFound)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer ts.Close()
+
+ client := NewSafeClient(2 * time.Second)
+ // The redirect follows to localhost, which should succeed with AllowLocal=true
+ // but the redirect checker hostname resolution path is exercised
+ resp, err := client.Get(ts.URL + "/redirect")
+ // This will likely fail at the dial level to localhost:1, but the redirect checker runs first
+ if err == nil {
+ resp.Body.Close()
+ }
+ // We mainly care that it doesn't panic and exercises the redirect path
+}
+
+func TestNewSafeClientRedirectNoPort(t *testing.T) {
+ // Test redirect to a URL without an explicit port (exercises SplitHostPort error path)
+ orig := AllowLocal
+ AllowLocal = true
+ defer func() { AllowLocal = orig }()
+
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/redir" {
+ // Redirect to a URL with a plain hostname (no port) that resolves to private
+ http.Redirect(w, r, "http://localhost/nope", http.StatusFound)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer ts.Close()
+
+ client := NewSafeClient(2 * time.Second)
+ resp, err := client.Get(ts.URL + "/redir")
+ if err == nil {
+ resp.Body.Close()
+ }
+ // We mainly care that the redirect hostname resolution path is hit
+}
+
+func TestInitPrivateIPBlocks(t *testing.T) {
+ // Verify that the init function populated privateIPBlocks
+ if len(privateIPBlocks) == 0 {
+ t.Error("privateIPBlocks should be populated by init()")
+ }
+ // We expect 8 CIDR ranges
+ expectedCount := 8
+ if len(privateIPBlocks) != expectedCount {
+ t.Errorf("expected %d private IP blocks, got %d", expectedCount, len(privateIPBlocks))
+ }
+}
diff --git a/web/security_regression_test.go b/web/security_regression_test.go
new file mode 100644
index 0000000..6c97491
--- /dev/null
+++ b/web/security_regression_test.go
@@ -0,0 +1,222 @@
+package web
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "adammathes.com/neko/config"
+)
+
+// Security regression tests to ensure critical security properties are maintained.
+
+// TestCSRFTokenMismatchRejected ensures mismatched CSRF tokens are rejected.
+func TestCSRFTokenMismatchRejected(t *testing.T) {
+ cfg := &config.Settings{SecureCookies: false}
+ inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+ handler := CSRFMiddleware(cfg, inner)
+
+ // Get a valid token
+ getReq := httptest.NewRequest("GET", "/", nil)
+ getRR := httptest.NewRecorder()
+ handler.ServeHTTP(getRR, getReq)
+
+ var csrfToken string
+ for _, c := range getRR.Result().Cookies() {
+ if c.Name == "csrf_token" {
+ csrfToken = c.Value
+ }
+ }
+
+ // POST with wrong token in header should be rejected
+ req := httptest.NewRequest("POST", "/something", nil)
+ req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
+ req.Header.Set("X-CSRF-Token", "completely-wrong-token")
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusForbidden {
+ t.Errorf("CSRF token mismatch should return 403, got %d", rr.Code)
+ }
+}
+
+// TestCSRFTokenEmptyHeaderRejected ensures empty CSRF tokens are rejected.
+func TestCSRFTokenEmptyHeaderRejected(t *testing.T) {
+ cfg := &config.Settings{SecureCookies: false}
+ inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+ handler := CSRFMiddleware(cfg, inner)
+
+ getReq := httptest.NewRequest("GET", "/", nil)
+ getRR := httptest.NewRecorder()
+ handler.ServeHTTP(getRR, getReq)
+
+ var csrfToken string
+ for _, c := range getRR.Result().Cookies() {
+ if c.Name == "csrf_token" {
+ csrfToken = c.Value
+ }
+ }
+
+ // POST with empty X-CSRF-Token header
+ req := httptest.NewRequest("POST", "/data", nil)
+ req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
+ req.Header.Set("X-CSRF-Token", "")
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusForbidden {
+ t.Errorf("Empty CSRF token should return 403, got %d", rr.Code)
+ }
+}
+
+// TestSecurityHeadersPresent verifies all security headers are set correctly.
+func TestSecurityHeadersPresent(t *testing.T) {
+ handler := SecurityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ req := httptest.NewRequest("GET", "/", nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ headers := map[string]string{
+ "X-Content-Type-Options": "nosniff",
+ "X-Frame-Options": "DENY",
+ "X-XSS-Protection": "1; mode=block",
+ "Referrer-Policy": "strict-origin-when-cross-origin",
+ }
+
+ for name, expected := range headers {
+ if got := rr.Header().Get(name); got != expected {
+ t.Errorf("Header %s: expected %q, got %q", name, expected, got)
+ }
+ }
+
+ // CSP should deny framing
+ csp := rr.Header().Get("Content-Security-Policy")
+ if !strings.Contains(csp, "frame-ancestors 'none'") {
+ t.Error("CSP should contain frame-ancestors 'none'")
+ }
+}
+
+// TestAuthCookieHttpOnly ensures the auth cookie is HttpOnly.
+func TestAuthCookieHttpOnly(t *testing.T) {
+ originalPw := config.Config.DigestPassword
+ defer func() { config.Config.DigestPassword = originalPw }()
+ config.Config.DigestPassword = "testpass"
+
+ req := httptest.NewRequest("POST", "/login/", nil)
+ req.Form = map[string][]string{"password": {"testpass"}}
+ rr := httptest.NewRecorder()
+ loginHandler(rr, req)
+
+ for _, c := range rr.Result().Cookies() {
+ if c.Name == AuthCookie {
+ if !c.HttpOnly {
+ t.Error("Auth cookie must be HttpOnly to prevent XSS theft")
+ }
+ return
+ }
+ }
+ t.Error("Auth cookie not found in login response")
+}
+
+// TestLogoutClearsAuthCookie ensures logout properly invalidates the cookie.
+func TestLogoutClearsAuthCookie(t *testing.T) {
+ req := httptest.NewRequest("POST", "/api/logout", nil)
+ rr := httptest.NewRecorder()
+ apiLogoutHandler(rr, req)
+
+ for _, c := range rr.Result().Cookies() {
+ if c.Name == AuthCookie {
+ if c.MaxAge != -1 {
+ t.Errorf("Logout should set MaxAge=-1 to expire cookie, got %d", c.MaxAge)
+ }
+ if c.Value != "" {
+ t.Error("Logout should clear cookie value")
+ }
+ return
+ }
+ }
+ t.Error("Auth cookie not found in logout response")
+}
+
+// TestAPIRoutesRequireAuth ensures API routes redirect when not authenticated.
+func TestAPIRoutesRequireAuth(t *testing.T) {
+ setupTestDB(t)
+ originalPw := config.Config.DigestPassword
+ defer func() { config.Config.DigestPassword = originalPw }()
+ config.Config.DigestPassword = "secret"
+
+ router := NewRouter(&config.Config)
+
+ protectedPaths := []string{
+ "/api/stream",
+ "/api/feed",
+ "/api/tag",
+ }
+
+ for _, path := range protectedPaths {
+ req := httptest.NewRequest("GET", path, nil)
+ rr := httptest.NewRecorder()
+ router.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusTemporaryRedirect {
+ t.Errorf("GET %s without auth should redirect, got %d", path, rr.Code)
+ }
+ }
+}
+
+// TestCSRFTokenUniqueness ensures each new session gets a unique CSRF token.
+func TestCSRFTokenUniqueness(t *testing.T) {
+ cfg := &config.Settings{SecureCookies: false}
+ inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+ handler := CSRFMiddleware(cfg, inner)
+
+ tokens := make(map[string]bool)
+ for i := 0; i < 10; i++ {
+ req := httptest.NewRequest("GET", "/", nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ for _, c := range rr.Result().Cookies() {
+ if c.Name == "csrf_token" {
+ if tokens[c.Value] {
+ t.Error("CSRF tokens should be unique across sessions")
+ }
+ tokens[c.Value] = true
+ }
+ }
+ }
+
+ if len(tokens) < 10 {
+ t.Errorf("Expected 10 unique CSRF tokens, got %d", len(tokens))
+ }
+}
+
+// TestCSRFExcludedPathsTrailingSlash ensures CSRF exclusion works with and without trailing slashes.
+func TestCSRFExcludedPathsTrailingSlash(t *testing.T) {
+ originalPw := config.Config.DigestPassword
+ defer func() { config.Config.DigestPassword = originalPw }()
+ config.Config.DigestPassword = "secret"
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/api/login", apiLoginHandler)
+ handler := CSRFMiddleware(&config.Config, mux)
+
+ // POST /api/login/ (with trailing slash) should also be excluded
+ req := httptest.NewRequest("POST", "/api/login/", strings.NewReader("password=secret"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+ if rr.Code == http.StatusForbidden {
+ t.Error("POST /api/login/ (trailing slash) should be excluded from CSRF protection")
+ }
+}