Skip to content

Commit

Permalink
split up Flatten layer to use the flatten function
Browse files Browse the repository at this point in the history
  • Loading branch information
gartangh committed Dec 5, 2019
1 parent 093b75b commit f8766d3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,20 @@ Flattening layer.
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
by linearizing all values for each element in the batch.
"""
struct Flatten end
struct Flatten{F}
    # Activation applied to the flattened output; defaults to `identity`.
    σ::F
    # NOTE(review): the source line was garbled to `function Flatten::F = ...` —
    # the `(σ` was dropped during extraction. Restored the intended signature.
    function Flatten(σ::F = identity) where {F}
        return new{F}(σ)
    end
end

function (f::Flatten)(x)
return reshape(x, :, size(x)[end])
# Apply the layer: flatten `x` into a (w*h*c, b) matrix, then run the
# stored activation `σ` over the result.
function (f::Flatten)(x::AbstractArray)
    return f.σ(flatten(x))
end

# Compact display for a `Flatten` layer, e.g. `Flatten()` or `Flatten(relu)`.
# The activation is printed only when it differs from `identity`.
# NOTE(review): removed the stale pre-commit line `print(io, "Flatten()")`
# that the diff view interleaved here — keeping it would print twice.
function Base.show(io::IO, f::Flatten)
    print(io, "Flatten(")
    f.σ == identity || print(io, f.σ)
    print(io, ")")
end
10 changes: 10 additions & 0 deletions src/layers/stateless.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,13 @@ function normalise(x::AbstractArray; dims=1)
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
return (x .- μ′) ./ σ′
end

"""
    flatten(x::AbstractArray)

Transform (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
by linearizing all values for each element in the batch.
"""
function flatten(x::AbstractArray)
    # Collapse all leading dimensions; `reshape` shares data, so this is
    # allocation-free apart from the wrapper.
    return reshape(x, :, size(x, ndims(x)))
end
9 changes: 8 additions & 1 deletion test/layers/stateless.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy
σ, binarycrossentropy, logitbinarycrossentropy, flatten

const ϵ = 1e-7

Expand Down Expand Up @@ -62,3 +62,10 @@ const ϵ = 1e-7
end
end
end

# Exercise the stateless helpers: `flatten` must collapse a 4-D
# (w,h,c,b) array into a (w*h*c, b) matrix.
@testset "helpers" begin
    @testset "flatten" begin
        input = randn(Float32, 10, 10, 3, 2)
        @test size(flatten(input)) == (300, 2)
    end
end

0 comments on commit f8766d3

Please sign in to comment.