Compatibility with Tracker from Flux for AD? #570
Comments
I can take a closer look at this after Wednesday. |
OK, so this appears to be because:
julia> Tracker.track(identity, rand(3)) |> typeof |> supertype
AbstractArray{Float64,1}
while
julia> Tracker.track(identity, rand(3)) |> eltype
Tracker.TrackedReal{Float64}
So in a sense, I also see that Tracker.jl is on its way out. I'm not familiar with Flux's status quo; do you know why they're replacing Tracker.jl and what they're replacing it with? |
After #571, I think the next issue is with
using Tracker
using StaticArrays
x = Tracker.track(identity, 1.0)
y = [Tracker.track(identity, 2.0), Tracker.track(identity, 2.0)]
@show typeof(x * y)
which results in TrackedArray{…,Array{Tracker.TrackedReal{Float64},1}} i.e., a But Tracker is being replaced with Zygote of course. It might make sense to switch to try the |
Also wanted to point to https://github.com/tkoolen/RigidBodyDynamicsDiff.jl, which, though experimental, has moderately optimized gradients w.r.t. |
Thank you for your detailed reply!
using RigidBodyDynamics
using RigidBodySim
using DifferentialEquations
using Zygote
urdf = joinpath(dirname(pathof(RigidBodySim)), "..", "test", "urdf", "Acrobot.urdf")
mechanism = parse_urdf(Float32, urdf)
remove_fixed_tree_joints!(mechanism);
state = MechanismState(mechanism)
shoulder, elbow = joints(mechanism)
# Set the initial state
configuration(state, shoulder) .= 0.3
configuration(state, elbow) .= 0.4
velocity(state, shoulder) .= 1.
velocity(state, elbow) .= 2.;
mutable struct ControllerParams
    param::Float32
end
p_test = ControllerParams(5)
function create_controller(p::ControllerParams)
    function control!(τ, t, state)
        view(τ, velocity_range(state, shoulder)) .= p.param * sin(t)
        view(τ, velocity_range(state, elbow)) .= -configuration(state, shoulder)
    end
    return control!
end
function simulate_full(p::ControllerParams)
    cc! = create_controller(p)
    open_loop_dynamics = Dynamics(mechanism, cc!)
    problem = ODEProblem(open_loop_dynamics, state, (0., 1.))
    sol = solve(problem, Tsit5(), abstol = 1e-7, dt = 0.05)
    println("Done")
    return sol[end][1]
end
# Check if we can get the gradient of a function using a ControllerParams
@show gradient(p -> p.param * 2, p_test)
# Check if we can simulate 1 sec of Dynamics
@show simulate_full(p_test)
# Check if we can get the gradient of the dynamics with respect to a ControllerParams
# -> Breaks
@show gradient(simulate_full, p_test)
Stack trace:
What I am trying to do is basically learn a parameterized controller with gradient descent. |
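For illustration only, here is a minimal sketch of what "learning a parameterized controller with gradient descent" could look like on a toy problem. It does not use RigidBodyDynamics: the pendulum model, the explicit Euler integrator, the names `final_angle`, `loss`, and `train`, and the step size are all assumptions made up for this sketch. It uses ForwardDiff.jl rather than Zygote because forward-mode dual numbers pass through ordinary Julia code without special support.

```julia
using ForwardDiff

# Toy stand-in for the simulation (hypothetical, NOT RigidBodyDynamics):
# a pendulum with feedback torque u = -k * θ, integrated with explicit Euler.
function final_angle(k; θ0 = 0.3, ω0 = 1.0, dt = 0.01, steps = 100)
    # Promote the state to the type of k so dual numbers can flow through.
    θ, ω = θ0 * one(k), ω0 * one(k)
    for _ in 1:steps
        u = -k * θ                     # controller with a single gain k
        α = -9.81 * sin(θ) + u         # angular acceleration
        θ, ω = θ + dt * ω, ω + dt * α  # explicit Euler step
    end
    return θ
end

# Loss: squared angle after 1 second of simulated time.
loss(k) = final_angle(k)^2

# Plain gradient descent on the scalar gain.
function train(k; lr = 0.1, iters = 20)
    for _ in 1:iters
        k -= lr * ForwardDiff.derivative(loss, k)
    end
    return k
end

k0 = 5.0
k1 = train(k0)
@show loss(k0) loss(k1)
```

The same loop structure would apply to the acrobot example once `gradient(simulate_full, p_test)` (or a ForwardDiff equivalent) can be evaluated. Whether that works with the real `Dynamics` object is exactly what this issue is about.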
I actually thought it was a bug with Atom, so I tried running the above script from the CLI, but I get the same error.
|
(away from a computer) I actually think that particular error is just due to the |
I actually removed the println and got the exact same error. |
Is it correct that Zygote and RigidBodyDynamics are not yet compatible? In my simple test, I fail because Zygote does not yet support mutating arrays. |
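A small sketch of the mutation limitation mentioned above, independent of RigidBodyDynamics. The functions `scaled_inplace` and `scaled_outofplace` are hypothetical names invented for this example. The in-place version writes into a buffer element by element, much as `control!(τ, t, state)` writes into `τ`, and Zygote is expected to reject it. The out-of-place version computes the same value without mutation.

```julia
using Zygote

# In-place style: fills a preallocated buffer via setindex!.
function scaled_inplace(p, x)
    out = similar(x)
    for i in eachindex(x)
        out[i] = p * x[i]   # array mutation
    end
    return sum(out)
end

# Out-of-place style: same value, no mutation.
scaled_outofplace(p, x) = sum(p .* x)

x = [1.0, 2.0, 3.0]

# Differentiating the out-of-place version with respect to p works.
grad = gradient(p -> scaled_outofplace(p, x), 2.0)
@show grad

# Differentiating the in-place version is expected to throw,
# so it is wrapped in try/catch to keep the script running.
try
    gradient(p -> scaled_inplace(p, x), 2.0)
catch err
    println("in-place version failed: ", typeof(err))
end
```

Rewriting a controller in out-of-place style is only part of the story, since the library itself also mutates state internally.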
Hey there,
I am trying to use Flux models as torque controllers and eventually backpropagate from some sort of loss to the parameters of those models, but it seems that RigidBodyDynamics.jl breaks Tracker.
I know that Flux and Tracker work with DifferentialEquations.jl (DiffEqFlux.jl as an example).
Edit: The reason I am using Tracker.TrackedReal{Float64} as the type for MechanismState is that I need those parameters to record the operations being executed on them (just like the ForwardDiff.jl tutorial for this package).
Here is a minimal repro:
Stack trace: