When I try to differentiate OMEinsum, the following error was thrown. The code is shown below, where the variable code is a mutable, callable object. Since it does not carry a gradient, I suppose it should work, but it does not.
using Enzyme.EnzymeRules, OMEinsum, Enzyme
using OMEinsum: get_size_dict!

function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, ::Type,
        code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    @assert sx.val == 1 && sy.val == 0 "Only α = 1 and β = 0 is supported, got: $sx, $sy"
    # Compute primal
    if EnzymeRules.needs_primal(config)
        primal = func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val)
    else
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if EnzymeRules.overwritten(config)[3]
        tape = copy(xs.val)
    else
        tape = nothing
    end
    shadow = ys.dval
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, dret::Type{<:Annotation}, tape,
        code::Const,
        xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val
    for i = 1:length(xs.val)
        xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
            xval, OMEinsum.getiy(code.val), size_dict.val, conj(ys.dval), i)
    end
    return (nothing, nothing, nothing, nothing, nothing, nothing)
end

using Test
@testset"EnzymeExt bp check 2"begin
A, B, C =randn(2, 3), randn(3, 4), randn(4, 2)
code =optimize_code(ein"ij, jk, ki->", uniformsize(ein"ij, jk, ki->", 2), TreeSA())
cost0 =code(A, B, C)[]
gA =zero(A); gB =zero(B); gC =zero(C);
f(code, a, b, c) =code(a, b, c)[]
Enzyme.autodiff(set_runtime_activity(Reverse), f, Active, Const(code), Duplicated(A, gA), Duplicated(B, gB), Duplicated(C, gC))
cost, mg = OMEinsum.cost_and_gradient(code, (A, B, C))
@test cost[] ≈ cost0
@testall(gA .≈ mg[1])
@testall(gB .≈ mg[2])
@testall(gC .≈ mg[3])
end
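For reference, here is a minimal sketch of the activity pattern the test above relies on, independent of OMEinsum (the Scaler type and the function g are hypothetical stand-ins): a mutable, callable object with no differentiable state is passed as Const, while the array argument is Duplicated.

using Enzyme

# Hypothetical stand-in for the optimized `code` object: callable and
# mutable, but its only field is an Int, so it carries no gradient.
mutable struct Scaler
    n::Int
end
(s::Scaler)(x) = s.n * sum(abs2, x)

g(s, x) = s(x)          # scalar-returning wrapper, like f(code, a, b, c) above

s = Scaler(3)
x = randn(4)
dx = zero(x)

# Const for the callable, Duplicated for the array, Active for the scalar return.
Enzyme.autodiff(Reverse, g, Active, Const(s), Duplicated(x, dx))
@assert dx ≈ 2 * s.n .* x   # gradient of n * sum(abs2, x) with respect to x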
Good point. I tried, but the log exceeded the length limit that GitHub allows. Is there any way I can circumvent this limit? Or which part do you want to see?