Add support for masked column stencils / operators #298
Is the masking something like (with some boundary conditions)?
How is it well defined on the edges of the mask? Are the masked ops only valid for interior elements?
I think it's something like that. This sort of operation is certainly well defined in the collocated case, where one-sided differences are used as we approach the mask boundaries (assuming the masked regions have enough points for an FD stencil). I think the problem with your example is that
Hmm, I thought that my example below gets around this by interpolating the field to faces first, but now that I'm thinking about it, I'm not sure. I've just about finished prototyping something that I think will work with our existing operators, so I think we can close this. Going to discuss this with @ilopezgp a bit. Here's the prototype:
I think we can assume that we know what the mask is on both cell faces and cell centers.
# ------------------------------------------------------------------ boiler plate
import ClimaCore as CC
import ClimaCore.Operators as CCO
struct Grid{FT, CS, FS, SC, SF}
zmin::FT
zmax::FT
Δz::FT
Δzi::FT
nz::Int
cs::CS
fs::FS
zc::SC
zf::SF
function Grid(Δz::FT, nz::Int) where {FT <: AbstractFloat}
z₀, z₁ = FT(0), FT(nz * Δz)
domain = CC.Domains.IntervalDomain(
CC.Geometry.ZPoint{FT}(z₀),
CC.Geometry.ZPoint{FT}(z₁),
boundary_tags = (:bottom, :top),
)
mesh = CC.Meshes.IntervalMesh(domain, nelems = nz)
cs = CC.Spaces.CenterFiniteDifferenceSpace(mesh)
fs = CC.Spaces.FaceFiniteDifferenceSpace(cs)
zc = CC.Fields.coordinate_field(cs)
zf = CC.Fields.coordinate_field(fs)
#Set the inverse grid spacing
Δzi = 1 / Δz
zmin = minimum(parent(zf))
zmax = maximum(parent(zf))
CS = typeof(cs)
FS = typeof(fs)
SC = typeof(zc)
SF = typeof(zf)
return new{FT, CS, FS, SC, SF}(zmin, zmax, Δz, Δzi, nz, cs, fs, zc, zf)
end
end
struct Cent{I <: Integer}
i::I
end
kc_surface(grid::Grid) = Cent(1)
kf_surface(grid::Grid) = CCO.PlusHalf(1)
kc_top_of_atmos(grid::Grid) = Cent(grid.nz)
kf_top_of_atmos(grid::Grid) = CCO.PlusHalf(grid.nz + 1)
real_center_indices(grid::Grid) = Cent.((kc_surface(grid).i):(kc_top_of_atmos(grid).i))
real_face_indices(grid::Grid) = CCO.PlusHalf.((kf_surface(grid).i):(kf_top_of_atmos(grid).i))
Base.isless(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.isless(zp1.z, zp2.z)
Base.isless(zp::CC.Geometry.ZPoint, val::Number) = Base.isless(zp.z, val)
Base.isless(val::Number, zp::CC.Geometry.ZPoint) = Base.isless(val, zp.z)
Base.:+(zp::CC.Geometry.ZPoint, val) = Base.:+(zp.z, val)
Base.:+(val, zp::CC.Geometry.ZPoint) = Base.:+(val, zp.z)
Base.:/(val, zp::CC.Geometry.ZPoint) = Base.:/(val, zp.z)
Base.:+(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.:+(zp1.z, zp2.z)
Base.:*(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.:*(zp1.z, zp2.z)
Base.log(zp::CC.Geometry.ZPoint) = Base.log(zp.z)
Base.:-(zp::CC.Geometry.ZPoint) = Base.:-(zp.z)
Base.:-(zp::CC.Geometry.ZPoint, val) = Base.:-(zp.z, val)
Base.:-(val, zp::CC.Geometry.ZPoint) = Base.:-(val, zp.z)
Base.:-(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.:-(zp1.z, zp2.z)
Base.convert(::Type{Float64}, zp::CC.Geometry.ZPoint) = zp.z
Base.:+(h::Cent) = h
Base.:-(h::Cent) = Cent(-h.i - one(h.i))
Base.:+(i::Integer, h::Cent) = Cent(i + h.i)
Base.:+(h::Cent, i::Integer) = Cent(h.i + i)
Base.:+(h1::Cent, h2::Cent) = h1.i + h2.i + one(h1.i)
Base.:-(i::Integer, h::Cent) = Cent(i - h.i - one(h.i))
Base.:-(h::Cent, i::Integer) = Cent(h.i - i)
Base.:-(h1::Cent, h2::Cent) = h1.i - h2.i
Base.:<=(h1::Cent, h2::Cent) = h1.i <= h2.i
Base.:<(h1::Cent, h2::Cent) = h1.i < h2.i
Base.max(h1::Cent, h2::Cent) = Cent(max(h1.i, h2.i))
Base.min(h1::Cent, h2::Cent) = Cent(min(h1.i, h2.i))
Base.@propagate_inbounds Base.getindex(field::CC.Fields.FiniteDifferenceField, i::Integer) = Base.getproperty(field, i)
Base.@propagate_inbounds Base.getindex(field::CC.Fields.CenterFiniteDifferenceField, i::Cent) =
Base.getindex(CC.Fields.field_values(field), i.i)
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.CenterFiniteDifferenceField, v, i::Cent) =
Base.setindex!(CC.Fields.field_values(field), v, i.i)
Base.@propagate_inbounds Base.getindex(field::CC.Fields.FaceFiniteDifferenceField, i::CCO.PlusHalf) =
Base.getindex(CC.Fields.field_values(field), i.i)
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.FaceFiniteDifferenceField, v, i::CCO.PlusHalf) =
Base.setindex!(CC.Fields.field_values(field), v, i.i)
Base.@propagate_inbounds Base.getindex(field::CC.Fields.FaceFiniteDifferenceField, ::Cent) =
error("Attempting to getindex with a center index (Cent) into a Face field")
Base.@propagate_inbounds Base.getindex(field::CC.Fields.CenterFiniteDifferenceField, ::CCO.PlusHalf) =
error("Attempting to getindex with a face index (PlusHalf) into a Center field")
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.FaceFiniteDifferenceField, v, ::Cent) =
error("Attempting to setindex with a center index (Cent) into a Face field")
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.CenterFiniteDifferenceField, v, ::CCO.PlusHalf) =
error("Attempting to setindex with a face index (PlusHalf) into a Center field")
# TODO: deprecate, we should not overload getindex/setindex for ordinary arrays.
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::Cent) = Base.getindex(arr, i.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::Cent) = Base.setindex!(arr, v, i.i)
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::CCO.PlusHalf) = Base.getindex(arr, i.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::CCO.PlusHalf) = Base.setindex!(arr, v, i.i)
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::Int, j::Cent) = Base.getindex(arr, i, j.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::Int, j::Cent) = Base.setindex!(arr, v, i, j.i)
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::Int, j::CCO.PlusHalf) = Base.getindex(arr, i, j.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::Int, j::CCO.PlusHalf) = Base.setindex!(arr, v, i, j.i)
function FieldFromNamedTuple(space, nt::NamedTuple)
cmv(z) = nt
return cmv.(CC.Fields.coordinate_field(space))
end
# ------------------------------------------------------------------ END boiler plate
const FT = Float64;
N = 100;
grid = Grid(π/N, N);
face = FieldFromNamedTuple(grid.fs, (;θ = FT(0), m = FT(0), ∇θ = FT(0), ∇θ_default = FT(0)));
cent = FieldFromNamedTuple(grid.cs, (;θ = FT(0), m = FT(0), ∇θ = FT(0), ∇θ_default = FT(0)));
for k in real_center_indices(grid)
if π/5 ≤ grid.zc[k] ≤ π/2
cent.m[k] = 1
else
cent.m[k] = 0
end
end
for k in real_face_indices(grid)
if π/5 ≤ grid.zf[k] ≤ π/2
face.m[k] = 1
else
face.m[k] = 0
end
end
f_bcs = (; bottom = CCO.Extrapolate(), top = CCO.Extrapolate())
If = CCO.InterpolateC2F(;f_bcs...)
wvec = CC.Geometry.WVector
m∇ = CCO.DivergenceF2C()
zc = grid.zc
parent(cent.∇θ_default) .= 0
parent(face.∇θ_default) .= 0
@. cent.θ = sin(getproperty(zc, :z))
# interp to face (with BCs, required regardless)
@. face.θ = If(cent.θ)
@. cent.∇θ = m∇(wvec(face.θ)) * cent.m + (1 - cent.m) * cent.∇θ_default
import Plots
Plots.plot(vec(grid.zc), vec(cent.θ); label = "θ")
Plots.plot!(vec(grid.zc), vec(cent.m); label = "mask")
Plots.plot!(vec(grid.zc), vec(cent.∇θ); label = "∇θ")
I spoke with @ilopezgp. My example doesn't address the problem. The root problem is the interpolation, which is not valid across subdomains where the field doesn't exist. So, we basically need a
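To see the interpolation problem in isolation, here is a minimal plain-Julia sketch on ordinary `Vector`s rather than ClimaCore fields; `c2f` is a hypothetical helper, not a ClimaCore operator. It assumes a uniform column with centers `1:n` and faces `1:n+1`, interior faces taking the two-point average and the two boundary faces copying the nearest center value (I believe `CCO.Extrapolate()` behaves similarly, but that is an assumption). Interpolating over the whole column lets the junk value `2.5` from masked-out cells leak into the faces that bound the masked region; interpolating only over the masked subdomain does not.

```julia
# Plain-Julia sketch: center-to-face interpolation on a uniform column.
# Centers 1:n, faces 1:n+1; face k lies between centers k-1 and k.
# `c2f` is a hypothetical helper, not a ClimaCore operator.
function c2f(θc::AbstractVector)
    n = length(θc)
    θf = similar(θc, n + 1)
    θf[1], θf[n+1] = θc[1], θc[n]      # boundary faces: copy nearest center
    for k in 2:n
        θf[k] = (θc[k-1] + θc[k]) / 2   # interior faces: two-point average
    end
    return θf
end

θc   = [2.5, 2.5, 1.0, 2.0, 3.0, 2.5]  # 2.5 = junk in masked-out cells
mask = Bool[0, 0, 1, 1, 1, 0]          # one masked subdomain: centers 3:5
sm   = findfirst(mask):findlast(mask)  # 3:5 (valid only for a single run)

naive = c2f(θc)        # mask ignored: faces 3 and 6 average in junk values
sub   = c2f(θc[sm])    # masked subdomain only; maps to faces 3:6

naive[sm.start:(sm.stop + 1)]  # [1.75, 1.5, 2.5, 2.75]
sub                            # [1.0, 1.5, 2.5, 3.0]
```

The interior faces agree; only the two faces on the mask edges differ, which is exactly where the interpolation stencil straddles the mask boundary.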
This is a bit of an insane example, but it "just works", all by reusing existing pieces. That said, it's highly allocating, and I'm sure that doing this properly would be more efficient.
import ClimaCore as CC
import ClimaCore.Operators as CCO
struct SubMasks{N,T,M,R<:NTuple{N,T}}
mask::M
ranges::R
end
function SubMasks(mask)
ranges = UnitRange{Int64}[]
iter = 0
mask_offset = 0
mask_len = length(mask)
while true
iter > mask_len && break # safety net
mask_start = findfirst(i -> mask[i]==1 && i>mask_offset, 1:mask_len)
isnothing(mask_start) && break
mask_stop = findfirst( i -> mask[i]==0 && i>mask_start, 1:mask_len)
if isnothing(mask_stop)
push!(ranges, mask_start:mask_len)
break
else
push!(ranges, mask_start:(mask_stop-1))
end
mask_offset = mask_stop
iter+=1
end
ranges = Tuple(ranges)
M = typeof(mask)
N = length(ranges)
T = UnitRange{Int64}
R = typeof(ranges)
return SubMasks{N,T,M,R}(mask, ranges)
end
Base.eltype(::Type{SDM}) where {N,T,SDM<:SubMasks{N,T}} = T
Base.length(sm::SubMasks{N}) where {N} = N
Base.iterate(sm::SubMasks{N}, state = 1) where {N} =
state > N ? nothing : (sm.ranges[state], state + 1)
# mask = Bool[0, 0, 0, 1, 1, 1, 0, 0, 1, 1];
# for sm in SubMasks(mask)
# @show sm
# end
struct Grid{FT, CS, FS, SC, SF}
zmin::FT
zmax::FT
Δz::FT
Δzi::FT
nz::Int
cs::CS
fs::FS
zc::SC
zf::SF
function Grid(Δz::FT, nz::Int) where {FT <: AbstractFloat}
z₀, z₁ = FT(0), FT(nz * Δz)
domain = CC.Domains.IntervalDomain(
CC.Geometry.ZPoint{FT}(z₀),
CC.Geometry.ZPoint{FT}(z₁),
boundary_tags = (:bottom, :top),
)
mesh = CC.Meshes.IntervalMesh(domain, nelems = nz)
cs = CC.Spaces.CenterFiniteDifferenceSpace(mesh)
fs = CC.Spaces.FaceFiniteDifferenceSpace(cs)
zc = CC.Fields.coordinate_field(cs)
zf = CC.Fields.coordinate_field(fs)
#Set the inverse grid spacing
Δzi = 1 / Δz
zmin = minimum(parent(zf))
zmax = maximum(parent(zf))
CS = typeof(cs)
FS = typeof(fs)
SC = typeof(zc)
SF = typeof(zf)
return new{FT, CS, FS, SC, SF}(zmin, zmax, Δz, Δzi, nz, cs, fs, zc, zf)
end
end
struct Cent{I <: Integer}
i::I
end
kc_surface(grid::Grid) = Cent(1)
kf_surface(grid::Grid) = CCO.PlusHalf(1)
kc_top_of_atmos(grid::Grid) = Cent(grid.nz)
kf_top_of_atmos(grid::Grid) = CCO.PlusHalf(grid.nz + 1)
real_center_indices(grid::Grid) = Cent.((kc_surface(grid).i):(kc_top_of_atmos(grid).i))
real_face_indices(grid::Grid) = CCO.PlusHalf.((kf_surface(grid).i):(kf_top_of_atmos(grid).i))
Base.isless(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.isless(zp1.z, zp2.z)
Base.isless(zp::CC.Geometry.ZPoint, val::Number) = Base.isless(zp.z, val)
Base.isless(val::Number, zp::CC.Geometry.ZPoint) = Base.isless(val, zp.z)
Base.:+(zp::CC.Geometry.ZPoint, val) = Base.:+(zp.z, val)
Base.:+(val, zp::CC.Geometry.ZPoint) = Base.:+(val, zp.z)
Base.:/(val, zp::CC.Geometry.ZPoint) = Base.:/(val, zp.z)
Base.:+(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.:+(zp1.z, zp2.z)
Base.:*(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.:*(zp1.z, zp2.z)
Base.log(zp::CC.Geometry.ZPoint) = Base.log(zp.z)
Base.:-(zp::CC.Geometry.ZPoint) = Base.:-(zp.z)
Base.:-(zp::CC.Geometry.ZPoint, val) = Base.:-(zp.z, val)
Base.:-(val, zp::CC.Geometry.ZPoint) = Base.:-(val, zp.z)
Base.:-(zp1::CC.Geometry.ZPoint, zp2::CC.Geometry.ZPoint) = Base.:-(zp1.z, zp2.z)
Base.convert(::Type{Float64}, zp::CC.Geometry.ZPoint) = zp.z
Base.:+(h::Cent) = h
Base.:-(h::Cent) = Cent(-h.i - one(h.i))
Base.:+(i::Integer, h::Cent) = Cent(i + h.i)
Base.:+(h::Cent, i::Integer) = Cent(h.i + i)
Base.:+(h1::Cent, h2::Cent) = h1.i + h2.i + one(h1.i)
Base.:-(i::Integer, h::Cent) = Cent(i - h.i - one(h.i))
Base.:-(h::Cent, i::Integer) = Cent(h.i - i)
Base.:-(h1::Cent, h2::Cent) = h1.i - h2.i
Base.:<=(h1::Cent, h2::Cent) = h1.i <= h2.i
Base.:<(h1::Cent, h2::Cent) = h1.i < h2.i
Base.max(h1::Cent, h2::Cent) = Cent(max(h1.i, h2.i))
Base.min(h1::Cent, h2::Cent) = Cent(min(h1.i, h2.i))
Base.@propagate_inbounds Base.getindex(field::CC.Fields.FiniteDifferenceField, i::Integer) = Base.getproperty(field, i)
Base.@propagate_inbounds Base.getindex(field::CC.Fields.CenterFiniteDifferenceField, i::Cent) =
Base.getindex(CC.Fields.field_values(field), i.i)
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.CenterFiniteDifferenceField, v, i::Cent) =
Base.setindex!(CC.Fields.field_values(field), v, i.i)
Base.@propagate_inbounds Base.getindex(field::CC.Fields.FaceFiniteDifferenceField, i::CCO.PlusHalf) =
Base.getindex(CC.Fields.field_values(field), i.i)
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.FaceFiniteDifferenceField, v, i::CCO.PlusHalf) =
Base.setindex!(CC.Fields.field_values(field), v, i.i)
Base.@propagate_inbounds Base.getindex(field::CC.Fields.FaceFiniteDifferenceField, ::Cent) =
error("Attempting to getindex with a center index (Cent) into a Face field")
Base.@propagate_inbounds Base.getindex(field::CC.Fields.CenterFiniteDifferenceField, ::CCO.PlusHalf) =
error("Attempting to getindex with a face index (PlusHalf) into a Center field")
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.FaceFiniteDifferenceField, v, ::Cent) =
error("Attempting to setindex with a center index (Cent) into a Face field")
Base.@propagate_inbounds Base.setindex!(field::CC.Fields.CenterFiniteDifferenceField, v, ::CCO.PlusHalf) =
error("Attempting to setindex with a face index (PlusHalf) into a Center field")
# TODO: deprecate, we should not overload getindex/setindex for ordinary arrays.
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::Cent) = Base.getindex(arr, i.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::Cent) = Base.setindex!(arr, v, i.i)
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::CCO.PlusHalf) = Base.getindex(arr, i.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::CCO.PlusHalf) = Base.setindex!(arr, v, i.i)
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::Int, j::Cent) = Base.getindex(arr, i, j.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::Int, j::Cent) = Base.setindex!(arr, v, i, j.i)
Base.@propagate_inbounds Base.getindex(arr::AbstractArray, i::Int, j::CCO.PlusHalf) = Base.getindex(arr, i, j.i)
Base.@propagate_inbounds Base.setindex!(arr::AbstractArray, v, i::Int, j::CCO.PlusHalf) = Base.setindex!(arr, v, i, j.i)
function FieldFromNamedTuple(space, nt::NamedTuple)
cmv(z) = nt
return cmv.(CC.Fields.coordinate_field(space))
end
const FT = Float64;
N = 20;
grid = Grid(π/N, N);
face = FieldFromNamedTuple(grid.fs, (;θ = FT(0), m = FT(0), ∇θ = FT(0), ∇θ_default = FT(0)));
cent = FieldFromNamedTuple(grid.cs, (;θ = FT(0), m = FT(0), ∇θ = FT(0), ∇θ_default = FT(0)));
f_bcs = (; bottom = CCO.Extrapolate(), top = CCO.Extrapolate())
If = CCO.InterpolateC2F(;f_bcs...)
wvec = CC.Geometry.WVector
div = CCO.DivergenceF2C()
zc = grid.zc
parent(cent.∇θ_default) .= 0
parent(face.∇θ_default) .= 0
@. cent.θ = sin(getproperty(zc, :z))
# interp to face (with BCs, required regardless)
@. face.θ = If(cent.θ)
function subdomain_field(field_in, space_in::CC.Spaces.CenterFiniteDifferenceSpace, sm)
FT = eltype(field_in)
fs_in = CC.Spaces.FaceFiniteDifferenceSpace(space_in)
zf_in = CC.Fields.coordinate_field(fs_in)
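zc_in = CC.Fields.coordinate_field(space_in)
nelems = length(vec(zc_in)[sm])
# the cells in `sm` are bounded by the faces sm.start:(sm.stop + 1)
zmin = minimum(vec(zf_in)[sm.start:(sm.stop + 1)])
zmax = maximum(vec(zf_in)[sm.start:(sm.stop + 1)])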
domain = CC.Domains.IntervalDomain(
CC.Geometry.ZPoint{FT}(zmin),
CC.Geometry.ZPoint{FT}(zmax);
boundary_tags = (:bottom, :top),
)
mesh = CC.Meshes.IntervalMesh(domain; nelems = nelems)
cs = CC.Spaces.CenterFiniteDifferenceSpace(mesh)
cf = CC.Fields.coordinate_field(cs)
# allocate a center field on the subdomain mesh; its values are overwritten below
sub_field = sin.(cf.z)
parent(sub_field) .= parent(field_in)[sm]
return sub_field
end
function masked_interpolate!(face, center, grid, mask)
If = CCO.InterpolateC2F(; bottom = CCO.Extrapolate(), top=CCO.Extrapolate())
for sm in SubMasks(vec(mask))
sub_field = subdomain_field(center, grid.cs, sm)
@show sub_field
parent(face)[sm.start:(sm.stop+1)] .= vec(If.(sub_field))
@show face
end
end
for k in real_center_indices(grid)
if π/5 ≤ grid.zc[k] ≤ π/2
cent.m[k] = 1
else
cent.θ[k] = 2.5
cent.m[k] = 0
end
end
for k in real_face_indices(grid)
if π/5 ≤ grid.zf[k] ≤ π/2
else
face.θ[k] = 2.5
end
end
masked_interpolate!(face.θ, cent.θ, grid, cent.m)
@. cent.∇θ = div(wvec(face.θ)) * cent.m + (1 - cent.m) * cent.∇θ_default
import Plots
Plots.plot(vec(grid.zc), vec(cent.θ); label = "θc", markershape = :circle)
Plots.plot!(vec(grid.zc), vec(cent.m); label = "mask", markershape = :circle)
Plots.plot!(vec(grid.zc), vec(cent.∇θ); label = "∇θ", markershape = :circle)
Plots.plot!(vec(grid.zf), vec(face.θ); label = "θf", markershape = :circle)
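As noted above, this prototype is highly allocating. One small piece of a leaner version is the run-finding step: the `SubMasks` constructor rescans the mask from index 1 with `findfirst` on every iteration and collects the ranges into a tuple. Below is a sketch of the same idea as a lazy iterator in plain Julia (`MaskRuns` and `mask_runs` are hypothetical names) that uses `findnext` to jump from one run to the next. It yields the same center ranges `sm`; the bounding face range used in `masked_interpolate!` is `sm.start:(sm.stop + 1)`. It does nothing about the allocations in `subdomain_field`.

```julia
# Lazy iterator over the contiguous `true` runs of a cell-center mask.
# Hypothetical plain-Julia analogue of `SubMasks`; no vector of ranges is built.
struct MaskRuns{M<:AbstractVector{Bool}}
    mask::M
end
mask_runs(mask::AbstractVector{Bool}) = MaskRuns(mask)

# The number of runs is not known up front, so declare the size unknown.
Base.IteratorSize(::Type{<:MaskRuns}) = Base.SizeUnknown()
Base.eltype(::Type{<:MaskRuns}) = UnitRange{Int}

function Base.iterate(r::MaskRuns, state::Int = firstindex(r.mask))
    start = findnext(r.mask, state)        # first `true` at or after `state`
    start === nothing && return nothing    # no more runs
    stop = findnext(!, r.mask, start)      # first `false` after the run begins
    stop = stop === nothing ? lastindex(r.mask) : stop - 1
    return (start:stop, stop + 1)          # resume scanning after this run
end

mask = Bool[0, 0, 0, 1, 1, 1, 0, 0, 1, 1]
for sm in mask_runs(mask)
    @show sm, sm.start:(sm.stop + 1)       # center ranges 4:6, 9:10; face ranges 4:7, 9:11
end
```

This sketch requires a `Bool` mask, while the prototype passes the `Float64` field `cent.m`; something like `mask_runs(vec(cent.m) .== 1)` should bridge that, assuming `vec` on the field returns a plain vector as it does in the plotting calls.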
After digging into this a bit more, we're able to achieve our existing behavior with the existing operators (TC.jl#565) and masking the field (with some small modifications). So, I don't think we have any dire need for new operators at this point. Closing for now.
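One caveat worth keeping in mind with the blend used in both prototypes, `op(θ) * m + (1 - m) * default`: the blend only selects which output values to keep; it does not change what the operator's stencil reads. A minimal plain-Julia sketch, where the hypothetical `grad_c2c` is a centered difference that ignores the mask, shows the masked cells next to the mask edges picking up junk from outside the mask. So the blend alone is only safe if the field outside the mask has been made consistent first, or if the stencil itself respects the mask.

```julia
# Plain-Julia sketch of the blend `op(θ) * m + (1 - m) * default`.
# `grad_c2c` is a hypothetical mask-unaware centered difference
# (one-sided at the two ends of the column).
function grad_c2c(θ::AbstractVector, Δz)
    n = length(θ)
    g = similar(θ)
    g[1] = (θ[2] - θ[1]) / Δz
    g[n] = (θ[n] - θ[n-1]) / Δz
    for i in 2:(n - 1)
        g[i] = (θ[i+1] - θ[i-1]) / (2Δz)
    end
    return g
end

θ = [2.5, 1.0, 2.0, 3.0, 2.5]   # slope 1 inside the mask, 2.5 = junk outside
m = [0.0, 1.0, 1.0, 1.0, 0.0]   # mask stored as Float64, as in the prototypes
default = 0.0

g = grad_c2c(θ, 1.0)
blended = @. g * m + (1 - m) * default

blended   # [0.0, -0.25, 1.0, 0.25, 0.0]; only the middle cell has the true slope 1.0
```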
Some parts of the EDMF code require "masked operators", where a user provides a mask (in the form of a field?) and a field, in order to interpolate or compute gradients/divergences. Here is a sketch for illustration:
Perhaps the interface could be something like:
From what I can see in the code, it looks like the only operator we actually use is `∇C2C`, but, IIRC, the EDMF team wanted to use this in several other places. It would be very helpful to specify exactly which operators we need support for, to limit our scope:
Also, the picture for `∇C2C` seems clear to me, but I think the picture is a bit different for staggered derivatives (no one-sided differences are then needed).
I think there are still some loose ends that need clarification (cc @ilopezgp). For example, what should happen if (or when) only a single point, or only two points, in the mask are true? It seems like we can't guarantee any order of accuracy, because we're potentially limited by the size of these masked subdomains.
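To make the collocated and staggered pictures concrete, here is a minimal plain-Julia sketch on `Vector`s. Both function names are hypothetical and the stencil choices are assumptions, not a settled design. `masked_grad_c2c` is the collocated case: centered differences where both neighbors are in the mask, first-order one-sided differences at the mask edges, and a `default` value for an isolated single point (one possible answer to the loose-end question above). `masked_div_f2c` is the staggered case: cell `i` only reads its own bounding faces `i` and `i + 1`, so no one-sided stencil is needed, provided those face values are valid.

```julia
# Collocated: cell-center field θ -> cell-center gradient, restricted to `mask`.
function masked_grad_c2c(θ::AbstractVector, mask::AbstractVector{Bool}, Δz;
                         default = zero(eltype(θ)))
    n = length(θ)
    ∇θ = fill!(similar(θ, float(eltype(θ))), default)
    inmask(j) = 1 <= j <= n && mask[j]
    for i in 1:n
        mask[i] || continue                       # outside the mask: keep default
        if inmask(i - 1) && inmask(i + 1)
            ∇θ[i] = (θ[i+1] - θ[i-1]) / (2Δz)    # centered, second order
        elseif inmask(i + 1)
            ∇θ[i] = (θ[i+1] - θ[i]) / Δz         # forward, first order
        elseif inmask(i - 1)
            ∇θ[i] = (θ[i] - θ[i-1]) / Δz         # backward, first order
        end                                       # isolated point: keep default
    end
    return ∇θ
end

# Staggered: face field θf (length n + 1) -> cell-center divergence.
# Cell i is bounded by faces i and i + 1.
function masked_div_f2c(θf::AbstractVector, mask::AbstractVector{Bool}, Δz;
                        default = zero(eltype(θf)))
    n = length(mask)
    length(θf) == n + 1 ||
        throw(DimensionMismatch("expected length(θf) == length(mask) + 1"))
    out = fill!(similar(θf, float(eltype(θf)), n), default)
    for i in 1:n
        mask[i] && (out[i] = (θf[i+1] - θf[i]) / Δz)
    end
    return out
end

θ = [0.0, 1.0, 4.0, 9.0, 16.0, 2.5]               # z^2 at z = 0:4, then junk
masked_grad_c2c(θ, Bool[1, 1, 1, 1, 1, 0], 1.0)   # [1.0, 2.0, 4.0, 6.0, 7.0, 0.0]

θf = [9.0, 1.0, 2.0, 4.0, 9.0]                    # 9.0 = junk on faces outside the mask
masked_div_f2c(θf, Bool[0, 0, 1, 0], 1.0)         # [0.0, 0.0, 2.0, 0.0]; one-cell subdomain
```

The collocated edge values (1.0 and 7.0 against the exact 0 and 8) show the first-order one-sided error; a two-point subdomain only ever gets such values, and a one-point subdomain gets `default`. In the staggered case the accuracy question does not disappear, it moves into however `θf` was obtained, e.g. interpolating within a one-cell subdomain.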
I highlighted a section in bold above because, as I was writing this, I started to think that perhaps we can already compose these operators from existing pieces. I'm going to try prototyping something now, but still opening the issue for reference.