package crawler
import (
"net/http"
"net/http/httptest"
"testing"
"adammathes.com/neko/config"
"adammathes.com/neko/models"
"adammathes.com/neko/models/feed"
)
// setupTestDB points the global config at an in-memory SQLite database,
// initializes the models layer, and registers a cleanup that closes the
// handle once the test (and its subtests) finish.
func setupTestDB(t *testing.T) {
	t.Helper()
	config.Config.DBFile = ":memory:"
	models.InitDB()
	t.Cleanup(func() {
		if db := models.DB; db != nil {
			db.Close()
		}
	})
}
// TestGetFeedContentSuccess verifies that GetFeedContent returns the
// body served by a healthy endpoint and that the request carries a
// User-Agent header. t.Error is safe to call from the handler goroutine.
func TestGetFeedContentSuccess(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("User-Agent") == "" {
			t.Error("Request should include User-Agent")
		}
		w.WriteHeader(200)
		w.Write([]byte("Test"))
	}))
	defer srv.Close()

	got := GetFeedContent(srv.URL)
	if got == "" {
		t.Error("GetFeedContent should return content for valid URL")
	}
	if got != "Test" {
		t.Errorf("Unexpected content: %q", got)
	}
}
// TestGetFeedContentBadURL verifies that an unresolvable host (and an
// out-of-range port) yields an empty string rather than a panic.
func TestGetFeedContentBadURL(t *testing.T) {
	if got := GetFeedContent("http://invalid.invalid.invalid:99999/feed"); got != "" {
		t.Errorf("GetFeedContent should return empty string for bad URL, got %q", got)
	}
}
// TestGetFeedContent404 verifies that a 404 response produces no content.
func TestGetFeedContent404(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(404)
	}))
	defer srv.Close()

	if got := GetFeedContent(srv.URL); got != "" {
		t.Errorf("GetFeedContent should return empty for 404, got %q", got)
	}
}
// TestGetFeedContent500 verifies that a server error produces no content.
func TestGetFeedContent500(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(500)
	}))
	defer srv.Close()

	if got := GetFeedContent(srv.URL); got != "" {
		t.Errorf("GetFeedContent should return empty for 500, got %q", got)
	}
}
func TestGetFeedContentUserAgent(t *testing.T) {
var receivedUA string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedUA = r.Header.Get("User-Agent")
w.WriteHeader(200)
w.Write([]byte("ok"))
}))
defer ts.Close()
GetFeedContent(ts.URL)
expected := "neko RSS Crawler +https://github.com/adammathes/neko"
if receivedUA != expected {
t.Errorf("Expected UA %q, got %q", expected, receivedUA)
}
}
// TestCrawlFeedWithTestServer serves a minimal valid RSS 2.0 document
// with two items and verifies that CrawlFeed parses it, reports a
// result on the channel, and persists both items to the database.
func TestCrawlFeedWithTestServer(t *testing.T) {
	setupTestDB(t)
	// Valid RSS 2.0 — the feed parser needs real XML, not bare text.
	rssContent := `<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0">
<channel>
<title>Test Feed</title>
<link>https://example.com</link>
<item>
<title>Article 1</title>
<link>https://example.com/article1</link>
<description>First article</description>
<pubDate>Mon, 01 Jan 2024 00:00:00 GMT</pubDate>
</item>
<item>
<title>Article 2</title>
<link>https://example.com/article2</link>
<description>Second article</description>
<pubDate>Tue, 02 Jan 2024 00:00:00 GMT</pubDate>
</item>
</channel>
</rss>`
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/rss+xml")
		w.WriteHeader(200)
		w.Write([]byte(rssContent))
	}))
	defer ts.Close()

	// Create a feed pointing to the test server
	f := &feed.Feed{Url: ts.URL, Title: "Test"}
	f.Create()

	ch := make(chan string, 1)
	CrawlFeed(f, ch)
	if result := <-ch; result == "" {
		t.Error("CrawlFeed should send a result")
	}

	// Verify items were created
	var count int
	if err := models.DB.QueryRow("SELECT COUNT(*) FROM item").Scan(&count); err != nil {
		t.Fatalf("counting items: %v", err)
	}
	if count != 2 {
		t.Errorf("Expected 2 items, got %d", count)
	}
}
// TestCrawlFeedBadContent verifies that CrawlFeed still reports a
// result on the channel when the served body is not parseable XML.
func TestCrawlFeedBadContent(t *testing.T) {
	setupTestDB(t)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("not xml at all"))
	}))
	defer srv.Close()

	f := &feed.Feed{Url: srv.URL, Title: "Bad"}
	f.Create()

	results := make(chan string, 1)
	CrawlFeed(f, results)
	if msg := <-results; msg == "" {
		t.Error("CrawlFeed should send a result even on failure")
	}
}
// TestCrawlWorker feeds a single feed through the worker channel and
// verifies a result is produced for it.
func TestCrawlWorker(t *testing.T) {
	setupTestDB(t)
	// Valid RSS 2.0 — the feed parser needs real XML, not bare text.
	rssContent := `<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0">
<channel>
<title>Worker Feed</title>
<link>https://example.com</link>
<item>
<title>Worker Article</title>
<link>https://example.com/worker-article</link>
<description>An article</description>
</item>
</channel>
</rss>`
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte(rssContent))
	}))
	defer ts.Close()

	f := &feed.Feed{Url: ts.URL, Title: "Worker Test"}
	f.Create()

	feeds := make(chan *feed.Feed, 1)
	results := make(chan string, 1)
	feeds <- f
	close(feeds) // worker exits once the feed channel drains
	CrawlWorker(feeds, results)

	if result := <-results; result == "" {
		t.Error("CrawlWorker should produce a result")
	}
}
// TestCrawl runs the full crawl over every stored feed and verifies the
// single served item lands in the database.
func TestCrawl(t *testing.T) {
	setupTestDB(t)
	// Valid RSS 2.0 — the feed parser needs real XML, not bare text.
	rssContent := `<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0">
<channel>
<title>Crawl Feed</title>
<link>https://example.com</link>
<item>
<title>Crawl Article</title>
<link>https://example.com/crawl-article</link>
<description>Article for crawl test</description>
</item>
</channel>
</rss>`
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte(rssContent))
	}))
	defer ts.Close()

	f := &feed.Feed{Url: ts.URL, Title: "Full Crawl"}
	f.Create()

	// Should not panic
	Crawl()

	var count int
	if err := models.DB.QueryRow("SELECT COUNT(*) FROM item").Scan(&count); err != nil {
		t.Fatalf("counting items: %v", err)
	}
	if count != 1 {
		t.Errorf("Expected 1 item after crawl, got %d", count)
	}
}