Skip to content

Commit

Permalink
port partr multiq to julia
Browse files Browse the repository at this point in the history
Direct translation, not necessarily fully idiomatic. In preparation for
future improvements.
  • Loading branch information
vtjnash committed Mar 17, 2022
1 parent b9b2a3c commit 837e3d3
Show file tree
Hide file tree
Showing 13 changed files with 227 additions and 233 deletions.
1 change: 1 addition & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ include("condition.jl")
include("threads.jl")
include("lock.jl")
include("channels.jl")
include("partr.jl")
include("task.jl")
include("threads_overloads.jl")
include("weakkeydict.jl")
Expand Down
2 changes: 1 addition & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ primitive type Char <: AbstractChar 32 end
primitive type Int8 <: Signed 8 end
#primitive type UInt8 <: Unsigned 8 end
primitive type Int16 <: Signed 16 end
primitive type UInt16 <: Unsigned 16 end
#primitive type UInt16 <: Unsigned 16 end
#primitive type Int32 <: Signed 32 end
#primitive type UInt32 <: Unsigned 32 end
#primitive type Int64 <: Signed 64 end
Expand Down
164 changes: 164 additions & 0 deletions base/partr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module Partr

# a task heap
mutable struct taskheap
const lock::ReentrantLock
const tasks::Vector{Task}
@atomic ntasks::Int32
@atomic priority::UInt16
taskheap() = new(ReentrantLock(), Vector{Task}(undef, tasks_per_heap), zero(Int32), typemax(UInt16))
end

# multiqueue parameters
const heap_d = UInt32(8)
const heap_c = UInt32(2)

# size of each heap
const tasks_per_heap = Int32(65536) # TODO: this should be smaller by default, but growable!

# the multiqueue's heaps
global heaps::Vector{taskheap}
global heap_p::UInt32 = 0

# unbias state for the RNG
global cong_unbias::UInt32 = 0


cong(max::UInt32, unbias::UInt32) = ccall(:jl_rand_ptls, UInt32, (UInt32, UInt32), max, unbias) + UInt32(1)

function unbias_cong(max::UInt32)
return typemax(UInt32) - ((typemax(UInt32) % max) + UInt32(1))
end


function multiq_init(nthreads)
global heap_p = heap_c * nthreads
global cong_unbias = unbias_cong(UInt32(heap_p))
global heaps = Vector{taskheap}(undef, heap_p)
for i = UInt32(1):heap_p
heaps[i] = taskheap()
end
nothing
end


function sift_up(heap::taskheap, idx::Int32)
while idx > Int32(1)
parent = (idx - Int32(2)) ÷ heap_d + Int32(1)
if heap.tasks[idx].priority < heap.tasks[parent].priority
t = heap.tasks[parent]
heap.tasks[parent] = heap.tasks[idx]
heap.tasks[idx] = t
idx = parent
else
break
end
end
end


function sift_down(heap::taskheap, idx::Int32)
if idx <= heap.ntasks
for child = (heap_d * idx - heap_d + Int32(2)):(heap_d * idx + Int32(1))
child > tasks_per_heap && break
if isassigned(heap.tasks, child) &&
heap.tasks[child].priority < heap.tasks[idx].priority
t = heap.tasks[idx]
heap.tasks[idx] = heap.tasks[child]
heap.tasks[child] = t
sift_down(heap, child)
end
end
end
end


function multiq_insert(task::Task, priority::UInt16)
task.priority = priority

rn = cong(heap_p, cong_unbias)
while !trylock(heaps[rn].lock)
rn = cong(heap_p, cong_unbias)
end

if heaps[rn].ntasks >= tasks_per_heap
unlock(heaps[rn].lock)
# multiq insertion failed, increase #tasks per heap
return false
end

ntasks = heaps[rn].ntasks + Int32(1)
@atomic :monotonic heaps[rn].ntasks = ntasks
heaps[rn].tasks[ntasks] = task
sift_up(heaps[rn], ntasks)
priority = heaps[rn].priority
if task.priority < priority
@atomic :monotonic heaps[rn].priority = task.priority
end
unlock(heaps[rn].lock)
return true
end


function multiq_deletemin()
local rn1, rn2
local prio1, prio2

@label retry
GC.safepoint()
for i = UInt32(1):heap_p
if i == heap_p
return nothing
end
rn1 = cong(heap_p, cong_unbias)
rn2 = cong(heap_p, cong_unbias)
prio1 = heaps[rn1].priority
prio2 = heaps[rn2].priority
if prio1 > prio2
prio1 = prio2
rn1 = rn2
elseif prio1 == prio2 && prio1 == typemax(UInt16)
continue
end
if trylock(heaps[rn1].lock)
if prio1 == heaps[rn1].priority
break
end
unlock(heaps[rn1].lock)
end
end

task = heaps[rn1].tasks[1]
tid = Threads.threadid()
if ccall(:jl_set_task_tid, Cint, (Any, Cint), task, tid-1) == 0
unlock(heaps[rn1].lock)
@goto retry
end
ntasks = heaps[rn1].ntasks
@atomic :monotonic heaps[rn1].ntasks = ntasks - Int32(1)
heaps[rn1].tasks[1] = heaps[rn1].tasks[ntasks]
Base._unsetindex!(heaps[rn1].tasks, Int(ntasks))
prio1 = typemax(UInt16)
if ntasks > 1
sift_down(heaps[rn1], Int32(1))
prio1 = heaps[rn1].tasks[1].priority
end
@atomic :monotonic heaps[rn1].priority = prio1
unlock(heaps[rn1].lock)

return task
end


function multiq_check_empty()
for i = UInt32(1):heap_p
if heaps[i].ntasks != 0
return false
end
end
return true
end

end
40 changes: 24 additions & 16 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -668,12 +668,14 @@ const StickyWorkqueue = InvasiveLinkedListSynchronized{Task}
global const Workqueues = [StickyWorkqueue()]
global const Workqueue = Workqueues[1] # default work queue is thread 1
function __preinit_threads__()
if length(Workqueues) < Threads.nthreads()
resize!(Workqueues, Threads.nthreads())
for i = 2:length(Workqueues)
nt = Threads.nthreads()
if length(Workqueues) < nt
resize!(Workqueues, nt)
for i = 2:nt
Workqueues[i] = StickyWorkqueue()
end
end
Partr.multiq_init(nt)
nothing
end

Expand All @@ -698,7 +700,7 @@ function enq_work(t::Task)
end
push!(Workqueues[tid], t)
else
if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0
if !Partr.multiq_insert(t, t.priority)
# if multiq is full, give to a random thread (TODO fix)
if tid == 0
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
Expand Down Expand Up @@ -864,24 +866,30 @@ function ensure_rescheduled(othertask::Task)
end

function trypoptask(W::StickyWorkqueue)
isempty(W) && return
t = popfirst!(W)
if t._state !== task_state_runnable
# assume this somehow got queued twice,
# probably broken now, but try discarding this switch and keep going
# can't throw here, because it's probably not the fault of the caller to wait
# and don't want to use print() here, because that may try to incur a task switch
ccall(:jl_safe_printf, Cvoid, (Ptr{UInt8}, Int32...),
"\nWARNING: Workqueue inconsistency detected: popfirst!(Workqueue).state != :runnable\n")
return
while !isempty(W)
t = popfirst!(W)
if t._state !== task_state_runnable
# assume this somehow got queued twice,
# probably broken now, but try discarding this switch and keep going
# can't throw here, because it's probably not the fault of the caller to wait
# and don't want to use print() here, because that may try to incur a task switch
ccall(:jl_safe_printf, Cvoid, (Ptr{UInt8}, Int32...),
"\nWARNING: Workqueue inconsistency detected: popfirst!(Workqueue).state != :runnable\n")
continue
end
return t
end
return t
return Partr.multiq_deletemin()
end

function checktaskempty()
return Partr.multiq_check_empty()
end

@noinline function poptask(W::StickyWorkqueue)
task = trypoptask(W)
if !(task isa Task)
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any), trypoptask, W)
task = ccall(:jl_task_get_next, Ref{Task}, (Any, Any, Any), trypoptask, W, checktaskempty)
end
set_next_task(task)
nothing
Expand Down
5 changes: 3 additions & 2 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -2021,10 +2021,11 @@ void jl_init_primitives(void) JL_GC_DISABLED

add_builtin("Bool", (jl_value_t*)jl_bool_type);
add_builtin("UInt8", (jl_value_t*)jl_uint8_type);
add_builtin("Int32", (jl_value_t*)jl_int32_type);
add_builtin("Int64", (jl_value_t*)jl_int64_type);
add_builtin("UInt16", (jl_value_t*)jl_uint16_type);
add_builtin("UInt32", (jl_value_t*)jl_uint32_type);
add_builtin("UInt64", (jl_value_t*)jl_uint64_type);
add_builtin("Int32", (jl_value_t*)jl_int32_type);
add_builtin("Int64", (jl_value_t*)jl_int64_type);
#ifdef _P64
add_builtin("Int", (jl_value_t*)jl_int64_type);
#else
Expand Down
4 changes: 0 additions & 4 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -2820,7 +2820,6 @@ static void jl_gc_queue_thread_local(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp
gc_mark_queue_obj(gc_cache, sp, ptls2->previous_exception);
}

void jl_gc_mark_enqueued_tasks(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp);
extern jl_value_t *cmpswap_names JL_GLOBALLY_ROOTED;

// mark the initial root set
Expand All @@ -2829,9 +2828,6 @@ static void mark_roots(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
// modules
gc_mark_queue_obj(gc_cache, sp, jl_main_module);

// tasks
jl_gc_mark_enqueued_tasks(gc_cache, sp);

// invisible builtin values
if (jl_an_empty_vec_any != NULL)
gc_mark_queue_obj(gc_cache, sp, jl_an_empty_vec_any);
Expand Down
6 changes: 3 additions & 3 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,6 @@ static void post_boot_hooks(void)
jl_char_type = (jl_datatype_t*)core("Char");
jl_int8_type = (jl_datatype_t*)core("Int8");
jl_int16_type = (jl_datatype_t*)core("Int16");
jl_uint16_type = (jl_datatype_t*)core("UInt16");
jl_float16_type = (jl_datatype_t*)core("Float16");
jl_float32_type = (jl_datatype_t*)core("Float32");
jl_float64_type = (jl_datatype_t*)core("Float64");
Expand All @@ -785,10 +784,11 @@ static void post_boot_hooks(void)

jl_bool_type->super = jl_integer_type;
jl_uint8_type->super = jl_unsigned_type;
jl_int32_type->super = jl_signed_type;
jl_int64_type->super = jl_signed_type;
jl_uint16_type->super = jl_unsigned_type;
jl_uint32_type->super = jl_unsigned_type;
jl_uint64_type->super = jl_unsigned_type;
jl_int32_type->super = jl_signed_type;
jl_int64_type->super = jl_signed_type;

jl_errorexception_type = (jl_datatype_t*)core("ErrorException");
jl_stackovf_exception = jl_new_struct_uninit((jl_datatype_t*)core("StackOverflowError"));
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@
XX(jl_egal__bits) \
XX(jl_egal__special) \
XX(jl_eh_restore_state) \
XX(jl_enqueue_task) \
XX(jl_enter_handler) \
XX(jl_enter_threaded_region) \
XX(jl_environ) \
Expand Down
14 changes: 9 additions & 5 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,8 @@ void jl_init_types(void) JL_GC_DISABLED
jl_any_type, jl_emptysvec, 64);
jl_uint8_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt8"), core,
jl_any_type, jl_emptysvec, 8);
jl_uint16_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt16"), core,
jl_any_type, jl_emptysvec, 16);

jl_ssavalue_type = jl_new_datatype(jl_symbol("SSAValue"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(1, "id"),
Expand Down Expand Up @@ -2508,7 +2510,7 @@ void jl_init_types(void) JL_GC_DISABLED
"inferred",
//"edges",
//"absolute_max",
"ipo_purity_bits", "purity_bits",
"ipo_purity_bits", "purity_bits",
"argescapes",
"isspecsig", "precompile", "invoke", "specptr", // function object decls
"relocatability"),
Expand Down Expand Up @@ -2602,7 +2604,7 @@ void jl_init_types(void) JL_GC_DISABLED
NULL,
jl_any_type,
jl_emptysvec,
jl_perm_symsvec(14,
jl_perm_symsvec(15,
"next",
"queue",
"storage",
Expand All @@ -2616,8 +2618,9 @@ void jl_init_types(void) JL_GC_DISABLED
"rngState3",
"_state",
"sticky",
"_isexception"),
jl_svec(14,
"_isexception",
"priority"),
jl_svec(15,
jl_any_type,
jl_any_type,
jl_any_type,
Expand All @@ -2631,7 +2634,8 @@ void jl_init_types(void) JL_GC_DISABLED
jl_uint64_type,
jl_uint8_type,
jl_bool_type,
jl_bool_type),
jl_bool_type,
jl_uint16_type),
jl_emptysvec,
0, 1, 6);
jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type);
Expand Down
4 changes: 2 additions & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1881,12 +1881,12 @@ typedef struct _jl_task_t {
_Atomic(uint8_t) _state;
uint8_t sticky; // record whether this Task can be migrated to a new thread
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with
// multiqueue priority
uint16_t priority;

// hidden state:
// id of owning thread - does not need to be defined until the task runs
_Atomic(int16_t) tid;
// multiqueue priority
int16_t prio;
// saved gc stack top for context switches
jl_gcframe_t *gcstack;
size_t world_age;
Expand Down
Loading

0 comments on commit 837e3d3

Please sign in to comment.