Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try mutating Tapir #126

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ using DifferentiationInterface: AutoTapir
import DifferentiationInterface as DI
using Tapir: CoDual, build_rrule, value_and_pullback!!, zero_codual

DI.supports_mutation(::AutoTapir) = DI.MutationNotSupported()

function zero_sametype!!(x_target, x::Number)
return zero(x)
end
Expand All @@ -17,31 +15,7 @@ function zero_sametype!!(x_target, x::AbstractArray)
return x_sametype
end

function DI.value_and_pullback(f, ::AutoTapir, x, dy, rrule)
y = f(x)
dy_righttype = convert(typeof(y), dy)
_, (_, dx) = value_and_pullback!!(rrule, dy_righttype, f, x)
return y, dx
end

function DI.value_and_pullback!!(f, dx, ::AutoTapir, x, dy, rrule)
y = f(x)
dy_righttype = convert(typeof(y), dy)
dx_righttype = zero_sametype!!(dx, x)
new_y, (_, new_dx) = value_and_pullback!!(
rrule, dy_righttype, zero_codual(f), CoDual(x, dx_righttype)
)
return new_y, new_dx
end

for op in [:pushforward, :pullback, :derivative, :gradient, :jacobian]
prep_op = Symbol(:prepare_, op)
@eval function DI.$prep_op(f, backend::AutoTapir, x)
return build_rrule(f, x)
end
@eval function DI.$prep_op(rrule, f, backend::AutoTapir, x)
return rrule
end
end
include("allocating.jl")
include("mutating.jl")

end
26 changes: 26 additions & 0 deletions ext/DifferentiationInterfaceTapirExt/allocating.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
function DI.value_and_pullback(f, ::AutoTapir, x, dy, rrule)
y = f(x) # TODO: one call too many, just for the conversion
dy_righttype = convert(typeof(y), dy)
_, (_, dx) = value_and_pullback!!(rrule, dy_righttype, f, x)
return y, dx
end

function DI.value_and_pullback!!(f, dx, ::AutoTapir, x, dy, rrule)
y = f(x) # TODO: one call too many, just for the conversion
dy_righttype = convert(typeof(y), dy)
dx_righttype = zero_sametype!!(dx, x)
new_y, (_, new_dx) = value_and_pullback!!(
rrule, dy_righttype, zero_codual(f), CoDual(x, dx_righttype)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
)
return new_y, new_dx
end

for op in [:pushforward, :pullback, :derivative, :gradient, :jacobian]
prep_op = Symbol(:prepare_, op)
@eval function DI.$prep_op(f, ::AutoTapir, x)
return build_rrule(f, x)
end
@eval function DI.$prep_op(rrule, f, ::AutoTapir, x)
return rrule
end
end
10 changes: 10 additions & 0 deletions ext/DifferentiationInterfaceTapirExt/mutating.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function DI.value_and_pullback!!(f!, y, dx, ::AutoTapir, x, dy, extras::Nothing)
rrule = build_rrule(f!, y, x)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
dy_righttype = convert(typeof(y), dy)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
dx_righttype = zero_sametype!!(dx, x)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
dz = nothing # f!(y, x) = nothing
gdalle marked this conversation as resolved.
Show resolved Hide resolved
_, (_, _, new_dx) = value_and_pullback!!(
rrule, dz, zero_codual(f!), CoDual(y, dy_righttype), CoDual(x, dx_righttype)
)
return y, new_dx
end
29 changes: 16 additions & 13 deletions test/first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ using Zygote: Zygote
##

all_backends = [
AutoChainRules(Zygote.ZygoteRuleConfig()),
AutoDiffractor(),
AutoEnzyme(Enzyme.Forward),
AutoEnzyme(Enzyme.Reverse),
AutoFastDifferentiation(),
AutoFiniteDiff(),
AutoFiniteDifferences(FiniteDifferences.central_fdm(3, 1)),
AutoForwardDiff(),
AutoPolyesterForwardDiff(; chunksize=2),
AutoReverseDiff(),
# AutoChainRules(Zygote.ZygoteRuleConfig()),
# AutoDiffractor(),
# AutoEnzyme(Enzyme.Forward),
# AutoEnzyme(Enzyme.Reverse),
# AutoFastDifferentiation(),
# AutoFiniteDiff(),
# AutoFiniteDifferences(FiniteDifferences.central_fdm(3, 1)),
# AutoForwardDiff(),
# AutoPolyesterForwardDiff(; chunksize=2),
# AutoReverseDiff(),
AutoTapir(),
AutoTracker(),
AutoZygote(),
# AutoTracker(),
# AutoZygote(),
]

##
Expand All @@ -39,5 +39,8 @@ for backend in all_backends
end

test_differentiation(
all_backends; second_order=false, logging=get(ENV, "CI", "false") == "false"
all_backends;
second_order=false,
logging=get(ENV, "CI", "false") == "false",
detailed=true,
);
6 changes: 4 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ include("test_imports.jl")
## Main tests

@testset verbose = true "DifferentiationInterface.jl" begin
#=
@testset verbose = true "Formal tests" begin
@testset "Aqua" begin
Aqua.test_all(
Expand All @@ -27,11 +28,11 @@ include("test_imports.jl")
@testset "Zero backends" begin
include("zero_backends.jl")
end

=#
@testset verbose = true "First order" begin
include("first_order.jl")
end

#=
@testset verbose = true "Second order" begin
include("second_order.jl")
end
Expand All @@ -49,4 +50,5 @@ include("test_imports.jl")
include("weird_arrays.jl")
end
end
=#
end;
Loading