diff --git a/spec/compiler/interpreter/classes_spec.cr b/spec/compiler/interpreter/classes_spec.cr index 334df66a7368..ea54f0b98fb2 100644 --- a/spec/compiler/interpreter/classes_spec.cr +++ b/spec/compiler/interpreter/classes_spec.cr @@ -159,4 +159,24 @@ describe Crystal::Repl::Interpreter do expression.value.foo CODE end + + it "downcasts virtual type to its only type (#12351)" do + interpret(<<-CODE).should eq(1) + abstract class A + end + + class B < A + def x + 1 + end + end + + def foo(b : B) + b = 1 + end + + b = B.new.as(A) + foo(b) + CODE + end end diff --git a/spec/compiler/interpreter/multidispatch_spec.cr b/spec/compiler/interpreter/multidispatch_spec.cr index ecce46c1df9a..ebd16a44f720 100644 --- a/spec/compiler/interpreter/multidispatch_spec.cr +++ b/spec/compiler/interpreter/multidispatch_spec.cr @@ -422,5 +422,18 @@ describe Crystal::Repl::Interpreter do a_value - b_value CODE end + + it "casts multidispatch argument to the def's arg type" do + interpret(<<-CODE) + def foo(a : String) forall T + end + + def foo(a) + a + end + + foo("b" || nil) + CODE + end end end diff --git a/src/compiler/crystal/interpreter/compiler.cr b/src/compiler/crystal/interpreter/compiler.cr index 2e4014dc1412..ec9cd2f11561 100644 --- a/src/compiler/crystal/interpreter/compiler.cr +++ b/src/compiler/crystal/interpreter/compiler.cr @@ -2244,7 +2244,7 @@ class Crystal::Repl::Compiler < Crystal::Visitor target_def_arg = target_def_args[i] target_def_var_type = target_def.vars.not_nil![target_def_arg.name].type - compile_call_arg(arg, arg_type, target_def_var_type) + compile_call_arg(arg, arg_type, target_def_arg.type, target_def_var_type) i += 1 end @@ -2291,7 +2291,7 @@ class Crystal::Repl::Compiler < Crystal::Visitor end end - private def compile_call_arg(arg, arg_type, target_def_var_type) + private def compile_call_arg(arg, arg_type, target_def_arg_type, target_def_var_type) # Check autocasting from symbol to enum if arg.is_a?(SymbolLiteral) && target_def_var_type.is_a?(EnumType) symbol_name = arg.value.underscore @@ -2318,7 +2318,11 @@ class Crystal::Repl::Compiler < Crystal::Visitor request_value(arg) - # We need to cast the argument to the target_def variable + # We first cast the argument to the def's arg type, + # which is the external methods' type. + downcast arg, arg_type, target_def_arg_type + + # Then we need to cast the argument to the target_def variable # corresponding to the argument. If for example we have this: # # ``` @@ -2331,7 +2335,7 @@ class Crystal::Repl::Compiler < Crystal::Visitor # # Then the actual type of `x` inside `foo` is (Int32 | Nil), # and we must cast `1` to it. - upcast arg, arg_type, target_def_var_type + upcast arg, target_def_arg_type, target_def_var_type end private def compile_pointerof_node(obj : Var, owner : Type) : Nil diff --git a/src/compiler/crystal/interpreter/multidispatch.cr b/src/compiler/crystal/interpreter/multidispatch.cr index 83c5ca304151..6decee13cde2 100644 --- a/src/compiler/crystal/interpreter/multidispatch.cr +++ b/src/compiler/crystal/interpreter/multidispatch.cr @@ -142,7 +142,7 @@ module Crystal::Repl::Multidispatch blocks = [] of Block - target_defs.each do |target_def| + target_defs.each_with_index do |target_def, target_def_index| i = 0 condition = nil @@ -168,8 +168,26 @@ module Crystal::Repl::Multidispatch call_args = [] of ASTNode i = 0 - node.args.each do - call_args << Var.new("arg#{i}") + node.args.each_with_index do |arg, arg_index| + var = Var.new("arg#{i}") + + # If the argument was autocasted it will always match in a multidispatch + if autocast_types.try &.[arg_index]? + call_args << var + next + end + + # Make sure to cast the argument to the target def arg's type + # in the last case, where the argument's type is not restricted by an if is_a? + if target_def_index == target_defs.size - 1 + target_def_arg = target_def.args[i] + + cast = Cast.new(var, TypeNode.new(target_def_arg.type)) + call_args << cast + else + call_args << var + end + i += 1 end