diff --git a/src/utils.jl b/src/utils.jl index 4f4e440f87..347ff20241 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,6 +9,7 @@ const AArray = AbstractArray initn(dims...) = randn(dims...)/100 unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) +squeeze(xs, dim = 1) = Base.squeeze(xs, dim) stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...) unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] diff --git a/test/recurrent.jl b/test/recurrent.jl index 7610ea1300..3f04d6c48d 100644 --- a/test/recurrent.jl +++ b/test/recurrent.jl @@ -13,5 +13,5 @@ end _, ys = apply(unroll1(r).model, xs, (r.y.x,)) @test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x) ru = unroll(r, 3) - ru(batchone(Seq(squeeze.(xs, 1))))[1] == squeeze.(ys, 1) + ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys) end diff --git a/test/runtests.jl b/test/runtests.jl index e110725ee3..4a86c6e504 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Flux, DataFlow, MacroTools, Base.Test -using Flux: graph, Param, unsqueeze +using Flux: graph, Param, squeeze, unsqueeze using DataFlow: Line, Frame macro mxonly(ex)