Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let Array#sort only use <=>, and let <=> return nil for partial comparability #6611

Merged
merged 5 commits into from
Mar 15, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions spec/compiler/macro/macro_methods_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ module Crystal
assert_macro "", "{{1 <=> -1}}", [] of ASTNode, "1"
end

it "executes <=> (returns nil)" do
assert_macro "", "{{0.0/0.0 <=> -1}}", [] of ASTNode, "nil"
end

it "executes +" do
assert_macro "", "{{1 + 2}}", [] of ASTNode, "3"
end
Expand Down
75 changes: 75 additions & 0 deletions spec/std/array_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ private class BadSortingClass
end
end

private class Spaceship
getter value : Float64

def initialize(@value : Float64, @return_nil = false)
end

def <=>(other : Spaceship)
return nil if @return_nil

value <=> other.value
end
end

describe "Array" do
describe "new" do
it "creates with default value" do
Expand Down Expand Up @@ -1059,6 +1072,37 @@ describe "Array" do
[1, 2, 3].sort { 1 }
Array.new(10) { BadSortingClass.new }.sort
end

it "can sort just by using <=> (#6608)" do
spaceships = [
Spaceship.new(2),
Spaceship.new(0),
Spaceship.new(1),
Spaceship.new(3),
]

sorted = spaceships.sort
4.times do |i|
sorted[i].value.should eq(i)
end
end

it "raises if <=> returns nil" do
spaceships = [
Spaceship.new(2, return_nil: true),
Spaceship.new(0, return_nil: true),
]

expect_raises(ArgumentError) do
spaceships.sort
end
end

it "raises if sort block returns nil" do
expect_raises(ArgumentError) do
[1, 2].sort { nil }
end
end
end

describe "sort!" do
Expand All @@ -1079,6 +1123,37 @@ describe "Array" do
b = a.sort { -1 }
a.should eq(b)
end

it "can sort! just by using <=> (#6608)" do
spaceships = [
Spaceship.new(2),
Spaceship.new(0),
Spaceship.new(1),
Spaceship.new(3),
]

spaceships.sort!
4.times do |i|
spaceships[i].value.should eq(i)
end
end

it "raises if <=> returns nil" do
spaceships = [
Spaceship.new(2, return_nil: true),
Spaceship.new(0, return_nil: true),
]

expect_raises(ArgumentError) do
spaceships.sort!
end
end

it "raises if sort! block returns nil" do
expect_raises(ArgumentError) do
[1, 2].sort! { nil }
end
end
end

describe "sort_by" do
Expand Down
27 changes: 26 additions & 1 deletion spec/std/comparable_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ require "spec"
private class ComparableTestClass
include Comparable(Int)

def initialize(@value : Int32)
def initialize(@value : Int32, @return_nil = false)
end

def <=>(other : Int)
return nil if @return_nil

@value <=> other
end
end
Expand All @@ -15,7 +17,30 @@ describe Comparable do
it "can compare against Int (#2461)" do
obj = ComparableTestClass.new(4)
(obj == 3).should be_false
(obj == 4).should be_true

(obj < 3).should be_false
(obj < 4).should be_false

(obj > 3).should be_true
(obj > 4).should be_false

(obj <= 3).should be_false
(obj <= 4).should be_true
(obj <= 5).should be_true

(obj >= 3).should be_true
(obj >= 4).should be_true
(obj >= 5).should be_false
end

it "checks for nil" do
obj = ComparableTestClass.new(4, return_nil: true)

(obj < 1).should be_false
(obj <= 1).should be_false
(obj == 1).should be_false
(obj >= 1).should be_false
(obj > 1).should be_false
end
end
50 changes: 50 additions & 0 deletions spec/std/float_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,54 @@ describe "Float" do
Float64::EPSILON.unsafe_as(UInt64).should eq 0x3cb0000000000000_u64
Float64::MIN_POSITIVE.unsafe_as(UInt64).should eq 0x0010000000000000_u64
end

it "returns nil in <=> for NaN values (Float32)" do
nan = Float32::NAN

(1_f32 <=> nan).should be_nil
(1_f64 <=> nan).should be_nil

(1_u8 <=> nan).should be_nil
(1_u16 <=> nan).should be_nil
(1_u32 <=> nan).should be_nil
(1_u64 <=> nan).should be_nil
(1_i8 <=> nan).should be_nil
(1_i16 <=> nan).should be_nil
(1_i32 <=> nan).should be_nil
(1_i64 <=> nan).should be_nil

(nan <=> 1_u8).should be_nil
(nan <=> 1_u16).should be_nil
(nan <=> 1_u32).should be_nil
(nan <=> 1_u64).should be_nil
(nan <=> 1_i8).should be_nil
(nan <=> 1_i16).should be_nil
(nan <=> 1_i32).should be_nil
(nan <=> 1_i64).should be_nil
end

it "returns nil in <=> for NaN values (Float64)" do
nan = Float64::NAN

(1_f32 <=> nan).should be_nil
(1_f64 <=> nan).should be_nil

(1_u8 <=> nan).should be_nil
(1_u16 <=> nan).should be_nil
(1_u32 <=> nan).should be_nil
(1_u64 <=> nan).should be_nil
(1_i8 <=> nan).should be_nil
(1_i16 <=> nan).should be_nil
(1_i32 <=> nan).should be_nil
(1_i64 <=> nan).should be_nil

(nan <=> 1_u8).should be_nil
(nan <=> 1_u16).should be_nil
(nan <=> 1_u32).should be_nil
(nan <=> 1_u64).should be_nil
(nan <=> 1_i8).should be_nil
(nan <=> 1_i16).should be_nil
(nan <=> 1_i32).should be_nil
(nan <=> 1_i64).should be_nil
end
end
68 changes: 44 additions & 24 deletions src/array.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1631,7 +1631,11 @@ class Array(T)
# b # => [3, 2, 1]
# a # => [3, 1, 2]
# ```
def sort(&block : T, T -> Int32) : Array(T)
def sort(&block : T, T -> U) : Array(T) forall U
{% unless U <= Int32? %}
{% raise "expected block to return Int32 or Nil, not #{U}" %}
{% end %}

dup.sort! &block
end

Expand Down Expand Up @@ -1661,7 +1665,11 @@ class Array(T)
# a.sort! { |a, b| b <=> a }
# a # => [3, 2, 1]
# ```
def sort!(&block : T, T -> Int32) : Array(T)
def sort!(&block : T, T -> U) : Array(T) forall U
{% unless U <= Int32? %}
{% raise "expected block to return Int32 or Nil, not #{U}" %}
{% end %}

Array.intro_sort!(@buffer, @size, block)
self
end
Expand Down Expand Up @@ -1939,14 +1947,14 @@ class Array(T)
v, c = a[p], p
while c < (n - 1) / 2
c = 2 * (c + 1)
c -= 1 if a[c] < a[c - 1]
break unless v <= a[c]
c -= 1 if cmp(a[c], a[c - 1]) < 0
break unless cmp(v, a[c]) <= 0
a[p] = a[c]
p = c
end
if n & 1 == 0 && c == n / 2 - 1
c = 2 * c + 1
if v < a[c]
if cmp(v, a[c]) < 0
a[p] = a[c]
p = c
end
Expand All @@ -1956,17 +1964,17 @@ class Array(T)

protected def self.center_median!(a, n)
b, c = a + n / 2, a + n - 1
if a.value <= b.value
if b.value <= c.value
if cmp(a.value, b.value) <= 0
if cmp(b.value, c.value) <= 0
return
elsif a.value <= c.value
elsif cmp(a.value, c.value) <= 0
b.value, c.value = c.value, b.value
else
a.value, b.value, c.value = c.value, a.value, b.value
end
elsif a.value <= c.value
elsif cmp(a.value, c.value) <= 0
a.value, b.value = b.value, a.value
elsif b.value <= c.value
elsif cmp(b.value, c.value) <= 0
a.value, b.value, c.value = b.value, c.value, a.value
else
a.value, c.value = c.value, a.value
Expand All @@ -1976,11 +1984,11 @@ class Array(T)
protected def self.partition_for_quick_sort!(a, n)
v, l, r = a[n / 2], a + 1, a + n - 1
loop do
while l.value < v
while cmp(l.value, v) < 0
l += 1
end
r -= 1
while v < r.value
while cmp(v, r.value) < 0
r -= 1
end
return l unless l < r
Expand All @@ -1994,7 +2002,7 @@ class Array(T)
l = a + i
v = l.value
p = l - 1
while l > a && v < p.value
while l > a && cmp(v, p.value) < 0
l.value = p.value
l, p = p, p - 1
end
Expand Down Expand Up @@ -2037,14 +2045,14 @@ class Array(T)
v, c = a[p], p
while c < (n - 1) / 2
c = 2 * (c + 1)
c -= 1 if comp.call(a[c], a[c - 1]) < 0
break unless comp.call(v, a[c]) <= 0
c -= 1 if cmp(a[c], a[c - 1], comp) < 0
break unless cmp(v, a[c], comp) <= 0
a[p] = a[c]
p = c
end
if n & 1 == 0 && c == n / 2 - 1
c = 2 * c + 1
if comp.call(v, a[c]) < 0
if cmp(v, a[c], comp) < 0
a[p] = a[c]
p = c
end
Expand All @@ -2054,17 +2062,17 @@ class Array(T)

protected def self.center_median!(a, n, comp)
b, c = a + n / 2, a + n - 1
if comp.call(a.value, b.value) <= 0
if comp.call(b.value, c.value) <= 0
if cmp(a.value, b.value, comp) <= 0
if cmp(b.value, c.value, comp) <= 0
return
elsif comp.call(a.value, c.value) <= 0
elsif cmp(a.value, c.value, comp) <= 0
b.value, c.value = c.value, b.value
else
a.value, b.value, c.value = c.value, a.value, b.value
end
elsif comp.call(a.value, c.value) <= 0
elsif cmp(a.value, c.value, comp) <= 0
a.value, b.value = b.value, a.value
elsif comp.call(b.value, c.value) <= 0
elsif cmp(b.value, c.value, comp) <= 0
a.value, b.value, c.value = b.value, c.value, a.value
else
a.value, c.value = c.value, a.value
Expand All @@ -2074,11 +2082,11 @@ class Array(T)
protected def self.partition_for_quick_sort!(a, n, comp)
v, l, r = a[n / 2], a + 1, a + n - 1
loop do
while l < a + n && comp.call(l.value, v) < 0
while l < a + n && cmp(l.value, v, comp) < 0
l += 1
end
r -= 1
while r >= a && comp.call(v, r.value) < 0
while r >= a && cmp(v, r.value, comp) < 0
r -= 1
end
return l unless l < r
Expand All @@ -2092,14 +2100,26 @@ class Array(T)
l = a + i
v = l.value
p = l - 1
while l > a && comp.call(v, p.value) < 0
while l > a && cmp(v, p.value, comp) < 0
l.value = p.value
l, p = p, p - 1
end
l.value = v
end
end

protected def self.cmp(v1, v2)
v = v1 <=> v2
raise ArgumentError.new("Comparison of #{v1} and #{v2} failed") if v.nil?
v
end

protected def self.cmp(v1, v2, block)
v = block.call(v1, v2)
raise ArgumentError.new("Comparison of #{v1} and #{v2} failed") if v.nil?
v
end

protected def to_lookup_hash
to_lookup_hash { |elem| elem }
end
Expand Down
Loading