
TrackedArrays and Scalars with a Chain #179

Closed
ZacCranko opened this issue Feb 20, 2018 · 6 comments
@ZacCranko

Is this expected behaviour?

m = Chain(Dense(5, 1), σ)
data = rand(5, 10)
m(data)
ERROR: DimensionMismatch("multiplicative identity defined only for square matrices")
Stacktrace:
 [1] _one(::Float64, ::TrackedArray{…,Array{Float64,2}}) at ./array.jl:357
 [2] σ(::TrackedArray{…,Array{Float64,2}}) at /Users/zcranko/.julia/v0.6/NNlib/src/activation.jl:25
 [3] mapfoldl_impl(::Base.#identity, ::Flux.##65#66, ::Array{Float64,2}, ::Array{Any,1}, ::Int64) at ./reduce.jl:46
 [4] (::Flux.Chain)(::Array{Float64,2}) at /Users/zcranko/.julia/v0.6/Flux/src/layers/basic.jl:31

To get this to work, we have to write:

m = Chain(Dense(5,1), x -> σ.(x))

which is potentially a little counterintuitive.

Now you might be tempted to reply "you should just use m = Chain(Dense(5, 1, σ))".

But if one has a sigmoid-activated dense layer Dense(5, 1, σ), it's awkward to pop the sigmoid off and extract the features when the model is parameterised that way. Compare:

model    = Chain(Dense(5, 1), σ)
features = model[1:end-1]  
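To make that slicing idiom runnable on its own, here is a hedged sketch using a minimal stand-in for Flux's Chain (not the real Flux type; `dense` is a toy layer standing in for a trained Dense(5, 1)):

```julia
# Minimal stand-in for Flux's Chain, just to illustrate the slicing idiom;
# not the real Flux implementation.
struct Chain{T<:Tuple}
    layers::T
end
Chain(xs...) = Chain(xs)
(c::Chain)(x) = foldl((acc, f) -> f(acc), c.layers; init = x)
Base.lastindex(c::Chain) = length(c.layers)
Base.getindex(c::Chain, i::AbstractRange) = Chain(c.layers[i]...)

σ(z) = 1 / (1 + exp(-z))       # logistic sigmoid, for self-containment
dense(x) = 2 .* x              # toy stand-in for a trained Dense(5, 1)

model    = Chain(dense, x -> σ.(x))
features = model[1:end-1]      # drops the activation, keeps the feature map
```

Because the activation is its own entry in the chain, `model[1:end-1]` returns the feature extractor with no knowledge of the layer internals.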
@MikeInnes
Member

Yeah, that's correct, see also #108. This is a fundamental design choice that Base Julia has made, so there's not much we can really do about it:

julia> tanh([1,2,3])
WARNING: tanh(x::AbstractArray{T}) where T <: Number is deprecated, use tanh.(x) instead.
Stacktrace:
 [1] depwarn(::String, ::Symbol) at ./deprecated.jl:70
 [2] tanh(::Array{Int64,1}) at ./deprecated.jl:57
 [3] eval(::Module, ::Any) at ./boot.jl:235
 [4] eval_user_input(::Any, ::Base.REPL.REPLBackend) at ./REPL.jl:66
 [5] macro expansion at ./REPL.jl:97 [inlined]
 [6] (::Base.REPL.##1#2{Base.REPL.REPLBackend})() at ./event.jl:73
while loading no file, in expression starting on line 0
3-element Array{Float64,1}:
 0.761594
 0.964028
 0.995055

One way to make it more convenient would be to write something like broadcasted(f) in place of x -> f.(x).
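That suggestion is essentially a one-liner; a sketch, with the caveat that the name broadcasted is the proposal here, not an existing Flux export:

```julia
# Hypothetical helper from the suggestion above: lifts an elementwise
# function into an array-to-array layer. Not an existing Flux API.
broadcasted(f) = x -> f.(x)

σ(z) = 1 / (1 + exp(-z))     # logistic sigmoid, for self-containment
layer = broadcasted(σ)       # usable as e.g. Chain(Dense(5, 1), broadcasted(σ))
layer([0.0, 0.0])            # applies σ elementwise
```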

Although it's somewhat unusual among DL frameworks, closures are a pattern we use quite often in Flux (e.g. Chain(x -> reshape(x, 784, :), Dense(784, 10))). So we probably just need to make sure it's documented carefully.

@ZacCranko
Author

ZacCranko commented Feb 21, 2018

broadcasted(f) would at least be cleaner.

Up to you of course, but there might be reason to go against the Julia design choice, since you could say that a Chain operates on a single datum, and Flux should work out if you've fed it multiple data or a single datum and broadcast appropriately.

@MikeInnes
Member

One possibility would be that Chain itself does something here, but I'm not sure what that'd be.

It's not great to just special case a bunch of possible activation functions, and especially not to override Base's versions (e.g. for tanh).

@ZacCranko
Author

Yeah, I don't know either. Perhaps this could be solved nicely with traits?

@ZacCranko
Author

Let data be some N-dimensional array. Would it be consistent to define m.(data), where the "broadcasting" occurs over the last dimension of data? Then m(data) would throw a dimension mismatch unless fed an (N-1)-dimensional array.
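A sketch of that semantics for the matrix case, where each column is one datum (`colwise` is a hypothetical helper for illustration, not part of Flux):

```julia
# Hypothetical `colwise`: apply a model to each column (each datum) of a
# matrix and hcat the results — one reading of "broadcasting over the
# last dimension". Not part of Flux.
colwise(m, data::AbstractMatrix) =
    mapreduce(i -> m(data[:, i]), hcat, 1:size(data, 2))

m(x) = sum(x)            # toy model: maps one column to a scalar
X = [1 2 3;              # 2 features × 3 data
     4 5 6]
colwise(m, X)            # one output per column
```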

@MikeInnes
Member

FluxML/NNlib.jl#98 will give a nicer error message here. I don't think there's much we can do beyond that. I guess if someone wants to PR broadcasted or similar then that'd be welcome.
