diff --git a/test/transform/selecttransform.jl b/test/transform/selecttransform.jl index b0e3bfa81..a6d382462 100644 --- a/test/transform/selecttransform.jl +++ b/test/transform/selecttransform.jl @@ -104,17 +104,41 @@ end @testset "$(AD)" for AD in [:ReverseDiff] - @test_broken ga = gradient(AD, A) do a - testfunction(ta_row, a, 2) + @test_broken let + gx = gradient(AD, X) do x + testfunction(tx_row, x, 2) + end + ga = gradient(AD, A) do a + testfunction(ta_row, a, 2) + end + gx ≈ ga end - @test_broken ga = gradient(AD, A) do a - testfunction(ta_col, a, 1) + @test_broken let + gx = gradient(AD, X) do x + testfunction(tx_col, x, 1) + end + ga = gradient(AD, A) do a + testfunction(ta_col, a, 1) + end + gx ≈ ga end - @test_broken ga = gradient(AD, A) do a - testfunction(ta_row, a, B, 2) + @test_broken let + gx = gradient(AD, X) do x + testfunction(tx_row, x, Y, 2) + end + ga = gradient(AD, A) do a + testfunction(ta_row, a, B, 2) + end + gx ≈ ga end - @test_broken ga = gradient(AD, A) do a - testfunction(ta_col, a, C, 1) + @test_broken let + gx = gradient(AD, X) do x + testfunction(tx_col, x, Z, 1) + end + ga = gradient(AD, A) do a + testfunction(ta_col, a, C, 1) + end + gx ≈ ga end end