Skip to content

Commit

Permalink
Merge pull request #576 from mcabbott/patch-1
Browse files Browse the repository at this point in the history
PermutedDimsArray
  • Loading branch information
MikeInnes authored Jan 29, 2019
2 parents dd95416 + 55a7359 commit 0469394
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/tracker/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs,

@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)

Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)

Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)

function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
Expand Down
1 change: 1 addition & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ end
end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))

@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
Expand Down

0 comments on commit 0469394

Please sign in to comment.