Skip to content

Commit

Permalink
Inception support added
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush1999 committed Jun 4, 2018
1 parent b0f913c commit 39217df
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ function graphify(a::Array{Any, 1}, structure_file, weight_file, ip)
op_activation = ops[:Dense](ele)[2]
inputs = ele.input_nodes[1][1][1]
res[ele.fields["name"]] = vcall(op_activation, vcall(op_dense, res[inputs]))
elseif ele.layer_type == :GlobalAveragePooling2D
inputs = ele.input_nodes[1][1][1]
res[ele.fields["name"]] = vcall(x -> reshape(x, size(x)[3], size(x)[4]),
vcall(ops[:GlobalAveragePooling2D](ele), res[inputs]))
else
inputs = ele.input_nodes[1][1][1]
res[ele.fields["name"]] = vcall(ops[ele.layer_type](ele), res[inputs])
Expand Down
4 changes: 3 additions & 1 deletion src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ function weights(file="weights.h5")
for ele2 in keys(weight[ele][ele*"_1"])
weight[ele][ele*"_1"][ele2] = convert(Array{Float64, N} where N, weight[ele][ele*"_1"][ele2])
end
else
elseif haskey(weight[ele], ele)
for ele2 in keys(weight[ele][ele])
weight[ele][ele][ele2] = convert(Array{Float64, N} where N, weight[ele][ele][ele2])
end
elseif ele == "conv1"
weight[ele]["conv"][ele]["conv"] = convert(Array{Float64, N} where N, weight[ele]["conv"][ele]["conv"]["kernel:0"])
end
end
end
Expand Down

0 comments on commit 39217df

Please sign in to comment.