From c9131b0e89db2989dd5e2e6894d6bec6c0075cb6 Mon Sep 17 00:00:00 2001
From: dzaikos <you@example.com>
Date: Sun, 24 Jun 2018 13:16:07 -0400
Subject: [PATCH] Improve sanitizer to remove style tag contents.

See #157.

Refactored how blacklisted tags are handled so they're easier to manage in the future.
---
 reader/sanitizer/sanitizer.go      | 29 +++++++++++++++++++++--------
 reader/sanitizer/sanitizer_test.go | 10 ++++++++++
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/reader/sanitizer/sanitizer.go b/reader/sanitizer/sanitizer.go
index d7a4626a..2a0a2af3 100644
--- a/reader/sanitizer/sanitizer.go
+++ b/reader/sanitizer/sanitizer.go
@@ -25,7 +25,7 @@ func Sanitize(baseURL, input string) string {
 	tokenizer := html.NewTokenizer(bytes.NewBufferString(input))
 	var buffer bytes.Buffer
 	var tagStack []string
-	scriptTagDepth := 0
+	blacklistedTagDepth := 0
 
 	for {
 		if tokenizer.Next() == html.ErrorToken {
@@ -40,7 +40,7 @@ func Sanitize(baseURL, input string) string {
 		token := tokenizer.Token()
 		switch token.Type {
 		case html.TextToken:
-			if scriptTagDepth > 0 {
+			if blacklistedTagDepth > 0 {
 				continue
 			}
 
@@ -60,15 +60,15 @@ func Sanitize(baseURL, input string) string {
 
 					tagStack = append(tagStack, tagName)
 				}
-			} else if isScriptTag(tagName) {
-				scriptTagDepth++
+			} else if isBlacklistedTag(tagName) {
+				blacklistedTagDepth++
 			}
 		case html.EndTagToken:
 			tagName := token.DataAtom.String()
 			if isValidTag(tagName) && inList(tagName, tagStack) {
 				buffer.WriteString(fmt.Sprintf("</%s>", tagName))
-			} else if isScriptTag(tagName) {
-				scriptTagDepth--
+			} else if isBlacklistedTag(tagName) {
+				blacklistedTagDepth--
 			}
 		case html.SelfClosingTagToken:
 			tagName := token.DataAtom.String()
@@ -394,6 +394,19 @@ func rewriteIframeURL(link string) string {
 	return link
 }
 
-func isScriptTag(tagName string) bool {
-	return tagName == "script" || tagName == "noscript"
+// isBlacklistedTag reports whether tagName is blacklisted, i.e. the tag is removed together with all of its descendants.
+func isBlacklistedTag(tagName string) bool {
+	blacklist := []string{
+		"noscript",
+		"script",
+		"style",
+	}
+
+	for _, element := range blacklist {
+		if element == tagName {
+			return true
+		}
+	}
+
+	return false
 }
diff --git a/reader/sanitizer/sanitizer_test.go b/reader/sanitizer/sanitizer_test.go
index fa7dd6d9..374c107c 100644
--- a/reader/sanitizer/sanitizer_test.go
+++ b/reader/sanitizer/sanitizer_test.go
@@ -232,3 +232,13 @@ func TestReplaceScript(t *testing.T) {
 		t.Errorf(`Wrong output: "%s" != "%s"`, expected, output)
 	}
 }
+
+func TestReplaceStyle(t *testing.T) { // a <style> element must be dropped along with its CSS text
+	input := `<p>Before paragraph.</p><style>body { background-color: #ff0000; }</style><p>After paragraph.</p>`
+	expected := `<p>Before paragraph.</p><p>After paragraph.</p>`
+	output := Sanitize("http://example.org/", input)
+
+	if expected != output {
+		t.Errorf(`Wrong output: "%s" != "%s"`, expected, output)
+	}
+}