Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 9, 2024
1 parent 7a95e64 commit 5187e13
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
4 changes: 2 additions & 2 deletions ext/cuda/fill.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cartesian_index(::AbstractData, inds) = CartesianIndex(inds)

function knl_fill_flat!(dest::AbstractData, val)
n = size(dest)
n = DataLayouts.universal_size(dest)
inds = kernel_indexes(n)
if valid_range(inds, n)
I = cartesian_index(dest, inds)
Expand All @@ -11,7 +11,7 @@ function knl_fill_flat!(dest::AbstractData, val)
end

function cuda_fill!(dest::AbstractData, val)
(Nij, Nij, Nf, Nv, Nh) = size(dest)
(_, _, _, Nv, Nh) = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(knl_fill_flat!, (dest, val), dest; auto = true)
end
Expand Down
36 changes: 16 additions & 20 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ abstract type AbstractData{S} end

Base.size(data::AbstractData, i::Integer) = size(data)[i]

function universal_size(data::AbstractData)
s = size(data)
return (s[1], s[2], ncomponents(data), s[4], s[5])
end

function Base.show(io::IO, data::AbstractData)
indent_width = 2
(rows, cols) = displaysize(io)
Expand Down Expand Up @@ -270,8 +275,7 @@ Base.copy(data::IJFH{S, Nij}) where {S, Nij} = IJFH{S, Nij}(copy(parent(data)))
function Base.size(data::IJFH{S, Nij}) where {S, Nij}
Nv = 1
Nh = size(parent(data), 4)
Nf = ncomponents(data)
(Nij, Nij, Nf, Nv, Nh)
(Nij, Nij, 1, Nv, Nh)
end

function Base.fill!(data::IJFH, val)
Expand Down Expand Up @@ -445,8 +449,7 @@ Base.copy(data::IFH{S, Ni}) where {S, Ni} = IFH{S, Ni}(copy(parent(data)))
function Base.size(data::IFH{S, Ni}) where {S, Ni}
Nv = 1
Nh = size(parent(data), 3)
Nf = ncomponents(data)
(Ni, 1, Nf, Nv, Nh)
(Ni, 1, 1, Nv, Nh)
end

function Base.fill!(data::IFH, val)
Expand Down Expand Up @@ -641,8 +644,7 @@ end
@inbounds slab[I[1], I[2]] = val
end

Base.size(data::DataSlab2D{S, Nij}) where {S, Nij} =
(Nij, Nij, ncomponents(data), 1, 1)
Base.size(data::DataSlab2D{S, Nij}) where {S, Nij} = (Nij, Nij, 1, 1, 1)
Base.axes(data::DataSlab2D{S, Nij}) where {S, Nij} = (SOneTo(Nij), SOneTo(Nij))

@inline function slab(data::DataSlab2D, h)
Expand Down Expand Up @@ -694,8 +696,7 @@ function replace_basetype(data::IJF{S, Nij}, ::Type{T}) where {S, Nij, T}
end

function Base.size(data::IJF{S, Nij}) where {S, Nij}
Nf = ncomponents(data)
return (Nij, Nij, Nf, 1, 1)
return (Nij, Nij, 1, 1, 1)
end
function Base.fill!(data::IJF{S, Nij}, val) where {S, Nij}
@inbounds for j in 1:Nij, i in 1:Nij
Expand Down Expand Up @@ -781,8 +782,7 @@ end
end

function Base.size(data::DataSlab1D{<:Any, Ni}) where {Ni}
Nf = ncomponents(data)
return (Ni, 1, Nf, 1, 1)
return (Ni, 1, 1, 1, 1)
end
Base.axes(::DataSlab1D{S, Ni}) where {S, Ni} = (SOneTo(Ni),)
Base.lastindex(::DataSlab1D{S, Ni}) where {S, Ni} = Ni
Expand Down Expand Up @@ -894,7 +894,7 @@ end
# ======================

Base.length(data::DataColumn) = size(parent(data), 1)
Base.size(data::DataColumn) = (1, 1, ncomponents(data), length(data), 1)
Base.size(data::DataColumn) = (1, 1, 1, length(data), 1)

"""
VF{S, A} <: DataColumn{S, Nv}
Expand Down Expand Up @@ -939,7 +939,7 @@ end

Base.copy(data::VF{S, Nv}) where {S, Nv} = VF{S, Nv}(copy(parent(data)))
Base.lastindex(data::VF) = length(data)
Base.size(data::VF{S, Nv}) where {S, Nv} = (1, 1, ncomponents(data), Nv, 1)
Base.size(data::VF{S, Nv}) where {S, Nv} = (1, 1, 1, Nv, 1)

nlevels(::VF{S, Nv}) where {S, Nv} = Nv

Expand Down Expand Up @@ -1059,8 +1059,7 @@ end

function Base.size(data::VIJFH{<:Any, Nv, Nij}) where {Nv, Nij}
Nh = size(parent(data), 5)
Nf = ncomponents(data)
return (Nij, Nij, Nf, Nv, Nh)
return (Nij, Nij, 1, Nv, Nh)
end

function Base.length(data::VIJFH)
Expand Down Expand Up @@ -1226,8 +1225,7 @@ Base.copy(data::VIFH{S, Nv, Ni}) where {S, Nv, Ni} =

function Base.size(data::VIFH{<:Any, Nv, Ni}) where {Nv, Ni}
Nh = size(parent(data), 4)
Nf = ncomponents(data)
return (Ni, 1, Nf, Nv, Nh)
return (Ni, 1, 1, Nv, Nh)
end

function Base.length(data::VIFH)
Expand Down Expand Up @@ -1376,8 +1374,7 @@ Base.copy(data::IH1JH2{S, Nij}) where {S, Nij} =
function Base.size(data::IH1JH2{S, Nij}) where {S, Nij}
Nv = 1
Nh = div(length(parent(data)), Nij * Nij)
Nv = ncomponents(data)
(Nij, Nij, Nf, Nv, Nh)
(Nij, Nij, 1, Nv, Nh)
end

Base.length(data::IH1JH2{S, Nij}) where {S, Nij} =
Expand Down Expand Up @@ -1422,8 +1419,7 @@ Base.copy(data::IV1JH2{S, n1, Ni}) where {S, n1, Ni} =

function Base.size(data::IV1JH2{S, n1, Ni}) where {S, n1, Ni}
Nh = div(size(parent(data), 2), Ni)
Nf = ncomponents(data)
(Ni, 1, Nf, n1, Nh)
(Ni, 1, 1, n1, Nh)
end

Base.length(data::IV1JH2{S, n1, Ni}) where {S, n1, Ni} =
Expand Down

0 comments on commit 5187e13

Please sign in to comment.