diff options
Diffstat (limited to 'api')
| -rw-r--r-- | api/api.go | 51 | ||||
| -rw-r--r-- | api/api_test.go | 113 |
2 files changed, 92 insertions, 72 deletions
@@ -7,23 +7,36 @@ import ( "strconv" "strings" - "adammathes.com/neko/crawler" - "adammathes.com/neko/exporter" + "adammathes.com/neko/config" + "adammathes.com/neko/internal/crawler" + "adammathes.com/neko/internal/exporter" "adammathes.com/neko/models/feed" "adammathes.com/neko/models/item" ) -// NewRouter returns a configured mux with all API routes. -func NewRouter() *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc("/stream", HandleStream) - mux.HandleFunc("/item/", HandleItem) - mux.HandleFunc("/feed", HandleFeed) - mux.HandleFunc("/feed/", HandleFeed) - mux.HandleFunc("/tag", HandleCategory) - mux.HandleFunc("/export/", HandleExport) - mux.HandleFunc("/crawl", HandleCrawl) - return mux +type Server struct { + Config *config.Settings + *http.ServeMux +} + +// NewServer returns a configured server with all API routes. +func NewServer(cfg *config.Settings) *Server { + s := &Server{ + Config: cfg, + ServeMux: http.NewServeMux(), + } + s.routes() + return s +} + +func (s *Server) routes() { + s.HandleFunc("/stream", s.HandleStream) + s.HandleFunc("/item/", s.HandleItem) + s.HandleFunc("/feed", s.HandleFeed) + s.HandleFunc("/feed/", s.HandleFeed) + s.HandleFunc("/tag", s.HandleCategory) + s.HandleFunc("/export/", s.HandleExport) + s.HandleFunc("/crawl", s.HandleCrawl) } func jsonError(w http.ResponseWriter, msg string, code int) { @@ -37,7 +50,7 @@ func jsonResponse(w http.ResponseWriter, data interface{}) { json.NewEncoder(w).Encode(data) } -func HandleStream(w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleStream(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { jsonError(w, "method not allowed", http.StatusMethodNotAllowed) return @@ -72,7 +85,7 @@ func HandleStream(w http.ResponseWriter, r *http.Request) { jsonResponse(w, items) } -func HandleItem(w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleItem(w http.ResponseWriter, r *http.Request) { idStr := strings.TrimPrefix(r.URL.Path, "/item/") id, _ := strconv.ParseInt(idStr, 10, 64) @@ -115,7 +128,7 @@ func HandleItem(w http.ResponseWriter, r *http.Request) { } } -func HandleFeed(w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleFeed(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: feeds, err := feed.All() @@ -180,7 +193,7 @@ func HandleFeed(w http.ResponseWriter, r *http.Request) { } } -func HandleCategory(w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleCategory(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { jsonError(w, "method not allowed", http.StatusMethodNotAllowed) return @@ -194,7 +207,7 @@ func HandleCategory(w http.ResponseWriter, r *http.Request) { jsonResponse(w, categories) } -func HandleExport(w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleExport(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { jsonError(w, "method not allowed", http.StatusMethodNotAllowed) return @@ -208,7 +221,7 @@ func HandleExport(w http.ResponseWriter, r *http.Request) { w.Write([]byte(exporter.ExportFeeds(format))) } -func HandleCrawl(w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleCrawl(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { jsonError(w, "method not allowed", http.StatusMethodNotAllowed) return diff --git a/api/api_test.go b/api/api_test.go index 2adc357..15679f7 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -46,14 +46,18 @@ func seedData(t *testing.T) { i.Create() } +func newTestServer() *Server { + return NewServer(&config.Config) +} + func TestStream(t *testing.T) { setupTestDB(t) seedData(t) - router := NewRouter() + server := newTestServer() req := httptest.NewRequest("GET", "/stream", nil) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) @@ -68,14 +72,14 @@ func TestStream(t *testing.T) { func TestFeedCRUD(t *testing.T) { setupTestDB(t) - router := NewRouter() + server := newTestServer() // Create f := feed.Feed{Url: "http://example.com", Title: "New Feed"} b, _ := json.Marshal(f) req := httptest.NewRequest("POST", "/feed", bytes.NewBuffer(b)) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusCreated { t.Errorf("expected 201, got %d", rr.Code) @@ -84,7 +88,7 @@ func TestFeedCRUD(t *testing.T) { // List req = httptest.NewRequest("GET", "/feed", nil) rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) var feeds []feed.Feed json.NewDecoder(rr.Body).Decode(&feeds) @@ -99,7 +103,7 @@ func TestFeedCRUD(t *testing.T) { b, _ = json.Marshal(feeds[0]) req = httptest.NewRequest("PUT", "/feed", bytes.NewBuffer(b)) rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) @@ -108,7 +112,7 @@ func TestFeedCRUD(t *testing.T) { // Delete req = httptest.NewRequest("DELETE", "/feed/"+strconv.FormatInt(feedID, 10), nil) rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusNoContent { t.Errorf("expected 204, got %d", rr.Code) @@ -118,7 +122,7 @@ func TestFeedCRUD(t *testing.T) { func TestItemUpdate(t *testing.T) { setupTestDB(t) seedData(t) - router := NewRouter() + server := newTestServer() // Get an item first to know its ID var id int64 @@ -131,7 +135,7 @@ func TestItemUpdate(t *testing.T) { b, _ := json.Marshal(i) req := httptest.NewRequest("PUT", "/item/"+strconv.FormatInt(id, 10), bytes.NewBuffer(b)) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) @@ -141,11 +145,11 @@ func TestItemUpdate(t *testing.T) { func TestGetCategories(t *testing.T) { setupTestDB(t) seedData(t) - router := NewRouter() + server := newTestServer() req := httptest.NewRequest("GET", "/tag", nil) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) @@ -153,20 +157,21 @@ func TestGetCategories(t *testing.T) { var cats []feed.Category json.NewDecoder(rr.Body).Decode(&cats) - if len(cats) != 1 { // Corrected 'categories' to 'cats' for syntactic correctness - t.Errorf("Expected 1 category, got %d", len(cats)) // Corrected 'categories' to 'cats' + if len(cats) != 1 { + t.Errorf("Expected 1 category, got %d", len(cats)) } } func TestHandleExport(t *testing.T) { setupTestDB(t) seedData(t) + server := newTestServer() formats := []string{"text", "json", "opml", "html"} for _, fmt := range formats { req := httptest.NewRequest("GET", "/export/"+fmt, nil) rr := httptest.NewRecorder() - HandleExport(rr, req) + server.HandleExport(rr, req) if rr.Code != http.StatusOK { t.Errorf("Expected 200 for format %s, got %d", fmt, rr.Code) @@ -175,18 +180,19 @@ func TestHandleExport(t *testing.T) { req := httptest.NewRequest("GET", "/export/unknown", nil) rr := httptest.NewRecorder() - HandleExport(rr, req) - if rr.Code != http.StatusOK { // This should probably be http.StatusBadRequest or similar for unknown format + server.HandleExport(rr, req) + if rr.Code != http.StatusOK { t.Errorf("Expected 200 for unknown format, got %d", rr.Code) } } func TestHandleCrawl(t *testing.T) { setupTestDB(t) + server := newTestServer() req := httptest.NewRequest("POST", "/crawl", nil) rr := httptest.NewRecorder() - HandleCrawl(rr, req) + server.HandleCrawl(rr, req) if rr.Code != http.StatusOK { t.Errorf("Expected 200, got %d", rr.Code) @@ -194,14 +200,14 @@ func TestHandleCrawl(t *testing.T) { if !strings.Contains(rr.Body.String(), "crawl started") { t.Error("Expected crawl started message in response") } - // Wait for background goroutine to at least start/finish before DB is closed by cleanup time.Sleep(100 * time.Millisecond) } func TestJsonError(t *testing.T) { + server := newTestServer() req := httptest.NewRequest("PUT", "/item/notanumber", nil) rr := httptest.NewRecorder() - HandleItem(rr, req) + server.HandleItem(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400, got %d", rr.Code) @@ -216,7 +222,7 @@ func TestJsonError(t *testing.T) { func TestHandleStreamFilters(t *testing.T) { setupTestDB(t) seedData(t) - router := NewRouter() + server := newTestServer() testCases := []struct { url string @@ -225,14 +231,14 @@ func TestHandleStreamFilters(t *testing.T) { {"/stream?tag=tech", 1}, {"/stream?tag=missing", 0}, {"/stream?feed_url=http://example.com", 1}, - {"/stream?starred=1", 0}, // none starred in seed + {"/stream?starred=1", 0}, {"/stream?q=Test", 1}, } for _, tc := range testCases { req := httptest.NewRequest("GET", tc.url, nil) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) var items []item.Item json.NewDecoder(rr.Body).Decode(&items) @@ -244,29 +250,26 @@ func TestHandleStreamFilters(t *testing.T) { func TestHandleFeedErrors(t *testing.T) { setupTestDB(t) - router := NewRouter() + server := newTestServer() - // Post missing URL b, _ := json.Marshal(feed.Feed{Title: "No URL"}) req := httptest.NewRequest("POST", "/feed", bytes.NewBuffer(b)) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400 for missing URL, got %d", rr.Code) } - // Invalid JSON req = httptest.NewRequest("POST", "/feed", strings.NewReader("not json")) rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400 for invalid JSON, got %d", rr.Code) } - // Method Not Allowed req = httptest.NewRequest("PATCH", "/feed", nil) rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusMethodNotAllowed { t.Errorf("Expected 405, got %d", rr.Code) } @@ -275,25 +278,22 @@ func TestHandleFeedErrors(t *testing.T) { func TestHandleItemEdgeCases(t *testing.T) { setupTestDB(t) seedData(t) - router := NewRouter() + server := newTestServer() - // Item not found req := httptest.NewRequest("GET", "/item/999", nil) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusNotFound { t.Errorf("Expected 404, got %d", rr.Code) } - // Method not allowed req = httptest.NewRequest("DELETE", "/item/1", nil) rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusMethodNotAllowed { t.Errorf("Expected 405, got %d", rr.Code) } - // GET/POST for content extraction (mocked content extraction is tested in models/item) var id int64 err := models.DB.QueryRow("SELECT id FROM item LIMIT 1").Scan(&id) if err != nil { @@ -302,7 +302,7 @@ func TestHandleItemEdgeCases(t *testing.T) { req = httptest.NewRequest("GET", "/item/"+strconv.FormatInt(id, 10), nil) rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("Expected 200, got %d", rr.Code) } @@ -310,11 +310,11 @@ func TestHandleItemEdgeCases(t *testing.T) { func TestHandleFeedDeleteNoId(t *testing.T) { setupTestDB(t) - router := NewRouter() + server := newTestServer() req := httptest.NewRequest("DELETE", "/feed/", nil) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400, got %d", rr.Code) } @@ -322,7 +322,7 @@ func TestHandleFeedDeleteNoId(t *testing.T) { func TestMethodNotAllowed(t *testing.T) { setupTestDB(t) - router := NewRouter() + server := newTestServer() testCases := []struct { method string @@ -336,7 +336,7 @@ func TestMethodNotAllowed(t *testing.T) { for _, tc := range testCases { req := httptest.NewRequest(tc.method, tc.url, nil) rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + server.ServeHTTP(rr, req) if rr.Code != http.StatusMethodNotAllowed { t.Errorf("Expected 405 for %s %s, got %d", tc.method, tc.url, rr.Code) } @@ -345,9 +345,10 @@ func TestMethodNotAllowed(t *testing.T) { func TestExportBadRequest(t *testing.T) { setupTestDB(t) + server := newTestServer() req := httptest.NewRequest("GET", "/export/", nil) rr := httptest.NewRecorder() - HandleExport(rr, req) + server.HandleExport(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400 for empty format, got %d", rr.Code) } @@ -355,9 +356,10 @@ func TestExportBadRequest(t *testing.T) { func TestHandleFeedPutInvalidJson(t *testing.T) { setupTestDB(t) + server := newTestServer() req := httptest.NewRequest("PUT", "/feed", strings.NewReader("not json")) rr := httptest.NewRecorder() - HandleFeed(rr, req) + server.HandleFeed(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400 for invalid JSON in PUT, got %d", rr.Code) } @@ -365,10 +367,11 @@ func TestHandleFeedPutInvalidJson(t *testing.T) { func TestHandleFeedPutMissingId(t *testing.T) { setupTestDB(t) + server := newTestServer() b, _ := json.Marshal(feed.Feed{Title: "No ID"}) req := httptest.NewRequest("PUT", "/feed", bytes.NewBuffer(b)) rr := httptest.NewRecorder() - HandleFeed(rr, req) + server.HandleFeed(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400 for missing ID in PUT, got %d", rr.Code) } @@ -377,22 +380,24 @@ func TestHandleFeedPutMissingId(t *testing.T) { func TestHandleItemIdMismatch(t *testing.T) { setupTestDB(t) seedData(t) - b, _ := json.Marshal(item.Item{Id: 999}) // mismatch with path 1 + server := newTestServer() + b, _ := json.Marshal(item.Item{Id: 999}) req := httptest.NewRequest("PUT", "/item/1", bytes.NewBuffer(b)) rr := httptest.NewRecorder() - HandleItem(rr, req) + server.HandleItem(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("Expected 400 for ID mismatch, got %d", rr.Code) } } + func TestHandleCategoryError(t *testing.T) { setupTestDB(t) - // Close DB to force error + server := newTestServer() models.DB.Close() req := httptest.NewRequest("GET", "/tag", nil) rr := httptest.NewRecorder() - HandleCategory(rr, req) + server.HandleCategory(rr, req) if rr.Code != http.StatusInternalServerError { t.Errorf("Expected 500, got %d", rr.Code) @@ -402,15 +407,15 @@ func TestHandleCategoryError(t *testing.T) { func TestHandleItemAlreadyHasContent(t *testing.T) { setupTestDB(t) seedData(t) + server := newTestServer() var id int64 models.DB.QueryRow("SELECT id FROM item LIMIT 1").Scan(&id) - // Pre-set content models.DB.Exec("UPDATE item SET full_content = 'existing' WHERE id = ?", id) req := httptest.NewRequest("GET", "/item/"+strconv.FormatInt(id, 10), nil) rr := httptest.NewRecorder() - HandleItem(rr, req) + server.HandleItem(rr, req) if rr.Code != http.StatusOK { t.Errorf("Expected 200, got %d", rr.Code) @@ -418,9 +423,10 @@ func TestHandleItemAlreadyHasContent(t *testing.T) { } func TestHandleCrawlMethodNotAllowed(t *testing.T) { + server := newTestServer() req := httptest.NewRequest("GET", "/crawl", nil) rr := httptest.NewRecorder() - HandleCrawl(rr, req) + server.HandleCrawl(rr, req) if rr.Code != http.StatusMethodNotAllowed { t.Errorf("Expected 405, got %d", rr.Code) } @@ -429,11 +435,11 @@ func TestHandleCrawlMethodNotAllowed(t *testing.T) { func TestHandleStreamComplexFilters(t *testing.T) { setupTestDB(t) seedData(t) + server := newTestServer() - // Test max_id, feed_id combo req := httptest.NewRequest("GET", "/stream?max_id=999&feed_id=1", nil) rr := httptest.NewRecorder() - HandleStream(rr, req) + server.HandleStream(rr, req) if rr.Code != http.StatusOK { t.Errorf("Expected 200, got %d", rr.Code) } @@ -441,12 +447,13 @@ func TestHandleStreamComplexFilters(t *testing.T) { func TestHandleCategorySuccess(t *testing.T) { setupTestDB(t) + server := newTestServer() f := &feed.Feed{Url: "http://example.com/cat", Category: "News"} f.Create() req := httptest.NewRequest("GET", "/api/categories", nil) rr := httptest.NewRecorder() - HandleCategory(rr, req) + server.HandleCategory(rr, req) if rr.Code != http.StatusOK { t.Errorf("Expected 200, got %d", rr.Code) } |
