diff --git a/spec/compiler/semantic/generic_class_spec.cr b/spec/compiler/semantic/generic_class_spec.cr index b85e0cf76b98..af647a32f2bc 100644 --- a/spec/compiler/semantic/generic_class_spec.cr +++ b/spec/compiler/semantic/generic_class_spec.cr @@ -17,12 +17,67 @@ describe "Semantic: generic class" do class Foo(T) end - class Bar < Foo(A, B) + class Bar < Foo(Int32, String) end ), "wrong number of type vars for Foo(T) (given 2, expected 1)" end + it "errors if inheriting from variadic generic and incorrect number of type vars" do + assert_error %( + class Foo(*T, U, V) + end + + class Bar < Foo(Int32) + end + ), + "wrong number of type vars for Foo(*T, U, V) (given 1, expected 2+)" + end + + it "errors if splatting type var into generic without splats" do + assert_error %( + class Foo(T) + end + + class Bar(*T) < Foo(*T) + end + ), + "cannot splat *T into Foo(T)" + end + + it "errors if splatting type var into non-splat parameter, before splat in definition" do + assert_error %( + class Foo(T, *U) + end + + class Bar(*T) < Foo(*T, Int32) + end + ), + "cannot splat *T into non-splat type parameter T of Foo(T, *U)" + end + + it "errors if splatting type var into non-splat parameter, after splat in definition" do + assert_error %( + class Foo(*T, U) + end + + class Bar(*T) < Foo(Int32, *T) + end + ), + "cannot splat *T into non-splat type parameter U of Foo(*T, U)" + end + + it "errors if splatting type var into non-splat parameter, more args" do + assert_error %( + class Foo(T, *U, V, W) + end + + class Bar(T, U, *V, W) < Foo(V, U, T, W, *V, T) + end + ), + "cannot splat *V into non-splat type parameter V of Foo(T, *U, V, W)" + end + it "inherits from generic with instantiation" do assert_type(%( class Foo(T) diff --git a/spec/compiler/semantic/module_spec.cr b/spec/compiler/semantic/module_spec.cr index c85a0b777be9..4dff1813433f 100644 --- a/spec/compiler/semantic/module_spec.cr +++ b/spec/compiler/semantic/module_spec.cr @@ -89,7 +89,7 @@ describe "Semantic: module" do "Foo is not a generic type" end - it "includes module but wrong number of arguments" do + it "errors if including generic and incorrect number of type vars" do assert_error " module Foo(T, U) end @@ -101,6 +101,66 @@ describe "Semantic: module" do "wrong number of type vars for Foo(T, U) (given 1, expected 2)" end + it "errors if including variadic generic and incorrect number of type vars" do + assert_error %( + module Foo(*T, U, V) + end + + class Bar + include Foo(Int32) + end + ), + "wrong number of type vars for Foo(*T, U, V) (given 1, expected 2+)" + end + + it "errors if splatting type var into generic without splats" do + assert_error %( + module Foo(T) + end + + class Bar(*T) + include Foo(*T) + end + ), + "cannot splat *T into Foo(T)" + end + + it "errors if splatting type var into non-splat parameter, before splat in definition" do + assert_error %( + module Foo(T, *U) + end + + class Bar(*T) + include Foo(*T, Int32) + end + ), + "cannot splat *T into non-splat type parameter T of Foo(T, *U)" + end + + it "errors if splatting type var into non-splat parameter, after splat in definition" do + assert_error %( + module Foo(*T, U) + end + + class Bar(*T) + include Foo(Int32, *T) + end + ), + "cannot splat *T into non-splat type parameter U of Foo(*T, U)" + end + + it "errors if splatting type var into non-splat parameter, more args" do + assert_error %( + module Foo(T, *U, V, W) + end + + class Bar(T, U, *V, W) + include Foo(V, U, T, W, *V, T) + end + ), + "cannot splat *V into non-splat type parameter V of Foo(T, *U, V, W)" + end + it "errors if including generic module and not specifying type vars" do assert_error " module Foo(T) diff --git a/src/compiler/crystal/semantic/type_lookup.cr b/src/compiler/crystal/semantic/type_lookup.cr index a3020ff397e8..3de29e827f3b 100644 --- a/src/compiler/crystal/semantic/type_lookup.cr +++ b/src/compiler/crystal/semantic/type_lookup.cr @@ -200,23 +200,8 @@ class Crystal::Type node.raise "instantiating #{node}", inner: ex if @raise end when GenericType - if instance_type.splat_index - if node.named_args - node.raise "can only use named arguments with NamedTuple" - end - - min_needed = instance_type.type_vars.size - 1 - if node.type_vars.size < min_needed - node.wrong_number_of "type vars", instance_type, node.type_vars.size, "#{min_needed}+" - end - else - if node.named_args - node.raise "can only use named arguments with NamedTuple" - end - - if instance_type.type_vars.size != node.type_vars.size - node.wrong_number_of "type vars", instance_type, node.type_vars.size, instance_type.type_vars.size - end + if node.named_args + node.raise "can only use named arguments with NamedTuple" end else node.raise "#{instance_type} is not a generic type, it's a #{instance_type.type_desc}" @@ -279,6 +264,41 @@ class Crystal::Type type_vars << type.virtual_type end + splat_index = instance_type.splat_index + var_count = instance_type.type_vars.size + + case {type_vars.any?(TypeSplat), splat_index} + in {true, Int32} + # node is `(A, B, C, *D, E)`, definition is `(T, *U, V, W)`; *D would splat into V + 0.upto(splat_index - 1) do |index| + if (splat = type_vars[index]).is_a?(TypeSplat) + type_var = instance_type.type_vars[index] + node.raise "cannot splat #{splat} into non-splat type parameter #{type_var} of #{instance_type}" + end + end + (splat_index + 1).upto(var_count - 1) do |index| + if (splat = type_vars[index - var_count]).is_a?(TypeSplat) + type_var = instance_type.type_vars[index] + node.raise "cannot splat #{splat} into non-splat type parameter #{type_var} of #{instance_type}" + end + end + in {true, Nil} + # node is `(A, B, *C)`, definition is `(T)`; instantiation would always fail + splat = type_vars.find(&.is_a?(TypeSplat)).not_nil! + node.raise "cannot splat #{splat} into #{instance_type}" + in {false, Int32} + # node is `(A)`, definition is `(T, U, *V)`; instantiation would always fail + min_needed = var_count - 1 + if type_vars.size < min_needed + node.wrong_number_of "type vars", instance_type, type_vars.size, "#{min_needed}+" + end + in {false, Nil} + # neither side contains splats; type var counts must be equal + if type_vars.size != var_count + node.wrong_number_of "type vars", instance_type, type_vars.size, var_count + end + end + begin if instance_type.is_a?(GenericUnionType) && type_vars.any? &.is_a?(TypeSplat) # In the case of `Union(*T)`, we don't need to instantiate the union right