Skip to content

Commit

Permalink
Merge pull request #1685 from CliMA/ck/new_repro
Browse files Browse the repository at this point in the history
Update gpu prog edmf repro
  • Loading branch information
charleskawczynski authored Apr 25, 2024
2 parents 0c30ae5 + 22d5f56 commit 1c7d7b5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ steps:
- label: "Unit: bidiag matrix row example (GPU)"
key: gpu_compat_bidiag_matrix_row
command: "julia --color=yes --project=test test/MatrixFields/gpu_compat_bidiag_matrix_row.jl"
soft_fail: true
agents:
slurm_gpus: 1

Expand Down
70 changes: 63 additions & 7 deletions test/MatrixFields/gpu_compat_bidiag_matrix_row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@ end
import .TestUtilities as TU

import ClimaCore: Spaces, Geometry, Operators, Fields, MatrixFields
using LinearAlgebra: Adjoint
import StaticArrays: SArray
import ClimaCore.Geometry: AxisTensor, CovariantAxis, ContravariantAxis
using ClimaCore.MatrixFields:
BandMatrixRow,
DiagonalMatrixRow,
BidiagonalMatrixRow,
TridiagonalMatrixRow,
MultiplyColumnwiseBandMatrixField,
const C3 = Geometry.Covariant3Vector
FT = Float64
const CT3 = Geometry.Contravariant3Vector
GFT = Float64
const ᶠgradᵥ = Operators.GradientC2F(
bottom = Operators.SetGradient(C3(0)),
top = Operators.SetGradient(C3(0)),
Expand All @@ -32,27 +38,77 @@ const ᶠgradᵥ_matrix = MatrixFields.operator_matrix(ᶠgradᵥ)
device = ClimaComms.device()
context = ClimaComms.context(device)
cspace =
TU.CenterExtrudedFiniteDifferenceSpace(FT; zelem = 25, helem = 10, context)
TU.CenterExtrudedFiniteDifferenceSpace(GFT; zelem = 25, helem = 10, context)
fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace)
@info "device = $device"

∂ᶠu₃ʲ_err_∂ᶠu₃ʲ_type = BandMatrixRow{
-1,
3,
AxisTensor{
GFT,
2,
Tuple{CovariantAxis{(3,)}, ContravariantAxis{(3,)}},
SArray{Tuple{1, 1}, GFT, 2, 1},
},
}

f = (;
ᶠtridiagonal_matrix_c3 = Fields.Field(TridiagonalMatrixRow{C3{FT}}, fspace),
∂ᶠu₃ʲ_err_∂ᶠu₃ʲ = Fields.Field(∂ᶠu₃ʲ_err_∂ᶠu₃ʲ_type, fspace),
ᶠtridiagonal_matrix_c3 = Fields.Field(
TridiagonalMatrixRow{C3{GFT}},
fspace,
),
ᶠu₃ = Fields.Field(C3{GFT}, fspace),
adj_u₃ = Fields.Field(DiagonalMatrixRow{Adjoint{GFT, CT3{GFT}}}, fspace),
)
c = (;
ᶜu₃ʲ = Fields.Field(C3{GFT}, cspace),
bdmr_l = Fields.Field(BidiagonalMatrixRow{GFT}, cspace),
bdmr_r = Fields.Field(BidiagonalMatrixRow{GFT}, cspace),
bdmr = Fields.Field(BidiagonalMatrixRow{GFT}, cspace),
)

const ᶜleft_bias = Operators.LeftBiasedF2C()
const ᶜright_bias = Operators.RightBiasedF2C()
const ᶜleft_bias_matrix = MatrixFields.operator_matrix(ᶜleft_bias)
const ᶜright_bias_matrix = MatrixFields.operator_matrix(ᶜright_bias)

one_C3xACT3(::Type{_FT}) where {_FT} = C3(_FT(1)) * CT3(_FT(1))'
get_I_u₃(::Type{_FT}) where {_FT} = DiagonalMatrixRow(one_C3xACT3(_FT))

conv(::Type{_FT}, ᶜbias_matrix) where {_FT} =
convert(BidiagonalMatrixRow{_FT}, ᶜbias_matrix)
function foo(f)
(; ᶠtridiagonal_matrix_c3) = f
function foo(c, f)
(; ᶠtridiagonal_matrix_c3, ᶠu₃, ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ, adj_u₃) = f
(; ᶜu₃ʲ, bdmr_l, bdmr_r, bdmr) = c
space = axes(ᶠtridiagonal_matrix_c3)
FT = Spaces.undertype(space)
@. ᶠtridiagonal_matrix_c3 = ᶠgradᵥ_matrix() conv(FT, ᶜleft_bias_matrix())
I_u₃ = get_I_u₃(FT)
dtγ = FT(1)

@. ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ =
dtγ * ᶠtridiagonal_matrix_c3 DiagonalMatrixRow(adjoint(CT3(ᶠu₃))) -
(I_u₃,)

@. ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ = dtγ * ᶠtridiagonal_matrix_c3 adj_u₃ - (I_u₃,)

# Fails on gpu
@. ᶠtridiagonal_matrix_c3 =
-(ᶠgradᵥ_matrix()) ifelse(
ᶜu₃ʲ.components.data.:1 > 0,
convert(BidiagonalMatrixRow{FT}, ᶜleft_bias_matrix()),
convert(BidiagonalMatrixRow{FT}, ᶜright_bias_matrix()),
)

# However, this can be decomposed into simpler broadcast
# expressions that will run on gpus:
@. bdmr_l = convert(BidiagonalMatrixRow{FT}, ᶜleft_bias_matrix())
@. bdmr_r = convert(BidiagonalMatrixRow{FT}, ᶜright_bias_matrix())
@. bdmr = ifelse(ᶜu₃ʲ.components.data.:1 > 0, bdmr_l, bdmr_r)
@. ᶠtridiagonal_matrix_c3 = -(ᶠgradᵥ_matrix()) bdmr

return nothing
end

foo(f)
foo(c, f)

0 comments on commit 1c7d7b5

Please sign in to comment.