From e7b46c40739902db70a2dceebd09caf1a7c70553 Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Tue, 30 Mar 2021 17:21:14 +0800 Subject: [PATCH] Fix logic for subclass restricted against uninstantiated nested generic superclass (#10522) --- spec/compiler/codegen/is_a_spec.cr | 72 +++++++++++++++++++ spec/compiler/semantic/restrictions_spec.cr | 48 +++++++++++++ src/compiler/crystal/semantic/restrictions.cr | 35 +++++++-- 3 files changed, 148 insertions(+), 7 deletions(-) diff --git a/spec/compiler/codegen/is_a_spec.cr b/spec/compiler/codegen/is_a_spec.cr index 77289e3902e6..46250bc81400 100644 --- a/spec/compiler/codegen/is_a_spec.cr +++ b/spec/compiler/codegen/is_a_spec.cr @@ -687,6 +687,78 @@ describe "Codegen: is_a?" do )).to_i.should eq(2) end + it "does is_a?(generic type) for nested generic inheritance (1) (#9660)" do + run(%( + class Cxx + end + + class Foo(T) + end + + class Bar(T) < Foo(T) + end + + class Baz < Bar(Cxx) + end + + Baz.new.is_a?(Foo) + ), inject_primitives: false).to_b.should be_true + end + + it "does is_a?(generic type) for nested generic inheritance (2)" do + run(%( + class Cxx + end + + class Foo(T) + end + + class Bar(T) < Foo(T) + end + + class Baz(T) < Bar(T) + end + + Baz(Cxx).new.is_a?(Foo) + ), inject_primitives: false).to_b.should be_true + end + + it "does is_a?(generic type) for nested generic inheritance, through upcast (1)" do + run(%( + class Cxx + end + + class Foo(T) + end + + class Bar(T) < Foo(T) + end + + class Baz < Bar(Cxx) + end + + Baz.new.as(Foo(Cxx)).is_a?(Bar) + ), inject_primitives: false).to_b.should be_true + end + + it "does is_a?(generic type) for nested generic inheritance, through upcast (2)" do + run(%( + class Cxx + end + + class Foo(T) + end + + class Bar(T) < Foo(T) + end + + class Baz(T) < Bar(T) + end + + Baz(Cxx).new.as(Foo(Cxx)).is_a?(Bar) + ), inject_primitives: false).to_b.should be_true + end + it "doesn't consider generic type to be a generic type of a recursive alias (#3524)" do run(%( class Gen(T) diff --git a/spec/compiler/semantic/restrictions_spec.cr b/spec/compiler/semantic/restrictions_spec.cr index 08107002e1dd..3d9567972df8 100644 --- a/spec/compiler/semantic/restrictions_spec.cr +++ b/spec/compiler/semantic/restrictions_spec.cr @@ -60,6 +60,54 @@ describe "Restrictions" do mod.t("Axx+").restrict(mod.t("Mxx"), MatchContext.new(mod, mod)).should eq(mod.union_of(mod.t("Bxx+"), mod.t("Cxx+"))) end + + it "restricts class against uninstantiated generic base class through multiple inheritance (1) (#9660)" do + mod = Program.new + mod.semantic parse(" + class Axx(T); end + class Bxx(T) < Axx(T); end + class Cxx < Bxx(Int32); end + ") + + result = mod.t("Cxx").restrict(mod.t("Axx"), MatchContext.new(mod, mod)) + result.should eq(mod.t("Cxx")) + end + + it "restricts class against uninstantiated generic base class through multiple inheritance (2) (#9660)" do + mod = Program.new + mod.semantic parse(" + class Axx(T); end + class Bxx(T) < Axx(T); end + class Cxx(T) < Bxx(T); end + ") + + result = mod.generic_class("Cxx", mod.int32).restrict(mod.t("Axx"), MatchContext.new(mod, mod)) + result.should eq(mod.generic_class("Cxx", mod.int32)) + end + + it "restricts virtual generic class against uninstantiated generic subclass (1)" do + mod = Program.new + mod.semantic parse(" + class Axx(T); end + class Bxx(T) < Axx(T); end + class Cxx < Bxx(Int32); end + ") + + result = mod.generic_class("Axx", mod.int32).virtual_type.restrict(mod.generic_class("Bxx", mod.int32), MatchContext.new(mod, mod)) + result.should eq(mod.generic_class("Bxx", mod.int32).virtual_type) + end + + it "restricts virtual generic class against uninstantiated generic subclass (2)" do + mod = Program.new + mod.semantic parse(" + class Axx(T); end + class Bxx(T) < Axx(T); end + class Cxx(T) < Bxx(T); end + ") + + result = mod.generic_class("Axx", mod.int32).virtual_type.restrict(mod.generic_class("Bxx", mod.int32), MatchContext.new(mod, mod)) + result.should eq(mod.generic_class("Bxx", mod.int32).virtual_type) + end end describe "restriction_of?" do diff --git a/src/compiler/crystal/semantic/restrictions.cr b/src/compiler/crystal/semantic/restrictions.cr index 8fc128422667..74811e7c8356 100644 --- a/src/compiler/crystal/semantic/restrictions.cr +++ b/src/compiler/crystal/semantic/restrictions.cr @@ -460,6 +460,19 @@ module Crystal implements?(other.base_type) ? self : nil end + def restrict(other : GenericClassType, context) + parents.try &.each do |parent| + if parent.module? + return self if parent.restriction_of?(other, context.instantiated_type, context) + else + restricted = parent.restrict other, context + return self if restricted + end + end + + nil + end + def restrict(other : Union, context) # Match all concrete types first free_var_count = other.types.count do |other_type| @@ -728,7 +741,18 @@ module Crystal end def restrict(other : GenericType, context) - generic_type == other ? self : super + return self if generic_type == other + + parents.try &.each do |parent| + if parent.module? + return self if parent.restriction_of?(other, context.instantiated_type, context) + else + restricted = parent.restrict other, context + return self if restricted + end + end + + nil end def restrict(other : Generic, context) @@ -1016,14 +1040,11 @@ module Crystal elsif base_type.is_a?(GenericInstanceType) && other.is_a?(GenericType) # Consider the case of Foo(Int32) vs. Bar(T), with Bar(T) < Foo(T): # we want to return Bar(Int32), so we search in Bar's generic instantiations - other.each_instantiated_type do |instance| + types = other.instantiated_types.compact_map do |instance| next if instance.unbound? || instance.abstract? - - if instance.implements?(base_type) - return instance - end + instance.virtual_type if instance.implements?(base_type) end - nil + program.type_merge_union_of types else nil end