Fix bug that caused Flux.params(x) call to not be cached (Closes issue #2040) #2048

Closed
16 changes: 7 additions & 9 deletions src/functor.jl
@@ -36,16 +36,14 @@ Possible values include:
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)

+ params!(p::Params, x::DenseArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x)
Member @mcabbott, Aug 25, 2022

Can a DenseArray{<:Number} ever not be a leaf?

Suggested change
- params!(p::Params, x::DenseArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x)
+ params!(p::Params, x::DenseArray{<:Number}, seen) = push!(p, x)

Member @ToucheSir, Aug 25, 2022

Technically yes, Base.Experimental.Const is a pure wrapper type and subtypes DenseArray. I've seen it used in JuliaGPU libraries, but am unsure if those would ever come in contact with params.

Member

That's an interesting type. But I guess it will always be leaf-like; Functors should treat it as it would an SArray, right?

More broadly, if this method tests for isleaf, then it has to do something in the other branch, and at that point it is just the other method. I guess it could assert isleaf, so that you get an error if someone does something really weird.

Member

I think it has to be a non-leaf for the same reason Transpose is: shared inner arrays.

RE the other branch, I thought the latest change addressed that but it appears I misremembered. Silently dropping an array instead of recursing is definitely not good.
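
To illustrate what "shared inner arrays" means here, a small sketch using only Base Julia (no Flux or Functors calls; the variable names are illustrative): a transposed array is a lazy wrapper around its parent, so a traversal that stops at the wrapper would not notice that two fields share the same storage.

```julia
# A Transpose is a lazy wrapper: it holds a reference to its parent array.
w  = [1.0 2.0; 3.0 4.0]
wt = transpose(w)

println(parent(wt) === w)   # the wrapper's parent is the very same object

# Writing through the wrapper mutates the shared storage.
wt[1, 2] = 10.0
println(w[2, 1])            # the parent sees the write at the mirrored index
```

Under this reading, recursing into the wrapper and comparing parents by identity is what lets a parameter collection count the shared array once.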

Member

I guess. Although transposing one of two shared arrays is common, marking only one of the two as Const seems perverse.

Member @mcabbott, Aug 25, 2022

Let me update my suggestion. I think this ought to be safe, and will at least throw an error should someone ever @functor Base.Experimental.Const (or its CUDA analogue):

Suggested change
- params!(p::Params, x::DenseArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x)
+ function params!(p::Params, x::DenseArray{<:Number}, seen = IdSet())
+   # Fast path for the most common case, Array & CuArray. Solves issue 2040.
+   Functors.isleaf(x) || error("For efficiency, params believes every DenseArray of numbers is leaflike")
+   push!(p, x)
+ end

Contributor Author

With this code, size.(Flux.params((x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi)))) returns [(1, 2)], so the view is dropped.

Member

This code and what else?

Contributor Author

This code with the above suggestion.

Member @mcabbott, Aug 25, 2022

Ok. On this example, the suggestion changes nothing compared to the PR. It only turns a failed isleaf test into an error instead of a silent skip.

I think such a fast method should exist alongside the method which was here before the PR, which handles all cases (but has more branches). That should be correct. Whether it still solves 2040 I don't know.
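
A minimal sketch of that arrangement, with a fast DenseArray method sitting next to a general method that still handles every other case. Everything here is a toy stand-in written against Base Julia and the LinearAlgebra standard library: collect_params!, children and isleaf are hypothetical names, not Flux's or Functors' definitions, and a plain Vector with an identity check stands in for Params.

```julia
using LinearAlgebra: Transpose

# Toy stand-ins (hypothetical, NOT the Functors.jl definitions).
children(x) = ()                                    # by default, a value is a leaf
children(x::Union{Tuple,NamedTuple}) = values(x)    # containers expose their fields
children(x::Transpose) = (parent(x),)               # wrappers expose the wrapped array
isleaf(x) = isempty(children(x))

# Push x unless the very same object is already stored (identity, not equality).
push_unique!(ps, x) = (any(p -> p === x, ps) || push!(ps, x); ps)

# Fast path: dense numeric arrays are stored directly, with no `seen` bookkeeping.
function collect_params!(ps::Vector{Any}, x::DenseArray{<:Number}, seen = Base.IdSet{Any}())
    isleaf(x) || error("expected every DenseArray of numbers to be leaf-like")
    return push_unique!(ps, x)
end

# General method: keeps non-dense leaf arrays and recurses into everything else.
function collect_params!(ps::Vector{Any}, x, seen = Base.IdSet{Any}())
    if x isa AbstractArray{<:Number} && isleaf(x)
        push_unique!(ps, x)
    elseif !(x in seen)
        push!(seen, x)
        for child in children(x)
            collect_params!(ps, child, seen)
        end
    end
    return ps
end

# The example from this thread: the view is kept, the transpose is unwrapped.
ps = collect_params!(Any[], (x = view(pi .* [1, 2, 3], 1:2), y = transpose(pi .* [4 5])))
println(size.(ps))
```

In this toy version the view survives because the general method still has its leaf-array branch; removing that branch reproduces the dropped-view behaviour reported above. Whether such a pair of methods still fixes issue 2040 in Flux itself is not something this sketch can show.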


  function params!(p::Params, x, seen = IdSet())
-   if x isa AbstractArray{<:Number} && Functors.isleaf(x)
-     return push!(p, x)
-   elseif x in seen
-     nothing
-   else
-     push!(seen, x)
-     for child in trainable(x)
-       params!(p, child, seen)
-     end
+   x in seen && return
+
+   push!(seen, x)
+   for child in trainable(x)
+     params!(p, child, seen)
+   end
Member

What if I have a leaf type which isn't a DenseArray? The current behaviour is:

julia> using NamedDims, StaticArrays

julia> Flux.params((SA[2.2], 3:4.0, NamedDimsArray([5.0], :x)))
Params([[2.2], 3.0:1.0:4.0, NamedDimsArray([5.0], :x)])

What I meant with the DenseArray idea was that this method could be a short-cut for the common case, in addition to the existing method.

Of course the tests as always don't try very hard. But I do think that it ought to keep working with wrappers like NamedDims.

Member

Are there any wrappers already on the dependency chain which have this same behaviour outside of NamedDims?

Member

Maybe SubArray? ReshapedArray, SymTridiagonal ... for tests I guess you want something unlikely to be @functor-ed in the future.

julia> Flux.params(view([1,2,3]pi, 1:2))
Params([[3.141592653589793, 6.283185307179586]])

julia> ans[1] isa DenseArray
false

Member

SubArrays are a bit of a landmine IMO because they don't "cover" the entirety of the wrapped array. ReshapedArray makes sense though. Was it that or PermutedDimsArray that we found couldn't have its transform easily reversed?

Member

IIRC ReshapedArray was the tricky one, as its type doesn't have the shape.
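
For reference, a short Base-only sketch of the two properties discussed in this thread: these wrappers are not DenseArrays, and a ReshapedArray's type records the number of dimensions but not the sizes. The variable names are illustrative, and Base.ReshapedArray is an unexported Base type.

```julia
v  = view(pi .* [1, 2, 3], 1:2)   # SubArray over a Vector
r  = reshape(1:6, 2, 3)           # lazy ReshapedArray over a range
r2 = reshape(1:6, 3, 2)           # same parent, different shape

println(v isa DenseArray)             # a view is not a DenseArray
println(r isa DenseArray)             # neither is the lazy reshape
println(r isa Base.ReshapedArray)     # it is the wrapper type mentioned above
println(typeof(r) == typeof(r2))      # the type alone does not distinguish 2x3 from 3x2
println(size(r), " ", size(r2))       # the shape lives in the value, not the type
```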

end
