Skip to content

Commit

Permalink
port partr multiq to julia (#44653)
Browse files Browse the repository at this point in the history
Direct translation, not necessarily fully idiomatic. In preparation for
future improvements.
  • Loading branch information
vtjnash authored Mar 23, 2022
1 parent 62e0729 commit 6366f40
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 233 deletions.
1 change: 1 addition & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ include("condition.jl")
include("threads.jl")
include("lock.jl")
include("channels.jl")
include("partr.jl")
include("task.jl")
include("threads_overloads.jl")
include("weakkeydict.jl")
Expand Down
2 changes: 1 addition & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ primitive type Char <: AbstractChar 32 end
primitive type Int8 <: Signed 8 end
#primitive type UInt8 <: Unsigned 8 end
primitive type Int16 <: Signed 16 end
primitive type UInt16 <: Unsigned 16 end
#primitive type UInt16 <: Unsigned 16 end
#primitive type Int32 <: Signed 32 end
#primitive type UInt32 <: Unsigned 32 end
#primitive type Int64 <: Signed 64 end
Expand Down
166 changes: 166 additions & 0 deletions base/partr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module Partr

using ..Threads: SpinLock

# One heap of the multiqueue: a fixed-capacity array of tasks guarded by a spin lock.
#   lock     - taken by the insert/deletemin paths before the heap is modified
#   tasks    - backing storage, allocated with `tasks_per_heap` (initially unassigned) slots
#   ntasks   - number of tasks currently stored (starts at zero)
#   priority - starts at typemax(UInt16); the insert/deletemin paths keep it updated
mutable struct taskheap
    const lock::SpinLock
    const tasks::Vector{Task}
    @atomic ntasks::Int32
    @atomic priority::UInt16
    function taskheap()
        slots = Vector{Task}(undef, tasks_per_heap)
        return new(SpinLock(), slots, zero(Int32), typemax(UInt16))
    end
end

# multiqueue parameters
# heap_d: arity of each heap (used by sift_up/sift_down to locate parents and children)
# heap_c: number of heaps created per thread (multiq_init sets heap_p = heap_c * nthreads)
const heap_d = UInt32(8)
const heap_c = UInt32(2)

# size of each heap (capacity of the `tasks` vector allocated for every taskheap)
const tasks_per_heap = Int32(65536) # TODO: this should be smaller by default, but growable!

# the multiqueue's heaps
# `heaps` is declared here without a value and assigned by multiq_init;
# `heap_p` is the number of heaps and stays 0 until multiq_init runs
global heaps::Vector{taskheap}
global heap_p::UInt32 = 0

# unbias state for the RNG (computed from heap_p by unbias_cong in multiq_init
# and passed to every cong call)
global cong_unbias::UInt32 = 0


# Draw a random 1-based heap index using the runtime's per-thread RNG.
# `unbias` is the rejection threshold produced by `unbias_cong(max)`.
# NOTE(review): presumably `jl_rand_ptls` yields a value in 0:max-1, which the
# `+ 1` shifts into 1:max for Julia indexing -- confirm against the runtime.
function cong(max::UInt32, unbias::UInt32)
    zero_based = ccall(:jl_rand_ptls, UInt32, (UInt32, UInt32), max, unbias)
    return zero_based + UInt32(1)
end

# Compute the value passed as the `unbias` argument of `cong` for a given `max`:
# typemax(UInt32) minus one more than the remainder of typemax(UInt32) divided by `max`.
function unbias_cong(max::UInt32)
    top = typemax(UInt32)
    excess = (top % max) + UInt32(1)
    return top - excess
end


# Set up the multiqueue for `nthreads` threads: record the heap count
# (`heap_c` heaps per thread), derive the RNG unbias value for that count,
# and allocate one fresh `taskheap` per slot. Returns `nothing`.
function multiq_init(nthreads)
    nheaps = UInt32(heap_c * nthreads)
    global heap_p = nheaps
    global cong_unbias = unbias_cong(nheaps)
    global heaps = taskheap[taskheap() for _ in UInt32(1):nheaps]
    return nothing
end


# Restore heap order after placing a task at position `idx`: while the task has
# a smaller `priority` value than its parent, swap the two and continue from
# the parent's position. Positions are 1-based; the parent of `idx` in a heap
# of arity `heap_d` is `(idx - 2) ÷ heap_d + 1`.
#
# The arity is converted to Int32 up front so `idx` keeps a single concrete
# type for the whole loop (mixing Int32 with the UInt32 constant would promote
# the parent index to UInt32 and rebind `idx` to a different type).
function sift_up(heap::taskheap, idx::Int32)
    arity = Int32(heap_d)
    tasks = heap.tasks
    while idx > Int32(1)
        parent = (idx - Int32(2)) ÷ arity + Int32(1)
        # stop as soon as the parent is at least as urgent as the moved task
        tasks[idx].priority < tasks[parent].priority || break
        tasks[idx], tasks[parent] = tasks[parent], tasks[idx]
        idx = parent
    end
    return nothing
end


# Restore heap order below position `idx`: every assigned child whose
# `priority` value is smaller than the task at `idx` is swapped with it, and
# the displaced task is then sifted down from the child's position.
# Children of `idx` occupy positions `heap_d*(idx-1) + 2` through `heap_d*idx + 1`.
#
# Child indices are computed in Int32. Multiplying the UInt32 constant `heap_d`
# by an Int32 promotes to UInt32, and a UInt32 child cannot be passed to the
# recursive call because this method only accepts `idx::Int32`.
function sift_down(heap::taskheap, idx::Int32)
    idx <= heap.ntasks || return nothing
    arity = Int32(heap_d)
    first_child = arity * idx - arity + Int32(2)
    last_child = arity * idx + Int32(1)
    for child = first_child:last_child
        # children past the end of the backing array do not exist
        child > tasks_per_heap && break
        # slots beyond the live tasks are unassigned, so probe before reading
        if isassigned(heap.tasks, Int(child)) &&
                heap.tasks[child].priority < heap.tasks[idx].priority
            t = heap.tasks[idx]
            heap.tasks[idx] = heap.tasks[child]
            heap.tasks[child] = t
            sift_down(heap, child)
        end
    end
    return nothing
end


# Insert `task` into the multiqueue with the given `priority`.
# Stores the priority on the task, locks a randomly chosen heap, appends the
# task, and sifts it up. Returns `true` on success and `false` when the chosen
# heap is already at `tasks_per_heap` capacity.
function multiq_insert(task::Task, priority::UInt16)
    task.priority = priority

    # keep drawing random heaps until one of them can be locked
    local heap
    while true
        heap = heaps[cong(heap_p, cong_unbias)]
        trylock(heap.lock) && break
    end

    if heap.ntasks >= tasks_per_heap
        # multiq insertion failed, increase #tasks per heap
        unlock(heap.lock)
        return false
    end

    slot = heap.ntasks + Int32(1)
    @atomic :monotonic heap.ntasks = slot
    heap.tasks[slot] = task
    sift_up(heap, slot)
    # lower the heap's advertised priority if the new task is more urgent
    if task.priority < heap.priority
        @atomic :monotonic heap.priority = task.priority
    end
    unlock(heap.lock)
    return true
end


# Try to remove and return an urgent task from the multiqueue.
#
# Each attempt samples two heaps at random, prefers the one advertising the
# smaller `priority` value, and tries to lock it. On success the task at the
# top of that heap is claimed for the current thread, removed from the heap,
# and returned. Returns `nothing` when `heap_p` attempts in a row fail to lock
# a suitable heap (including when the multiqueue has no heaps at all).
function multiq_deletemin()
    local rn1, rn2
    local prio1, prio2

    @label retry
    GC.safepoint()
    # The range has heap_p + 1 values: iterations 0 .. heap_p-1 are real attempts,
    # and the final value (i == heap_p) only reports that every attempt failed.
    for i = UInt32(0):heap_p
        if i == heap_p
            return nothing
        end
        rn1 = cong(heap_p, cong_unbias)
        rn2 = cong(heap_p, cong_unbias)
        prio1 = heaps[rn1].priority
        prio2 = heaps[rn2].priority
        if prio1 > prio2
            prio1 = prio2
            rn1 = rn2
        elseif prio1 == prio2 && prio1 == typemax(UInt16)
            # both samples carry the priority value used for an empty heap
            continue
        end
        if trylock(heaps[rn1].lock)
            # only proceed if the priority did not change before the lock was taken
            if prio1 == heaps[rn1].priority
                break
            end
            unlock(heaps[rn1].lock)
        end
    end

    task = heaps[rn1].tasks[1]
    tid = Threads.threadid()
    # NOTE(review): a zero return presumably means the task could not be bound
    # to this thread; in that case leave it queued and start over
    if ccall(:jl_set_task_tid, Cint, (Any, Cint), task, tid-1) == 0
        unlock(heaps[rn1].lock)
        @goto retry
    end
    # remove the top task: move the last task into slot 1 and clear the old slot
    ntasks = heaps[rn1].ntasks
    @atomic :monotonic heaps[rn1].ntasks = ntasks - Int32(1)
    heaps[rn1].tasks[1] = heaps[rn1].tasks[ntasks]
    Base._unsetindex!(heaps[rn1].tasks, Int(ntasks))
    prio1 = typemax(UInt16)
    if ntasks > 1
        sift_down(heaps[rn1], Int32(1))
        prio1 = heaps[rn1].tasks[1].priority
    end
    # publish the new top priority (typemax(UInt16) when the heap is now empty)
    @atomic :monotonic heaps[rn1].priority = prio1
    unlock(heaps[rn1].lock)

    return task
end


# Report whether every heap of the multiqueue currently holds zero tasks.
# The task counts are read without taking the heap locks.
function multiq_check_empty()
    return all(i -> heaps[i].ntasks == 0, UInt32(1):heap_p)
end

end
40 changes: 24 additions & 16 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -711,12 +711,14 @@ const StickyWorkqueue = InvasiveLinkedListSynchronized{Task}
global const Workqueues = [StickyWorkqueue()]
global const Workqueue = Workqueues[1] # default work queue is thread 1
function __preinit_threads__()
if length(Workqueues) < Threads.nthreads()
resize!(Workqueues, Threads.nthreads())
for i = 2:length(Workqueues)
nt = Threads.nthreads()
if length(Workqueues) < nt
resize!(Workqueues, nt)
for i = 2:nt
Workqueues[i] = StickyWorkqueue()
end
end
Partr.multiq_init(nt)
nothing
end

Expand All @@ -741,7 +743,7 @@ function enq_work(t::Task)
end
push!(Workqueues[tid], t)
else
if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0
if !Partr.multiq_insert(t, t.priority)
# if multiq is full, give to a random thread (TODO fix)
if tid == 0
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
Expand Down Expand Up @@ -907,24 +909,30 @@ function ensure_rescheduled(othertask::Task)
end

function trypoptask(W::StickyWorkqueue)
isempty(W) && return
t = popfirst!(W)
if t._state !== task_state_runnable
# assume this somehow got queued twice,
# probably broken now, but try discarding this switch and keep going
# can't throw here, because it's probably not the fault of the caller to wait
# and don't want to use print() here, because that may try to incur a task switch
ccall(:jl_safe_printf, Cvoid, (Ptr{UInt8}, Int32...),
"\nWARNING: Workqueue inconsistency detected: popfirst!(Workqueue).state != :runnable\n")
return
while !isempty(W)
t = popfirst!(W)
if t._state !== task_state_runnable
# assume this somehow got queued twice,
# probably broken now, but try discarding this switch and keep going
# can't throw here, because it's probably not the fault of the caller to wait
# and don't want to use print() here, because that may try to incur a task switch
ccall(:jl_safe_printf, Cvoid, (Ptr{UInt8}, Int32...),
"\nWARNING: Workqueue inconsistency detected: popfirst!(Workqueue).state != :runnable\n")
continue
end
return t
end
return t
return Partr.multiq_deletemin()
end

function checktaskempty()
return Partr.multiq_check_empty()
end

@noinline function poptask(W::StickyWorkqueue)
task = trypoptask(W)
if !(task isa Task)
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any), trypoptask, W)
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any, Any), trypoptask, W, checktaskempty)
end
set_next_task(task)
nothing
Expand Down
5 changes: 3 additions & 2 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -2021,10 +2021,11 @@ void jl_init_primitives(void) JL_GC_DISABLED

add_builtin("Bool", (jl_value_t*)jl_bool_type);
add_builtin("UInt8", (jl_value_t*)jl_uint8_type);
add_builtin("Int32", (jl_value_t*)jl_int32_type);
add_builtin("Int64", (jl_value_t*)jl_int64_type);
add_builtin("UInt16", (jl_value_t*)jl_uint16_type);
add_builtin("UInt32", (jl_value_t*)jl_uint32_type);
add_builtin("UInt64", (jl_value_t*)jl_uint64_type);
add_builtin("Int32", (jl_value_t*)jl_int32_type);
add_builtin("Int64", (jl_value_t*)jl_int64_type);
#ifdef _P64
add_builtin("Int", (jl_value_t*)jl_int64_type);
#else
Expand Down
4 changes: 0 additions & 4 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -2824,7 +2824,6 @@ static void jl_gc_queue_thread_local(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp
gc_mark_queue_obj(gc_cache, sp, ptls2->previous_exception);
}

void jl_gc_mark_enqueued_tasks(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp);
extern jl_value_t *cmpswap_names JL_GLOBALLY_ROOTED;

// mark the initial root set
Expand All @@ -2833,9 +2832,6 @@ static void mark_roots(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
// modules
gc_mark_queue_obj(gc_cache, sp, jl_main_module);

// tasks
jl_gc_mark_enqueued_tasks(gc_cache, sp);

// invisible builtin values
if (jl_an_empty_vec_any != NULL)
gc_mark_queue_obj(gc_cache, sp, jl_an_empty_vec_any);
Expand Down
6 changes: 3 additions & 3 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,6 @@ static void post_boot_hooks(void)
jl_char_type = (jl_datatype_t*)core("Char");
jl_int8_type = (jl_datatype_t*)core("Int8");
jl_int16_type = (jl_datatype_t*)core("Int16");
jl_uint16_type = (jl_datatype_t*)core("UInt16");
jl_float16_type = (jl_datatype_t*)core("Float16");
jl_float32_type = (jl_datatype_t*)core("Float32");
jl_float64_type = (jl_datatype_t*)core("Float64");
Expand All @@ -792,10 +791,11 @@ static void post_boot_hooks(void)

jl_bool_type->super = jl_integer_type;
jl_uint8_type->super = jl_unsigned_type;
jl_int32_type->super = jl_signed_type;
jl_int64_type->super = jl_signed_type;
jl_uint16_type->super = jl_unsigned_type;
jl_uint32_type->super = jl_unsigned_type;
jl_uint64_type->super = jl_unsigned_type;
jl_int32_type->super = jl_signed_type;
jl_int64_type->super = jl_signed_type;

jl_errorexception_type = (jl_datatype_t*)core("ErrorException");
jl_stackovf_exception = jl_new_struct_uninit((jl_datatype_t*)core("StackOverflowError"));
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@
XX(jl_egal__bits) \
XX(jl_egal__special) \
XX(jl_eh_restore_state) \
XX(jl_enqueue_task) \
XX(jl_enter_handler) \
XX(jl_enter_threaded_region) \
XX(jl_environ) \
Expand Down
14 changes: 9 additions & 5 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2152,6 +2152,8 @@ void jl_init_types(void) JL_GC_DISABLED
jl_any_type, jl_emptysvec, 64);
jl_uint8_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt8"), core,
jl_any_type, jl_emptysvec, 8);
jl_uint16_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt16"), core,
jl_any_type, jl_emptysvec, 16);

jl_ssavalue_type = jl_new_datatype(jl_symbol("SSAValue"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(1, "id"),
Expand Down Expand Up @@ -2516,7 +2518,7 @@ void jl_init_types(void) JL_GC_DISABLED
"inferred",
//"edges",
//"absolute_max",
"ipo_purity_bits", "purity_bits",
"ipo_purity_bits", "purity_bits",
"argescapes",
"isspecsig", "precompile", "invoke", "specptr", // function object decls
"relocatability"),
Expand Down Expand Up @@ -2610,7 +2612,7 @@ void jl_init_types(void) JL_GC_DISABLED
NULL,
jl_any_type,
jl_emptysvec,
jl_perm_symsvec(14,
jl_perm_symsvec(15,
"next",
"queue",
"storage",
Expand All @@ -2624,8 +2626,9 @@ void jl_init_types(void) JL_GC_DISABLED
"rngState3",
"_state",
"sticky",
"_isexception"),
jl_svec(14,
"_isexception",
"priority"),
jl_svec(15,
jl_any_type,
jl_any_type,
jl_any_type,
Expand All @@ -2639,7 +2642,8 @@ void jl_init_types(void) JL_GC_DISABLED
jl_uint64_type,
jl_uint8_type,
jl_bool_type,
jl_bool_type),
jl_bool_type,
jl_uint16_type),
jl_emptysvec,
0, 1, 6);
jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type);
Expand Down
4 changes: 2 additions & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1881,12 +1881,12 @@ typedef struct _jl_task_t {
_Atomic(uint8_t) _state;
uint8_t sticky; // record whether this Task can be migrated to a new thread
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with
// multiqueue priority
uint16_t priority;

// hidden state:
// id of owning thread - does not need to be defined until the task runs
_Atomic(int16_t) tid;
// multiqueue priority
int16_t prio;
// saved gc stack top for context switches
jl_gcframe_t *gcstack;
size_t world_age;
Expand Down
Loading

0 comments on commit 6366f40

Please sign in to comment.