diff options
| -rw-r--r-- | internal/crawler/crawler.go | 3 | ||||
| -rw-r--r-- | internal/crawler/security_test.go | 34 |
2 files changed, 36 insertions(+), 1 deletion(-)
diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go index 4f5de98..e664e06 100644 --- a/internal/crawler/crawler.go +++ b/internal/crawler/crawler.go @@ -15,6 +15,7 @@ import ( ) const MAX_CRAWLERS = 5 +const MAX_FEED_SIZE = 10 * 1024 * 1024 // 10MB func Crawl() { crawlJobs := make(chan *feed.Feed, 100) @@ -88,7 +89,7 @@ func GetFeedContent(feedURL string) string { return "" } - bodyBytes, err := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, MAX_FEED_SIZE)) if err != nil { return "" } diff --git a/internal/crawler/security_test.go b/internal/crawler/security_test.go new file mode 100644 index 0000000..198f7ee --- /dev/null +++ b/internal/crawler/security_test.go @@ -0,0 +1,34 @@ +package crawler + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "adammathes.com/neko/internal/safehttp" +) + +func init() { + safehttp.AllowLocal = true +} + +func TestGetFeedContentLimit(t *testing.T) { + // 10MB limit expected + limit := 10 * 1024 * 1024 + // 11MB payload + size := limit + 1024*1024 + largeBody := strings.Repeat("a", size) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(largeBody)) + })) + defer ts.Close() + + content := GetFeedContent(ts.URL) + + if len(content) != limit { + t.Errorf("Expected content length %d, got %d", limit, len(content)) + } +} |
