Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require type var splats in included/inherited generic to match splats in definition #10240

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion spec/compiler/semantic/generic_class_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,67 @@ describe "Semantic: generic class" do
class Foo(T)
end

class Bar < Foo(A, B)
class Bar < Foo(Int32, String)
end
),
"wrong number of type vars for Foo(T) (given 2, expected 1)"
end

it "errors if inheriting from variadic generic and incorrect number of type vars" do
assert_error %(
class Foo(*T, U, V)
end

class Bar < Foo(Int32)
end
),
"wrong number of type vars for Foo(*T, U, V) (given 1, expected 2+)"
end

it "errors if splatting type var into generic without splats" do
assert_error %(
class Foo(T)
end

class Bar(*T) < Foo(*T)
end
),
"cannot splat *T into Foo(T)"
end

it "errors if splatting type var into non-splat parameter, before splat in definition" do
assert_error %(
class Foo(T, *U)
end

class Bar(*T) < Foo(*T, Int32)
end
),
"cannot splat *T into non-splat type parameter T of Foo(T, *U)"
end

it "errors if splatting type var into non-splat parameter, after splat in definition" do
assert_error %(
class Foo(*T, U)
end

class Bar(*T) < Foo(Int32, *T)
end
),
"cannot splat *T into non-splat type parameter U of Foo(*T, U)"
end

it "errors if splatting type var into non-splat parameter, more args" do
assert_error %(
class Foo(T, *U, V, W)
end

class Bar(T, U, *V, W) < Foo(V, U, T, W, *V, T)
end
),
"cannot splat *V into non-splat type parameter V of Foo(T, *U, V, W)"
end

it "inherits from generic with instantiation" do
assert_type(%(
class Foo(T)
Expand Down
62 changes: 61 additions & 1 deletion spec/compiler/semantic/module_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ describe "Semantic: module" do
"Foo is not a generic type"
end

it "includes module but wrong number of arguments" do
it "errors if including generic and incorrect number of type vars" do
assert_error "
module Foo(T, U)
end
Expand All @@ -101,6 +101,66 @@ describe "Semantic: module" do
"wrong number of type vars for Foo(T, U) (given 1, expected 2)"
end

it "errors if including variadic generic and incorrect number of type vars" do
assert_error %(
module Foo(*T, U, V)
end

class Bar
include Foo(Int32)
end
),
"wrong number of type vars for Foo(*T, U, V) (given 1, expected 2+)"
end

it "errors if splatting type var into generic without splats" do
assert_error %(
module Foo(T)
end

class Bar(*T)
include Foo(*T)
end
),
"cannot splat *T into Foo(T)"
end

it "errors if splatting type var into non-splat parameter, before splat in definition" do
assert_error %(
module Foo(T, *U)
end

class Bar(*T)
include Foo(*T, Int32)
end
),
"cannot splat *T into non-splat type parameter T of Foo(T, *U)"
end

it "errors if splatting type var into non-splat parameter, after splat in definition" do
assert_error %(
module Foo(*T, U)
end

class Bar(*T)
include Foo(Int32, *T)
end
),
"cannot splat *T into non-splat type parameter U of Foo(*T, U)"
end

it "errors if splatting type var into non-splat parameter, more args" do
assert_error %(
module Foo(T, *U, V, W)
end

class Bar(T, U, *V, W)
include Foo(V, U, T, W, *V, T)
end
),
"cannot splat *V into non-splat type parameter V of Foo(T, *U, V, W)"
end

it "errors if including generic module and not specifying type vars" do
assert_error "
module Foo(T)
Expand Down
54 changes: 37 additions & 17 deletions src/compiler/crystal/semantic/type_lookup.cr
Original file line number Diff line number Diff line change
Expand Up @@ -200,23 +200,8 @@ class Crystal::Type
node.raise "instantiating #{node}", inner: ex if @raise
end
when GenericType
if instance_type.splat_index
if node.named_args
node.raise "can only use named arguments with NamedTuple"
end

min_needed = instance_type.type_vars.size - 1
if node.type_vars.size < min_needed
node.wrong_number_of "type vars", instance_type, node.type_vars.size, "#{min_needed}+"
end
else
if node.named_args
node.raise "can only use named arguments with NamedTuple"
end

if instance_type.type_vars.size != node.type_vars.size
node.wrong_number_of "type vars", instance_type, node.type_vars.size, instance_type.type_vars.size
end
if node.named_args
node.raise "can only use named arguments with NamedTuple"
end
else
node.raise "#{instance_type} is not a generic type, it's a #{instance_type.type_desc}"
Expand Down Expand Up @@ -279,6 +264,41 @@ class Crystal::Type
type_vars << type.virtual_type
end

splat_index = instance_type.splat_index
var_count = instance_type.type_vars.size

case {type_vars.any?(TypeSplat), splat_index}
in {true, Int32}
# node is `(A, B, C, *D, E)`, definition is `(T, *U, V, W)`; *D would splat into V
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain this a bit more? Aren't all types on the left hand side need to be know to know how to splat them? I don't understand why D would splat only into V, which is not a splat.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The arguments before and after the splat in the instantiation are always matched first, so A, B, and C would match T and *U, followed by E with W, and finally *D with *U and V, which would be disallowed regardless of what types the TypeParameters and TypeSplats have. Thus the definition must not have more than 3 non-splat parameters before *U or more than 1 non-splat parameter after *U.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the PR description it says:

class Foo(T, *U, V); end
class Bar(*A) < Foo(*A, Int32, Int32); end        # error, *A splats into T

But if I do:

Bar(Int32, Int32, Int32).new

then we get:

Foo(Int32, Int32, Int32, Int32, Int32)

and for:

Foo(T, *U, V)

we get:

  • T: Int32
  • U: Int32, Int32, Int32
  • V: Int32

So that shouldn't be an error?

That said, I'm sure I'm misunderstanding something.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess what I'm saying is... don't we need to expand/collect all the types on node and then match that against the definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In light of #3649 (comment), maybe this restriction isn't really necessary...

0.upto(splat_index - 1) do |index|
if (splat = type_vars[index]).is_a?(TypeSplat)
type_var = instance_type.type_vars[index]
node.raise "cannot splat #{splat} into non-splat type parameter #{type_var} of #{instance_type}"
end
end
(splat_index + 1).upto(var_count - 1) do |index|
if (splat = type_vars[index - var_count]).is_a?(TypeSplat)
type_var = instance_type.type_vars[index]
node.raise "cannot splat #{splat} into non-splat type parameter #{type_var} of #{instance_type}"
end
end
in {true, Nil}
# node is `(A, B, *C)`, definition is `(T)`; instantiation would always fail
splat = type_vars.find(&.is_a?(TypeSplat)).not_nil!
node.raise "cannot splat #{splat} into #{instance_type}"
in {false, Int32}
# node is `(A)`, definition is `(T, U, *V)`; instantiation would always fail
min_needed = var_count - 1
if type_vars.size < min_needed
node.wrong_number_of "type vars", instance_type, type_vars.size, "#{min_needed}+"
end
in {false, Nil}
# neither side contains splats; type var counts must be equal
if type_vars.size != var_count
node.wrong_number_of "type vars", instance_type, type_vars.size, var_count
end
end

begin
if instance_type.is_a?(GenericUnionType) && type_vars.any? &.is_a?(TypeSplat)
# In the case of `Union(*T)`, we don't need to instantiate the union right
Expand Down