diff --git a/spec/compiler/codegen/macro_spec.cr b/spec/compiler/codegen/macro_spec.cr index 6807486a951b..d5c940764302 100644 --- a/spec/compiler/codegen/macro_spec.cr +++ b/spec/compiler/codegen/macro_spec.cr @@ -1272,28 +1272,6 @@ describe "Code gen: macro" do )).to_i.should eq(1) end - it "solves macro expression arguments before macro expansion (type)" do - run(%( - macro name(x) - {{x.name.stringify}} - end - - name({{String}}) - )).to_string.should eq("String") - end - - it "solves macro expression arguments before macro expansion (constant)" do - run(%( - CONST = 1 - - macro id(x) - {{x}} - end - - id({{CONST}}) - )).to_i.should eq(1) - end - it "can use macro inside array literal" do run(%( require "prelude" diff --git a/spec/compiler/semantic/macro_spec.cr b/spec/compiler/semantic/macro_spec.cr index ec282fe924eb..17cd6d03c11c 100644 --- a/spec/compiler/semantic/macro_spec.cr +++ b/spec/compiler/semantic/macro_spec.cr @@ -841,6 +841,64 @@ describe "Semantic: macro" do "missing argument: z" end + it "solves macro expression arguments before macro expansion (type)" do + assert_type(%( + macro foo(x) + {% if x.is_a?(TypeNode) && x.name == "String" %} + 1 + {% else %} + 'a' + {% end %} + end + + foo({{ String }}) + )) { int32 } + end + + it "solves macro expression arguments before macro expansion (constant)" do + assert_type(%( + macro foo(x) + {% if x.is_a?(NumberLiteral) && x == 1 %} + 1 + {% else %} + 'a' + {% end %} + end + + CONST = 1 + foo({{ CONST }}) + )) { int32 } + end + + it "solves named macro expression arguments before macro expansion (type) (#2423)" do + assert_type(%( + macro foo(x) + {% if x.is_a?(TypeNode) && x.name == "String" %} + 1 + {% else %} + 'a' + {% end %} + end + + foo(x: {{ String }}) + )) { int32 } + end + + it "solves named macro expression arguments before macro expansion (constant) (#2423)" do + assert_type(%( + macro foo(x) + {% if x.is_a?(NumberLiteral) && x == 1 %} + 1 + {% else %} + 'a' + {% end %} + end + + CONST = 1 + foo(x: {{ CONST }}) + )) { int32 } + end + it "finds generic type argument of included module" do assert_type(%( module Bar(T) diff --git a/src/compiler/crystal/semantic/semantic_visitor.cr b/src/compiler/crystal/semantic/semantic_visitor.cr index a6c5156a5de2..0f19aaaa2266 100644 --- a/src/compiler/crystal/semantic/semantic_visitor.cr +++ b/src/compiler/crystal/semantic/semantic_visitor.cr @@ -308,14 +308,14 @@ abstract class Crystal::SemanticVisitor < Crystal::Visitor expansion_scope = (macro_scope || @scope || current_type) - args = expand_macro_arguments(node, expansion_scope) + args, named_args = expand_macro_arguments(node, expansion_scope) @exp_nest -= 1 generated_nodes = expand_macro(the_macro, node, visibility: node.visibility, accept: accept) do - old_args = node.args - node.args = args + old_args, old_named_args = node.args, node.named_args + node.args, node.named_args = args, named_args expanded_macro, macro_expansion_pragmas = @program.expand_macro the_macro, node, expansion_scope, expansion_scope, @untyped_def - node.args = old_args + node.args, node.named_args = old_args, old_named_args {expanded_macro, macro_expansion_pragmas} end @exp_nest += 1 @@ -377,33 +377,44 @@ abstract class Crystal::SemanticVisitor < Crystal::Visitor end end - def expand_macro_arguments(node, expansion_scope) + def expand_macro_arguments(call, expansion_scope) # If any argument is a MacroExpression, solve it first and # replace Path with Const/TypeNode if it denotes such thing - args = node.args - if args.any? &.is_a?(MacroExpression) + args = call.args + named_args = call.named_args + + if args.any?(MacroExpression) || named_args.try &.any? &.value.is_a?(MacroExpression) @exp_nest -= 1 args = args.map do |arg| - if arg.is_a?(MacroExpression) - arg.accept self - expanded = arg.expanded.not_nil! - if expanded.is_a?(Path) - expanded_type = expansion_scope.lookup_path(expanded) - case expanded_type - when Const - expanded = expanded_type.value - when Type - expanded = TypeNode.new(expanded_type) - end - end - expanded - else - arg - end + expand_macro_argument(arg, expansion_scope) + end + named_args = named_args.try &.map do |named_arg| + value = expand_macro_argument(named_arg.value, expansion_scope) + NamedArgument.new(named_arg.name, value) end @exp_nest += 1 end - args + + {args, named_args} + end + + def expand_macro_argument(node, expansion_scope) + if node.is_a?(MacroExpression) + node.accept self + expanded = node.expanded.not_nil! + if expanded.is_a?(Path) + expanded_type = expansion_scope.lookup_path(expanded) + case expanded_type + when Const + expanded = expanded_type.value + when Type + expanded = TypeNode.new(expanded_type) + end + end + expanded + else + node + end end def expand_inline_macro(node, mode = nil, accept = true)