Skip to content

Commit

Permalink
Fix logic for subclass restricted against uninstantiated nested gener…
Browse files Browse the repository at this point in the history
…ic superclass (crystal-lang#10522)
  • Loading branch information
HertzDevil authored Mar 30, 2021
1 parent a0ce5f3 commit e7b46c4
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 7 deletions.
72 changes: 72 additions & 0 deletions spec/compiler/codegen/is_a_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,78 @@ describe "Codegen: is_a?" do
)).to_i.should eq(2)
end

it "does is_a?(generic type) for nested generic inheritance (1) (#9660)" do
run(%(
class Cxx
end
class Foo(T)
end
class Bar(T) < Foo(T)
end
class Baz < Bar(Cxx)
end
Baz.new.is_a?(Foo)
), inject_primitives: false).to_b.should be_true
end

it "does is_a?(generic type) for nested generic inheritance (2)" do
run(%(
class Cxx
end
class Foo(T)
end
class Bar(T) < Foo(T)
end
class Baz(T) < Bar(T)
end
Baz(Cxx).new.is_a?(Foo)
), inject_primitives: false).to_b.should be_true
end

it "does is_a?(generic type) for nested generic inheritance, through upcast (1)" do
run(%(
class Cxx
end
class Foo(T)
end
class Bar(T) < Foo(T)
end
class Baz < Bar(Cxx)
end
Baz.new.as(Foo(Cxx)).is_a?(Bar)
), inject_primitives: false).to_b.should be_true
end

it "does is_a?(generic type) for nested generic inheritance, through upcast (2)" do
run(%(
class Cxx
end
class Foo(T)
end
class Bar(T) < Foo(T)
end
class Baz(T) < Bar(T)
end
Baz(Cxx).new.as(Foo(Cxx)).is_a?(Bar)
), inject_primitives: false).to_b.should be_true
end

it "doesn't consider generic type to be a generic type of a recursive alias (#3524)" do
run(%(
class Gen(T)
Expand Down
48 changes: 48 additions & 0 deletions spec/compiler/semantic/restrictions_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,54 @@ describe "Restrictions" do

mod.t("Axx+").restrict(mod.t("Mxx"), MatchContext.new(mod, mod)).should eq(mod.union_of(mod.t("Bxx+"), mod.t("Cxx+")))
end

it "restricts class against uninstantiated generic base class through multiple inheritance (1) (#9660)" do
mod = Program.new
mod.semantic parse("
class Axx(T); end
class Bxx(T) < Axx(T); end
class Cxx < Bxx(Int32); end
")

result = mod.t("Cxx").restrict(mod.t("Axx"), MatchContext.new(mod, mod))
result.should eq(mod.t("Cxx"))
end

it "restricts class against uninstantiated generic base class through multiple inheritance (2) (#9660)" do
mod = Program.new
mod.semantic parse("
class Axx(T); end
class Bxx(T) < Axx(T); end
class Cxx(T) < Bxx(T); end
")

result = mod.generic_class("Cxx", mod.int32).restrict(mod.t("Axx"), MatchContext.new(mod, mod))
result.should eq(mod.generic_class("Cxx", mod.int32))
end

it "restricts virtual generic class against uninstantiated generic subclass (1)" do
mod = Program.new
mod.semantic parse("
class Axx(T); end
class Bxx(T) < Axx(T); end
class Cxx < Bxx(Int32); end
")

result = mod.generic_class("Axx", mod.int32).virtual_type.restrict(mod.generic_class("Bxx", mod.int32), MatchContext.new(mod, mod))
result.should eq(mod.generic_class("Bxx", mod.int32).virtual_type)
end

it "restricts virtual generic class against uninstantiated generic subclass (2)" do
mod = Program.new
mod.semantic parse("
class Axx(T); end
class Bxx(T) < Axx(T); end
class Cxx(T) < Bxx(T); end
")

result = mod.generic_class("Axx", mod.int32).virtual_type.restrict(mod.generic_class("Bxx", mod.int32), MatchContext.new(mod, mod))
result.should eq(mod.generic_class("Bxx", mod.int32).virtual_type)
end
end

describe "restriction_of?" do
Expand Down
35 changes: 28 additions & 7 deletions src/compiler/crystal/semantic/restrictions.cr
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,19 @@ module Crystal
implements?(other.base_type) ? self : nil
end

def restrict(other : GenericClassType, context)
parents.try &.each do |parent|
if parent.module?
return self if parent.restriction_of?(other, context.instantiated_type, context)
else
restricted = parent.restrict other, context
return self if restricted
end
end

nil
end

def restrict(other : Union, context)
# Match all concrete types first
free_var_count = other.types.count do |other_type|
Expand Down Expand Up @@ -728,7 +741,18 @@ module Crystal
end

def restrict(other : GenericType, context)
generic_type == other ? self : super
return self if generic_type == other

parents.try &.each do |parent|
if parent.module?
return self if parent.restriction_of?(other, context.instantiated_type, context)
else
restricted = parent.restrict other, context
return self if restricted
end
end

nil
end

def restrict(other : Generic, context)
Expand Down Expand Up @@ -1016,14 +1040,11 @@ module Crystal
elsif base_type.is_a?(GenericInstanceType) && other.is_a?(GenericType)
# Consider the case of Foo(Int32) vs. Bar(T), with Bar(T) < Foo(T):
# we want to return Bar(Int32), so we search in Bar's generic instantiations
other.each_instantiated_type do |instance|
types = other.instantiated_types.compact_map do |instance|
next if instance.unbound? || instance.abstract?

if instance.implements?(base_type)
return instance
end
instance.virtual_type if instance.implements?(base_type)
end
nil
program.type_merge_union_of types
else
nil
end
Expand Down

0 comments on commit e7b46c4

Please sign in to comment.