Skip to content

Commit

Permalink
Merge #1584
Browse files Browse the repository at this point in the history
1584: fixes #1583 r=DhairyaLGandhi a=DhairyaLGandhi

Is the classic `Union{}` dispatch piracy gotcha

Co-authored-by: Dhairya Gandhi <[email protected]>
  • Loading branch information
bors[bot] and DhairyaLGandhi authored Apr 25, 2021
2 parents 509b21a + bdde17d commit f0fc291
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}
OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)

_indices(x::OneHotArray) = x.indices
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
_indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
reshape(parent(x).indices, x.dims[2:end])

const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
Expand Down Expand Up @@ -64,22 +64,22 @@ Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), xs)
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = dims)
function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
else
return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
end
end

Base.hcat(xs::OneHotLike...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1)

batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()

Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
Expand Down

0 comments on commit f0fc291

Please sign in to comment.