Skip to content

Commit

Permalink
conceptualise the interface for julia native ops
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Dec 12, 2015
1 parent 0110435 commit 22c1ecb
Showing 1 changed file with 38 additions and 18 deletions.
56 changes: 38 additions & 18 deletions src/nativeops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,55 @@ immutable NativeOpInfo
p_list_outputs :: Ptr{Void}
p_list_arguments :: Ptr{Void}

function NativeOpInfo(forward :: Function, backwards :: Function, infer_shape :: Function, list_outputs :: Function, list_arguments :: Function)
function NativeOpInfo(op :: Operator, forward, backward)
p_is, p_loa, p_la = pointer_from_objref(op)

c_wrapper_fb = cfunction(_wrapper_fb, Void, (Cint, Ptr{Ptr{Cfloat}}, Ptr{Cint}, Ptr{Ptr{Cuint}}, Ptr{Cint}, Ptr{Void}))
c_wrapper_infer = cfunction(_wrapper_infer, Void, (Cint, Ptr{Cint}, Ptr{Ptr{Cuint}}, Ptr{Void}))
const c_wrapper_list = cfunction(_wrapper_list, Void, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))

p_f = pointer_from_objref(forward)
p_b = pointer_from_objref(backwards)
p_is = pointer_from_objref(infer_shape)
p_lo = pointer_from_objref(list_outputs)
p_la = pointer_from_objref(list_arguments)
new(c_wrapper_fb, c_wrapper_fb, c_wrapper_infer, c_wrapper_list,
c_wrapper_list, p_f, p_b, p_is, p_lo, p_la)
end
end

###
# Infer and list are called in sync.
###
function _wrapper_infer(size :: Cint, ndims :: Ptr{Cint}, shapes :: Ptr{Ptr{Cuint}}, _op :: Ptr{Void})
op = unsafe_pointer_to_objref(_op) :: Operator

n_in = length(list_arguments(op))
n_out = length(list_outputs(op))
@assert size == n_in + n_out

shapes = [[tensor_shapes[i][j] for j in 1:tensor_dims[i]] for i in 1:n_in]]

ishape, oshape = infer_shape(op, shapes)
@assert length(ishape) == n_in
@assert length(oshape) == n_out

rshape = cat(ishape, oshape)
unsafe_store!(shapes, rshapes)
return nothing
end

function _wrapper_list_arguments(data :: Ptr{Ptr{Cstring}}, _op :: Ptr{Void})
op = unsafe_pointer_to_objref(_op) :: Operator
arguments = list_arguments(op)
unsafe_store!(data, arguments)
return nothing
end

function _wrapper_list_outputs(data :: Ptr{Ptr{Cstring}}, _op :: Ptr{Void})
op = unsafe_pointer_to_objref(_op) :: Operator
outputs = list_outputs(op)
unsafe_store!(data, outputs)
return nothing
end

##
# Forward and backward can be called from different threads in libmxnet and
# so we need to take special care in handling these callbacks correctly in
Expand Down Expand Up @@ -87,20 +121,6 @@ function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Cfloat}}, ndims :: Ptr{Cint},
nothing
end

###
# Infer and list are called in sync.
###
function _wrapper_infer(size :: Cint, ndims :: Ptr{Cint}, shapes :: Ptr{Ptr{Cuint}}, jf :: Ptr{Void})
entry = unsafe_pointer_to_objref(jf) :: Function
entry(size, ndims, shapes)
return nothing
end

function _wrapper_list(data :: Ptr{Ptr{Cstring}}, jf :: Ptr{Void})
entry = unsafe_pointer_to_objref(jf) :: Function
entry(data)
return nothing
end

create_info() = NativeOpInfo(fb_entry, fb_entry, infer_entry, list_entry, list_entry)
# pstring = bytestring("0x", hex(reinterpret(UInt, pointer_from_objref(info))))
Expand Down

0 comments on commit 22c1ecb

Please sign in to comment.