Skip to content

Commit

Permalink
batch -> stack (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Dec 1, 2022
1 parent 3059520 commit 9c43cb7
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 3 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.2.1"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Expand All @@ -14,8 +15,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Adapt = "3.0"
CUDA = "3.8"
ChainRulesCore = "1.13"
Compat = "4.2"
GPUArraysCore = "0.1.0"
MLUtils = "0.2, 0.3"
NNlib = "0.8"
Zygote = "0.6.35"
julia = "1.6"
Expand Down
2 changes: 1 addition & 1 deletion src/OneHotArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Adapt
using ChainRulesCore
using GPUArraysCore
using LinearAlgebra
using MLUtils
using Compat: Compat
using NNlib

export onehot, onehotbatch, onecold,
Expand Down
11 changes: 10 additions & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,16 @@ Base.hcat(x::OneHotMatrix, xs::OneHotMatrix...) =
Base.hcat(x::OneHotVector, xs::OneHotVector...) =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...))

MLUtils.batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(_indices.(xs), _nlabels(xs...))
if isdefined(Base, :stack)
import Base: _stack
else
import Compat: _stack
end
function _stack(::Colon, xs::AbstractArray{<:OneHotArray})
n = _nlabels(first(xs))
all(x -> _nlabels(x)==n, xs) || throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays."))
OneHotArray(Compat.stack(_indices, xs), n)
end

Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels)

Expand Down
7 changes: 7 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ end
@test cat(oa, oa; dims = 3) isa OneHotArray
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)

# stack
@test stack([ov, ov]) == hcat(ov, ov)
@test stack([ov, ov, ov]) isa OneHotMatrix
@test stack([om, om]) == cat(om, om; dims = 3)
@test stack([om, om]) isa OneHotArray
@test stack([oa, oa, oa, oa]) isa OneHotArray

# proper error handling of inconsistent sizes
@test_throws DimensionMismatch hcat(ov, ov2)
@test_throws DimensionMismatch hcat(om, om2)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using OneHotArrays
using Test
using Compat: stack

@testset "OneHotArray" begin
include("array.jl")
Expand Down

2 comments on commit 9c43cb7

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/73253

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" 9c43cb75a334f63dafb75143bb8e8adb6289d3f0
git push origin v0.2.1

Please sign in to comment.