Skip to content

Commit

Permalink
improved what node is being warned for unsafeInterpolation compiler…
Browse files Browse the repository at this point in the history
… warning
  • Loading branch information
RandomHashTags committed Oct 15, 2024
1 parent f22ca51 commit 38d479d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 32 deletions.
68 changes: 37 additions & 31 deletions Sources/HTMLKitMacros/HTMLElement.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ private extension HTMLElement {
break
case "event":
key = "on" + key_element.memberAccess!.declName.baseName.text
if var (literalValue, returnType):(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: function.arguments.last!) {
if returnType == .string {
literalValue.escapeHTML(escapeAttributes: true)
if var result:(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: function.arguments.last!) {
if result.1 == .string {
result.0.escapeHTML(escapeAttributes: true)
}
value = literalValue
value = result.0
} else {
unallowed_expression(context: context, node: function.arguments.last!)
return []
Expand Down Expand Up @@ -164,13 +164,13 @@ private extension HTMLElement {

static func parse_attribute(context: some MacroExpansionContext, elementType: HTMLElementType, key: String, argument: LabeledExprSyntax) -> String? {
let expression:ExprSyntax = argument.expression
if var (string, returnType):(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: argument) {
switch returnType {
case .boolean: return string.elementsEqual("true") ? "" : nil
if var result:(String, LiteralReturnType) = parse_literal_value(context: context, elementType: elementType, key: key, argument: argument) {
switch result.1 {
case .boolean: return result.0.elementsEqual("true") ? "" : nil
case .string:
string.escapeHTML(escapeAttributes: true)
return string
case .interpolation: return string
result.0.escapeHTML(escapeAttributes: true)
return result.0
case .interpolation: return result.0
}
}
func member(_ value: String) -> String {
Expand Down Expand Up @@ -215,22 +215,13 @@ private extension HTMLElement {
if function.calledExpression.as(DeclReferenceExprSyntax.self)?.baseName.text == "StaticString" {
return (function.arguments.first!.expression.stringLiteral!.string, .string)
}
return ("\\(\(function))", .interpolation)
return ("\(function)", .interpolation)
}
}
if let member:MemberAccessExprSyntax = expression.memberAccess {
let decl:String = member.declName.baseName.text
if let _:ExprSyntax = member.base {
/*if let integer:String = base.integerLiteral?.literal.text {
switch decl {
case "description":
return (integer, .integer)
default:
return (integer, .interpolation)
}
} else {*/
return ("\\(\(member))", .interpolation)
//}
return ("\(member)", .interpolation)
} else {
return (HTMLElementAttribute.Extra.htmlValue(enumName: enumName(elementType: elementType, key: key), for: decl), .string)
}
Expand Down Expand Up @@ -268,18 +259,18 @@ private extension HTMLElement {
let interpolation:[ExpressionSegmentSyntax] = expression.stringLiteral?.segments.compactMap({ $0.as(ExpressionSegmentSyntax.self) }) ?? []
var remaining_interpolation:Int = interpolation.count
for expr in interpolation {
string = flatten_interpolation(remaining_interpolation: &remaining_interpolation, expr: expr)
string = flatten_interpolation(context: context, remaining_interpolation: &remaining_interpolation, expr: expr)
}
if returnType == .interpolation || remaining_interpolation > 0 {
if !string.contains("\\(") {
string = "\\(" + string + ")"
warn_interpolation(context: context, node: expression)
}
returnType = .interpolation
context.diagnose(Diagnostic(node: expression, message: DiagnosticMsg(id: "unsafeInterpolation", message: "Interpolation may introduce raw HTML.", severity: .warning)))
}
return (string, returnType)
}
static func flatten_interpolation(remaining_interpolation: inout Int, expr: ExpressionSegmentSyntax) -> String {
static func flatten_interpolation(context: some MacroExpansionContext, remaining_interpolation: inout Int, expr: ExpressionSegmentSyntax) -> String {
let expression:ExprSyntax = expr.expressions.first!.expression
var string:String = "\(expr)"
if let stringLiteral:StringLiteralExprSyntax = expression.stringLiteral {
Expand All @@ -288,24 +279,39 @@ private extension HTMLElement {
remaining_interpolation = 0
string = segments.map({ $0.as(StringSegmentSyntax.self)!.content.text }).joined()
} else {
var values:[String] = []
string = ""
for segment in segments {
if let literal:String = segment.as(StringSegmentSyntax.self)?.content.text {
values.append(literal)
string += literal
} else if let interpolation:ExpressionSegmentSyntax = segment.as(ExpressionSegmentSyntax.self) {
values.append(flatten_interpolation(remaining_interpolation: &remaining_interpolation, expr: interpolation))
let flattened:String = flatten_interpolation(context: context, remaining_interpolation: &remaining_interpolation, expr: interpolation)
if "\(interpolation)" == flattened {
//string += "\\(\"\(flattened)\".escapingHTML(escapeAttributes: true))"
string += "\(flattened)"
warn_interpolation(context: context, node: interpolation)
} else {
string += flattened
}
} else {
values.append("\(segment)")
//string += "\\(\"\(segment)\".escapingHTML(escapeAttributes: true))"
warn_interpolation(context: context, node: segment)
string += "\(segment)"
}
}
string = values.joined()
}
} else if let fix:String = expression.integerLiteral?.literal.text ?? expression.floatLiteral?.literal.text {
remaining_interpolation -= string.ranges(of: "\(expr)").count
string.replace("\(expr)", with: fix)
let target:String = "\(expr)"
remaining_interpolation -= string.ranges(of: target).count
string.replace(target, with: fix)
} else {
//string = "\\(\"\(string)\".escapingHTML(escapeAttributes: true))"
warn_interpolation(context: context, node: expr)
}
return string
}
static func warn_interpolation(context: some MacroExpansionContext, node: some SyntaxProtocol) {
context.diagnose(Diagnostic(node: node, message: DiagnosticMsg(id: "unsafeInterpolation", message: "Interpolation may introduce raw HTML.", severity: .warning)))
}
}

enum LiteralReturnType {
Expand Down
12 changes: 11 additions & 1 deletion Tests/HTMLKitTests/HTMLKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct HTMLKitTests {
@Test func escape_html() {
let unescaped:String = "<!DOCTYPE html><html>Test</html>"
let escaped:String = "&lt;!DOCTYPE html&gt;&lt;html&gt;Test&lt;/html&gt;"
let expected_result:String = "<p>\(escaped)</p>"
var expected_result:String = "<p>\(escaped)</p>"

var string:String = #p("<!DOCTYPE html><html>Test</html>")
#expect(string == expected_result)
Expand All @@ -25,6 +25,16 @@ struct HTMLKitTests {

string = #p("\(unescaped.escapingHTML(escapeAttributes: false))")
#expect(string == expected_result)

expected_result = "<div title=\"&lt;p&gt;\">&lt;p&gt;&lt;/p&gt;</div>"
string = #div(attributes: [.title(StaticString("<p>"))], StaticString("<p></p>")).description
#expect(string == expected_result)

string = #div(attributes: [.title("<p>")], StaticString("<p></p>")).description
#expect(string == expected_result)

string = #div(attributes: [.title("<p>")], "<p></p>")
#expect(string == expected_result)
}
}

Expand Down

0 comments on commit 38d479d

Please sign in to comment.