From d642d4f3f6ff30d17ae608f0f62e64235a925832 Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Tue, 6 Sep 2022 02:52:08 +0800 Subject: [PATCH] Hoist complex element expressions outside container literals (#12366) --- spec/compiler/normalize/array_literal_spec.cr | 45 ++++ spec/compiler/normalize/hash_literal_spec.cr | 47 +++- spec/spec_helper.cr | 16 ++ .../crystal/semantic/literal_expander.cr | 218 +++++++++++------- src/compiler/crystal/semantic/main_visitor.cr | 38 +-- 5 files changed, 247 insertions(+), 117 deletions(-) diff --git a/spec/compiler/normalize/array_literal_spec.cr b/spec/compiler/normalize/array_literal_spec.cr index b466785b6a3e..ab247427e144 100644 --- a/spec/compiler/normalize/array_literal_spec.cr +++ b/spec/compiler/normalize/array_literal_spec.cr @@ -56,4 +56,49 @@ describe "Normalize: array literal" do __temp_1 CR end + + it "hoists complex element expressions" do + assert_expand "[[1]]", <<-CR + __temp_1 = [1] + __temp_2 = ::Array(typeof(__temp_1)).unsafe_build(1) + __temp_3 = __temp_2.to_unsafe + __temp_3[0] = __temp_1 + __temp_2 + CR + end + + it "hoists complex element expressions, with splat" do + assert_expand "[*[1]]", <<-CR + __temp_1 = [1] + __temp_2 = ::Array(typeof(::Enumerable.element_type(__temp_1))).new(0) + __temp_2.concat(__temp_1) + __temp_2 + CR + end + + it "hoists complex element expressions, array-like" do + assert_expand_named "Foo{[1], *[2]}", <<-CR + __temp_1 = [1] + __temp_2 = [2] + __temp_3 = Foo.new + __temp_3 << __temp_1 + __temp_2.each do |__temp_4| + __temp_3 << __temp_4 + end + __temp_3 + CR + end + + it "hoists complex element expressions, array-like generic" do + assert_expand_named "Foo{[1], *[2]}", <<-CR, generic: "Foo" + __temp_1 = [1] + __temp_2 = [2] + __temp_3 = Foo(typeof(__temp_1, ::Enumerable.element_type(__temp_2))).new + __temp_3 << __temp_1 + __temp_2.each do |__temp_4| + __temp_3 << __temp_4 + end + __temp_3 + CR + end end diff --git a/spec/compiler/normalize/hash_literal_spec.cr b/spec/compiler/normalize/hash_literal_spec.cr index 571e0f528e29..67f95424368a 100644 --- a/spec/compiler/normalize/hash_literal_spec.cr +++ b/spec/compiler/normalize/hash_literal_spec.cr @@ -6,10 +6,53 @@ describe "Normalize: hash literal" do end it "normalizes non-empty with of" do - assert_expand "{1 => 2, 3 => 4} of Int => Float", "__temp_1 = ::Hash(Int, Float).new\n__temp_1[1] = 2\n__temp_1[3] = 4\n__temp_1" + assert_expand "{1 => 2, 3 => 4} of Int => Float", <<-CR + __temp_1 = ::Hash(Int, Float).new + __temp_1[1] = 2 + __temp_1[3] = 4 + __temp_1 + CR end it "normalizes non-empty without of" do - assert_expand "{1 => 2, 3 => 4}", "__temp_1 = ::Hash(typeof(1, 3), typeof(2, 4)).new\n__temp_1[1] = 2\n__temp_1[3] = 4\n__temp_1" + assert_expand "{1 => 2, 3 => 4}", <<-CR + __temp_1 = ::Hash(typeof(1, 3), typeof(2, 4)).new + __temp_1[1] = 2 + __temp_1[3] = 4 + __temp_1 + CR + end + + it "hoists complex element expressions" do + assert_expand "{[1] => 2, 3 => [4]}", <<-CR + __temp_1 = [1] + __temp_2 = [4] + __temp_3 = ::Hash(typeof(__temp_1, 3), typeof(2, __temp_2)).new + __temp_3[__temp_1] = 2 + __temp_3[3] = __temp_2 + __temp_3 + CR + end + + it "hoists complex element expressions, hash-like" do + assert_expand_named "Foo{[1] => 2, 3 => [4]}", <<-CR + __temp_1 = [1] + __temp_2 = [4] + __temp_3 = Foo.new + __temp_3[__temp_1] = 2 + __temp_3[3] = __temp_2 + __temp_3 + CR + end + + it "hoists complex element expressions, hash-like generic" do + assert_expand_named "Foo{[1] => 2, 3 => [4]}", <<-CR, generic: "Foo" + __temp_1 = [1] + __temp_2 = [4] + __temp_3 = Foo(typeof(__temp_1, 3), typeof(2, __temp_2)).new + __temp_3[__temp_1] = 2 + __temp_3[3] = __temp_2 + __temp_3 + CR end end diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr index 026fab0a9ab7..9bf1b0ca05bb 100644 --- a/spec/spec_helper.cr +++ b/spec/spec_helper.cr @@ -106,6 +106,22 @@ def assert_expand_third(from : String, to, *, flags = nil, file = __FILE__, line assert_expand node, to, flags: flags, file: file, line: line end +def assert_expand_named(from : String, to, *, generic = nil, flags = nil, file = __FILE__, line = __LINE__) + program = new_program + program.flags.concat(flags.split) if flags + from_nodes = Parser.parse(from) + generic_type = generic.path if generic + case from_nodes + when ArrayLiteral + to_nodes = LiteralExpander.new(program).expand_named(from_nodes, generic_type) + when HashLiteral + to_nodes = LiteralExpander.new(program).expand_named(from_nodes, generic_type) + else + fail "Expected: ArrayLiteral | HashLiteral, got: #{from_nodes.class}", file: file, line: line + end + to_nodes.to_s.strip.should eq(to.strip), file: file, line: line +end + def assert_error(str, message = nil, *, inject_primitives = false, flags = nil, file = __FILE__, line = __LINE__) expect_raises TypeException, message, file, line do semantic str, inject_primitives: inject_primitives, flags: flags diff --git a/src/compiler/crystal/semantic/literal_expander.cr b/src/compiler/crystal/semantic/literal_expander.cr index 70aa3593225a..74c59209838b 100644 --- a/src/compiler/crystal/semantic/literal_expander.cr +++ b/src/compiler/crystal/semantic/literal_expander.cr @@ -33,17 +33,20 @@ module Crystal # # To: # - # ary = ::Array(typeof(1, ::Enumerable.element_type(exp2), ::Enumerable.element_type(exp3), 4)).new(2) + # temp1 = exp2 + # temp2 = exp3 + # ary = ::Array(typeof(1, ::Enumerable.element_type(temp1), ::Enumerable.element_type(temp2), 4)).new(2) # ary << 1 - # ary.concat(exp2) - # ary.concat(exp3) + # ary.concat(temp1) + # ary.concat(temp2) # ary << 4 # ary def expand(node : ArrayLiteral) + elem_temp_vars, elem_temp_var_count = complex_elem_temp_vars(node.elements) if node_of = node.of type_var = node_of else - type_var = typeof_exp(node) + type_var = typeof_exp(node, elem_temp_vars) end capacity = node.elements.count { |elem| !elem.is_a?(Splat) } @@ -55,18 +58,25 @@ module Crystal ary_instance = Call.new(generic, "new", NumberLiteral.new(capacity).at(node)).at(node) - exps = Array(ASTNode).new(node.elements.size + 2) + exps = Array(ASTNode).new(node.elements.size + elem_temp_var_count + 2) + elem_temp_vars.try &.each_with_index do |elem_temp_var, i| + next unless elem_temp_var + elem_exp = node.elements[i] + elem_exp = elem_exp.exp if elem_exp.is_a?(Splat) + exps << Assign.new(elem_temp_var, elem_exp.clone).at(elem_temp_var) + end exps << Assign.new(ary_var.clone, ary_instance).at(node) - node.elements.each do |elem| + node.elements.each_with_index do |elem, i| + temp_var = elem_temp_vars.try &.[i] if elem.is_a?(Splat) - exps << Call.new(ary_var.clone, "concat", elem.exp.clone).at(node) + exps << Call.new(ary_var.clone, "concat", (temp_var || elem.exp).clone).at(node) else - exps << Call.new(ary_var.clone, "<<", elem.clone).at(node) + exps << Call.new(ary_var.clone, "<<", (temp_var || elem).clone).at(node) end end - exps << ary_var.clone + exps << ary_var Expressions.new(exps).at(node) elsif capacity.zero? @@ -79,12 +89,18 @@ module Crystal buffer = Call.new(ary_var, "to_unsafe").at(node) buffer_var = new_temp_var.at(node) - exps = Array(ASTNode).new(node.elements.size + 3) + exps = Array(ASTNode).new(node.elements.size + elem_temp_var_count + 3) + elem_temp_vars.try &.each_with_index do |elem_temp_var, i| + next unless elem_temp_var + elem_exp = node.elements[i] + exps << Assign.new(elem_temp_var, elem_exp.clone).at(elem_temp_var) + end exps << Assign.new(ary_var.clone, ary_instance).at(node) exps << Assign.new(buffer_var, buffer).at(node) node.elements.each_with_index do |elem, i| - exps << Call.new(buffer_var.clone, "[]=", NumberLiteral.new(i).at(node), elem.clone).at(node) + temp_var = elem_temp_vars.try &.[i] + exps << Call.new(buffer_var.clone, "[]=", NumberLiteral.new(i).at(node), (temp_var || elem).clone).at(node) end exps << ary_var.clone @@ -93,12 +109,34 @@ module Crystal end end - def typeof_exp(node : ArrayLiteral) - type_exps = node.elements.map do |elem| + def complex_elem_temp_vars(elems : Array, &) + temp_vars = nil + count = 0 + + elems.each_with_index do |elem, i| + elem = yield elem + elem = elem.exp if elem.is_a?(Splat) + next if elem.is_a?(Var) || elem.is_a?(InstanceVar) || elem.is_a?(ClassVar) || elem.simple_literal? + + temp_vars ||= Array(Var?).new(elems.size, nil) + temp_vars[i] = new_temp_var.at(elem) + count += 1 + end + + {temp_vars, count} + end + + def complex_elem_temp_vars(elems : Array(ASTNode)) + complex_elem_temp_vars(elems, &.itself) + end + + def typeof_exp(node : ArrayLiteral, temp_vars : Array(Var?)? = nil) + type_exps = node.elements.map_with_index do |elem, i| + temp_var = temp_vars.try &.[i] if elem.is_a?(Splat) - Call.new(Path.global("Enumerable").at(node), "element_type", elem.exp.clone).at(node) + Call.new(Path.global("Enumerable").at(node), "element_type", (temp_var || elem.exp).clone).at(node) else - elem.clone + (temp_var || elem).clone end end @@ -132,79 +170,55 @@ module Crystal # ary << 4 # ary # - # If `T` is an uninstantiated generic type, its type argument is injected by - # `MainVisitor` with a `typeof`. - def expand_named(node : ArrayLiteral) - temp_var = new_temp_var - - constructor = Call.new(node.name, "new").at(node) + # If `T` is an uninstantiated generic type, injects a `typeof` with the + # element types. + def expand_named(node : ArrayLiteral, generic_type : ASTNode?) + elem_temp_vars, elem_temp_var_count = complex_elem_temp_vars(node.elements) + if generic_type + type_of = typeof_exp(node, elem_temp_vars) + node_name = Generic.new(generic_type, type_of).at(node.location) + else + node_name = node.name + end + constructor = Call.new(node_name, "new").at(node) if node.elements.empty? return constructor end - exps = Array(ASTNode).new(node.elements.size + 2) - exps << Assign.new(temp_var.clone, constructor).at(node) - node.elements.each do |elem| + ary_var = new_temp_var.at(node) + + exps = Array(ASTNode).new(node.elements.size + elem_temp_var_count + 2) + elem_temp_vars.try &.each_with_index do |elem_temp_var, i| + next unless elem_temp_var + elem_exp = node.elements[i] + elem_exp = elem_exp.exp if elem_exp.is_a?(Splat) + exps << Assign.new(elem_temp_var, elem_exp.clone).at(elem_temp_var) + end + exps << Assign.new(ary_var.clone, constructor).at(node) + + node.elements.each_with_index do |elem, i| + temp_var = elem_temp_vars.try &.[i] if elem.is_a?(Splat) yield_var = new_temp_var - each_body = Call.new(temp_var.clone, "<<", yield_var.clone).at(node) + each_body = Call.new(ary_var.clone, "<<", yield_var.clone).at(node) each_block = Block.new(args: [yield_var], body: each_body).at(node) - exps << Call.new(elem.exp.clone, "each", block: each_block).at(node) + exps << Call.new((temp_var || elem.exp).clone, "each", block: each_block).at(node) else - exps << Call.new(temp_var.clone, "<<", elem.clone).at(node) + exps << Call.new(ary_var.clone, "<<", (temp_var || elem).clone).at(node) end end - exps << temp_var.clone + + exps << ary_var Expressions.new(exps).at(node) end - # Converts a hash literal into creating a Hash and assigning keys and values: - # - # From: - # - # {} of K => V - # - # To: - # - # Hash(K, V).new + # Converts a hash literal into creating a Hash and assigning keys and values. # - # From: - # - # {a => b, c => d} - # - # To: - # - # hash = ::Hash(typeof(a, c), typeof(b, d)).new - # hash[a] = b - # hash[c] = d - # hash + # Equivalent to a hash-like literal using `::Hash`. def expand(node : HashLiteral) - if of = node.of - type_vars = [of.key, of.value] of ASTNode - else - typeof_key = TypeOf.new(node.entries.map { |x| x.key.clone.as(ASTNode) }).at(node) - typeof_value = TypeOf.new(node.entries.map { |x| x.value.clone.as(ASTNode) }).at(node) - type_vars = [typeof_key, typeof_value] of ASTNode - end - - generic = Generic.new(Path.global("Hash"), type_vars).at(node) - constructor = Call.new(generic, "new").at(node) - - if node.entries.empty? - constructor - else - temp_var = new_temp_var - - exps = Array(ASTNode).new(node.entries.size + 2) - exps << Assign.new(temp_var.clone, constructor).at(node) - node.entries.each do |entry| - exps << Call.new(temp_var.clone, "[]=", entry.key.clone, entry.value.clone).at(node) - end - exps << temp_var.clone - Expressions.new(exps).at(node) - end + expand_named(node, Path.global("Hash")) end # Converts a hash-like literal into creating a Hash and assigning keys and values: @@ -219,6 +233,14 @@ module Crystal # # From: # + # {} of K => V + # + # To: + # + # ::Hash(K, V).new + # + # From: + # # T{a => b, c => d} # # To: @@ -228,24 +250,54 @@ module Crystal # hash[c] = d # hash # - # If `T` is an uninstantiated generic type, its type arguments are injected - # by `MainVisitor` with `typeof`s. - def expand_named(node : HashLiteral) - constructor = Call.new(node.name, "new").at(node) + # Or if `T` is an uninstantiated generic type: + # + # hash = T(typeof(a, c), typeof(b, d)).new + # hash[a] = b + # hash[c] = d + # hash + def expand_named(node : HashLiteral, generic_type : ASTNode?) + key_temp_vars, key_temp_var_count = complex_elem_temp_vars(node.entries, &.key) + value_temp_vars, value_temp_var_count = complex_elem_temp_vars(node.entries, &.value) - if node.entries.empty? - return constructor + if of = node.of + # `generic_type` is nil here + type_vars = [of.key, of.value] of ASTNode + generic = Generic.new(Path.global("Hash"), type_vars).at(node) + elsif generic_type + # `node.entries` is non-empty here + typeof_key = TypeOf.new(node.entries.map_with_index { |x, i| (key_temp_vars.try(&.[i]) || x.key).clone.as(ASTNode) }).at(node) + typeof_value = TypeOf.new(node.entries.map_with_index { |x, i| (value_temp_vars.try(&.[i]) || x.value).clone.as(ASTNode) }).at(node) + generic = Generic.new(generic_type, [typeof_key, typeof_value] of ASTNode).at(node) + else + generic = node.name end - temp_var = new_temp_var + constructor = Call.new(generic, "new").at(node) + return constructor if node.entries.empty? + + hash_var = new_temp_var + + exps = Array(ASTNode).new(node.entries.size + key_temp_var_count + value_temp_var_count + 2) + key_temp_vars.try &.each_with_index do |key_temp_var, i| + next unless key_temp_var + key_exp = node.entries[i].key + exps << Assign.new(key_temp_var, key_exp.clone).at(key_temp_var) + end + value_temp_vars.try &.each_with_index do |value_temp_var, i| + next unless value_temp_var + value_exp = node.entries[i].value + exps << Assign.new(value_temp_var, value_exp.clone).at(value_temp_var) + end + exps << Assign.new(hash_var.clone, constructor).at(node) - exps = Array(ASTNode).new(node.entries.size + 2) - exps << Assign.new(temp_var.clone, constructor).at(node) - node.entries.each do |entry| - exps << Call.new(temp_var.clone, "[]=", entry.key.clone, entry.value.clone).at(node) + node.entries.each_with_index do |entry, i| + key_exp = key_temp_vars.try(&.[i]) || entry.key + value_exp = value_temp_vars.try(&.[i]) || entry.value + exps << Call.new(hash_var.clone, "[]=", key_exp.clone, value_exp.clone).at(node) end - exps << temp_var.clone + exps << hash_var Expressions.new(exps).at(node) end diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index 8758fb48d549..0e6826e72be7 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -2919,20 +2919,8 @@ module Crystal if name = node.name name.accept self type = name.type.instance_type - - case type - when GenericClassType - generic_type = TypeNode.new(type).at(node.location) - type_of = @program.literal_expander.typeof_exp(node) - generic = Generic.new(generic_type, type_of).at(node.location) - node.name = generic - when GenericClassInstanceType - # Nothing - else - node.name = TypeNode.new(name.type).at(node.location) - end - - expand_named(node) + generic_type = TypeNode.new(type).at(node.location) if type.is_a?(GenericClassType) + expand_named(node, generic_type) else expand(node) end @@ -2942,22 +2930,8 @@ module Crystal if name = node.name name.accept self type = name.type.instance_type - - case type - when GenericClassType - generic_type = TypeNode.new(type).at(node.location) - type_of_keys = TypeOf.new(node.entries.map { |x| x.key.as(ASTNode) }).at(node.location) - type_of_values = TypeOf.new(node.entries.map { |x| x.value.as(ASTNode) }).at(node.location) - generic = Generic.new(generic_type, [type_of_keys, type_of_values] of ASTNode).at(node.location) - - node.name = generic - when GenericClassInstanceType - # Nothing - else - node.name = TypeNode.new(name.type).at(node.location) - end - - expand_named(node) + generic_type = TypeNode.new(type).at(node.location) if type.is_a?(GenericClassType) + expand_named(node, generic_type) else expand(node) end @@ -3022,8 +2996,8 @@ module Crystal expand(node) { @program.literal_expander.expand node } end - def expand_named(node) - expand(node) { @program.literal_expander.expand_named node } + def expand_named(node, generic_type) + expand(node) { @program.literal_expander.expand_named node, generic_type } end def expand(node)