diff --git a/spec/compiler/semantic/while_spec.cr b/spec/compiler/semantic/while_spec.cr index e73dc792aa0b..e551047a9f4d 100644 --- a/spec/compiler/semantic/while_spec.cr +++ b/spec/compiler/semantic/while_spec.cr @@ -169,4 +169,32 @@ describe "Semantic: while" do a )) { nilable int32 } end + + it "rebinds condition variable after while body (#6158)" do + assert_type(%( + class Foo + @parent : self? + + def parent + @parent + end + end + + class Bar + def initialize(@parent : Foo) + end + + def parent + @parent + end + end + + a = Foo.new + b = Bar.new(a) + while b = b.parent + break if 1 == 1 + end + b + )) { nilable types["Foo"] } + end end diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index 8b02f747333c..225f5088c34d 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -2081,6 +2081,22 @@ module Crystal node.body.accept self end + # After while's body, bind variables *before* the condition to the + # ones after the body, because the loop will repeat. + # + # For example: + # + # x = exp + # # x starts with the type of exp + # while x = x.method + # # but after the loop, the x above (in x.method) + # # should now also get the type of x.method, recursively + # end + before_cond_vars.each do |name, before_cond_var| + var = @vars[name]? + before_cond_var.bind_to(var) if var && !var.same?(before_cond_var) + end + cond = node.cond.single_expression endless_while = cond.true_literal?