Skip to content

Commit

Permalink
Merge #1616
Browse files Browse the repository at this point in the history
1616: Warn on reconstruct length mismatch r=CarloLucibello a=ToucheSir

Ref. #1601. This is kept as a plain warning for backwards compat, but perhaps we want to consider it a bugfix and error/depwarn instead?

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] API changes require approval from a committer (different from the author, if applicable)


Co-authored-by: Brian Chen <[email protected]>
  • Loading branch information
bors[bot] and ToucheSir authored Jun 17, 2021
2 parents 335286a + 0f9e672 commit 27c4c77
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,16 +610,24 @@ end

function _restructure(m, xs)
i = 0
fmap(m) do x
= fmap(m) do x
x isa AbstractArray || return x
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
return
end

@adjoint function _restructure(m, xs)
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
return (nothing, xs′)
end
return m̄, _restructure_pullback
end

"""
Expand Down
9 changes: 9 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,15 @@ end
p, re = destructure(m)
testdense(re(p), bt)
end

@testset "restructure in gradient" begin
x = rand(Float32, 3, 1)
m = dm(zeros)
∇m = gradient(m -> sum(m(x)), m)[1]
p, re = destructure(m)
∇p = gradient-> sum(re(θ)(x)), p)[1]
@test ∇p destructure(∇m)[1]
end
end
end

Expand Down

0 comments on commit 27c4c77

Please sign in to comment.