Skip to content

Commit

Permalink
Merge #1619
Browse files Browse the repository at this point in the history
1619: Forward map(f, ::OneHotLike) to broadcast r=darsnack a=darsnack

Fixes #958 by forwarding `Base.map(f, ::OneHotLike)` to `Base.broadcast`.

### PR Checklist

- [x] Tests are added
- [x] ~~Entry in NEWS.md~~
- [x] ~~Documentation, if applicable~~
- [ ] API changes require approval from a committer (different from the author, if applicable)


Co-authored-by: Kyle Daruwalla <[email protected]>
  • Loading branch information
bors[bot] and darsnack authored Jun 15, 2021
2 parents 108cbc8 + 373728b commit 7e00a0e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T

Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()

Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
Expand Down
6 changes: 6 additions & 0 deletions test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ end
@test Flux.onecold(y, l) == ['a', 'a', 'a']
end

@testset "onehot forward map to broadcast" begin
oa = OneHotArray(rand(1:10, 5, 5), 10) |> gpu
@test all(map(identity, oa) .== oa)
@test all(map(x -> 2 * x, oa) .== 2 .* oa)
end

@testset "restructure gpu" begin
dudt = Dense(1,1) |> gpu
p,re = Flux.destructure(dudt)
Expand Down
5 changes: 5 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,9 @@ end
@test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1)
@test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3)
end

@testset "Forward map to broadcast" begin
@test map(identity, oa) == oa
@test map(x -> 2 * x, oa) == 2 .* oa
end
end

0 comments on commit 7e00a0e

Please sign in to comment.