Skip to content

Commit

Permalink
create new tasks in the parent world
Browse files Browse the repository at this point in the history
N.B. This means serialized tasks will discard this stateful information
and pick up new/different information.

Closes #35690
Closes #41332
  • Loading branch information
vtjnash committed Jul 13, 2021
1 parent 2c91d7f commit ab9156c
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 15 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ New language features
Language changes
----------------

* Newly created Task objects (`@spawn`, `@async`, etc.) now adopt the world-age for methods from their parent
Task upon creation, instead of using the global latest world at start. This is done to enable inference to
eventually optimize these calls. Places that wish for the old behavior may use `Base.invokelatest`. ([#41449])

Compiler/Runtime improvements
-----------------------------
Expand Down
5 changes: 3 additions & 2 deletions base/docs/basedocs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1507,8 +1507,9 @@ DomainError
"""
Task(func)
Create a `Task` (i.e. coroutine) to execute the given function `func` (which must be
callable with no arguments). The task exits when this function returns.
Create a `Task` (i.e. coroutine) to execute the given function `func` (which
must be callable with no arguments). The task exits when this function returns.
The task will run in the "world age" from the parent at construction when [`schedule`](@ref)d.
# Examples
```jldoctest
Expand Down
3 changes: 1 addition & 2 deletions src/task.c
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion
t->prio = -1;
t->tid = t->copy_stack ? ct->tid : -1; // copy_stacks are always pinned since they can't be moved
t->ptls = NULL;
t->world_age = 0;
t->world_age = ct->world_age;

#ifdef COPY_STACKS
if (!t->copy_stack) {
Expand Down Expand Up @@ -876,7 +876,6 @@ CFI_NORETURN
jl_sigint_safepoint(ptls);
}
JL_TIMING(ROOT);
ct->world_age = jl_world_counter;
res = jl_apply(&ct->start, 1);
}
JL_CATCH {
Expand Down
10 changes: 5 additions & 5 deletions stdlib/Distributed/src/process_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function showerror(io::IO, re::RemoteException)
showerror(io, re.captured)
end

function run_work_thunk(thunk, print_error)
function run_work_thunk(thunk::Function, print_error::Bool)
local result
try
result = thunk()
Expand Down Expand Up @@ -271,11 +271,11 @@ function process_hdr(s, validate_cookie)
end

function handle_msg(msg::CallMsg{:call}, header, r_stream, w_stream, version)
schedule_call(header.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
schedule_call(header.response_oid, ()->invokelatest(msg.f, msg.args...; msg.kwargs...))
end
function handle_msg(msg::CallMsg{:call_fetch}, header, r_stream, w_stream, version)
errormonitor(@async begin
v = run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), false)
v = run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), false)
if isa(v, SyncTake)
try
deliver_result(w_stream, :call_fetch, header.notify_oid, v.v)
Expand All @@ -291,14 +291,14 @@ end

function handle_msg(msg::CallWaitMsg, header, r_stream, w_stream, version)
errormonitor(@async begin
rv = schedule_call(header.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
rv = schedule_call(header.response_oid, ()->invokelatest(msg.f, msg.args...; msg.kwargs...))
deliver_result(w_stream, :call_wait, header.notify_oid, fetch(rv.c))
nothing
end)
end

function handle_msg(msg::RemoteDoMsg, header, r_stream, w_stream, version)
errormonitor(@async run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), true))
errormonitor(@async run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), true))
end

function handle_msg(msg::ResultMsg, header, r_stream, w_stream, version)
Expand Down
5 changes: 1 addition & 4 deletions stdlib/Distributed/src/remotecall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,7 @@ end
# make a thunk to call f on args in a way that simulates what would happen if
# the function were sent elsewhere
function local_remotecall_thunk(f, args, kwargs)
if isempty(args) && isempty(kwargs)
return f
end
return ()->f(args...; kwargs...)
return ()->invokelatest(f, args...; kwargs...)
end

function remotecall(f, w::LocalProcess, args...; kwargs...)
Expand Down
20 changes: 18 additions & 2 deletions test/worlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,24 @@ end

g265() = [f265(x) for x in 1:3.]
wc265 = get_world_counter()
f265(::Any) = 1.0
@test wc265 + 1 == get_world_counter()
wc265_41332a = Task(tls_world_age)
@test tls_world_age() == wc265
(function ()
global wc265_41332b = Task(tls_world_age)
@eval f265(::Any) = 1.0
global wc265_41332c = Base.invokelatest(Task, tls_world_age)
global wc265_41332d = Task(tls_world_age)
nothing
end)()
@test wc265 + 2 == get_world_counter() == tls_world_age()
schedule(wc265_41332a)
schedule(wc265_41332b)
schedule(wc265_41332c)
schedule(wc265_41332d)
@test wc265 == fetch(wc265_41332a)
@test wc265 + 1 == fetch(wc265_41332b)
@test wc265 + 2 == fetch(wc265_41332c)
@test wc265 + 1 == fetch(wc265_41332d)
chnls, tasks = Base.channeled_tasks(2, wfunc)
t265 = tasks[1]

Expand Down

0 comments on commit ab9156c

Please sign in to comment.