diff --git a/Project.toml b/Project.toml index 733f91345..6d5a50658 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.4" +version = "0.10.5" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/utils.jl b/src/utils.jl index 33a085c56..d17e4273c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -76,6 +76,8 @@ Base.getindex(D::ColVecs, i::CartesianIndex{1}) = view(D.X, :, i) Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i)) Base.setindex!(D::ColVecs, v::AbstractVector, i) = setindex!(D.X, v, :, i) +Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X)) + dim(x::ColVecs) = size(x.X, 1) pairwise(d::PreMetric, x::ColVecs) = Distances.pairwise(d, x.X; dims=2) @@ -144,6 +146,8 @@ Base.getindex(D::RowVecs, i::CartesianIndex{1}) = view(D.X, i, :) Base.getindex(D::RowVecs, i) = RowVecs(view(D.X, i, :)) Base.setindex!(D::RowVecs, v::AbstractVector, i) = setindex!(D.X, v, i, :) +Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X)) + dim(x::RowVecs) = size(x.X, 2) pairwise(d::PreMetric, x::RowVecs) = Distances.pairwise(d, x.X; dims=1) diff --git a/test/utils.jl b/test/utils.jl index 8bdd16330..a671f5b09 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -33,6 +33,8 @@ pairwise(SqEuclidean(), X; dims=2) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ pairwise(SqEuclidean(), X, Y; dims=2) + @test vcat(DX, DY) isa ColVecs + @test vcat(DX, DY).X == hcat(X, Y) K = zeros(N, N) KernelFunctions.pairwise!(K, SqEuclidean(), DX) @test K ≈ pairwise(SqEuclidean(), X; dims=2) @@ -72,6 +74,8 @@ pairwise(SqEuclidean(), X; dims=1) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ pairwise(SqEuclidean(), X, Y; dims=1) + @test vcat(DX, DY) isa RowVecs + @test vcat(DX, DY).X == vcat(X, Y) K = zeros(D, D) KernelFunctions.pairwise!(K, SqEuclidean(), DX) @test K ≈ pairwise(SqEuclidean(), X; dims=1)