Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

summary simplification, test cases, and eager checks #28

Merged
merged 8 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/permanent.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ on:
push:
branches:
- 'master'
pull_request:
jobs:
document:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: '1.3'
- uses: julia-actions/julia-docdeploy@releases/v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
name = "JuliennedArrays"
uuid = "5cadff95-7770-533d-a838-a1bf817ee6e0"
authors = ["Brandon Taylor <[email protected]>"]
version = "0.2.2"
repo = "https://github.com/bramtayl/JuliennedArrays.jl.git"
version = "0.2.2"

[compat]
julia = "1"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Documenter"]

[compat]
julia = "1"
test = ["Test"]
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JuliennedArrays = "5cadff95-7770-533d-a838-a1bf817ee6e0"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using JuliennedArrays
using Documenter: deploydocs, makedocs

makedocs(sitename = "JuliennedArrays.jl", modules = [JuliennedArrays], doctest = false)
makedocs(sitename = "JuliennedArrays.jl", modules = [JuliennedArrays], doctest = true, strict=true)
deploydocs(repo = "github.com/bramtayl/JuliennedArrays.jl.git")
101 changes: 63 additions & 38 deletions src/JuliennedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ module JuliennedArrays
import Base: axes, getindex, setindex!, size
using Base: @pure, tail

export Slices, Align
export True, False
Comment on lines +6 to +7
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think these names are too common to be exported.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who uses them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slice exists in Julia Base with a completely different meaning, I feel that's quite ambiguous to have Slices meaning the others. SliceView and AlignView are more accurate IMO.

True/False are too common names here and I think they should be either implementation details or preserved by Base. But apparently, people just invent their own without much effort:

https://github.com/SciML/Static.jl/blob/0f293e94fcbfbc812cdeaa796bd549b4ec2bc1ce/src/bool.jl#L2-L11

Copy link
Owner

@bramtayl bramtayl Aug 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be fine with SlicesView and AlignView. I added the s because it's many slices. I think post JuliaLang/julia#40561 we might be able to just use true and false instead with no extra performance cost as long as they are constant, but someone would need to verify that.


@inline is_in(needle::Needle, straw1::Needle, straws...) where {Needle} = True()
@inline is_in(needle, straw1, straws...) = is_in(needle, straws...)
@inline is_in(needle) = False()
Expand Down Expand Up @@ -33,9 +36,6 @@ struct False <: TypedBool end
@inline not(::False) = True()
@inline not(::True) = False()

export True
export False

@inline getindex_unrolled(into::Tuple{}, switches::Tuple{}) = ()
@inline function getindex_unrolled(into, switches)
next = getindex_unrolled(tail(into), tail(switches))
Expand All @@ -54,9 +54,16 @@ end
first(old), setindex_unrolled(tail(old), new, tail(switches))...
end

###
# Slices
###
struct Slices{Item,Dimensions,Whole,Alongs} <: AbstractArray{Item,Dimensions}
whole::Whole
alongs::Alongs
function Slices{T,N,W,A}(whole::W, alongs::A) where {T,N,W,A}
# any(isequal(True()), alongs) || throw(DimensionMismatch("Expected to have at least one active slicing dimension."))
new{T,N,W,A}(whole, alongs)
end
end
@inline Slices{Item,Dimensions}(
whole::Whole,
Expand All @@ -69,17 +76,13 @@ end

@inline slice_index(slices, indices) =
setindex_unrolled(axes(slices.whole), indices, map(not, slices.alongs))
@inline getindex(slices::Slices, indices::Int...) =
@inline getindex(slices::Slices{T,N}, indices::Vararg{Int,N}) where {T,N} =
view(slices.whole, slice_index(slices, indices)...)
@inline setindex!(slices::Slices, value, indices::Int...) =
@inline setindex!(slices::Slices{T,N}, value, indices::Vararg{Int,N}) where {T,N} =
slices.whole[slice_index(slices, indices)...] = value

@inline axis_or_1(switch, axis) =
if untyped(switch)
axis
else
1
end
@inline axis_or_1(switch, axis) = untyped(switch) ? axis : 1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tend not to like to use the ternary operator. It's inconvenient if you want to add more lines to the if statement

Copy link
Contributor Author

@johnnychen94 johnnychen94 Aug 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is to only use short-form function definitions when they can be fit into one line.

https://github.com/invenia/BlueStyle#method-definitions

I have to admit that the current codebase requires a little bit of extra effort to read even if I've been used to Julia for years. I didn't make style changes to every line of codes that I feel not so comfortable to read, because they're mostly unrelated to this PR. If you insist we can revert this one.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eh, I don't really care that much. Just the way I do things

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's up to you. IMHO the codebase should be clean and simplified so that others could jump in and make contributions without much mental efforts. If that codebase or code style is too opinioned then it's setting a high wall and blocks people.


"""
Slices(whole, alongs::TypedBool...)

Expand All @@ -93,7 +96,7 @@ julia> using JuliennedArrays
julia> whole = [1 2; 3 4];

julia> slices = Slices(whole, False(), True())
2-element Slices{SubArray{Int64,1,Array{Int64,2},Tuple{Int64,Base.OneTo{Int64}},true},1,Array{Int64,2},Tuple{False,True}}:
2-element Slices{SubArray{$Int, 1}, 1}:
[1, 2]
[3, 4]

Expand All @@ -103,7 +106,7 @@ true
julia> slices[1] = [2, 1];

julia> whole
2×2 Array{Int64,2}:
2×2 Matrix{$Int}:
2 1
3 4

Expand All @@ -115,13 +118,12 @@ julia> size(first(larger_slices))
(5,)
```
"""
Slices(whole::AbstractArray, alongs::TypedBool...) = Slices{
typeof(@inbounds view(whole, map(axis_or_1, alongs, axes(whole))...)),
length(getindex_unrolled(alongs, map(not, alongs))),
}(
whole,
alongs,
)
function Slices(whole::AbstractArray, alongs::TypedBool...)
# length(alongs) == ndims(whole) || throw(ArgumentError("$(length(alongs)) dimensions are specified, expected to be == $(ndims(whole))"))
x = @inbounds view(whole, map(axis_or_1, alongs, axes(whole))...)
N = length(getindex_unrolled(alongs, map(not, alongs)))
return Slices{typeof(x),N}(whole, alongs)
end

"""
Slices(whole, alongs::Int...)
Expand All @@ -132,7 +134,7 @@ Alternative syntax: `alongs` is which dimensions will be replaced with `:` when
julia> using JuliennedArrays

julia> input = reshape(1:8, 2, 2, 2)
2×2×2 reshape(::UnitRange{Int64}, 2, 2, 2) with eltype Int64:
2×2×2 reshape(::UnitRange{$Int}, 2, 2, 2) with eltype $Int:
[:, :, 1] =
1 3
2 4
Expand All @@ -142,26 +144,45 @@ julia> input = reshape(1:8, 2, 2, 2)
6 8

julia> s = Slices(input, 1, 3)
2-element Slices{SubArray{Int64,2,Base.ReshapedArray{Int64,3,UnitRange{Int64},Tuple{}},Tuple{Base.OneTo{Int64},Int64,Base.OneTo{Int64}},false},1,Base.ReshapedArray{Int64,3,UnitRange{Int64},Tuple{}},Tuple{True,False,True}}:
2-element Slices{SubArray{$Int, 2}, 1}:
[1 5; 2 6]
[3 7; 4 8]

julia> map(sum, s)
2-element Array{Int64,1}:
2-element Vector{$Int}:
14
22
```
"""
Slices(
whole::AbstractArray{Item,NumberOfDimensions},
alongs::Int...,
) where {Item,NumberOfDimensions} =
Slices(whole, in_unrolled(as_vals(alongs...), ntuple(Val, NumberOfDimensions)...)...)
export Slices
function Slices(whole::AbstractArray{T,N}, alongs::Int...) where {T,N}
# any(x->x>N, alongs) && throw(ArgumentError("All alongs values $(alongs) should be less than $(N)"))
Slices(whole, in_unrolled(as_vals(alongs...), ntuple(Val, N)...)...)
end

function Base.showarg(io::IO, ::Slices{T,N}, toplevel) where {T,N}
print(io, "Slices{", basetype(T), "{", eltype(T), ", ", ndims(T), "}, ", N, "}")
end

# This is expected to be added to Julia (maybe under a different name)
# Follow https://github.com/JuliaLang/julia/issues/35543 for progress
basetype(T::Type) = Base.typename(T).wrapper
basetype(T) = basetype(typeof(T))

###
# Align
###
struct Align{Item,Dimensions,Sliced,Alongs} <: AbstractArray{Item,Dimensions}
slices::Sliced
alongs::Alongs
function Align{T,N,S,A}(slices::S, alongs::A) where {T,N,S,A}
# TODO: run eager size check without introducing much overheads
# sz = @inbounds size(first(slices))
# all(x->sz==size(x), slices) || throw(ArgumentError("All sizes of slices should be the same."))
# length(alongs) == N || throw(DimensionMismatch("The total dimension $(N) is expected to be the sum of inner dimension $(length(sz)) and outer dimension $(length(alongs))"))
# inner_dimensions = mapreduce(isequal(True()), +, alongs)
# inner_dimensions == ndims(first(slices)) || throw(DimensionMismatch("Only $inner_dimensions inner dimensions are used, expected $(ndims(first(slices))) dimensions."))
new{T,N,S,A}(slices, alongs)
end
end
@inline Align{Item,Dimensions}(
slices::Sliced,
Expand All @@ -179,11 +200,11 @@ end
@inline split_indices(aligned, indices) =
getindex_unrolled(indices, map(not, aligned.alongs)),
getindex_unrolled(indices, aligned.alongs)
@inline function getindex(aligned::Align, indices::Int...)
@inline function getindex(aligned::Align{T,N}, indices::Vararg{Int,N}) where {T,N}
outer, inner = split_indices(aligned, indices)
aligned.slices[outer...][inner...]
end
@inline function setindex!(aligned::Align, value, indices::Int...)
@inline function setindex!(aligned::Align{T,N}, value, indices::Vararg{Int,N}) where {T,N}
outer, inner = split_indices(aligned, indices)
aligned.slices[outer...][inner...] = value
end
Expand All @@ -201,7 +222,7 @@ julia> using JuliennedArrays
julia> slices = [[1, 2], [3, 4]];

julia> aligned = Align(slices, False(), True())
2×2 Align{Int64,2,Array{Array{Int64,1},1},Tuple{False,True}}:
2×2 Align{$Int, 2} with eltype $Int:
1 2
3 4

Expand All @@ -211,7 +232,7 @@ true
julia> aligned[1, 1] = 0;

julia> slices
2-element Array{Array{Int64,1},1}:
2-element Vector{Vector{$Int}}:
[0, 2]
[3, 4]
```
Expand All @@ -221,7 +242,6 @@ julia> slices
alongs::TypedBool...,
) where {Item,InnerDimensions,OuterDimensions} =
Align{Item,OuterDimensions + InnerDimensions}(slices, alongs)
export Align

"""
Along(slices, alongs::Int...)
Expand All @@ -232,7 +252,7 @@ Alternative syntax: `alongs` is which dimensions will be taken up by the inner a
julia> using JuliennedArrays

julia> input = reshape(1:8, 2, 2, 2)
2×2×2 reshape(::UnitRange{Int64}, 2, 2, 2) with eltype Int64:
2×2×2 reshape(::UnitRange{$Int}, 2, 2, 2) with eltype $Int:
[:, :, 1] =
1 3
2 4
Expand All @@ -241,13 +261,13 @@ julia> input = reshape(1:8, 2, 2, 2)
5 7
6 8

julia> slices = collect(Slices(input, 1, 3))
2-element Array{SubArray{Int64,2,Base.ReshapedArray{Int64,3,UnitRange{Int64},Tuple{}},Tuple{Base.OneTo{Int64},Int64,Base.OneTo{Int64}},false},1}:
julia> slices = Slices(input, 1, 3)
2-element Slices{SubArray{$Int, 2}, 1}:
[1 5; 2 6]
[3 7; 4 8]

julia> Align(slices, 1, 3)
2×2×2 Align{Int64,3,Array{SubArray{Int64,2,Base.ReshapedArray{Int64,3,UnitRange{Int64},Tuple{}},Tuple{Base.OneTo{Int64},Int64,Base.OneTo{Int64}},false},1},Tuple{True,False,True}}:
2×2×2 Align{$Int, 3} with eltype $Int:
[:, :, 1] =
1 3
2 4
Expand All @@ -268,4 +288,9 @@ Align(
)...,
)

function Base.showarg(io::IO, ::Align{T,N}, toplevel) where {T,N}
print(io, "Align{", T, ", ", N, "}")
toplevel && print(io, " with eltype ", T)
end

end # module
59 changes: 56 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,57 @@
import JuliennedArrays
using Documenter: doctest
using JuliennedArrays
using Test

doctest(JuliennedArrays)
@testset "JuliennedArrays.jl" begin
@testset "Align" begin
Xs = [rand(2, 3) for _ in 1:4]
X = @inferred Align(Xs, True(), False(), True())
@test size(X) == (2, 4, 3)
@test permutedims(cat(Xs...; dims=3), (1, 3, 2)) == X
@test X[1] == Xs[1][1] # test linear indexing
@test X[1, 2, 3] == Xs[2][1, 3] # test cartesian indexing

Xs = reshape(Xs, 1, 4)
X = @inferred Align(Xs, True(), False(), True(), False())
@test size(X) == (2, 1, 3, 4)
@test permutedims(X, (1, 3, 2, 4))[:] == cat(Xs...; dims=3)[:]

# type is not inferrable for integer alongs
Xs = [rand(2, 3) for _ in 1:4]
RT = Base.return_types(Align, (typeof(Xs), Int, Int))[1]
@test !isconcretetype(RT)
@test Align(Xs, True(), False(), True()) == Align(Xs, 1, 3)

# @test_throws DimensionMismatch Align([rand(2, 3) for _ in 1:4], 1) # issue #25
# @test_throws MethodError Align(ones(2, 3, 4), 1, 2, 3)
end

@testset "Slice" begin
X = rand(2, 3, 4, 5)
Xs = @inferred Slices(X, True(), False(), False(), False())

Xs = Slices(X, 1)
@test size(Xs) == (3, 4, 5)
@test Xs[1, 1, 1] == X[:, 1, 1, 1]

Xs = Slices(X, 2)
@test size(Xs) == (2, 4, 5)
@test Xs[1, 1, 1] == X[1, :, 1, 1]

Xs = Slices(X, 1, 3)
@test size(Xs) == (3, 5)
@test Xs[1, 2] == X[:, 1, :, 2]
@test Align(Xs, 1, 3) == X # Slices is the inverse of Align
@test Xs[1] == X[:, 1, :, 1] # test linear indexing
@test Xs[1, 3] == X[:, 1, :, 3] # test cartesian indexing

# type is not inferrable for integer alongs
RT = Base.return_types(Slices, (typeof(X), Int, Int))[1]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I wrote it I took special care to make sure that this would be inferable if constant propagation kicks in

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless wrapped with Val, I don't think it's inferrable if the input is an integer whose value is not known at compile time.

This is the behavior on master branch:

julia> @inferred Slices(rand(2, 3), 2)
ERROR: return type Slices{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.OneTo{Int64}}, true}, 1, Matrix{Float64}, Tuple{False, True}} does not match inferred return type Slices{Item, Dimensions, Matrix{Float64}, Alongs} where {Item, Dimensions, Alongs}

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but @inferred (x -> Slices(x, 2))(rand(2, 3)) or something like that should hopefully work

@test !isconcretetype(RT)
@test Slices(X, True(), False(), True(), False()) == Slices(X, 1, 3)

# X = rand(2, 3, 4, 5)
# @test_throws ArgumentError Slices(X, True())
# @test_throws ArgumentError Slices(X, True(), False(), False(), False(), False())
# @test_throws ArgumentError Slices(X, 5)
end
end