Skip to content

Commit

Permalink
Merge #956
Browse files Browse the repository at this point in the history
956: fix differentiation of loopinfo exprs r=DhairyaLGandhi a=simeonschaub

addresses the first part of #897

Co-authored-by: Simeon Schaub <[email protected]>
  • Loading branch information
bors[bot] and simeonschaub authored Apr 30, 2021
2 parents 4e54eb7 + 3218e50 commit 58c1f36
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ function adjoint(pr::Primal)
end
elseif ex isa Core.PiNode
grads[ex.val] = grads[v]
elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta)
elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo)
elseif isexpr(ex)
push!(rb, stmt(xcall(Base, :error, "Can't differentiate $(ex.head) expression"),
line = b[v].line))
Expand Down
3 changes: 3 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,6 @@ end
ms = MyStruct(1, 2)
@test Zygote.gradient(sumall, ms) == ((a = 2, b = 2),)
end

# issue 897
@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] fill(0.5773502691896258, 3, 400)
2 changes: 1 addition & 1 deletion test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ function pow_simd(x, n)
return r
end

@test_broken gradient(pow_simd, 2, 3) == (12,nothing)
@test gradient(pow_simd, 2, 3) == (12,nothing)

@testset "tuple getindex" begin
@test gradient(x -> size(x)[2], ones(2,2,2)) == (nothing,)
Expand Down

0 comments on commit 58c1f36

Please sign in to comment.