Hi,

```julia
using Lux, LuxCUDA, Random

const gpu = gpu_device()
rng = Random.default_rng()  # `rng` must be defined before it is used below
x = rand(rng, 5, 3)
x_gpu = x |> gpu
x_gpu_f64 = f64(x_gpu)
```
Not yet, see https://github.com/LuxDL/MLDataDevices.jl/issues/61.
But if you pass `MLDataDevices.default_device_rng(dev)` to the initialization functions, they will generate the parameters directly on the GPU (see https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers#Supported-RNG-Types-WeightInit). Combine that with passing `Float64` as the precision to the `init_*` functions and you will get the desired behavior. It's a bit inconvenient currently, but we should fix it...
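A minimal sketch of that workaround, assuming LuxCUDA and MLDataDevices are both in the active environment. The `Dense(3 => 5)` layer and the closure-based Float64 initializers are hypothetical illustrations, not code from the answer above; they only show the device RNG and the `Float64` element type being passed through to the `init_*` functions.

```julia
using Lux, LuxCUDA, MLDataDevices

dev = gpu_device()                           # CUDADevice if a functional GPU is found
rng = MLDataDevices.default_device_rng(dev)  # RNG that lives on the same device as `dev`

# Hypothetical example layer: both initializers forward the RNG and force Float64.
# Varargs `dims...` keeps the closures valid whatever shape Lux requests.
model = Dense(3 => 5;
    init_weight = (r, dims...) -> glorot_uniform(r, Float64, dims...),
    init_bias = (r, dims...) -> zeros64(r, dims...))

# Parameters are created directly on the device `rng` belongs to, already in Float64,
# so no CPU -> GPU transfer or f64 conversion is needed afterwards.
ps, st = Lux.setup(rng, model)

x = rand(Float64, 3, 4) |> dev  # input batch with 4 samples, moved to the same device
y, _ = model(x, ps, st)
```

If `gpu_device()` finds no functional GPU it should fall back to a CPU device, and `default_device_rng` then returns a CPU RNG, so the same sketch should still run on the CPU.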