From 38d479d3bf08be4f74c8b896e757e69a270c8e03 Mon Sep 17 00:00:00 2001 From: RandomHashTags Date: Tue, 15 Oct 2024 00:17:53 -0500 Subject: [PATCH] improved what node is being warned for `unsafeInterpolation` compiler warning --- Sources/HTMLKitMacros/HTMLElement.swift | 68 ++++++++++++++----------- Tests/HTMLKitTests/HTMLKitTests.swift | 12 ++++- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/Sources/HTMLKitMacros/HTMLElement.swift b/Sources/HTMLKitMacros/HTMLElement.swift index f6576ee..895aabc 100644 --- a/Sources/HTMLKitMacros/HTMLElement.swift +++ b/Sources/HTMLKitMacros/HTMLElement.swift @@ -88,11 +88,11 @@ private extension HTMLElement { break case "event": key = "on" + key_element.memberAccess!.declName.baseName.text - if var (literalValue, returnType):(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: function.arguments.last!) { - if returnType == .string { - literalValue.escapeHTML(escapeAttributes: true) + if var result:(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: function.arguments.last!) { + if result.1 == .string { + result.0.escapeHTML(escapeAttributes: true) } - value = literalValue + value = result.0 } else { unallowed_expression(context: context, node: function.arguments.last!) return [] @@ -164,13 +164,13 @@ private extension HTMLElement { static func parse_attribute(context: some MacroExpansionContext, elementType: HTMLElementType, key: String, argument: LabeledExprSyntax) -> String? { let expression:ExprSyntax = argument.expression - if var (string, returnType):(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: argument) { - switch returnType { - case .boolean: return string.elementsEqual("true") ? "" : nil + if var result:(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: argument) { + switch result.1 { + case .boolean: return result.0.elementsEqual("true") ? "" : nil case .string: - string.escapeHTML(escapeAttributes: true) - return string - case .interpolation: return string + result.0.escapeHTML(escapeAttributes: true) + return result.0 + case .interpolation: return result.0 } } func member(_ value: String) -> String { @@ -215,22 +215,13 @@ private extension HTMLElement { if function.calledExpression.as(DeclReferenceExprSyntax.self)?.baseName.text == "StaticString" { return (function.arguments.first!.expression.stringLiteral!.string, .string) } - return ("\\(\(function))", .interpolation) + return ("\(function)", .interpolation) } } if let member:MemberAccessExprSyntax = expression.memberAccess { let decl:String = member.declName.baseName.text if let _:ExprSyntax = member.base { - /*if let integer:String = base.integerLiteral?.literal.text { - switch decl { - case "description": - return (integer, .integer) - default: - return (integer, .interpolation) - } - } else {*/ - return ("\\(\(member))", .interpolation) - //} + return ("\(member)", .interpolation) } else { return (HTMLElementAttribute.Extra.htmlValue(enumName: enumName(elementType: elementType, key: key), for: decl), .string) } @@ -268,18 +259,18 @@ private extension HTMLElement { let interpolation:[ExpressionSegmentSyntax] = expression.stringLiteral?.segments.compactMap({ $0.as(ExpressionSegmentSyntax.self) }) ?? [] var remaining_interpolation:Int = interpolation.count for expr in interpolation { - string = flatten_interpolation(remaining_interpolation: &remaining_interpolation, expr: expr) + string = flatten_interpolation(context: context, remaining_interpolation: &remaining_interpolation, expr: expr) } if returnType == .interpolation || remaining_interpolation > 0 { if !string.contains("\\(") { string = "\\(" + string + ")" + warn_interpolation(context: context, node: expression) } returnType = .interpolation - context.diagnose(Diagnostic(node: expression, message: DiagnosticMsg(id: "unsafeInterpolation", message: "Interpolation may introduce raw HTML.", severity: .warning))) } return (string, returnType) } - static func flatten_interpolation(remaining_interpolation: inout Int, expr: ExpressionSegmentSyntax) -> String { + static func flatten_interpolation(context: some MacroExpansionContext, remaining_interpolation: inout Int, expr: ExpressionSegmentSyntax) -> String { let expression:ExprSyntax = expr.expressions.first!.expression var string:String = "\(expr)" if let stringLiteral:StringLiteralExprSyntax = expression.stringLiteral { @@ -288,24 +279,39 @@ private extension HTMLElement { remaining_interpolation = 0 string = segments.map({ $0.as(StringSegmentSyntax.self)!.content.text }).joined() } else { - var values:[String] = [] + string = "" for segment in segments { if let literal:String = segment.as(StringSegmentSyntax.self)?.content.text { - values.append(literal) + string += literal } else if let interpolation:ExpressionSegmentSyntax = segment.as(ExpressionSegmentSyntax.self) { - values.append(flatten_interpolation(remaining_interpolation: &remaining_interpolation, expr: interpolation)) + let flattened:String = flatten_interpolation(context: context, remaining_interpolation: &remaining_interpolation, expr: interpolation) + if "\(interpolation)" == flattened { + //string += "\\(\"\(flattened)\".escapingHTML(escapeAttributes: true))" + string += "\(flattened)" + warn_interpolation(context: context, node: interpolation) + } else { + string += flattened + } } else { - values.append("\(segment)") + //string += "\\(\"\(segment)\".escapingHTML(escapeAttributes: true))" + warn_interpolation(context: context, node: segment) + string += "\(segment)" } } - string = values.joined() } } else if let fix:String = expression.integerLiteral?.literal.text ?? expression.floatLiteral?.literal.text { - remaining_interpolation -= string.ranges(of: "\(expr)").count - string.replace("\(expr)", with: fix) + let target:String = "\(expr)" + remaining_interpolation -= string.ranges(of: target).count + string.replace(target, with: fix) + } else { + //string = "\\(\"\(string)\".escapingHTML(escapeAttributes: true))" + warn_interpolation(context: context, node: expr) } return string } + static func warn_interpolation(context: some MacroExpansionContext, node: some SyntaxProtocol) { + context.diagnose(Diagnostic(node: node, message: DiagnosticMsg(id: "unsafeInterpolation", message: "Interpolation may introduce raw HTML.", severity: .warning))) + } } enum LiteralReturnType { diff --git a/Tests/HTMLKitTests/HTMLKitTests.swift b/Tests/HTMLKitTests/HTMLKitTests.swift index 9d1d9af..ffdc375 100644 --- a/Tests/HTMLKitTests/HTMLKitTests.swift +++ b/Tests/HTMLKitTests/HTMLKitTests.swift @@ -12,7 +12,7 @@ struct HTMLKitTests { @Test func escape_html() { let unescaped:String = "Test" let escaped:String = "<!DOCTYPE html><html>Test</html>" - let expected_result:String = "

\(escaped)

" + var expected_result:String = "

\(escaped)

" var string:String = #p("Test") #expect(string == expected_result) @@ -25,6 +25,16 @@ struct HTMLKitTests { string = #p("\(unescaped.escapingHTML(escapeAttributes: false))") #expect(string == expected_result) + + expected_result = "
<p></p>
" + string = #div(attributes: [.title(StaticString("

"))], StaticString("

")).description + #expect(string == expected_result) + + string = #div(attributes: [.title("

")], StaticString("

")).description + #expect(string == expected_result) + + string = #div(attributes: [.title("

")], "

") + #expect(string == expected_result) } }