From 9c43cb75a334f63dafb75143bb8e8adb6289d3f0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 1 Dec 2022 09:55:32 -0500 Subject: [PATCH] batch -> stack (#26) --- Project.toml | 3 ++- src/OneHotArrays.jl | 2 +- src/array.jl | 11 ++++++++++- test/array.jl | 7 +++++++ test/runtests.jl | 1 + 5 files changed, 21 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 3e2f41a..e26dae3 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -14,8 +15,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Adapt = "3.0" CUDA = "3.8" ChainRulesCore = "1.13" +Compat = "4.2" GPUArraysCore = "0.1.0" -MLUtils = "0.2, 0.3" NNlib = "0.8" Zygote = "0.6.35" julia = "1.6" diff --git a/src/OneHotArrays.jl b/src/OneHotArrays.jl index 0ded073..f067e52 100644 --- a/src/OneHotArrays.jl +++ b/src/OneHotArrays.jl @@ -4,7 +4,7 @@ using Adapt using ChainRulesCore using GPUArraysCore using LinearAlgebra -using MLUtils +using Compat: Compat using NNlib export onehot, onehotbatch, onecold, diff --git a/src/array.jl b/src/array.jl index ca1580e..6350eb9 100644 --- a/src/array.jl +++ b/src/array.jl @@ -129,7 +129,16 @@ Base.hcat(x::OneHotMatrix, xs::OneHotMatrix...) = Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), _nlabels(x, xs...)) -MLUtils.batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(_indices.(xs), _nlabels(xs...)) +if isdefined(Base, :stack) + import Base: _stack +else + import Compat: _stack +end +function _stack(::Colon, xs::AbstractArray{<:OneHotArray}) + n = _nlabels(first(xs)) + all(x -> _nlabels(x)==n, xs) || throw(DimensionMismatch("The number of labels are not the same for all one-hot arrays.")) + OneHotArray(Compat.stack(_indices, xs), n) +end Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels) diff --git a/test/array.jl b/test/array.jl index 76ef2b1..70247a1 100644 --- a/test/array.jl +++ b/test/array.jl @@ -64,6 +64,13 @@ end @test cat(oa, oa; dims = 3) isa OneHotArray @test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1) + # stack + @test stack([ov, ov]) == hcat(ov, ov) + @test stack([ov, ov, ov]) isa OneHotMatrix + @test stack([om, om]) == cat(om, om; dims = 3) + @test stack([om, om]) isa OneHotArray + @test stack([oa, oa, oa, oa]) isa OneHotArray + # proper error handling of inconsistent sizes @test_throws DimensionMismatch hcat(ov, ov2) @test_throws DimensionMismatch hcat(om, om2) diff --git a/test/runtests.jl b/test/runtests.jl index 94453e2..11077fb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using OneHotArrays using Test +using Compat: stack @testset "OneHotArray" begin include("array.jl")