From b53026c058d1415857d611ccd7fd4aaf8c31b1f1 Mon Sep 17 00:00:00 2001 From: Venkatesh-Prasad Ranganath Date: Thu, 26 Dec 2024 14:13:15 -0600 Subject: [PATCH] Fixes #15269 Using the first type of a union type as the type of the result of `Enumerable#sum/product()` call can cause runtime failures, e.g. `[1, 10000000000_u64].sum/product` will result in an ``OverflowError``. A safer alternative is to flag/disallow the use of union types with `Enumerable#sum/product()` and recommend the use of `Enumerable#sum/product(initial)` with an initial value of the expected type of the sum/product call. --- spec/std/enumerable_spec.cr | 26 ++++++++++++++++++++++++++ src/enumerable.cr | 18 ++++++++++++++---- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/spec/std/enumerable_spec.cr b/spec/std/enumerable_spec.cr index 084fe80dcf96..31984e2ff7a5 100644 --- a/spec/std/enumerable_spec.cr +++ b/spec/std/enumerable_spec.cr @@ -1,4 +1,5 @@ require "spec" +require "../spec_helper" require "spec/helpers/iterate" module SomeInterface; end @@ -1364,6 +1365,18 @@ describe "Enumerable" do it { [1, 2, 3].sum(4.5).should eq(10.5) } it { (1..3).sum { |x| x * 2 }.should eq(12) } it { (1..3).sum(1.5) { |x| x * 2 }.should eq(13.5) } + it { [1, 3_u64].sum(0_i32).should eq(4_u32) } + it { [1, 3].sum(0_u64).should eq(4_u64) } + it { [1, 10000000000_u64].sum(0_u64).should eq(10000000001) } + it "raises if union types are summed", tags: %w[slow] do + exc = assert_error <<-CRYSTAL, + require "prelude" + [1, 10000000000_u64].sum + CRYSTAL + "Enumerable#sum/product() does support Union types. Instead, " + + "use Enumerable#sum/product(initial) with an initial value of " + + "the expected type of the sum/product call." + end it "uses additive_identity from type" do typeof([1, 2, 3].sum).should eq(Int32) @@ -1405,6 +1418,19 @@ describe "Enumerable" do typeof([1.5, 2.5, 3.5].product).should eq(Float64) typeof([1, 2, 3].product(&.to_f)).should eq(Float64) end + + it { [1, 3_u64].product(3_i32).should eq(9_u32) } + it { [1, 3].product(3_u64).should eq(9_u64) } + it { [1, 10000000000_u64].product(3_u64).should eq(30000000000_u64) } + it "raises if union types are multiplied", tags: %w[slow] do + exc = assert_error <<-CRYSTAL, + require "prelude" + [1, 10000000000_u64].product + CRYSTAL + "Enumerable#sum/product() does support Union types. Instead, " + + "use Enumerable#sum/product(initial) with an initial value of " + + "the expected type of the sum/product call." + end end describe "first" do diff --git a/src/enumerable.cr b/src/enumerable.cr index 0993f38bbc4d..e7b2b917a8f3 100644 --- a/src/enumerable.cr +++ b/src/enumerable.cr @@ -1808,7 +1808,10 @@ module Enumerable(T) # Expects all types returned from the block to respond to `#+` method. # # This method calls `.additive_identity` on the yielded type to determine the - # type of the sum value. + # type of the sum value. Hence, it can fail to compile if + # `.additive_identity` fails to determine a safe type, e.g., in case of + # union types. In such cases, use `sum(initial)` with an initial value of + # the expected type of the sum value. # # If the collection is empty, returns `additive_identity`. # @@ -1886,8 +1889,11 @@ module Enumerable(T) # # Expects all types returned from the block to respond to `#*` method. # - # This method calls `.multiplicative_identity` on the element type to determine the - # type of the sum value. + # This method calls `.multiplicative_identity` on the element type to + # determine the type of the product value. Hence, it can fail to compile if + # `.multiplicative_identity` fails to determine a safe type, e.g., in case + # of union types. In such cases, use `product(initial)` with an initial + # value of the expected type of the product value. # # If the collection is empty, returns `multiplicative_identity`. # @@ -2292,7 +2298,11 @@ module Enumerable(T) # if the type is a union. def self.first {% if X.union? %} - {{X.union_types.first}} + {{ + raise("Enumerable#sum/product() does support Union types. " + + "Instead, use Enumerable#sum/product(initial) with an " + + "initial value of the expected type of the sum/product call.") + }} {% else %} X {% end %}