diff --git a/spec/compiler/semantic/restrictions_spec.cr b/spec/compiler/semantic/restrictions_spec.cr index 139f881cc447..4a3bd01b3694 100644 --- a/spec/compiler/semantic/restrictions_spec.cr +++ b/spec/compiler/semantic/restrictions_spec.cr @@ -548,6 +548,46 @@ describe "Restrictions" do end end + describe "Path vs NumberLiteral" do + it "inserts constant before number literal of same value with generic arguments" do + assert_type(<<-CR) { bool } + X = 1 + + class Foo(N) + end + + def foo(a : Foo(1)) + 'a' + end + + def foo(a : Foo(X)) + true + end + + foo(Foo(1).new) + CR + end + + it "inserts number literal before constant of same value with generic arguments" do + assert_type(<<-CR) { bool } + X = 1 + + class Foo(N) + end + + def foo(a : Foo(X)) + 'a' + end + + def foo(a : Foo(1)) + true + end + + foo(Foo(1).new) + CR + end + end + describe "free variables" do it "inserts path before free variable with same name" do assert_type(<<-CR) { tuple_of([char, bool]) } diff --git a/src/compiler/crystal/semantic/restrictions.cr b/src/compiler/crystal/semantic/restrictions.cr index 5e157e3ef200..fca3e70a5f7b 100644 --- a/src/compiler/crystal/semantic/restrictions.cr +++ b/src/compiler/crystal/semantic/restrictions.cr @@ -605,11 +605,57 @@ module Crystal false end + def restriction_of?(other : NumberLiteral, owner, self_free_vars = nil, other_free_vars = nil) + # this happens when `self` and `other` are generic arguments: + # + # ``` + # X = 1 + # + # def foo(param : StaticArray(Int32, X)) + # end + # + # def foo(param : StaticArray(Int32, 1)) + # end + # ``` + case self_type = owner.lookup_path(self) + when Const + self_type.value == other + when NumberLiteral + self_type == other + else + false + end + end + def restriction_of?(other, owner, self_free_vars = nil, other_free_vars = nil) false end end + class NumberLiteral + def restriction_of?(other : Path, owner, self_free_vars = nil, other_free_vars = nil) + # this happens when `self` and `other` are generic arguments: + # + # ``` + # X = 1 + # + # def foo(param : StaticArray(Int32, 1)) + # end + # + # def foo(param : StaticArray(Int32, X)) + # end + # ``` + case other_type = owner.lookup_path(other) + when Const + other_type.value == self + when NumberLiteral + other_type == self + else + false + end + end + end + class Union def restriction_of?(other, owner, self_free_vars = nil, other_free_vars = nil) # For a union to be considered before another restriction,