From 39217df1662bea050df391dcda2ec83b9d260589 Mon Sep 17 00:00:00 2001 From: ayush1999 Date: Mon, 4 Jun 2018 16:48:36 +0530 Subject: [PATCH] Inception support added --- src/functional.jl | 4 ++++ src/read.jl | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/functional.jl b/src/functional.jl index c5a0b72..a88ac4c 100644 --- a/src/functional.jl +++ b/src/functional.jl @@ -32,6 +32,10 @@ function graphify(a::Array{Any, 1}, structure_file, weight_file, ip) op_activation = ops[:Dense](ele)[2] inputs = ele.input_nodes[1][1][1] res[ele.fields["name"]] = vcall(op_activation, vcall(op_dense, res[inputs])) + elseif ele.layer_type == :GlobalAveragePooling2D + inputs = ele.input_nodes[1][1][1] + res[ele.fields["name"]] = vcall(x -> reshape(x, size(x)[3], size(x)[4]), + vcall(ops[:GlobalAveragePooling2D](ele), res[inputs])) else inputs = ele.input_nodes[1][1][1] res[ele.fields["name"]] = vcall(ops[ele.layer_type](ele), res[inputs]) diff --git a/src/read.jl b/src/read.jl index 16ed533..81cc430 100644 --- a/src/read.jl +++ b/src/read.jl @@ -20,10 +20,12 @@ function weights(file="weights.h5") for ele2 in keys(weight[ele][ele*"_1"]) weight[ele][ele*"_1"][ele2] = convert(Array{Float64, N} where N, weight[ele][ele*"_1"][ele2]) end - else + elseif haskey(weight[ele], ele) for ele2 in keys(weight[ele][ele]) weight[ele][ele][ele2] = convert(Array{Float64, N} where N, weight[ele][ele][ele2]) end + elseif ele == "conv1" + weight[ele]["conv"][ele]["conv"] = convert(Array{Float64, N} where N, weight[ele]["conv"][ele]["conv"]["kernel:0"]) end end end