diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 690b0e18ec..de97f5ae50 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -223,8 +223,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, @grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing) -Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims) -@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing) +Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm) +@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing) + +Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm) +@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing) function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) diff --git a/test/tracker.jl b/test/tracker.jl index 8a9bded17b..34c14afa6d 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -116,6 +116,7 @@ end end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) +@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> repeat(x; inner=2), rand(5)) @test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))