Skip to content

Commit

Permalink
Let Array#sort only use <=>, and let <=> return nil for partial…
Browse files Browse the repository at this point in the history
… comparability.

- Float <=> Float now will return `nil` for NaN
- Removed PartialComparable
  • Loading branch information
asterite committed Aug 26, 2018
1 parent 720928f commit 1fe21d4
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 110 deletions.
4 changes: 4 additions & 0 deletions spec/compiler/macro/macro_methods_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ module Crystal
assert_macro "", "{{1 <=> -1}}", [] of ASTNode, "1"
end

it "executes <=> (returns nil)" do
assert_macro "", "{{0.0/0.0 <=> -1}}", [] of ASTNode, "nil"
end

it "executes +" do
assert_macro "", "{{1 + 2}}", [] of ASTNode, "3"
end
Expand Down
75 changes: 75 additions & 0 deletions spec/std/array_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ private class BadSortingClass
end
end

private class Spaceship
getter value : Float64

def initialize(@value : Float64, @return_nil = false)
end

def <=>(other : Spaceship)
return nil if @return_nil

value <=> other.value
end
end

describe "Array" do
describe "new" do
it "creates with default value" do
Expand Down Expand Up @@ -1017,6 +1030,37 @@ describe "Array" do
[1, 2, 3].sort { 1 }
Array.new(10) { BadSortingClass.new }.sort
end

it "can sort just by using <=> (#6608)" do
spaceships = [
Spaceship.new(2),
Spaceship.new(0),
Spaceship.new(1),
Spaceship.new(3),
]

sorted = spaceships.sort
4.times do |i|
sorted[i].value.should eq(i)
end
end

it "raises if <=> returns nil" do
spaceships = [
Spaceship.new(2, return_nil: true),
Spaceship.new(0, return_nil: true),
]

expect_raises(ArgumentError) do
spaceships.sort
end
end

it "raises if sort block returns nil" do
expect_raises(ArgumentError) do
[1, 2].sort { nil }
end
end
end

describe "sort!" do
Expand All @@ -1037,6 +1081,37 @@ describe "Array" do
b = a.sort { -1 }
a.should eq(b)
end

it "can sort! just by using <=> (#6608)" do
spaceships = [
Spaceship.new(2),
Spaceship.new(0),
Spaceship.new(1),
Spaceship.new(3),
]

spaceships.sort!
4.times do |i|
spaceships[i].value.should eq(i)
end
end

it "raises if <=> returns nil" do
spaceships = [
Spaceship.new(2, return_nil: true),
Spaceship.new(0, return_nil: true),
]

expect_raises(ArgumentError) do
spaceships.sort!
end
end

it "raises if sort! block returns nil" do
expect_raises(ArgumentError) do
[1, 2].sort! { nil }
end
end
end

describe "sort_by" do
Expand Down
27 changes: 26 additions & 1 deletion spec/std/comparable_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ require "spec"
private class ComparableTestClass
include Comparable(Int)

def initialize(@value : Int32)
def initialize(@value : Int32, @return_nil = false)
end

def <=>(other : Int)
return nil if @return_nil

@value <=> other
end
end
Expand All @@ -15,7 +17,30 @@ describe Comparable do
it "can compare against Int (#2461)" do
obj = ComparableTestClass.new(4)
(obj == 3).should be_false
(obj == 4).should be_true

(obj < 3).should be_false
(obj < 4).should be_false

(obj > 3).should be_true
(obj > 4).should be_false

(obj <= 3).should be_false
(obj <= 4).should be_true
(obj <= 5).should be_true

(obj >= 3).should be_true
(obj >= 4).should be_true
(obj >= 5).should be_false
end

it "checks for nil" do
obj = ComparableTestClass.new(4, return_nil: true)

(obj < 1).should be_false
(obj <= 1).should be_false
(obj == 1).should be_false
(obj >= 1).should be_false
(obj > 1).should be_false
end
end
50 changes: 50 additions & 0 deletions spec/std/float_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,54 @@ describe "Float" do
Float64::EPSILON.unsafe_as(UInt64).should eq 0x3cb0000000000000_u64
Float64::MIN_POSITIVE.unsafe_as(UInt64).should eq 0x0010000000000000_u64
end

it "returns nil in <=> for NaN values (Float32)" do
nan = Float32::NAN

(1_f32 <=> nan).should be_nil
(1_f64 <=> nan).should be_nil

(1_u8 <=> nan).should be_nil
(1_u16 <=> nan).should be_nil
(1_u32 <=> nan).should be_nil
(1_u64 <=> nan).should be_nil
(1_i8 <=> nan).should be_nil
(1_i16 <=> nan).should be_nil
(1_i32 <=> nan).should be_nil
(1_i64 <=> nan).should be_nil

(nan <=> 1_u8).should be_nil
(nan <=> 1_u16).should be_nil
(nan <=> 1_u32).should be_nil
(nan <=> 1_u64).should be_nil
(nan <=> 1_i8).should be_nil
(nan <=> 1_i16).should be_nil
(nan <=> 1_i32).should be_nil
(nan <=> 1_i64).should be_nil
end

it "returns nil in <=> for NaN values (Float64)" do
nan = Float64::NAN

(1_f32 <=> nan).should be_nil
(1_f64 <=> nan).should be_nil

(1_u8 <=> nan).should be_nil
(1_u16 <=> nan).should be_nil
(1_u32 <=> nan).should be_nil
(1_u64 <=> nan).should be_nil
(1_i8 <=> nan).should be_nil
(1_i16 <=> nan).should be_nil
(1_i32 <=> nan).should be_nil
(1_i64 <=> nan).should be_nil

(nan <=> 1_u8).should be_nil
(nan <=> 1_u16).should be_nil
(nan <=> 1_u32).should be_nil
(nan <=> 1_u64).should be_nil
(nan <=> 1_i8).should be_nil
(nan <=> 1_i16).should be_nil
(nan <=> 1_i32).should be_nil
(nan <=> 1_i64).should be_nil
end
end
76 changes: 52 additions & 24 deletions src/array.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1561,7 +1561,15 @@ class Array(T)
# b # => [3, 2, 1]
# a # => [3, 1, 2]
# ```
def sort(&block : T, T -> Int32) : Array(T)
def sort(&block : T, T -> U) : Array(T) forall U
# TODO: use a better way to check U < Int32?
{% begin %}
{% block_type = U.union? ? U.union_types.first { |t| t != Nil } : U %}
{% if block_type != Int32 && block_type != Nil %}
{% raise "expected block to return Int32 or Nil, not #{U}" %}
{% end %}
{% end %}

dup.sort! &block
end

Expand Down Expand Up @@ -1591,7 +1599,15 @@ class Array(T)
# a.sort! { |a, b| b <=> a }
# a # => [3, 2, 1]
# ```
def sort!(&block : T, T -> Int32) : Array(T)
def sort!(&block : T, T -> U) : Array(T) forall U
# TODO: use a better way to check U < Int32?
{% begin %}
{% block_type = U.union? ? U.union_types.first { |t| t != Nil } : U %}
{% if block_type != Int32 && block_type != Nil %}
{% raise "expected block to return Int32 or Nil, not #{U}" %}
{% end %}
{% end %}

Array.intro_sort!(@buffer, @size, block)
self
end
Expand Down Expand Up @@ -1869,14 +1885,14 @@ class Array(T)
v, c = a[p], p
while c < (n - 1) / 2
c = 2 * (c + 1)
c -= 1 if a[c] < a[c - 1]
break unless v <= a[c]
c -= 1 if cmp(a[c], a[c - 1]) < 0
break unless cmp(v, a[c]) <= 0
a[p] = a[c]
p = c
end
if n & 1 == 0 && c == n / 2 - 1
c = 2 * c + 1
if v < a[c]
if cmp(v, a[c]) < 0
a[p] = a[c]
p = c
end
Expand All @@ -1886,17 +1902,17 @@ class Array(T)

protected def self.center_median!(a, n)
b, c = a + n / 2, a + n - 1
if a.value <= b.value
if b.value <= c.value
if cmp(a.value, b.value) <= 0
if cmp(b.value, c.value) <= 0
return
elsif a.value <= c.value
elsif cmp(a.value, c.value) <= 0
b.value, c.value = c.value, b.value
else
a.value, b.value, c.value = c.value, a.value, b.value
end
elsif a.value <= c.value
elsif cmp(a.value, c.value) <= 0
a.value, b.value = b.value, a.value
elsif b.value <= c.value
elsif cmp(b.value, c.value) <= 0
a.value, b.value, c.value = b.value, c.value, a.value
else
a.value, c.value = c.value, a.value
Expand All @@ -1906,11 +1922,11 @@ class Array(T)
protected def self.partition_for_quick_sort!(a, n)
v, l, r = a[n / 2], a + 1, a + n - 1
loop do
while l.value < v
while cmp(l.value, v) < 0
l += 1
end
r -= 1
while v < r.value
while cmp(v, r.value) < 0
r -= 1
end
return l unless l < r
Expand All @@ -1924,7 +1940,7 @@ class Array(T)
l = a + i
v = l.value
p = l - 1
while l > a && v < p.value
while l > a && cmp(v, p.value) < 0
l.value = p.value
l, p = p, p - 1
end
Expand Down Expand Up @@ -1967,14 +1983,14 @@ class Array(T)
v, c = a[p], p
while c < (n - 1) / 2
c = 2 * (c + 1)
c -= 1 if comp.call(a[c], a[c - 1]) < 0
break unless comp.call(v, a[c]) <= 0
c -= 1 if cmp(a[c], a[c - 1], comp) < 0
break unless cmp(v, a[c], comp) <= 0
a[p] = a[c]
p = c
end
if n & 1 == 0 && c == n / 2 - 1
c = 2 * c + 1
if comp.call(v, a[c]) < 0
if cmp(v, a[c], comp) < 0
a[p] = a[c]
p = c
end
Expand All @@ -1984,17 +2000,17 @@ class Array(T)

protected def self.center_median!(a, n, comp)
b, c = a + n / 2, a + n - 1
if comp.call(a.value, b.value) <= 0
if comp.call(b.value, c.value) <= 0
if cmp(a.value, b.value, comp) <= 0
if cmp(b.value, c.value, comp) <= 0
return
elsif comp.call(a.value, c.value) <= 0
elsif cmp(a.value, c.value, comp) <= 0
b.value, c.value = c.value, b.value
else
a.value, b.value, c.value = c.value, a.value, b.value
end
elsif comp.call(a.value, c.value) <= 0
elsif cmp(a.value, c.value, comp) <= 0
a.value, b.value = b.value, a.value
elsif comp.call(b.value, c.value) <= 0
elsif cmp(b.value, c.value, comp) <= 0
a.value, b.value, c.value = b.value, c.value, a.value
else
a.value, c.value = c.value, a.value
Expand All @@ -2004,11 +2020,11 @@ class Array(T)
protected def self.partition_for_quick_sort!(a, n, comp)
v, l, r = a[n / 2], a + 1, a + n - 1
loop do
while l < a + n && comp.call(l.value, v) < 0
while l < a + n && cmp(l.value, v, comp) < 0
l += 1
end
r -= 1
while r >= a && comp.call(v, r.value) < 0
while r >= a && cmp(v, r.value, comp) < 0
r -= 1
end
return l unless l < r
Expand All @@ -2022,14 +2038,26 @@ class Array(T)
l = a + i
v = l.value
p = l - 1
while l > a && comp.call(v, p.value) < 0
while l > a && cmp(v, p.value, comp) < 0
l.value = p.value
l, p = p, p - 1
end
l.value = v
end
end

protected def self.cmp(v1, v2)
v = v1 <=> v2
raise ArgumentError.new("comparison of #{v1} and #{v2} failed") if v.nil?
v
end

protected def self.cmp(v1, v2, block)
v = block.call(v1, v2)
raise ArgumentError.new("comparison of #{v1} and #{v2} failed") if v.nil?
v
end

protected def to_lookup_hash
to_lookup_hash { |elem| elem }
end
Expand Down
Loading

0 comments on commit 1fe21d4

Please sign in to comment.