Skip to content

Commit

Permalink
CPU <-> GPU covariant derivatives are now identical (towards tensor ops)
Browse files Browse the repository at this point in the history
	modified:   ext/cuda/operators_sem_shmem.jl
	modified:   ext/cuda/operators_spectral_element.jl

	modified:   ext/cuda/operators_integral.jl
	modified:   src/Operators/integrals.jl
	modified:   test/Operators/integrals.jl

Updated integral ops

wip fix columnmapreduce for gpu

	modified:   ext/cuda/operators_sem_shmem.jl
	modified:   ext/cuda/operators_spectral_element.jl
  • Loading branch information
Akshay Sridhar committed Jun 17, 2024
1 parent ba3f8d4 commit ea77cd9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 20 deletions.
16 changes: 8 additions & 8 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ using CUDA: @cuda

function column_integral_definite!(
::ClimaComms.CUDADevice,
∫field::Fields.Field,
ᶜfield::Fields.Field,
∫field::Fields.SpectralElementField,
ᶜfield::Fields.ExtrudedFiniteDifferenceField,
)
space = axes(∫field)
Ni, Nj, _, _, Nh = size(Fields.field_values(∫field))
Expand Down Expand Up @@ -62,8 +62,8 @@ end

function column_integral_indefinite!(
::ClimaComms.CUDADevice,
ᶠ∫field::Fields.Field,
ᶜfield::Fields.Field,
ᶠ∫field::Fields.FaceExtrudedFiniteDifferenceField,
ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶠ∫field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
Expand All @@ -81,15 +81,15 @@ function column_mapreduce_device!(
::ClimaComms.CUDADevice,
fn::F,
op::O,
reduced_field::Fields.Field,
fields::Fields.Field...,
reduced_field::Fields.SpectralElementField2D,
fields...,
) where {F, O}
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
kernel! = if first(fields) isa Fields.ExtrudedFiniteDifferenceField
column_mapreduce_kernel_extruded!
else
column_mapreduce_kernel!
# column_mapreduce_kernel!
end
args = (
fn,
Expand All @@ -111,7 +111,7 @@ end
function column_mapreduce_kernel_extruded!(
fn::F,
op::O,
reduced_field,
reduced_field,
fields...,
) where {F, O}
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Expand Down
9 changes: 8 additions & 1 deletion ext/cuda/operators_sem_shmem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,22 @@ Base.@propagate_inbounds function operator_fill_shmem!(
if RT <: Geometry.Covariant12Vector
(v,) = input
v[i, j, vt] = arg
elseif typeof(arg) <: Geometry.UVVector
elseif typeof(arg) <: Geometry.UVVector
# TODO classify based on returntype
v₁, v₂ = input
v₁[i, j, vt] = Geometry.LocalVector(arg, local_geometry).u
v₂[i, j, vt] = Geometry.LocalVector(arg, local_geometry).v
#v₁[i, j, vt] = Geometry.contravariant1(arg, local_geometry)
#v₂[i, j, vt] = Geometry.contravariant2(arg, local_geometry)
elseif typeof(arg) <: Geometry.UVWVector
# TODO classify based on returntype
v₁, v₂, v₃ = input
v₁[i, j, vt] = Geometry.LocalVector(arg, local_geometry).u
v₂[i, j, vt] = Geometry.LocalVector(arg, local_geometry).v
v₃[i, j, vt] = Geometry.LocalVector(arg, local_geometry).w
#v₁[i, j, vt] = Geometry.contravariant1(arg, local_geometry)
#v₂[i, j, vt] = Geometry.contravariant2(arg, local_geometry)
#v₃[i, j, vt] = Geometry.contravariant3(arg, local_geometry)
end
end

Expand Down
6 changes: 6 additions & 0 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ Base.@propagate_inbounds function operator_evaluate(
return Geometry.AxisTensor((Geometry.Covariant12Axis(), Geometry.UVAxis()),
(∂f₁∂ξ₁, ∂f₂∂ξ₁,
∂f₁∂ξ₂, ∂f₂∂ξ₂))
#return Geometry.AxisTensor((Geometry.Covariant12Axis(), Geometry.Contravariant12Axis()),
# (∂f₁∂ξ₁, ∂f₂∂ξ₁,
# ∂f₁∂ξ₂, ∂f₂∂ξ₂))
else
v₁, v₂, v₃ = input
∂f₁∂ξ₁ = D[i, 1] v₁[1, j, vt]
Expand All @@ -312,6 +315,9 @@ Base.@propagate_inbounds function operator_evaluate(
return Geometry.AxisTensor((Geometry.Covariant12Axis(), Geometry.UVWAxis()),
(∂f₁∂ξ₁, ∂f₂∂ξ₁, ∂f₃∂ξ₁,
∂f₁∂ξ₂, ∂f₂∂ξ₂, ∂f₃∂ξ₂))
#return Geometry.AxisTensor((Geometry.Covariant12Axis(), Geometry.Contravariant123Axis()),
# (∂f₁∂ξ₁, ∂f₂∂ξ₁, ∂f₃∂ξ₁,
# ∂f₁∂ξ₂, ∂f₂∂ξ₂, ∂f₃∂ξ₂))
end
end

Expand Down
22 changes: 11 additions & 11 deletions src/Operators/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function column_integral_definite_kernel!(
end

column_integral_definite!(
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
∫field::Fields.SpectralElementField,
ᶜfield::Fields.ExtrudedFiniteDifferenceField,
) =
Expand All @@ -33,7 +33,7 @@ column_integral_definite!(
end

column_integral_definite!(
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
∫field::Fields.PointField,
ᶜfield::Fields.FiniteDifferenceField,
) = _column_integral_definite!(∫field, ᶜfield)
Expand Down Expand Up @@ -84,7 +84,7 @@ column_integral_indefinite_kernel!(
) = _column_integral_indefinite!(ᶠ∫field, ᶜfield)

column_integral_indefinite!(
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
ᶠ∫field::Fields.FaceExtrudedFiniteDifferenceField,
ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
) =
Expand All @@ -93,7 +93,7 @@ column_integral_indefinite!(
end

column_integral_indefinite!(
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
ᶠ∫field::Fields.FaceFiniteDifferenceField,
ᶜfield::Fields.CenterFiniteDifferenceField,
) = _column_integral_indefinite!(ᶠ∫field, ᶜfield)
Expand Down Expand Up @@ -145,14 +145,14 @@ column_integral_indefinite!(

column_integral_indefinite!(
f::Function,
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
ᶠ∫field,
args...,
) = column_integral_indefinite_cpu!(f, ᶠ∫field, args...)

column_integral_indefinite!(
f::Function,
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
ᶠ∫field::Fields.FaceExtrudedFiniteDifferenceField,
args...,
) =
Expand Down Expand Up @@ -215,7 +215,7 @@ column_mapreduce!(
fn::F,
op::O,
reduced_field::Fields.Field,
fields::Fields.Field...,
fields...,
) where {F, O} = column_mapreduce_device!(
ClimaComms.device(reduced_field),
fn,
Expand All @@ -228,23 +228,23 @@ column_mapreduce_kernel!(fn::F, op::O, reduced_field, fields...) where {F, O} =
_column_mapreduce!(fn, op, reduced_field, fields...)

column_mapreduce_device!(
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
fn::F,
op::O,
reduced_field::Fields.SpectralElementField,
fields::Fields.ExtrudedFiniteDifferenceField...,
fields...,
) where {F, O} =
Fields.bycolumn(axes(reduced_field)) do colidx
column_fields = map(field -> field[colidx], fields)
_column_mapreduce!(fn, op, reduced_field[colidx], column_fields...)
end

column_mapreduce_device!(
::ClimaComms.AbstractCPUDevice,
::ClimaComms.AbstractDevice,
fn::F,
op::O,
reduced_field::Fields.PointField,
fields::Fields.FiniteDifferenceField...,
fields...,
) where {F, O} = _column_mapreduce!(fn, op, reduced_field, fields...)

function _column_mapreduce!(fn::F, op::O, reduced_field, fields...) where {F, O}
Expand Down

0 comments on commit ea77cd9

Please sign in to comment.