diff --git a/spec/compiler/macro/macro_methods_spec.cr b/spec/compiler/macro/macro_methods_spec.cr index 888e348d1f15..ac693fa57183 100644 --- a/spec/compiler/macro/macro_methods_spec.cr +++ b/spec/compiler/macro/macro_methods_spec.cr @@ -169,6 +169,10 @@ module Crystal assert_macro "", "{{1 <=> -1}}", [] of ASTNode, "1" end + it "executes <=> (returns nil)" do + assert_macro "", "{{0.0/0.0 <=> -1}}", [] of ASTNode, "nil" + end + it "executes +" do assert_macro "", "{{1 + 2}}", [] of ASTNode, "3" end diff --git a/spec/std/array_spec.cr b/spec/std/array_spec.cr index 6575ddafb7fb..6c0a9b462491 100644 --- a/spec/std/array_spec.cr +++ b/spec/std/array_spec.cr @@ -10,6 +10,19 @@ private class BadSortingClass end end +private class Spaceship + getter value : Float64 + + def initialize(@value : Float64, @return_nil = false) + end + + def <=>(other : Spaceship) + return nil if @return_nil + + value <=> other.value + end +end + describe "Array" do describe "new" do it "creates with default value" do @@ -1059,6 +1072,37 @@ describe "Array" do [1, 2, 3].sort { 1 } Array.new(10) { BadSortingClass.new }.sort end + + it "can sort just by using <=> (#6608)" do + spaceships = [ + Spaceship.new(2), + Spaceship.new(0), + Spaceship.new(1), + Spaceship.new(3), + ] + + sorted = spaceships.sort + 4.times do |i| + sorted[i].value.should eq(i) + end + end + + it "raises if <=> returns nil" do + spaceships = [ + Spaceship.new(2, return_nil: true), + Spaceship.new(0, return_nil: true), + ] + + expect_raises(ArgumentError) do + spaceships.sort + end + end + + it "raises if sort block returns nil" do + expect_raises(ArgumentError) do + [1, 2].sort { nil } + end + end end describe "sort!" do @@ -1079,6 +1123,37 @@ describe "Array" do b = a.sort { -1 } a.should eq(b) end + + it "can sort! just by using <=> (#6608)" do + spaceships = [ + Spaceship.new(2), + Spaceship.new(0), + Spaceship.new(1), + Spaceship.new(3), + ] + + spaceships.sort! + 4.times do |i| + spaceships[i].value.should eq(i) + end + end + + it "raises if <=> returns nil" do + spaceships = [ + Spaceship.new(2, return_nil: true), + Spaceship.new(0, return_nil: true), + ] + + expect_raises(ArgumentError) do + spaceships.sort! + end + end + + it "raises if sort! block returns nil" do + expect_raises(ArgumentError) do + [1, 2].sort! { nil } + end + end end describe "sort_by" do diff --git a/spec/std/comparable_spec.cr b/spec/std/comparable_spec.cr index fcaa34133ba4..12756170cec1 100644 --- a/spec/std/comparable_spec.cr +++ b/spec/std/comparable_spec.cr @@ -3,10 +3,12 @@ require "spec" private class ComparableTestClass include Comparable(Int) - def initialize(@value : Int32) + def initialize(@value : Int32, @return_nil = false) end def <=>(other : Int) + return nil if @return_nil + @value <=> other end end @@ -15,7 +17,30 @@ describe Comparable do it "can compare against Int (#2461)" do obj = ComparableTestClass.new(4) (obj == 3).should be_false + (obj == 4).should be_true + (obj < 3).should be_false + (obj < 4).should be_false + (obj > 3).should be_true + (obj > 4).should be_false + + (obj <= 3).should be_false + (obj <= 4).should be_true + (obj <= 5).should be_true + + (obj >= 3).should be_true + (obj >= 4).should be_true + (obj >= 5).should be_false + end + + it "checks for nil" do + obj = ComparableTestClass.new(4, return_nil: true) + + (obj < 1).should be_false + (obj <= 1).should be_false + (obj == 1).should be_false + (obj >= 1).should be_false + (obj > 1).should be_false end end diff --git a/spec/std/float_spec.cr b/spec/std/float_spec.cr index c2be91711435..eb0824958443 100644 --- a/spec/std/float_spec.cr +++ b/spec/std/float_spec.cr @@ -291,4 +291,54 @@ describe "Float" do Float64::EPSILON.unsafe_as(UInt64).should eq 0x3cb0000000000000_u64 Float64::MIN_POSITIVE.unsafe_as(UInt64).should eq 0x0010000000000000_u64 end + + it "returns nil in <=> for NaN values (Float32)" do + nan = Float32::NAN + + (1_f32 <=> nan).should be_nil + (1_f64 <=> nan).should be_nil + + (1_u8 <=> nan).should be_nil + (1_u16 <=> nan).should be_nil + (1_u32 <=> nan).should be_nil + (1_u64 <=> nan).should be_nil + (1_i8 <=> nan).should be_nil + (1_i16 <=> nan).should be_nil + (1_i32 <=> nan).should be_nil + (1_i64 <=> nan).should be_nil + + (nan <=> 1_u8).should be_nil + (nan <=> 1_u16).should be_nil + (nan <=> 1_u32).should be_nil + (nan <=> 1_u64).should be_nil + (nan <=> 1_i8).should be_nil + (nan <=> 1_i16).should be_nil + (nan <=> 1_i32).should be_nil + (nan <=> 1_i64).should be_nil + end + + it "returns nil in <=> for NaN values (Float64)" do + nan = Float64::NAN + + (1_f32 <=> nan).should be_nil + (1_f64 <=> nan).should be_nil + + (1_u8 <=> nan).should be_nil + (1_u16 <=> nan).should be_nil + (1_u32 <=> nan).should be_nil + (1_u64 <=> nan).should be_nil + (1_i8 <=> nan).should be_nil + (1_i16 <=> nan).should be_nil + (1_i32 <=> nan).should be_nil + (1_i64 <=> nan).should be_nil + + (nan <=> 1_u8).should be_nil + (nan <=> 1_u16).should be_nil + (nan <=> 1_u32).should be_nil + (nan <=> 1_u64).should be_nil + (nan <=> 1_i8).should be_nil + (nan <=> 1_i16).should be_nil + (nan <=> 1_i32).should be_nil + (nan <=> 1_i64).should be_nil + end end diff --git a/src/array.cr b/src/array.cr index 41bea57a9f79..a7c7e8d2470d 100644 --- a/src/array.cr +++ b/src/array.cr @@ -1631,7 +1631,11 @@ class Array(T) # b # => [3, 2, 1] # a # => [3, 1, 2] # ``` - def sort(&block : T, T -> Int32) : Array(T) + def sort(&block : T, T -> U) : Array(T) forall U + {% unless U <= Int32? %} + {% raise "expected block to return Int32 or Nil, not #{U}" %} + {% end %} + dup.sort! &block end @@ -1661,7 +1665,11 @@ class Array(T) # a.sort! { |a, b| b <=> a } # a # => [3, 2, 1] # ``` - def sort!(&block : T, T -> Int32) : Array(T) + def sort!(&block : T, T -> U) : Array(T) forall U + {% unless U <= Int32? %} + {% raise "expected block to return Int32 or Nil, not #{U}" %} + {% end %} + Array.intro_sort!(@buffer, @size, block) self end @@ -1939,14 +1947,14 @@ class Array(T) v, c = a[p], p while c < (n - 1) / 2 c = 2 * (c + 1) - c -= 1 if a[c] < a[c - 1] - break unless v <= a[c] + c -= 1 if cmp(a[c], a[c - 1]) < 0 + break unless cmp(v, a[c]) <= 0 a[p] = a[c] p = c end if n & 1 == 0 && c == n / 2 - 1 c = 2 * c + 1 - if v < a[c] + if cmp(v, a[c]) < 0 a[p] = a[c] p = c end @@ -1956,17 +1964,17 @@ class Array(T) protected def self.center_median!(a, n) b, c = a + n / 2, a + n - 1 - if a.value <= b.value - if b.value <= c.value + if cmp(a.value, b.value) <= 0 + if cmp(b.value, c.value) <= 0 return - elsif a.value <= c.value + elsif cmp(a.value, c.value) <= 0 b.value, c.value = c.value, b.value else a.value, b.value, c.value = c.value, a.value, b.value end - elsif a.value <= c.value + elsif cmp(a.value, c.value) <= 0 a.value, b.value = b.value, a.value - elsif b.value <= c.value + elsif cmp(b.value, c.value) <= 0 a.value, b.value, c.value = b.value, c.value, a.value else a.value, c.value = c.value, a.value @@ -1976,11 +1984,11 @@ class Array(T) protected def self.partition_for_quick_sort!(a, n) v, l, r = a[n / 2], a + 1, a + n - 1 loop do - while l.value < v + while cmp(l.value, v) < 0 l += 1 end r -= 1 - while v < r.value + while cmp(v, r.value) < 0 r -= 1 end return l unless l < r @@ -1994,7 +2002,7 @@ class Array(T) l = a + i v = l.value p = l - 1 - while l > a && v < p.value + while l > a && cmp(v, p.value) < 0 l.value = p.value l, p = p, p - 1 end @@ -2037,14 +2045,14 @@ class Array(T) v, c = a[p], p while c < (n - 1) / 2 c = 2 * (c + 1) - c -= 1 if comp.call(a[c], a[c - 1]) < 0 - break unless comp.call(v, a[c]) <= 0 + c -= 1 if cmp(a[c], a[c - 1], comp) < 0 + break unless cmp(v, a[c], comp) <= 0 a[p] = a[c] p = c end if n & 1 == 0 && c == n / 2 - 1 c = 2 * c + 1 - if comp.call(v, a[c]) < 0 + if cmp(v, a[c], comp) < 0 a[p] = a[c] p = c end @@ -2054,17 +2062,17 @@ class Array(T) protected def self.center_median!(a, n, comp) b, c = a + n / 2, a + n - 1 - if comp.call(a.value, b.value) <= 0 - if comp.call(b.value, c.value) <= 0 + if cmp(a.value, b.value, comp) <= 0 + if cmp(b.value, c.value, comp) <= 0 return - elsif comp.call(a.value, c.value) <= 0 + elsif cmp(a.value, c.value, comp) <= 0 b.value, c.value = c.value, b.value else a.value, b.value, c.value = c.value, a.value, b.value end - elsif comp.call(a.value, c.value) <= 0 + elsif cmp(a.value, c.value, comp) <= 0 a.value, b.value = b.value, a.value - elsif comp.call(b.value, c.value) <= 0 + elsif cmp(b.value, c.value, comp) <= 0 a.value, b.value, c.value = b.value, c.value, a.value else a.value, c.value = c.value, a.value @@ -2074,11 +2082,11 @@ class Array(T) protected def self.partition_for_quick_sort!(a, n, comp) v, l, r = a[n / 2], a + 1, a + n - 1 loop do - while l < a + n && comp.call(l.value, v) < 0 + while l < a + n && cmp(l.value, v, comp) < 0 l += 1 end r -= 1 - while r >= a && comp.call(v, r.value) < 0 + while r >= a && cmp(v, r.value, comp) < 0 r -= 1 end return l unless l < r @@ -2092,7 +2100,7 @@ class Array(T) l = a + i v = l.value p = l - 1 - while l > a && comp.call(v, p.value) < 0 + while l > a && cmp(v, p.value, comp) < 0 l.value = p.value l, p = p, p - 1 end @@ -2100,6 +2108,18 @@ class Array(T) end end + protected def self.cmp(v1, v2) + v = v1 <=> v2 + raise ArgumentError.new("Comparison of #{v1} and #{v2} failed") if v.nil? + v + end + + protected def self.cmp(v1, v2, block) + v = block.call(v1, v2) + raise ArgumentError.new("Comparison of #{v1} and #{v2} failed") if v.nil? + v + end + protected def to_lookup_hash to_lookup_hash { |elem| elem } end diff --git a/src/comparable.cr b/src/comparable.cr index ff6dbb014dd8..ba7157df20c8 100644 --- a/src/comparable.cr +++ b/src/comparable.cr @@ -1,22 +1,41 @@ # The `Comparable` mixin is used by classes whose objects may be ordered. # # Including types must provide an `<=>` method, which compares the receiver against -# another object, returning a negative number, `0`, or a positive number depending -# on whether the receiver is less than, equal to, or greater than the other object. +# another object, returning: +# - a negative number if `self` is less than the other object +# - a positive number if `self` is greater than the other object +# - `0` if `self` is equal to the other object +# - `nil` if `self` and the other object are not comparable # -# `Comparable` uses `<=>` to implement the conventional comparison operators (`<`, `<=`, `==`, `>=`, and `>`). +# `Comparable` uses `<=>` to implement the conventional comparison operators +# (`<`, `<=`, `==`, `>=`, and `>`). All of these return `false` when `<=>` +# returns `nil`. +# +# Note that returning `nil` is only useful when defining a partial comparable +# relationship. One such example is float values: they are generally comparable, +# except for `NaN`. If none of the values of a type are comparable between each +# other, `Comparable` shouldn't be included. +# +# NOTE: When `nil` is returned from `<=>`, `Array#sort` and related sorting +# methods will perform slightly slower. module Comparable(T) - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a negative number. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a negative number. def <(other : T) - (self <=> other) < 0 + cmp = self <=> other + cmp ? cmp < 0 : false end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a negative number or `0`. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a value equal or less then `0`. def <=(other : T) - (self <=> other) <= 0 + cmp = self <=> other + cmp ? cmp <= 0 : false end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns `0`. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns `0`. + # # Also returns `true` if this and *other* are the same object. def ==(other : T) if self.is_a?(Reference) @@ -27,23 +46,28 @@ module Comparable(T) return true if other.is_a?(Nil) && self.same?(other) end - (self <=> other) == 0 + cmp = self <=> other + cmp ? cmp == 0 : false end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a positive number. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a value greater then `0`. def >(other : T) - (self <=> other) > 0 + cmp = self <=> other + cmp ? cmp > 0 : false end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a positive number or `0`. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a value equal or greater than `0`. def >=(other : T) - (self <=> other) >= 0 + cmp = self <=> other + cmp ? cmp >= 0 : false end - # The comparison operator. - # - # Returns `-1`, `0` or `1` depending on whether `self` is less than *other*, equals *other* - # or is greater than *other*. + # The comparison operator. Returns `0` if the two objects are equal, + # a negative number if this object is considered less than *other*, + # a positive number if this object is considered greter than *other*, + # or `nil` if the two objects are not comparable. # # Subclasses define this method to provide class-specific ordering. # diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 72aa0445fd72..ce7c4b9873c9 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -422,7 +422,9 @@ module Crystal when "<=" bool_bin_op(method, args) { |me, other| me <= other } when "<=>" - num_bin_op(method, args) { |me, other| me <=> other } + num_bin_op(method, args) do |me, other| + (me <=> other) || (return NilLiteral.new) + end when "+" if args.empty? self diff --git a/src/compiler/crystal/syntax/location.cr b/src/compiler/crystal/syntax/location.cr index ac95ad491346..a9dbdbe569e1 100644 --- a/src/compiler/crystal/syntax/location.cr +++ b/src/compiler/crystal/syntax/location.cr @@ -1,8 +1,6 @@ -require "../../../partial_comparable" - # A location of an `ASTnode`, including its filename, line number and column number. class Crystal::Location - include PartialComparable(self) + include Comparable(self) getter line_number getter column_number diff --git a/src/docs_main.cr b/src/docs_main.cr index e535d07b79de..d6940c7b4984 100644 --- a/src/docs_main.cr +++ b/src/docs_main.cr @@ -47,7 +47,6 @@ require "./gzip" require "./ini" require "./levenshtein" require "./option_parser" -require "./partial_comparable" require "./random/**" require "./readline" require "./semantic_version" diff --git a/src/int.cr b/src/int.cr index 95aafbb268d3..d6389c8bdf15 100644 --- a/src/int.cr +++ b/src/int.cr @@ -237,6 +237,12 @@ struct Int end end + def <=>(other : Int) : Int32 + # Override Number#<=> because when comparing + # Int vs Int there's no way we can return `nil` + self > other ? 1 : (self < other ? -1 : 0) + end + def abs self >= 0 ? self : -self end diff --git a/src/number.cr b/src/number.cr index 18b3cab5d6ed..72e264a5114a 100644 --- a/src/number.cr +++ b/src/number.cr @@ -163,9 +163,16 @@ struct Number # The comparison operator. # - # Returns `-1`, `0` or `1` depending on whether `self` is less than *other*, equals *other* - # or is greater than *other*. - def <=>(other) + # Returns: + # - `-1` if `self` is less than *other* + # - `0` if `self` is equal to *other* + # - `-1` if `self` is greater than *other* + # - `nil` if self is `NaN` or *other* is `NaN`, because `NaN` values are not comparable + def <=>(other) : Int32? + # NaN can't be compared to other numbers + return nil if self.is_a?(Float) && self.nan? + return nil if other.is_a?(Float) && other.nan? + self > other ? 1 : (self < other ? -1 : 0) end