Skip to content

Commit

Permalink
Support Atomic(T)#compare_and_set when T is a reference union (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
HertzDevil authored Jun 24, 2023
1 parent 8689c06 commit d2add86
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 77 deletions.
152 changes: 95 additions & 57 deletions spec/std/atomic_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -15,64 +15,88 @@ enum AtomicEnumFlags
end

describe Atomic do
it "compares and sets with integer" do
atomic = Atomic.new(1)
describe "#compare_and_set" do
it "with integer" do
atomic = Atomic.new(1)

atomic.compare_and_set(2, 3).should eq({1, false})
atomic.get.should eq(1)
atomic.compare_and_set(2, 3).should eq({1, false})
atomic.get.should eq(1)

atomic.compare_and_set(1, 3).should eq({1, true})
atomic.get.should eq(3)
end
atomic.compare_and_set(1, 3).should eq({1, true})
atomic.get.should eq(3)
end

it "compares and set with enum" do
atomic = Atomic(AtomicEnum).new(AtomicEnum::One)
it "with enum" do
atomic = Atomic(AtomicEnum).new(AtomicEnum::One)

atomic.compare_and_set(AtomicEnum::Two, AtomicEnum::Three).should eq({AtomicEnum::One, false})
atomic.get.should eq(AtomicEnum::One)
atomic.compare_and_set(AtomicEnum::Two, AtomicEnum::Three).should eq({AtomicEnum::One, false})
atomic.get.should eq(AtomicEnum::One)

atomic.compare_and_set(AtomicEnum::One, AtomicEnum::Three).should eq({AtomicEnum::One, true})
atomic.get.should eq(AtomicEnum::Three)
end
atomic.compare_and_set(AtomicEnum::One, AtomicEnum::Three).should eq({AtomicEnum::One, true})
atomic.get.should eq(AtomicEnum::Three)
end

it "compares and set with flags enum" do
atomic = Atomic(AtomicEnumFlags).new(AtomicEnumFlags::One)
it "with flags enum" do
atomic = Atomic(AtomicEnumFlags).new(AtomicEnumFlags::One)

atomic.compare_and_set(AtomicEnumFlags::Two, AtomicEnumFlags::Three).should eq({AtomicEnumFlags::One, false})
atomic.get.should eq(AtomicEnumFlags::One)
atomic.compare_and_set(AtomicEnumFlags::Two, AtomicEnumFlags::Three).should eq({AtomicEnumFlags::One, false})
atomic.get.should eq(AtomicEnumFlags::One)

atomic.compare_and_set(AtomicEnumFlags::One, AtomicEnumFlags::Three).should eq({AtomicEnumFlags::One, true})
atomic.get.should eq(AtomicEnumFlags::Three)
end
atomic.compare_and_set(AtomicEnumFlags::One, AtomicEnumFlags::Three).should eq({AtomicEnumFlags::One, true})
atomic.get.should eq(AtomicEnumFlags::Three)
end

it "compares and sets with nilable type" do
atomic = Atomic(String?).new(nil)
string = "hello"
it "with nilable reference" do
atomic = Atomic(String?).new(nil)
string = "hello"

atomic.compare_and_set(string, "foo").should eq({nil, false})
atomic.get.should be_nil
atomic.compare_and_set(string, "foo").should eq({nil, false})
atomic.get.should be_nil

atomic.compare_and_set(nil, string).should eq({nil, true})
atomic.get.should be(string)
atomic.compare_and_set(nil, string).should eq({nil, true})
atomic.get.should be(string)

atomic.compare_and_set(string, nil).should eq({string, true})
atomic.get.should be_nil
end
atomic.compare_and_set(string, nil).should eq({string, true})
atomic.get.should be_nil
end

it "with reference type" do
str1 = "hello"
str2 = "bye"

atomic = Atomic(String).new(str1)

atomic.compare_and_set(str2, "foo").should eq({str1, false})
atomic.get.should be(str1)

it "compares and sets with reference type" do
str1 = "hello"
str2 = "bye"
atomic.compare_and_set(str1, str2).should eq({str1, true})
atomic.get.should be(str2)

atomic = Atomic(String).new(str1)
atomic.compare_and_set(str2, str1).should eq({str2, true})
atomic.get.should be(str1)

atomic.compare_and_set(str2, "foo").should eq({str1, false})
atomic.get.should eq(str1)
atomic.compare_and_set(String.build(&.<< "bye"), str2).should eq({str1, false})
atomic.get.should be(str1)
end

atomic.compare_and_set(str1, str2).should eq({str1, true})
atomic.get.should be(str2)
it "with reference union" do
arr1 = [1]
arr2 = [""]

atomic.compare_and_set(str2, str1).should eq({str2, true})
atomic.get.should be(str1)
atomic = Atomic(Array(Int32) | Array(String)).new(arr1)

atomic.compare_and_set(arr2, ["foo"]).should eq({arr1, false})
atomic.get.should be(arr1)

atomic.compare_and_set(arr1, arr2).should eq({arr1, true})
atomic.get.should be(arr2)

atomic.compare_and_set(arr2, arr1).should eq({arr2, true})
atomic.get.should be(arr1)

atomic.compare_and_set([1], arr2).should eq({arr1, false})
atomic.get.should be(arr1)
end
end

it "#adds" do
Expand Down Expand Up @@ -185,26 +209,40 @@ describe Atomic do
atomic.get.should eq(2)
end

it "#swap" do
atomic = Atomic.new(1)
atomic.swap(2).should eq(1)
atomic.get.should eq(2)
end
describe "#swap" do
it "with integer" do
atomic = Atomic.new(1)
atomic.swap(2).should eq(1)
atomic.get.should eq(2)
end

it "#swap with Reference type" do
atomic = Atomic.new("hello")
atomic.swap("world").should eq("hello")
atomic.get.should eq("world")
end
it "with reference type" do
atomic = Atomic.new("hello")
atomic.swap("world").should eq("hello")
atomic.get.should eq("world")
end

it "#swap with nil" do
atomic = Atomic(String?).new(nil)
it "with nilable reference" do
atomic = Atomic(String?).new(nil)

atomic.swap("not nil").should eq(nil)
atomic.get.should eq("not nil")
atomic.swap("not nil").should eq(nil)
atomic.get.should eq("not nil")

atomic.swap(nil).should eq("not nil")
atomic.get.should eq(nil)
atomic.swap(nil).should eq("not nil")
atomic.get.should eq(nil)
end

it "with reference union" do
arr1 = [1]
arr2 = [""]
atomic = Atomic(Array(Int32) | Array(String)).new(arr1)

atomic.swap(arr2).should be(arr1)
atomic.get.should be(arr2)

atomic.swap(arr1).should be(arr2)
atomic.get.should be(arr1)
end
end
end

Expand Down
61 changes: 41 additions & 20 deletions src/atomic.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@ require "llvm/enums/atomic"

# A value that may be updated atomically.
#
# Only primitive integer types, reference types or nilable reference types
# can be used with an Atomic type.
# If `T` is a non-union primitive integer type or enum type, all operations are
# supported. If `T` is a reference type, or a union type containing only
# reference types or `Nil`, then only `#compare_and_set`, `#swap`, `#set`,
# `#lazy_set`, `#get`, and `#lazy_get` are available.
struct Atomic(T)
# Creates an Atomic with the given initial value.
def initialize(@value : T)
{% if !T.union? && (T == Char || T < Int::Primitive || T < Enum) %}
# Support integer types, enum types, or char (because it's represented as an integer)
{% elsif T < Reference || (T.union? && T.union_types.all? { |t| t == Nil || t < Reference }) %}
{% elsif T.union_types.all? { |t| t == Nil || t < Reference } && T != Nil %}
# Support reference types, or union types with only nil or reference types
{% else %}
{{ raise "Can only create Atomic with primitive integer types, reference types or nilable reference types, not #{T}" }}
{% raise "Can only create Atomic with primitive integer types, reference types or nilable reference types, not #{T}" %}
{% end %}
end

Expand All @@ -21,6 +23,8 @@ struct Atomic(T)
# * if they are equal, sets the value to *new*, and returns `{old_value, true}`
# * if they are not equal the value remains the same, and returns `{old_value, false}`
#
# Reference types are compared by `#same?`, not `#==`.
#
# ```
# atomic = Atomic.new(1)
#
Expand All @@ -31,90 +35,97 @@ struct Atomic(T)
# atomic.get # => 3
# ```
def compare_and_set(cmp : T, new : T) : {T, Bool}
# Check if it's a nilable reference type
{% if T.union? && T.union_types.all? { |t| t == Nil || t < Reference } %}
# If so, use addresses because LLVM < 3.9 doesn't support cmpxchg with pointers
address, success = Ops.cmpxchg(pointerof(@value).as(LibC::SizeT*), LibC::SizeT.new(cmp.as(T).object_id), LibC::SizeT.new(new.as(T).object_id), :sequentially_consistent, :sequentially_consistent)
{address == 0 ? nil : Pointer(T).new(address).as(T), success}
# Check if it's a reference type
{% elsif T < Reference %}
# Use addresses again (but this can't return nil)
address, success = Ops.cmpxchg(pointerof(@value).as(LibC::SizeT*), LibC::SizeT.new(cmp.as(T).object_id), LibC::SizeT.new(new.as(T).object_id), :sequentially_consistent, :sequentially_consistent)
{Pointer(T).new(address).as(T), success}
{% else %}
# Otherwise, this is an integer type
Ops.cmpxchg(pointerof(@value), cmp, new, :sequentially_consistent, :sequentially_consistent)
{% end %}
Ops.cmpxchg(pointerof(@value), cmp.as(T), new.as(T), :sequentially_consistent, :sequentially_consistent)
end

# Performs `atomic_value &+= value`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(1)
# atomic.add(2) # => 1
# atomic.get # => 3
# ```
def add(value : T) : T
check_reference_type
Ops.atomicrmw(:add, pointerof(@value), value, :sequentially_consistent, false)
end

# Performs `atomic_value &-= value`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(9)
# atomic.sub(2) # => 9
# atomic.get # => 7
# ```
def sub(value : T) : T
check_reference_type
Ops.atomicrmw(:sub, pointerof(@value), value, :sequentially_consistent, false)
end

# Performs `atomic_value &= value`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(5)
# atomic.and(3) # => 5
# atomic.get # => 1
# ```
def and(value : T) : T
check_reference_type
Ops.atomicrmw(:and, pointerof(@value), value, :sequentially_consistent, false)
end

# Performs `atomic_value = ~(atomic_value & value)`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(5)
# atomic.nand(3) # => 5
# atomic.get # => -2
# ```
def nand(value : T) : T
check_reference_type
Ops.atomicrmw(:nand, pointerof(@value), value, :sequentially_consistent, false)
end

# Performs `atomic_value |= value`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(5)
# atomic.or(2) # => 5
# atomic.get # => 7
# ```
def or(value : T) : T
check_reference_type
Ops.atomicrmw(:or, pointerof(@value), value, :sequentially_consistent, false)
end

# Performs `atomic_value ^= value`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(5)
# atomic.xor(3) # => 5
# atomic.get # => 6
# ```
def xor(value : T) : T
check_reference_type
Ops.atomicrmw(:xor, pointerof(@value), value, :sequentially_consistent, false)
end

# Performs `atomic_value = {atomic_value, value}.max`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(5)
#
Expand All @@ -125,6 +136,7 @@ struct Atomic(T)
# atomic.get # => 10
# ```
def max(value : T)
check_reference_type
{% if T < Enum %}
if @value.value.is_a?(Int::Signed)
Ops.atomicrmw(:max, pointerof(@value), value, :sequentially_consistent, false)
Expand All @@ -140,6 +152,8 @@ struct Atomic(T)

# Performs `atomic_value = {atomic_value, value}.min`. Returns the old value.
#
# `T` cannot contain any reference types.
#
# ```
# atomic = Atomic.new(5)
#
Expand All @@ -150,6 +164,7 @@ struct Atomic(T)
# atomic.get # => 3
# ```
def min(value : T)
check_reference_type
{% if T < Enum %}
if @value.value.is_a?(Int::Signed)
Ops.atomicrmw(:min, pointerof(@value), value, :sequentially_consistent, false)
Expand All @@ -171,8 +186,8 @@ struct Atomic(T)
# atomic.get # => 10
# ```
def swap(value : T)
{% if T.union? && T.union_types.all? { |t| t == Nil || t < Reference } || T < Reference %}
address = Ops.atomicrmw(:xchg, pointerof(@value).as(LibC::SizeT*), LibC::SizeT.new(value.as(T).object_id), :sequentially_consistent, false)
{% if T.union_types.all? { |t| t == Nil || t < Reference } && T != Nil %}
address = Ops.atomicrmw(:xchg, pointerof(@value).as(LibC::SizeT*), LibC::SizeT.new(value.as(Void*).address), :sequentially_consistent, false)
Pointer(T).new(address).as(T)
{% else %}
Ops.atomicrmw(:xchg, pointerof(@value), value, :sequentially_consistent, false)
Expand Down Expand Up @@ -211,6 +226,12 @@ struct Atomic(T)
@value
end

private macro check_reference_type
{% if T.union_types.all? { |t| t == Nil || t < Reference } && T != Nil %}
{% raise "Cannot call `#{@type}##{@def.name}` as `#{T}` is a reference type" %}
{% end %}
end

# :nodoc:
module Ops
# Defines methods that directly map to LLVM instructions related to atomic operations.
Expand Down

0 comments on commit d2add86

Please sign in to comment.