aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--internal/crawler/crawler.go3
-rw-r--r--internal/crawler/security_test.go34
2 files changed, 36 insertions, 1 deletions
diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go
index 4f5de98..e664e06 100644
--- a/internal/crawler/crawler.go
+++ b/internal/crawler/crawler.go
@@ -15,6 +15,7 @@ import (
)
const MAX_CRAWLERS = 5
+const MAX_FEED_SIZE = 10 * 1024 * 1024 // 10MB
func Crawl() {
crawlJobs := make(chan *feed.Feed, 100)
@@ -88,7 +89,7 @@ func GetFeedContent(feedURL string) string {
return ""
}
- bodyBytes, err := io.ReadAll(resp.Body)
+ bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, MAX_FEED_SIZE))
if err != nil {
return ""
}
diff --git a/internal/crawler/security_test.go b/internal/crawler/security_test.go
new file mode 100644
index 0000000..198f7ee
--- /dev/null
+++ b/internal/crawler/security_test.go
@@ -0,0 +1,34 @@
+package crawler
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "adammathes.com/neko/internal/safehttp"
+)
+
+func init() {
+ safehttp.AllowLocal = true
+}
+
+func TestGetFeedContentLimit(t *testing.T) {
+ // 10MB limit expected
+ limit := 10 * 1024 * 1024
+ // 11MB payload
+ size := limit + 1024*1024
+ largeBody := strings.Repeat("a", size)
+
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ w.Write([]byte(largeBody))
+ }))
+ defer ts.Close()
+
+ content := GetFeedContent(ts.URL)
+
+ if len(content) != limit {
+ t.Errorf("Expected content length %d, got %d", limit, len(content))
+ }
+}