Skip to content

Commit

Permalink
Merge pull request #1959 from FluxML/cl/oh
Browse files Browse the repository at this point in the history
onehotbatch with CuArray
  • Loading branch information
CarloLucibello authored May 7, 2022
2 parents 12bad50 + 96cc8bc commit 25457f5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
"""
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)

_onehotbatch(data::CuArray, labels) = _onehotbatch(data |> cpu, labels) |> gpu

function _onehotbatch(data, labels)
indices = UInt32[something(_findval(i, labels), 0) for i in data]
if 0 in indices
Expand Down
7 changes: 7 additions & 0 deletions test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ end

gA = rand(3, 2) |> gpu;
@test gradient(A -> sum(A * y), gA)[1] isa CuArray

# construct from CuArray
x = [1, 3, 2]
y = Flux.onehotbatch(x, 0:3)
y2 = Flux.onehotbatch(x |> gpu, 0:3)
@test y2.indices isa CuArray
@test y2 |> cpu == y
end

@testset "onecold gpu" begin
Expand Down

0 comments on commit 25457f5

Please sign in to comment.