Allow threadsafe access to buffer of type inference profiling trees #47615

Open
wants to merge 2 commits into
base: master
147 changes: 90 additions & 57 deletions base/compiler/typeinfer.jl
@@ -22,8 +22,8 @@ being used for this purpose alone.
"""
module Timings

using Core.Compiler: -, +, :, Vector, length, first, empty!, push!, pop!, @inline,
@inbounds, copy, backtrace
using Core.Compiler: -, +, :, >, Vector, length, first, empty!, push!, pop!, @inline,
@inbounds, copy, backtrace, IdDict, Task, Ref, get!

# What we record for any given frame we infer during type inference.
struct InferenceFrameInfo
@@ -47,13 +47,16 @@ end

_typeinf_identifier(frame::InferenceFrameInfo) = frame

_typeinf_frame_linfo(frame::Core.Compiler.InferenceState) = frame.linfo
_typeinf_frame_linfo(frame::InferenceFrameInfo) = frame.mi

"""
Core.Compiler.Timing(mi_info, start_time, ...)
Core.Compiler.Timings.Timing(mi_info, start_time, ...)

Internal type containing the timing result for running type inference on a single
MethodInstance.
"""
struct Timing
mutable struct Timing
mi_info::InferenceFrameInfo
start_time::UInt64
cur_start_time::UInt64
@@ -66,24 +69,70 @@ Timing(mi_info, start_time) = Timing(mi_info, start_time, start_time, UInt64(0),

_time_ns() = ccall(:jl_hrtime, UInt64, ()) # Re-implemented here because Base not yet available.

# We keep a stack of the Timings for each of the MethodInstances currently being timed.
"""
Core.Compiler.Timings.clear_and_fetch_timings()

Return, then clear, the previously recorded type inference timings.

This fetches a vector of all of the type inference timings that have _finished_ as of this call. Note
that there may be concurrent invocations of inference that are still running in another thread, but
which haven't yet been added to this buffer. Those can be fetched in a future call.
"""
function clear_and_fetch_timings()
    # Pass in the type, since the C code doesn't know about our Timing struct.
    ccall(:jl_typeinf_profiling_clear_and_fetch, Any, (Any, Any,),
          _finished_timings, Vector{Timing})::Vector{Timing}
end

function finish_timing_profile(timing::Timing)
    ccall(:jl_typeinf_profiling_push_timing, Cvoid, (Any, Any,), _finished_timings, timing)
end

# DO NOT ACCESS DIRECTLY. This vector should only be accessed through the
# functions above. It is a buffer that lives in the Julia module only to be *rooted*
# for GC, but all accesses to the vector must go through C code, in order to be
# thread safe.
const _finished_timings = Timing[]

# We store a profiling stack for *each Task* as a task-local-storage variable, _timings.
# This is a stack of the Timings for each of the MethodInstances currently being timed.
# Since type inference currently operates via a depth-first search (during abstract
# evaluation), this vector operates like a call stack. The last node in _timings is the
# node currently being inferred, and its parent is directly before it, etc.
# Each Timing also contains its own vector for all of its children, so that the tree
# call structure through type inference is recorded. (It's recorded as a tree, not a graph,
# because we create a new node for duplicates.)
const _timings = Timing[]
# You will see this accessed below as `task_local_storage(:_timings)`

# ------- Task Local Storage -------
# Reimplementation of Task Local Storage, since these functions aren't available yet
# at this stage of bootstrapping.
current_task() = ccall(:jl_get_current_task, Ref{Task}, ())
task_local_storage() = get_task_tls(current_task())
function get_task_tls(t::Task)
    if t.storage === nothing
        t.storage = IdDict()
    end
    return (t.storage)::IdDict{Any,Any}
Member

@vchuravy / @pchintalapudi / @kpamnany: I finally got around to testing this, and indeed it doesn't work. :(

It turns out that Base.IdDict and Core.Compiler.IdDict are not the same thing. 😢

Does anyone know why that is? Why doesn't Base just reexport the IdDict from Core.Compiler?
I.e. here:

include("iddict.jl")

Why doesn't it just reexport the IdDict from Core.Compiler?

Member

BUT anyway, the error we encounter is that we're not allowed to set the Task's .storage field to a Core.Compiler.IdDict; it's expected to be a Base.IdDict.

So is there just no way to use task-local storage from Core.Compiler?

Member

Possibly a more robust solution here, which might also last longer if we parallelize inference, would be to store this on the AbstractInterpreter or InferenceState objects? Does anyone know if that's plausible? We essentially want some object that lives for the lifetime of the inference invocation, and is local to this invocation, where we can store the stack of timers.

Member

Or we'd have to thread it through every function call, but that seems like possibly a nonstarter.
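The alternative raised in the two comments above, keeping the timer stack on an object that lives exactly as long as one inference invocation and passing it along explicitly, might look roughly like this. It is a minimal sketch in plain Julia: `DemoInferenceContext` and `infer_frame!` are hypothetical names and do not correspond to `AbstractInterpreter`, `InferenceState`, or any other real compiler type.

```julia
# Hypothetical per-invocation context; NOT a real AbstractInterpreter.
struct DemoInferenceContext
    timer_stack::Vector{Symbol}   # frames currently being inferred, innermost last
    finished::Vector{Symbol}      # frames in the order they completed
end
DemoInferenceContext() = DemoInferenceContext(Symbol[], Symbol[])

# Every profiled step receives the context explicitly instead of looking up
# global or task-local state.
function infer_frame!(ctx::DemoInferenceContext, name::Symbol,
                      callees::Vector{Symbol}=Symbol[])
    push!(ctx.timer_stack, name)
    for callee in callees
        infer_frame!(ctx, callee)
    end
    done = pop!(ctx.timer_stack)
    push!(ctx.finished, done)
    return ctx
end

# Two invocations own two independent stacks, so no state is shared between them.
ctx_a = infer_frame!(DemoInferenceContext(), :f, [:g, :h])
ctx_b = infer_frame!(DemoInferenceContext(), :k)
```

The cost, as noted above, is that the context has to reach every function that starts or stops a timer.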

Member

> So is there just no way to use task-local storage from Core.Compiler?

Nobody needed this before it seems :)

> Does anyone know why that is? Why doesn't Base just reexport the IdDict from Core.Compiler?

Likely to isolate the compiler code from Base.

Member

You can do:

if isdefined(Core, :Main)
  Core.Main.Base.get_task_tls(current_task())
end

to get around the bootstrapping issue.

Member

It is because the compiler is forbidden from accessing TLS. Since we hijack the current task to run codegen, it could corrupt the running task if it does anything to interact with it.

Member

(instead, you could perhaps replace the TLS on entry with one from Core.Compiler, and restore it on return)
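The swap-and-restore idea in the last comment can be illustrated in ordinary user code. This is only a sketch of the pattern, assuming it is acceptable to assign the task's internal `storage` field directly (as `get_task_tls` in the diff does); `with_swapped_storage` is a hypothetical helper, and it uses `Base.IdDict`, not the `Core.Compiler` one the comment has in mind.

```julia
# Run `f` with the current task's storage temporarily replaced by `scratch`,
# restoring whatever was there before, even if `f` throws.
function with_swapped_storage(f, scratch::IdDict{Any,Any})
    t = current_task()
    saved = t.storage            # may be `nothing` if TLS was never touched
    t.storage = scratch
    try
        return f()
    finally
        t.storage = saved
    end
end

task_local_storage(:user_key, 1)          # state that belongs to the running task
scratch = IdDict{Any,Any}()
inner = with_swapped_storage(scratch) do
    task_local_storage(:_timings, Int[])  # lands in `scratch` only
    haskey(task_local_storage(), :user_key)
end
```

Inside the `do` block the task's own entries are invisible, and nothing written there leaks back into the task's storage once the call returns.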

end
# -------

tls_timings() = get!(task_local_storage(), :_timings, Vector{Timing}())

# ROOT() is an empty function used as the top-level Timing node to measure all time spent
# *not* in type inference during a given recording trace. It is used as a "dummy" node.
function ROOT() end
const ROOTmi = Core.Compiler.specialize_method(
first(Core.Compiler.methods(ROOT)), Tuple{typeof(ROOT)}, Core.svec())

"""
Core.Compiler.reset_timings()

Empty out the previously recorded type inference timings (`Core.Compiler._timings`), and
start the ROOT() timer again. `ROOT()` measures all time spent _outside_ inference.

!!! info
    This function is deprecated as of Julia 1.9; use [`clear_and_fetch_timings`](@ref) instead.
"""
function reset_timings()
empty!(_timings)
@@ -93,7 +142,6 @@ function reset_timings()
_time_ns()))
return nothing
end
reset_timings()

# (This is split into a function so that it can be called both in this module, at the top
# of `enter_new_timer()`, and once at the Very End of the operation, by whoever started
@@ -105,44 +153,32 @@ reset_timings()
parent_timer = _timings[end]
accum_time = stop_time - parent_timer.cur_start_time

# Add in accum_time ("modify" the immutable struct)
# Add in accum_time
@inbounds begin
_timings[end] = Timing(
parent_timer.mi_info,
parent_timer.start_time,
parent_timer.cur_start_time,
parent_timer.time + accum_time,
parent_timer.children,
parent_timer.bt,
)
_timings[end].time += accum_time
end
return nothing
end

@inline function enter_new_timer(frame)
_timings = tls_timings()

# Very first thing, stop the active timer: get the current time and add in the
# time since it was last started to its aggregate exclusive time.
close_current_timer()

mi_info = _typeinf_identifier(frame)
if length(_timings) > 0
close_current_timer()
end

# Start the new timer right before returning
mi_info = _typeinf_identifier(frame)
push!(_timings, Timing(mi_info, UInt64(0)))
len = length(_timings)
new_timer = @inbounds _timings[len]
new_timer = @inbounds _timings[end]

# Set the current time _after_ appending the node, to try to exclude the
# overhead from measurement.
start = _time_ns()

@inbounds begin
_timings[len] = Timing(
new_timer.mi_info,
start,
start,
new_timer.time,
new_timer.children,
)
end
new_timer.start_time = start
new_timer.cur_start_time = start

return nothing
end
@@ -151,46 +187,43 @@ end
# assert that indeed we are always returning to a parent after finishing all of its
# children (that is, asserting that inference proceeds via depth-first-search).
@inline function exit_current_timer(_expected_frame_)
_timings = tls_timings()

# Finish the new timer
stop_time = _time_ns()

expected_mi_info = _typeinf_identifier(_expected_frame_)
expected_linfo = _typeinf_frame_linfo(_expected_frame_)

# Grab the new timer again because it might have been modified in _timings
# (since it's an immutable struct)
# And remove it from the current timings stack
new_timer = pop!(_timings)
Core.Compiler.@assert new_timer.mi_info.mi === expected_mi_info.mi
Core.Compiler.@assert new_timer.mi_info.mi === expected_linfo

# Prepare to unwind one level of the stack and record in the parent
parent_timer = _timings[end]
# check for two cases: normal case & backcompat case
is_profile_root_normal = length(_timings) === 0
is_profile_root_backcompat = length(_timings) === 1 && _timings[1] === ROOTmi
is_profile_root = is_profile_root_normal || is_profile_root_backcompat

accum_time = stop_time - new_timer.cur_start_time
# Add in accum_time ("modify" the immutable struct)
new_timer = Timing(
new_timer.mi_info,
new_timer.start_time,
new_timer.cur_start_time,
new_timer.time + accum_time,
new_timer.children,
parent_timer.mi_info.mi === ROOTmi ? backtrace() : nothing,
)
# Record the final timing with the original parent timer
push!(parent_timer.children, new_timer)

# And finally restart the parent timer:
len = length(_timings)
@inbounds begin
_timings[len] = Timing(
parent_timer.mi_info,
parent_timer.start_time,
_time_ns(),
parent_timer.time,
parent_timer.children,
parent_timer.bt,
)
new_timer.time += accum_time
if is_profile_root
new_timer.bt = backtrace()
end

# Prepare to unwind one level of the stack and record in the parent
if is_profile_root
finish_timing_profile(new_timer)
else
parent_timer = _timings[end]

# Record the final timing with the original parent timer
push!(parent_timer.children, new_timer)

# And finally restart the parent timer:
parent_timer.cur_start_time = _time_ns()
end
return nothing
end
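
The mutable-`Timing` bookkeeping in `enter_new_timer` and `exit_current_timer` can be sketched outside the compiler roughly as follows. This is a simplified illustration, not the PR's code: `DemoTimer`, `enter_timer!`, `exit_timer!`, and the `:demo_*` storage keys are hypothetical names, it relies on Base's `task_local_storage` (which, per the review thread, the compiler itself must not use), and finished roots go into a task-local vector rather than the lock-protected shared buffer.

```julia
# Hypothetical, simplified timer node; field names mirror the diff but the
# type itself is illustrative only.
mutable struct DemoTimer
    name::Symbol
    cur_start_time::UInt64      # when this timer was last (re)started
    time::UInt64                # accumulated exclusive time, in ns
    children::Vector{DemoTimer}
end
DemoTimer(name::Symbol) = DemoTimer(name, UInt64(0), UInt64(0), DemoTimer[])

# One timer stack and one list of finished roots per task.
timer_stack() = get!(() -> DemoTimer[], task_local_storage(), :demo_timer_stack)::Vector{DemoTimer}
finished_roots() = get!(() -> DemoTimer[], task_local_storage(), :demo_finished)::Vector{DemoTimer}

function enter_timer!(name::Symbol)
    stack = timer_stack()
    t0 = time_ns()
    if !isempty(stack)
        # Pause the parent: bank the time it has run since it was last started.
        parent = stack[end]
        parent.time += t0 - parent.cur_start_time
    end
    timer = DemoTimer(name)
    push!(stack, timer)
    timer.cur_start_time = time_ns()   # start the clock after the bookkeeping
    return nothing
end

function exit_timer!(name::Symbol)
    stop = time_ns()
    stack = timer_stack()
    timer = pop!(stack)
    @assert timer.name === name        # depth-first: we close the innermost timer
    timer.time += stop - timer.cur_start_time
    if isempty(stack)
        push!(finished_roots(), timer) # a root: hand off the whole tree
    else
        parent = stack[end]
        push!(parent.children, timer)
        parent.cur_start_time = time_ns()  # resume the parent
    end
    return nothing
end

# Usage: f calls g, so g ends up as f's only child.
enter_timer!(:f)
enter_timer!(:g)
exit_timer!(:g)
exit_timer!(:f)
roots = finished_roots()
```

Because the nodes are mutable, pausing and resuming a parent is a field update on `stack[end]` instead of rebuilding the struct and writing it back into the vector.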

1 change: 1 addition & 0 deletions doc/src/devdocs/locks.md
@@ -44,6 +44,7 @@ The following is a leaf lock (level 2), and only acquires level 1 locks (safepoi
> * Module->lock
> * JLDebuginfoPlugin::PluginMutex
> * newly_inferred_mutex
> * typeinf_profiling_lock

The following is a level 3 lock, which can only acquire level 1 or level 2 locks internally:

3 changes: 2 additions & 1 deletion src/Makefile
@@ -43,7 +43,7 @@ endif
SRCS := \
jltypes gf typemap smallintset ast builtins module interpreter symbol \
dlload sys init task array staticdata toplevel jl_uv datatype \
simplevector runtime_intrinsics precompile jloptions \
simplevector runtime_intrinsics precompile jloptions inference-profiling \
threading partr stackwalk gc gc-debug gc-pages gc-stacks gc-alloc-profiler method \
jlapi signal-handling safepoint timing subtype rtutils gc-heap-snapshot \
crc32c APInt-C processor ircode opaque_closure codegen-stubs coverage runtime_ccall
@@ -300,6 +300,7 @@ $(BUILDDIR)/gc-pages.o $(BUILDDIR)/gc-pages.dbg.obj: $(SRCDIR)/gc.h
$(BUILDDIR)/gc.o $(BUILDDIR)/gc.dbg.obj: $(SRCDIR)/gc.h $(SRCDIR)/gc-heap-snapshot.h $(SRCDIR)/gc-alloc-profiler.h
$(BUILDDIR)/gc-heap-snapshot.o $(BUILDDIR)/gc-heap-snapshot.dbg.obj: $(SRCDIR)/gc.h $(SRCDIR)/gc-heap-snapshot.h
$(BUILDDIR)/gc-alloc-profiler.o $(BUILDDIR)/gc-alloc-profiler.dbg.obj: $(SRCDIR)/gc.h $(SRCDIR)/gc-alloc-profiler.h
$(BUILDDIR)/inference-profiling.o $(BUILDDIR)/inference-profiling.dbg.obj: $(SRCDIR)/gc.h
$(BUILDDIR)/init.o $(BUILDDIR)/init.dbg.obj: $(SRCDIR)/builtin_proto.h
$(BUILDDIR)/interpreter.o $(BUILDDIR)/interpreter.dbg.obj: $(SRCDIR)/builtin_proto.h
$(BUILDDIR)/jitlayers.o $(BUILDDIR)/jitlayers.dbg.obj: $(SRCDIR)/jitlayers.h $(SRCDIR)/llvm-codegen-shared.h
44 changes: 44 additions & 0 deletions src/inference-profiling.c
@@ -0,0 +1,44 @@
// Implementation for type inference profiling

#include "julia.h"
#include "julia_internal.h"

jl_mutex_t typeinf_profiling_lock;

// == exported interface ==

JL_DLLEXPORT jl_array_t* jl_typeinf_profiling_clear_and_fetch(
        jl_array_t *inference_profiling_results_array,
        jl_value_t *array_timing_type
    )
{
    JL_LOCK(&typeinf_profiling_lock);

    size_t len = jl_array_len(inference_profiling_results_array);

    jl_array_t *out = jl_alloc_array_1d(array_timing_type, len);
    JL_GC_PUSH1(&out);

    memcpy(out->data, inference_profiling_results_array->data, len * sizeof(void*));

    jl_array_del_end(inference_profiling_results_array, len);

    JL_UNLOCK(&typeinf_profiling_lock);

    JL_GC_POP();
    return out;
}

JL_DLLEXPORT void jl_typeinf_profiling_push_timing(
        jl_array_t *inference_profiling_results_array,
        jl_value_t *timing
    )
{
    JL_LOCK(&typeinf_profiling_lock);

    jl_array_grow_end(inference_profiling_results_array, 1);
    size_t len = jl_array_len(inference_profiling_results_array);
    jl_array_ptr_set(inference_profiling_results_array, len - 1, timing);

    JL_UNLOCK(&typeinf_profiling_lock);
}
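
The two C entry points above amount to a locked append and a locked copy-then-empty. A rough user-level Julia analogue is sketched below for illustration; it assumes an ordinary `ReentrantLock` in place of `typeinf_profiling_lock`, and `DemoTiming`, `FINISHED`, `push_timing!`, and `clear_and_fetch!` are hypothetical names that are not part of the PR.

```julia
# Hypothetical stand-in for the Timing struct; only the shape matters here.
mutable struct DemoTiming
    name::Symbol
    time::UInt64
end

const FINISHED = DemoTiming[]          # shared buffer, analogous to _finished_timings
const FINISHED_LOCK = ReentrantLock()  # plays the role of typeinf_profiling_lock

# Append one finished timing while holding the lock.
function push_timing!(t::DemoTiming)
    lock(FINISHED_LOCK) do
        push!(FINISHED, t)
    end
    return nothing
end

# Copy out everything recorded so far, then empty the shared buffer, all under
# the same lock, so a concurrent push is neither lost nor returned twice.
function clear_and_fetch!()
    lock(FINISHED_LOCK) do
        out = copy(FINISHED)
        empty!(FINISHED)
        return out
    end
end

# Usage: several tasks push concurrently, then one fetch drains the buffer.
tasks = [Threads.@spawn push_timing!(DemoTiming(Symbol(:f, i), UInt64(i))) for i in 1:8]
foreach(wait, tasks)
batch = clear_and_fetch!()
```

A timing pushed after `clear_and_fetch!` returns simply shows up in the next batch, which matches the behavior described in the `clear_and_fetch_timings` docstring.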
2 changes: 2 additions & 0 deletions src/jl_exported_funcs.inc
@@ -483,6 +483,8 @@
XX(jl_typeinf_lock_end) \
XX(jl_typeinf_timing_begin) \
XX(jl_typeinf_timing_end) \
XX(jl_typeinf_profiling_clear_and_fetch) \
XX(jl_typeinf_profiling_push_timing) \
XX(jl_typename_str) \
XX(jl_typeof_str) \
XX(jl_types_equal) \
29 changes: 16 additions & 13 deletions test/compiler/inference.jl
@@ -3804,15 +3804,19 @@ f37532(T, x) = (Core.bitcast(Ptr{T}, x); x)
# Helper functions for Core.Compiler.Timings. These are normally accessed via a package -
# usually (SnoopCompileCore).
function time_inference(f)
Core.Compiler.Timings.reset_timings()
Core.Compiler.__set_measure_typeinf(true)
f()
Core.Compiler.__set_measure_typeinf(false)
Core.Compiler.Timings.close_current_timer()
return Core.Compiler.Timings._timings[1]
return Core.Compiler.Timings.clear_and_fetch_timings()
end
function depth(t::Core.Compiler.Timings.Timing)
maximum(depth.(t.children), init=0) + 1
function max_depth(t::Vector{Core.Compiler.Timings.Timing})
maximum(max_depth.(t), init=0) + 1
end
function max_depth(t::Core.Compiler.Timings.Timing)
maximum(max_depth.(t.children), init=0) + 1
end
function flatten_times(t::Vector{Core.Compiler.Timings.Timing})
collect(Iterators.flatten(flatten_times(el) for el in t))
end
function flatten_times(t::Core.Compiler.Timings.Timing)
collect(Iterators.flatten([(t.time => t.mi_info,), flatten_times.(t.children)...]))
@@ -3829,14 +3833,14 @@ end
timing1 = time_inference() do
@eval M1.g(2, 3.0)
end
@test occursin(r"Core.Compiler.Timings.Timing\(InferenceFrameInfo for Core.Compiler.Timings.ROOT\(\)\) with \d+ children", sprint(show, timing1))
@test timing1 isa Vector{Core.Compiler.Timings.Timing}
# The last two functions to be inferred should be `i` and `i2`, inferred at runtime with
# their concrete types.
@test sort([mi_info.mi.def.name for (time,mi_info) in flatten_times(timing1)[end-1:end]]) == [:i, :i2]
@test all(child->isa(child.bt, Vector), timing1.children)
@test all(child->child.bt===nothing, timing1.children[1].children)
@test all(child->isa(child.bt, Vector), timing1)
@test all(child->child.bt===nothing, timing1[1].children)
# Test the stacktrace
@test isa(stacktrace(timing1.children[1].bt), Vector{Base.StackTraces.StackFrame})
@test isa(stacktrace(timing1[1].bt), Vector{Base.StackTraces.StackFrame})
# Test that inference has cached some of the Method Instances
timing2 = time_inference() do
@eval M1.g(2, 3.0)
@@ -3857,16 +3861,16 @@ end
end
end
end
@test occursin("thunk from $(@__MODULE__) starting at $(@__FILE__):$((@__LINE__) - 5)", string(timingmod.children))
@test occursin("thunk from $(@__MODULE__) starting at $(@__FILE__):$((@__LINE__) - 5)", string(timingmod))
# END LINE NUMBER SENSITIVITY

# Recursive function
@eval module _Recursive f(n::Integer) = n == 0 ? 0 : f(n-1) + 1 end
timing = time_inference() do
@eval _Recursive.f(Base.inferencebarrier(5))
end
@test 2 <= depth(timing) <= 3 # root -> f (-> +)
@test 2 <= length(flatten_times(timing)) <= 3 # root, f, +
@test 1 <= max_depth(timing) <= 2 # f (-> +)
@test 1 <= length(flatten_times(timing)) <= 2 # f, +

# Functions inferred with multiple constants
@eval module C
@@ -3894,7 +3898,6 @@ end
@test !isempty(ft)
str = sprint(show, ft)
@test occursin("InferenceFrameInfo for /(1::$Int, ::$Int)", str) # inference constants
@test occursin("InferenceFrameInfo for Core.Compiler.Timings.ROOT()", str) # qualified
# loopc has internal slots, check constant printing in this case
sel = filter(ti -> ti.second.mi.def.name === :loopc, ft)
ifi = sel[end].second