You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
using UUIDs
using Preferences
# Pick the GPU backend for this reproduction based on the host platform.
if Sys.isapple()
    # Store the "gpu_backend" preference under the UUID held in `flux_uuid`
    # (presumably Flux's package UUID — confirm against Flux's Project.toml).
    flux_uuid = UUID("587475ba-b771-5e3f-ad9e-33799f191a9c")
    set_preferences!(flux_uuid, "gpu_backend" => "Metal")
    using Metal
else
    using CUDA, cuDNN
    # Pass `false` to CUDA's scalar-indexing switch so accidental scalar
    # indexing of GPU arrays is not silently permitted.
    CUDA.allowscalar(false)
end
using Test
using Metal
using Flux
using Statistics
using Flux: _greek_ascii_depwarn, ofeltype
using Flux.Losses: huber_loss, _check_sizes, mse
using Zygote
# Two separate two-element Float32 vectors of ones, each passed through `Flux.gpu`.
X = Float32[1, 1] |> Flux.gpu
Y = Float32[1, 1] |> Flux.gpu
"""
    _huber_metric(abs_error, δ)

Per-element Huber term for an absolute error `abs_error` and threshold `δ`.

With `x = ofeltype(abs_error, 0.5)` (presumably `0.5` converted to the float type
of `abs_error` — confirm against `Flux.ofeltype`), the result is

- `x * abs_error^2`            when `abs_error < δ`,
- `δ * (abs_error - x * δ)`    otherwise.

The comparison mask is wrapped in `Zygote.ignore_derivatives`, so it is treated as
a constant during differentiation.
"""
function _huber_metric(abs_error, δ)
    # TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
    temp = Zygote.ignore_derivatives(abs_error .< δ)
    x = ofeltype(abs_error, 0.5)
    # `temp` selects the quadratic branch; `1 - temp` selects the linear branch.
    return ((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
end

"""
    huber_loss_alternate(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)

Variant of the Huber loss that broadcasts the scalar helper `_huber_metric` over
`abs.(ŷ .- y)` and reduces the result with `agg`.

# Keywords
- `agg`: reduction applied to the element-wise terms (defaults to `mean`).
- `delta`: threshold handed to `_huber_metric` after conversion with `ofeltype(ŷ, …)`.
- `δ`: forwarded together with `delta` to `_greek_ascii_depwarn`; presumably the
  deprecated Greek spelling of `delta` — confirm against Flux.

`_check_sizes(ŷ, y)` is called before the loss is computed; presumably it rejects
mismatched shapes — confirm against Flux.
"""
function huber_loss_alternate(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
    # The deprecation helper is deliberately called with `:huber_loss`, as in the
    # original reproduction.
    delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta")
    δ = ofeltype(ŷ, delta_tmp)
    _check_sizes(ŷ, y)
    abs_error = abs.(ŷ .- y)
    return agg(_huber_metric.(abs_error, δ))
end
# Differentiate three losses with respect to both inputs. The trailing
# "Works" / "Fails" markers are the outcomes reported by the issue author.
Flux.gradient(X, Y) do a, b  # Works
    mse(a, b)
end

Flux.gradient(X, Y) do a, b  # Works
    huber_loss_alternate(a, b)
end

Flux.gradient(X, Y) do a, b  # Fails
    huber_loss(a, b)
end
I think this is related to JuliaGPU/GPUArrays.jl#484.
MWE (minimal working example — the code above):
The text was updated successfully, but these errors were encountered: