From d34deb5b9b876a4d5c5457739b08a8d0ba94622d Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Sun, 26 Aug 2018 11:15:57 -0300 Subject: [PATCH 1/5] Let Array#sort only use `<=>`, and let `<=>` return `nil` for partial comparability. - Float <=> Float now will return `nil` for NaN - Removed PartialComparable --- spec/compiler/macro/macro_methods_spec.cr | 4 ++ spec/std/array_spec.cr | 75 ++++++++++++++++++++++ spec/std/comparable_spec.cr | 27 +++++++- spec/std/float_spec.cr | 50 +++++++++++++++ src/array.cr | 76 ++++++++++++++++------- src/comparable.cr | 61 +++++++++++++----- src/compiler/crystal/macros/methods.cr | 6 +- src/compiler/crystal/syntax/location.cr | 4 +- src/docs_main.cr | 1 - src/int.cr | 6 ++ src/number.cr | 4 ++ 11 files changed, 267 insertions(+), 47 deletions(-) diff --git a/spec/compiler/macro/macro_methods_spec.cr b/spec/compiler/macro/macro_methods_spec.cr index 888e348d1f15..ac693fa57183 100644 --- a/spec/compiler/macro/macro_methods_spec.cr +++ b/spec/compiler/macro/macro_methods_spec.cr @@ -169,6 +169,10 @@ module Crystal assert_macro "", "{{1 <=> -1}}", [] of ASTNode, "1" end + it "executes <=> (returns nil)" do + assert_macro "", "{{0.0/0.0 <=> -1}}", [] of ASTNode, "nil" + end + it "executes +" do assert_macro "", "{{1 + 2}}", [] of ASTNode, "3" end diff --git a/spec/std/array_spec.cr b/spec/std/array_spec.cr index 6575ddafb7fb..6c0a9b462491 100644 --- a/spec/std/array_spec.cr +++ b/spec/std/array_spec.cr @@ -10,6 +10,19 @@ private class BadSortingClass end end +private class Spaceship + getter value : Float64 + + def initialize(@value : Float64, @return_nil = false) + end + + def <=>(other : Spaceship) + return nil if @return_nil + + value <=> other.value + end +end + describe "Array" do describe "new" do it "creates with default value" do @@ -1059,6 +1072,37 @@ describe "Array" do [1, 2, 3].sort { 1 } Array.new(10) { BadSortingClass.new }.sort end + + it "can sort just by using <=> (#6608)" do + spaceships = [ + Spaceship.new(2), + Spaceship.new(0), + Spaceship.new(1), + Spaceship.new(3), + ] + + sorted = spaceships.sort + 4.times do |i| + sorted[i].value.should eq(i) + end + end + + it "raises if <=> returns nil" do + spaceships = [ + Spaceship.new(2, return_nil: true), + Spaceship.new(0, return_nil: true), + ] + + expect_raises(ArgumentError) do + spaceships.sort + end + end + + it "raises if sort block returns nil" do + expect_raises(ArgumentError) do + [1, 2].sort { nil } + end + end end describe "sort!" do @@ -1079,6 +1123,37 @@ describe "Array" do b = a.sort { -1 } a.should eq(b) end + + it "can sort! just by using <=> (#6608)" do + spaceships = [ + Spaceship.new(2), + Spaceship.new(0), + Spaceship.new(1), + Spaceship.new(3), + ] + + spaceships.sort! + 4.times do |i| + spaceships[i].value.should eq(i) + end + end + + it "raises if <=> returns nil" do + spaceships = [ + Spaceship.new(2, return_nil: true), + Spaceship.new(0, return_nil: true), + ] + + expect_raises(ArgumentError) do + spaceships.sort! + end + end + + it "raises if sort! block returns nil" do + expect_raises(ArgumentError) do + [1, 2].sort! { nil } + end + end end describe "sort_by" do diff --git a/spec/std/comparable_spec.cr b/spec/std/comparable_spec.cr index fcaa34133ba4..12756170cec1 100644 --- a/spec/std/comparable_spec.cr +++ b/spec/std/comparable_spec.cr @@ -3,10 +3,12 @@ require "spec" private class ComparableTestClass include Comparable(Int) - def initialize(@value : Int32) + def initialize(@value : Int32, @return_nil = false) end def <=>(other : Int) + return nil if @return_nil + @value <=> other end end @@ -15,7 +17,30 @@ describe Comparable do it "can compare against Int (#2461)" do obj = ComparableTestClass.new(4) (obj == 3).should be_false + (obj == 4).should be_true + (obj < 3).should be_false + (obj < 4).should be_false + (obj > 3).should be_true + (obj > 4).should be_false + + (obj <= 3).should be_false + (obj <= 4).should be_true + (obj <= 5).should be_true + + (obj >= 3).should be_true + (obj >= 4).should be_true + (obj >= 5).should be_false + end + + it "checks for nil" do + obj = ComparableTestClass.new(4, return_nil: true) + + (obj < 1).should be_false + (obj <= 1).should be_false + (obj == 1).should be_false + (obj >= 1).should be_false + (obj > 1).should be_false end end diff --git a/spec/std/float_spec.cr b/spec/std/float_spec.cr index c2be91711435..eb0824958443 100644 --- a/spec/std/float_spec.cr +++ b/spec/std/float_spec.cr @@ -291,4 +291,54 @@ describe "Float" do Float64::EPSILON.unsafe_as(UInt64).should eq 0x3cb0000000000000_u64 Float64::MIN_POSITIVE.unsafe_as(UInt64).should eq 0x0010000000000000_u64 end + + it "returns nil in <=> for NaN values (Float32)" do + nan = Float32::NAN + + (1_f32 <=> nan).should be_nil + (1_f64 <=> nan).should be_nil + + (1_u8 <=> nan).should be_nil + (1_u16 <=> nan).should be_nil + (1_u32 <=> nan).should be_nil + (1_u64 <=> nan).should be_nil + (1_i8 <=> nan).should be_nil + (1_i16 <=> nan).should be_nil + (1_i32 <=> nan).should be_nil + (1_i64 <=> nan).should be_nil + + (nan <=> 1_u8).should be_nil + (nan <=> 1_u16).should be_nil + (nan <=> 1_u32).should be_nil + (nan <=> 1_u64).should be_nil + (nan <=> 1_i8).should be_nil + (nan <=> 1_i16).should be_nil + (nan <=> 1_i32).should be_nil + (nan <=> 1_i64).should be_nil + end + + it "returns nil in <=> for NaN values (Float64)" do + nan = Float64::NAN + + (1_f32 <=> nan).should be_nil + (1_f64 <=> nan).should be_nil + + (1_u8 <=> nan).should be_nil + (1_u16 <=> nan).should be_nil + (1_u32 <=> nan).should be_nil + (1_u64 <=> nan).should be_nil + (1_i8 <=> nan).should be_nil + (1_i16 <=> nan).should be_nil + (1_i32 <=> nan).should be_nil + (1_i64 <=> nan).should be_nil + + (nan <=> 1_u8).should be_nil + (nan <=> 1_u16).should be_nil + (nan <=> 1_u32).should be_nil + (nan <=> 1_u64).should be_nil + (nan <=> 1_i8).should be_nil + (nan <=> 1_i16).should be_nil + (nan <=> 1_i32).should be_nil + (nan <=> 1_i64).should be_nil + end end diff --git a/src/array.cr b/src/array.cr index 41bea57a9f79..d7cc41f971a4 100644 --- a/src/array.cr +++ b/src/array.cr @@ -1631,7 +1631,15 @@ class Array(T) # b # => [3, 2, 1] # a # => [3, 1, 2] # ``` - def sort(&block : T, T -> Int32) : Array(T) + def sort(&block : T, T -> U) : Array(T) forall U + # TODO: use a better way to check U < Int32? + {% begin %} + {% block_type = U.union? ? U.union_types.first { |t| t != Nil } : U %} + {% if block_type != Int32 && block_type != Nil %} + {% raise "expected block to return Int32 or Nil, not #{U}" %} + {% end %} + {% end %} + dup.sort! &block end @@ -1661,7 +1669,15 @@ class Array(T) # a.sort! { |a, b| b <=> a } # a # => [3, 2, 1] # ``` - def sort!(&block : T, T -> Int32) : Array(T) + def sort!(&block : T, T -> U) : Array(T) forall U + # TODO: use a better way to check U < Int32? + {% begin %} + {% block_type = U.union? ? U.union_types.first { |t| t != Nil } : U %} + {% if block_type != Int32 && block_type != Nil %} + {% raise "expected block to return Int32 or Nil, not #{U}" %} + {% end %} + {% end %} + Array.intro_sort!(@buffer, @size, block) self end @@ -1939,14 +1955,14 @@ class Array(T) v, c = a[p], p while c < (n - 1) / 2 c = 2 * (c + 1) - c -= 1 if a[c] < a[c - 1] - break unless v <= a[c] + c -= 1 if cmp(a[c], a[c - 1]) < 0 + break unless cmp(v, a[c]) <= 0 a[p] = a[c] p = c end if n & 1 == 0 && c == n / 2 - 1 c = 2 * c + 1 - if v < a[c] + if cmp(v, a[c]) < 0 a[p] = a[c] p = c end @@ -1956,17 +1972,17 @@ class Array(T) protected def self.center_median!(a, n) b, c = a + n / 2, a + n - 1 - if a.value <= b.value - if b.value <= c.value + if cmp(a.value, b.value) <= 0 + if cmp(b.value, c.value) <= 0 return - elsif a.value <= c.value + elsif cmp(a.value, c.value) <= 0 b.value, c.value = c.value, b.value else a.value, b.value, c.value = c.value, a.value, b.value end - elsif a.value <= c.value + elsif cmp(a.value, c.value) <= 0 a.value, b.value = b.value, a.value - elsif b.value <= c.value + elsif cmp(b.value, c.value) <= 0 a.value, b.value, c.value = b.value, c.value, a.value else a.value, c.value = c.value, a.value @@ -1976,11 +1992,11 @@ class Array(T) protected def self.partition_for_quick_sort!(a, n) v, l, r = a[n / 2], a + 1, a + n - 1 loop do - while l.value < v + while cmp(l.value, v) < 0 l += 1 end r -= 1 - while v < r.value + while cmp(v, r.value) < 0 r -= 1 end return l unless l < r @@ -1994,7 +2010,7 @@ class Array(T) l = a + i v = l.value p = l - 1 - while l > a && v < p.value + while l > a && cmp(v, p.value) < 0 l.value = p.value l, p = p, p - 1 end @@ -2037,14 +2053,14 @@ class Array(T) v, c = a[p], p while c < (n - 1) / 2 c = 2 * (c + 1) - c -= 1 if comp.call(a[c], a[c - 1]) < 0 - break unless comp.call(v, a[c]) <= 0 + c -= 1 if cmp(a[c], a[c - 1], comp) < 0 + break unless cmp(v, a[c], comp) <= 0 a[p] = a[c] p = c end if n & 1 == 0 && c == n / 2 - 1 c = 2 * c + 1 - if comp.call(v, a[c]) < 0 + if cmp(v, a[c], comp) < 0 a[p] = a[c] p = c end @@ -2054,17 +2070,17 @@ class Array(T) protected def self.center_median!(a, n, comp) b, c = a + n / 2, a + n - 1 - if comp.call(a.value, b.value) <= 0 - if comp.call(b.value, c.value) <= 0 + if cmp(a.value, b.value, comp) <= 0 + if cmp(b.value, c.value, comp) <= 0 return - elsif comp.call(a.value, c.value) <= 0 + elsif cmp(a.value, c.value, comp) <= 0 b.value, c.value = c.value, b.value else a.value, b.value, c.value = c.value, a.value, b.value end - elsif comp.call(a.value, c.value) <= 0 + elsif cmp(a.value, c.value, comp) <= 0 a.value, b.value = b.value, a.value - elsif comp.call(b.value, c.value) <= 0 + elsif cmp(b.value, c.value, comp) <= 0 a.value, b.value, c.value = b.value, c.value, a.value else a.value, c.value = c.value, a.value @@ -2074,11 +2090,11 @@ class Array(T) protected def self.partition_for_quick_sort!(a, n, comp) v, l, r = a[n / 2], a + 1, a + n - 1 loop do - while l < a + n && comp.call(l.value, v) < 0 + while l < a + n && cmp(l.value, v, comp) < 0 l += 1 end r -= 1 - while r >= a && comp.call(v, r.value) < 0 + while r >= a && cmp(v, r.value, comp) < 0 r -= 1 end return l unless l < r @@ -2092,7 +2108,7 @@ class Array(T) l = a + i v = l.value p = l - 1 - while l > a && comp.call(v, p.value) < 0 + while l > a && cmp(v, p.value, comp) < 0 l.value = p.value l, p = p, p - 1 end @@ -2100,6 +2116,18 @@ class Array(T) end end + protected def self.cmp(v1, v2) + v = v1 <=> v2 + raise ArgumentError.new("Comparison of #{v1} and #{v2} failed") if v.nil? + v + end + + protected def self.cmp(v1, v2, block) + v = block.call(v1, v2) + raise ArgumentError.new("Comparison of #{v1} and #{v2} failed") if v.nil? + v + end + protected def to_lookup_hash to_lookup_hash { |elem| elem } end diff --git a/src/comparable.cr b/src/comparable.cr index ff6dbb014dd8..999d053ca868 100644 --- a/src/comparable.cr +++ b/src/comparable.cr @@ -1,22 +1,38 @@ # The `Comparable` mixin is used by classes whose objects may be ordered. # # Including types must provide an `<=>` method, which compares the receiver against -# another object, returning a negative number, `0`, or a positive number depending -# on whether the receiver is less than, equal to, or greater than the other object. +# another object, returning: +# - a value less than zero if `self` is less than the other object +# - a value greater than zero if `self` is greater than the other object +# - zero if `self` is equal to the other object +# - `nil` if `self` and the other object are not comparable # -# `Comparable` uses `<=>` to implement the conventional comparison operators (`<`, `<=`, `==`, `>=`, and `>`). +# `Comparable` uses `<=>` to implement the conventional comparison operators +# (`<`, `<=`, `==`, `>=`, and `>`). +# +# Note that returning `nil` is only useful when defining a partial comparable +# relationship. One such example is float values: they are generally comparable, +# except for `NaN`. If none of the values of a type are comparable between each +# other, `Comparable` shouldn't be included. +# +# NOTE: when `nil` is returned from `<=>`, `Array#sort` and related sorting +# methods will perform slightly slower. module Comparable(T) - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a negative number. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a value less than zero. def <(other : T) - (self <=> other) < 0 + _compare_with other, &.<(0) end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a negative number or `0`. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a value equal or less then zero. def <=(other : T) - (self <=> other) <= 0 + _compare_with other, &.<=(0) end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns `0`. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns `0`. + # # Also returns `true` if this and *other* are the same object. def ==(other : T) if self.is_a?(Reference) @@ -27,23 +43,25 @@ module Comparable(T) return true if other.is_a?(Nil) && self.same?(other) end - (self <=> other) == 0 + _compare_with other, &.==(0) end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a positive number. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a value greater then zero. def >(other : T) - (self <=> other) > 0 + _compare_with other, &.>(0) end - # Compares this object to *other* based on the receiver's `<=>` method, returning `true` if it returns a positive number or `0`. + # Compares this object to *other* based on the receiver’s `<=>` method, + # returning `true` if it returns a value equal or greater than zero. def >=(other : T) - (self <=> other) >= 0 + _compare_with other, &.>=(0) end - # The comparison operator. - # - # Returns `-1`, `0` or `1` depending on whether `self` is less than *other*, equals *other* - # or is greater than *other*. + # The comparison operator. Returns `0` if the two objects are equal, + # a negative number if this object is considered less than *other*, + # a positive number if this object is considered greter than *other*, + # or `nil` if the two objects are not comparable. # # Subclasses define this method to provide class-specific ordering. # @@ -57,4 +75,13 @@ module Comparable(T) # [3, 1, 2].sort { |x, y| x <=> y } # => [1, 2, 3] # ``` abstract def <=>(other : T) + + private def _compare_with(other : T) + cmp = self <=> other + if cmp + yield cmp + else + false + end + end end diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 72aa0445fd72..9b6487f15986 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -422,7 +422,11 @@ module Crystal when "<=" bool_bin_op(method, args) { |me, other| me <= other } when "<=>" - num_bin_op(method, args) { |me, other| me <=> other } + num_bin_op(method, args) do |me, other| + v = me <=> other + return NilLiteral.new if v.nil? + v + end when "+" if args.empty? self diff --git a/src/compiler/crystal/syntax/location.cr b/src/compiler/crystal/syntax/location.cr index ac95ad491346..a9dbdbe569e1 100644 --- a/src/compiler/crystal/syntax/location.cr +++ b/src/compiler/crystal/syntax/location.cr @@ -1,8 +1,6 @@ -require "../../../partial_comparable" - # A location of an `ASTnode`, including its filename, line number and column number. class Crystal::Location - include PartialComparable(self) + include Comparable(self) getter line_number getter column_number diff --git a/src/docs_main.cr b/src/docs_main.cr index e535d07b79de..d6940c7b4984 100644 --- a/src/docs_main.cr +++ b/src/docs_main.cr @@ -47,7 +47,6 @@ require "./gzip" require "./ini" require "./levenshtein" require "./option_parser" -require "./partial_comparable" require "./random/**" require "./readline" require "./semantic_version" diff --git a/src/int.cr b/src/int.cr index 95aafbb268d3..835fc8e8adfe 100644 --- a/src/int.cr +++ b/src/int.cr @@ -237,6 +237,12 @@ struct Int end end + def <=>(other : Int) + # Override Number#<=> because when comparing + # Int vs Int there's no way we can return `nil` + self > other ? 1 : (self < other ? -1 : 0) + end + def abs self >= 0 ? self : -self end diff --git a/src/number.cr b/src/number.cr index 18b3cab5d6ed..d748309c56bb 100644 --- a/src/number.cr +++ b/src/number.cr @@ -166,6 +166,10 @@ struct Number # Returns `-1`, `0` or `1` depending on whether `self` is less than *other*, equals *other* # or is greater than *other*. def <=>(other) + # NaN can't be compared to other numbers + return nil if self.is_a?(Float) && self.nan? + return nil if other.is_a?(Float) && other.nan? + self > other ? 1 : (self < other ? -1 : 0) end From f01a0e9fe8520d3733a01ae521284f4447637536 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Mon, 11 Feb 2019 17:09:02 -0300 Subject: [PATCH 2/5] Comparable: inline `_compare_with` private method --- src/comparable.cr | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/comparable.cr b/src/comparable.cr index 999d053ca868..9ae887b2fa6c 100644 --- a/src/comparable.cr +++ b/src/comparable.cr @@ -21,13 +21,15 @@ module Comparable(T) # Compares this object to *other* based on the receiver’s `<=>` method, # returning `true` if it returns a value less than zero. def <(other : T) - _compare_with other, &.<(0) + cmp = self <=> other + cmp ? cmp < 0 : false end # Compares this object to *other* based on the receiver’s `<=>` method, # returning `true` if it returns a value equal or less then zero. def <=(other : T) - _compare_with other, &.<=(0) + cmp = self <=> other + cmp ? cmp <= 0 : false end # Compares this object to *other* based on the receiver’s `<=>` method, @@ -43,19 +45,22 @@ module Comparable(T) return true if other.is_a?(Nil) && self.same?(other) end - _compare_with other, &.==(0) + cmp = self <=> other + cmp ? cmp == 0 : false end # Compares this object to *other* based on the receiver’s `<=>` method, # returning `true` if it returns a value greater then zero. def >(other : T) - _compare_with other, &.>(0) + cmp = self <=> other + cmp ? cmp > 0 : false end # Compares this object to *other* based on the receiver’s `<=>` method, # returning `true` if it returns a value equal or greater than zero. def >=(other : T) - _compare_with other, &.>=(0) + cmp = self <=> other + cmp ? cmp >= 0 : false end # The comparison operator. Returns `0` if the two objects are equal, @@ -75,13 +80,4 @@ module Comparable(T) # [3, 1, 2].sort { |x, y| x <=> y } # => [1, 2, 3] # ``` abstract def <=>(other : T) - - private def _compare_with(other : T) - cmp = self <=> other - if cmp - yield cmp - else - false - end - end end From debd4f97852df18bbbfa395c383d6a7fc001f82d Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Wed, 13 Mar 2019 09:41:13 -0300 Subject: [PATCH 3/5] Simpler macro code --- src/array.cr | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/array.cr b/src/array.cr index d7cc41f971a4..a7c7e8d2470d 100644 --- a/src/array.cr +++ b/src/array.cr @@ -1632,12 +1632,8 @@ class Array(T) # a # => [3, 1, 2] # ``` def sort(&block : T, T -> U) : Array(T) forall U - # TODO: use a better way to check U < Int32? - {% begin %} - {% block_type = U.union? ? U.union_types.first { |t| t != Nil } : U %} - {% if block_type != Int32 && block_type != Nil %} - {% raise "expected block to return Int32 or Nil, not #{U}" %} - {% end %} + {% unless U <= Int32? %} + {% raise "expected block to return Int32 or Nil, not #{U}" %} {% end %} dup.sort! &block @@ -1670,12 +1666,8 @@ class Array(T) # a # => [3, 2, 1] # ``` def sort!(&block : T, T -> U) : Array(T) forall U - # TODO: use a better way to check U < Int32? - {% begin %} - {% block_type = U.union? ? U.union_types.first { |t| t != Nil } : U %} - {% if block_type != Int32 && block_type != Nil %} - {% raise "expected block to return Int32 or Nil, not #{U}" %} - {% end %} + {% unless U <= Int32? %} + {% raise "expected block to return Int32 or Nil, not #{U}" %} {% end %} Array.intro_sort!(@buffer, @size, block) From dc0ba8141f53214cd98837a7cd9796a223f6d5af Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Wed, 13 Mar 2019 09:52:23 -0300 Subject: [PATCH 4/5] Use `0` instead of zero --- src/comparable.cr | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/comparable.cr b/src/comparable.cr index 9ae887b2fa6c..65ab61dead31 100644 --- a/src/comparable.cr +++ b/src/comparable.cr @@ -2,9 +2,9 @@ # # Including types must provide an `<=>` method, which compares the receiver against # another object, returning: -# - a value less than zero if `self` is less than the other object -# - a value greater than zero if `self` is greater than the other object -# - zero if `self` is equal to the other object +# - a negative number if `self` is less than the other object +# - a positive number if `self` is greater than the other object +# - `0` if `self` is equal to the other object # - `nil` if `self` and the other object are not comparable # # `Comparable` uses `<=>` to implement the conventional comparison operators @@ -19,14 +19,14 @@ # methods will perform slightly slower. module Comparable(T) # Compares this object to *other* based on the receiver’s `<=>` method, - # returning `true` if it returns a value less than zero. + # returning `true` if it returns a negative number. def <(other : T) cmp = self <=> other cmp ? cmp < 0 : false end # Compares this object to *other* based on the receiver’s `<=>` method, - # returning `true` if it returns a value equal or less then zero. + # returning `true` if it returns a value equal or less then `0`. def <=(other : T) cmp = self <=> other cmp ? cmp <= 0 : false @@ -50,14 +50,14 @@ module Comparable(T) end # Compares this object to *other* based on the receiver’s `<=>` method, - # returning `true` if it returns a value greater then zero. + # returning `true` if it returns a value greater then `0`. def >(other : T) cmp = self <=> other cmp ? cmp > 0 : false end # Compares this object to *other* based on the receiver’s `<=>` method, - # returning `true` if it returns a value equal or greater than zero. + # returning `true` if it returns a value equal or greater than `0`. def >=(other : T) cmp = self <=> other cmp ? cmp >= 0 : false From 0140f77d00250ba8083c8f6683f16947655dfc52 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Wed, 13 Mar 2019 20:29:52 -0300 Subject: [PATCH 5/5] Small doc and code improvements --- src/comparable.cr | 5 +++-- src/compiler/crystal/macros/methods.cr | 4 +--- src/int.cr | 2 +- src/number.cr | 9 ++++++--- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/comparable.cr b/src/comparable.cr index 65ab61dead31..ba7157df20c8 100644 --- a/src/comparable.cr +++ b/src/comparable.cr @@ -8,14 +8,15 @@ # - `nil` if `self` and the other object are not comparable # # `Comparable` uses `<=>` to implement the conventional comparison operators -# (`<`, `<=`, `==`, `>=`, and `>`). +# (`<`, `<=`, `==`, `>=`, and `>`). All of these return `false` when `<=>` +# returns `nil`. # # Note that returning `nil` is only useful when defining a partial comparable # relationship. One such example is float values: they are generally comparable, # except for `NaN`. If none of the values of a type are comparable between each # other, `Comparable` shouldn't be included. # -# NOTE: when `nil` is returned from `<=>`, `Array#sort` and related sorting +# NOTE: When `nil` is returned from `<=>`, `Array#sort` and related sorting # methods will perform slightly slower. module Comparable(T) # Compares this object to *other* based on the receiver’s `<=>` method, diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 9b6487f15986..ce7c4b9873c9 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -423,9 +423,7 @@ module Crystal bool_bin_op(method, args) { |me, other| me <= other } when "<=>" num_bin_op(method, args) do |me, other| - v = me <=> other - return NilLiteral.new if v.nil? - v + (me <=> other) || (return NilLiteral.new) end when "+" if args.empty? diff --git a/src/int.cr b/src/int.cr index 835fc8e8adfe..d6389c8bdf15 100644 --- a/src/int.cr +++ b/src/int.cr @@ -237,7 +237,7 @@ struct Int end end - def <=>(other : Int) + def <=>(other : Int) : Int32 # Override Number#<=> because when comparing # Int vs Int there's no way we can return `nil` self > other ? 1 : (self < other ? -1 : 0) diff --git a/src/number.cr b/src/number.cr index d748309c56bb..72e264a5114a 100644 --- a/src/number.cr +++ b/src/number.cr @@ -163,9 +163,12 @@ struct Number # The comparison operator. # - # Returns `-1`, `0` or `1` depending on whether `self` is less than *other*, equals *other* - # or is greater than *other*. - def <=>(other) + # Returns: + # - `-1` if `self` is less than *other* + # - `0` if `self` is equal to *other* + # - `-1` if `self` is greater than *other* + # - `nil` if self is `NaN` or *other* is `NaN`, because `NaN` values are not comparable + def <=>(other) : Int32? # NaN can't be compared to other numbers return nil if self.is_a?(Float) && self.nan? return nil if other.is_a?(Float) && other.nan?