Skip to content

Commit

Permalink
Merge #1632
Browse files Browse the repository at this point in the history
1632: Add WIP docstrings to CPU and GPU r=darsnack a=logankilpatrick

We could surely use more details, but ideally we would want to document these functions and what they do + usage.


Co-authored-by: Logan Kilpatrick <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
  • Loading branch information
3 people authored Aug 22, 2021
2 parents 2032733 + 3413c16 commit 48f6ae9
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,60 @@ function loadparams!(m, xs)
end

# CPU/GPU movement conveniences
"""
cpu(m)
Moves `m` onto the CPU.
This utility uses [`@functor`](@ref) to properly move structures to the CPU.
```julia-repl
julia> m = Dense(1,2)
Dense(1, 2)
julia> m_gpu = gpu(m)
Dense(1, 2)
julia> typeof(m_gpu.W)
CuArray{Float32, 2}
julia> m_cpu = cpu(m_gpu)
Dense(1, 2)
julia> typeof(m_cpu.W)
Matrix{Float32}
```
"""
cpu(m) = fmap(x -> adapt(Array, x), m)

_isbitsarray(::AbstractArray{<:Number}) = true
_isbitsarray(::AbstractArray{T}) where T = isbitstype(T)
_isbitsarray(x) = false

"""
gpu(x)
Moves `m` to the current GPU device, if available. It is a no-op otherwise.
See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
to help identify the current device.
This works for functions and
any struct with [`@functor`](@ref) defined.
```julia-repl
julia> m = Dense(1,2)
Dense(1, 2)
julia> typeof(m.W)
Matrix{Float32}
julia> m_gpu = gpu(m)
Dense(1, 2)
julia> typeof(m_gpu.W) # notice the type of the array changed to a CuArray
CuArray{Float32, 2}
```
"""
gpu(x) = use_cuda[] ? fmap(CUDA.cu, x; exclude = _isbitsarray) : x

# Precision
Expand Down

0 comments on commit 48f6ae9

Please sign in to comment.