Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP static dispatch for kwargs and structural fields #16580

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ export
# key types
Any, DataType, Vararg, ANY, NTuple,
Tuple, Type, TypeConstructor, TypeName, TypeVar, Union, Void,
SimpleVector, AbstractArray, DenseArray,
SimpleVector, AbstractArray, DenseArray, Struct,
# special objects
Function, LambdaInfo, Method, MethodTable, TypeMapEntry, TypeMapLevel,
Module, Symbol, Task, Array, WeakRef, VecElement,
Expand Down Expand Up @@ -332,15 +332,28 @@ end
atdoc = (str, expr) -> Expr(:escape, expr)
atdoc!(λ) = global atdoc = λ

immutable KwKeys{names}
end
function structdiff
end
function structmerge
end
function fieldname
end
function structadd
end
Struct(;args...) = args
module TopModule
# this defines the types that lowering expects to be defined in a (top) module
# that are usually inherited from Core, but could be defined custom for a module
using Core: Box, IntrinsicFunction, Builtin,
using Core: Box, IntrinsicFunction, Builtin, Struct, KwKeys,
arrayref, arrayset, arraysize,
_expr, _apply, typeassert, apply_type, svec, kwfunc
export Box, IntrinsicFunction, Builtin,
_expr, _apply, typeassert, apply_type, svec, kwfunc, fieldname,
struct, structdiff, structmerge, structadd
export Box, IntrinsicFunction, Builtin, Struct, KwKeys,
arrayref, arrayset, arraysize,
_expr, _apply, typeassert, apply_type, svec, kwfunc
_expr, _apply, typeassert, apply_type, svec, kwfunc, fieldname,
struct, structdiff, structmerge, structadd
end
using .TopModule
ccall(:jl_set_istopmod, Void, (Bool,), true)
94 changes: 94 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,97 @@ const (:) = Colon()
# For passing constants through type inference
immutable Val{T}
end

function sortedmerge(x,y)
ix,nx = 0, length(x)
iy,ny = 0, length(y)
z = Array(Symbol, nx+ny)
iz = 1
while ix+iy < nx+ny
c = if ix < nx && (iy < ny && x[ix+1] <= y[iy+1] || iy == ny)
ix += 1
x[ix]
else
iy += 1
y[iy]
end
if iz == 1 || z[iz-1] != c
z[iz] = c
iz += 1
end
end
resize!(z, iz-1)
(z...,)
end

@generated function Core.structmerge{xn,xT,yn,yT}(x::Core.Struct{xn,xT},y::Core.Struct{yn,yT})
names = sortedmerge(xn,yn)
fields = map(names) do name
if findfirst(xn, name) > 0
:(getfield(x, $(Expr(:quote, name))))
else
:(getfield(y, $(Expr(:quote, name))))
end
end
quote
Core.struct($names, $(fields...))
end
end

function Core.structmerge{xn,xT}(x::Core.Struct{xn,xT}, y)
kvs = collect(y)
sort!(kvs, 1, length(kvs), Base.Sort.InsertionSort, Base.Order.By(x -> x[1]))
names = sortedmerge(xn,map(first,kvs))
n = length(names)
values = Array(Any, n)
for (name,val) in y
idx = findfirst(names, name)
if idx > 0
values[idx] = val
end
end
for i = 1:n
if isdefined(x,names[i])
values[i] = getfield(x,names[i])
end
end
Core.struct(names, values...)
end
function Core.structadd{n,T}(x::Core.Struct{n,T}, y)
Core.structmerge(x, (y,))
end
function sorteddiff(x,y)
ix,xn = 1,length(x)
iy,yn = 1,length(y)
z = Array(Symbol, xn)
iz = 1
while ix <= xn
v = x[ix]
if iy > yn || v < y[iy]
z[iz] = v
iz += 1
ix += 1
elseif v === y[iy]
ix += 1
iy += 1
else
iy += 1
end
end
resize!(z, iz-1)
(z...,)
end

@generated function Core.structdiff{xn,xT,yn}(x::Core.Struct{xn,xT},::Core.KwKeys{yn})
names = sorteddiff(xn,yn)
fields = map(names) do name
:(getfield(x, $(Expr(:quote, name))))
end
quote
Core.struct($names, $(fields...))
end
end

start(s::Core.Struct) = 1
done(s::Core.Struct, i) = i > nfields(s)
next(s::Core.Struct, i) = ((fieldname(typeof(s),i) => getfield(s,i)), i+1)
76 changes: 72 additions & 4 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,42 @@ add_tfunc(is, 2, 2,
return Bool
end
end)
add_tfunc(isdefined, 1, IInf, (args...)->Bool)
function isdefined_tfunc(args...)
a1 = widenconst(args[1])
if isType(a1)
a1 = typeof(a1.parameters[1])
if a1 === TypeVar
return Bool
end
end
if isleaftype(a1)
if a1 <: Array # TODO
elseif a1 === Module # TODO
elseif length(args) == 2 && isa(args[2],Const)
n = nfields(a1)
val = args[2].val
idx::Int = 0
if isa(val, Symbol)
for i=1:n
if fieldname(a1, i) === val
idx = i
break
end
end
elseif isa(val, Int)
idx = val::Int
end

if 1 <= idx <= a1.ninitialized
return Const(true)
elseif idx <= 0 || idx > n
return Const(false)
end
end
end
Bool
end
add_tfunc(isdefined, 1, IInf, isdefined_tfunc)
add_tfunc(Core.sizeof, 1, 1, x->Int)
add_tfunc(nfields, 1, 1, x->(isa(x,Const) ? Const(nfields(x.val)) :
isType(x) && isleaftype(x.parameters[1]) ? Const(nfields(x.parameters[1])) :
Expand Down Expand Up @@ -469,7 +504,11 @@ function getfield_tfunc(s0::ANY, name)
end
end
end
snames = s.name.names
snames = isdefined(s,:names) ? s.names : s.name.names
if s <: Core.Struct && !isleaftype(s)
# TODO
return Any,false
end
for i=1:length(snames)
if is(snames[i],fld)
R = s.types[i]
Expand Down Expand Up @@ -642,6 +681,28 @@ function builtin_tfunction(f::ANY, argtypes::Array{Any,1}, sv::InferenceState)
end
end
return Const(tuple(map(a->a.val, argtypes)...))
elseif is(f,Core.struct)
if length(argtypes) >= 1
if isa(argtypes[1],Const)
names = argtypes[1].val
values = argtypes[2:end]
has_dup = false
n = nfields(names)
for i=2:n
name = getfield(names, i)
for j=1:i-1
if name == getfield(names, j)
has_dup = true
break
end
end
end
if !has_dup && length(values) == nfields(names)
tup = limit_tuple_type(argtypes_to_type(values))
return isleaftype(tup) ? Core.Struct{names, tup} : Core.Struct{names,TypeVar(:_,tup)}
end
end
end
elseif is(f,svec)
return SimpleVector
elseif is(f,arrayset)
Expand Down Expand Up @@ -2165,7 +2226,7 @@ end
const _pure_builtins = Any[tuple, svec, fieldtype, apply_type, is, isa, typeof]

# known effect-free calls (might not be affect-free)
const _pure_builtins_volatile = Any[getfield, arrayref]
const _pure_builtins_volatile = Any[getfield, arrayref, isdefined]

function is_pure_builtin(f::ANY)
if contains_is(_pure_builtins, f)
Expand Down Expand Up @@ -2315,6 +2376,13 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
return NF
end

if (is(f, Core.struct) &&
e.typ <: Core.Struct && isleaftype(e.typ))
new_e = Expr(:new, e.typ, argexprs[3:end]...)
new_e.typ = e.typ
return (new_e, ())
end

atype = argtypes_to_type(atypes)
if length(atype.parameters) - 1 > MAX_TUPLETYPE_LEN
atype = limit_tuple_type(atype)
Expand Down Expand Up @@ -3258,7 +3326,7 @@ function gotoifnot_elim_pass!(linfo::LambdaInfo, sv::InferenceState)
# doesn't recognize the error for strictly non-Bool condition)
if isa(val, Bool)
# in case there's side effects... (like raising `UndefVarError`)
body[i - 1] = cond
body[i - 1] = effect_free(cond, sv, true) ? nothing : cond
if val === false
insert!(body, i, GotoNode(expr.args[2]))
i += 1
Expand Down
2 changes: 1 addition & 1 deletion base/methodshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function arg_decl_parts(m::Method)
end

function kwarg_decl(sig::ANY, kwtype::DataType)
sig = Tuple{kwtype, Array, sig.parameters...}
sig = Tuple{kwtype, Struct, sig.parameters...}
kwli = ccall(:jl_methtable_lookup, Any, (Any, Any), kwtype.name.mt, sig)
if kwli !== nothing
kwli = kwli::Method
Expand Down
8 changes: 4 additions & 4 deletions base/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ end
function remotecall(f, w::Worker, args...; kwargs...)
rr = Future(w)
#println("$(myid()) asking for $rr")
send_msg(w, CallMsg{:call}(f, args, kwargs, remoteref_id(rr)))
send_msg(w, CallMsg{:call}(f, args, [kwargs...], remoteref_id(rr)))
rr
end

Expand All @@ -823,7 +823,7 @@ function remotecall_fetch(f, w::Worker, args...; kwargs...)
oid = RRID()
rv = lookup_ref(oid)
rv.waitingfor = w.id
send_msg(w, CallMsg{:call_fetch}(f, args, kwargs, oid))
send_msg(w, CallMsg{:call_fetch}(f, args, [kwargs...], oid))
v = take!(rv)
delete!(PGRP.refs, oid)
isa(v, RemoteException) ? throw(v) : v
Expand All @@ -840,7 +840,7 @@ function remotecall_wait(f, w::Worker, args...; kwargs...)
rv = lookup_ref(prid)
rv.waitingfor = w.id
rr = Future(w)
send_msg(w, CallWaitMsg(f, args, kwargs, remoteref_id(rr), prid))
send_msg(w, CallWaitMsg(f, args, [kwargs...], remoteref_id(rr), prid))
v = fetch(rv.c)
delete!(PGRP.refs, prid)
isa(v, RemoteException) && throw(v)
Expand All @@ -860,7 +860,7 @@ function remote_do(f, w::LocalProcess, args...; kwargs...)
end

function remote_do(f, w::Worker, args...; kwargs...)
send_msg(w, RemoteDoMsg(f, args, kwargs))
send_msg(w, RemoteDoMsg(f, args, [kwargs...]))
nothing
end

Expand Down
4 changes: 2 additions & 2 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ end

Get the name of field `i` of a `DataType`.
"""
fieldname(t::DataType, i::Integer) = t.name.names[i]::Symbol
fieldname{T<:Tuple}(t::Type{T}, i::Integer) = i < 1 || i > nfields(t) ? throw(BoundsError(t, i)) : Int(i)
Core.fieldname(t::DataType, i::Integer) = isdefined(t,:names) ? t.names[i]::Symbol : t.name.names[i]::Symbol
Core.fieldname{T<:Tuple}(t::Type{T}, i::Integer) = i < 1 || i > nfields(t) ? throw(BoundsError(t, i)) : Int(i)

"""
fieldnames(x::DataType)
Expand Down
5 changes: 4 additions & 1 deletion src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ jl_value_t *jl_stackovf_exception;
#ifdef SEGV_EXCEPTION
jl_value_t *jl_segv_exception;
#endif
jl_typename_t *jl_struct_typename;
jl_datatype_t *jl_struct_type;
JL_DLLEXPORT jl_value_t *jl_diverror_exception;
JL_DLLEXPORT jl_value_t *jl_domain_exception;
JL_DLLEXPORT jl_value_t *jl_overflow_exception;
Expand Down Expand Up @@ -160,7 +162,7 @@ void jl_assign_bits(void *dest, jl_value_t *bits)

JL_DLLEXPORT int jl_field_index(jl_datatype_t *t, jl_sym_t *fld, int err)
{
jl_svec_t *fn = t->name->names;
jl_svec_t *fn = jl_field_names(t);
for(size_t i=0; i < jl_svec_len(fn); i++) {
if (jl_svecref(fn,i) == (jl_value_t*)fld) {
return (int)i;
Expand Down Expand Up @@ -829,6 +831,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_uninitialized_datatype(size_t nfields, int8_t
// corruption otherwise.
t->fielddesc_type = fielddesc_type;
t->nfields = nfields;
t->names = NULL;
t->haspadding = 0;
t->pointerfree = 0;
t->depth = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/builtin_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ DECLARE_BUILTIN(fieldtype); DECLARE_BUILTIN(arrayref);
DECLARE_BUILTIN(arrayset); DECLARE_BUILTIN(arraysize);
DECLARE_BUILTIN(apply_type); DECLARE_BUILTIN(applicable);
DECLARE_BUILTIN(invoke); DECLARE_BUILTIN(_expr);

DECLARE_BUILTIN(struct);
#ifdef __cplusplus
}
#endif
Expand Down
Loading