From 96cc8bcafba0abe073e5174bb446a4078f1ee4ed Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 6 May 2022 12:30:39 +0200 Subject: [PATCH] onehotbatch with CuArray --- src/onehot.jl | 2 ++ test/cuda/cuda.jl | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/src/onehot.jl b/src/onehot.jl index 5c553db7c3..42d263aa44 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -185,6 +185,8 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl """ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) +_onehotbatch(data::CuArray, labels) = _onehotbatch(data |> cpu, labels) |> gpu + function _onehotbatch(data, labels) indices = UInt32[something(_findval(i, labels), 0) for i in data] if 0 in indices diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 47a2368e3d..6dad7cfa4b 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -45,6 +45,13 @@ end gA = rand(3, 2) |> gpu; @test gradient(A -> sum(A * y), gA)[1] isa CuArray + + # construct from CuArray + x = [1, 3, 2] + y = Flux.onehotbatch(x, 0:3) + y2 = Flux.onehotbatch(x |> gpu, 0:3) + @test y2.indices isa CuArray + @test y2 |> cpu == y end @testset "onecold gpu" begin