From 035e20a42def36a3ccbc866ac3fe4d51b001825d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 29 Dec 2023 19:01:14 +0530 Subject: [PATCH] test: add test for broadcast autodiff --- test/adjoints.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/adjoints.jl b/test/adjoints.jl index a6abb445..e06035af 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -37,6 +37,11 @@ function loss6(x) sum(abs2, Array(_prob.u0)) end +function loss7(x) + _x = VectorOfArray([x .* i for i in 1:5]) + return sum(abs2, x .- 1) +end + x = float.(6:10) loss(x) @test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x) @@ -45,3 +50,4 @@ loss(x) @test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x) @test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x) @test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x) +@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)