Skip to content

Commit

Permalink
Implement StructureInterpolation2D for 3-dimensional fields with extr…
Browse files Browse the repository at this point in the history
…a variables)
  • Loading branch information
wdeconinck committed Nov 15, 2024
1 parent e4f6b87 commit 91c6430
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -472,15 +472,24 @@ void StructuredInterpolation2D<Kernel>::do_execute( const FieldSet& src_fields,
if ( datatype.kind() == array::DataType::KIND_REAL64 && rank == 1 ) {
execute_impl<double, 1>( *kernel_, src_fields, tgt_fields );
}
if ( datatype.kind() == array::DataType::KIND_REAL32 && rank == 1 ) {
else if ( datatype.kind() == array::DataType::KIND_REAL32 && rank == 1 ) {
execute_impl<float, 1>( *kernel_, src_fields, tgt_fields );
}
if ( datatype.kind() == array::DataType::KIND_REAL64 && rank == 2 ) {
else if ( datatype.kind() == array::DataType::KIND_REAL64 && rank == 2 ) {
execute_impl<double, 2>( *kernel_, src_fields, tgt_fields );
}
if ( datatype.kind() == array::DataType::KIND_REAL32 && rank == 2 ) {
else if ( datatype.kind() == array::DataType::KIND_REAL32 && rank == 2 ) {
execute_impl<float, 2>( *kernel_, src_fields, tgt_fields );
}
else if ( datatype.kind() == array::DataType::KIND_REAL64 && rank == 3 ) {
execute_impl<double, 3>( *kernel_, src_fields, tgt_fields );
}
else if ( datatype.kind() == array::DataType::KIND_REAL32 && rank == 3 ) {
execute_impl<float, 3>( *kernel_, src_fields, tgt_fields );
}
else {
ATLAS_NOTIMPLEMENTED;
}

tgt_fields.set_dirty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,39 @@ class CubicHorizontalKernel {
}
}

template <typename stencil_t, typename weights_t, typename Value, int Rank>
typename std::enable_if<(Rank == 3), void>::type interpolate(const stencil_t& stencil, const weights_t& weights,
const array::ArrayView<const Value, Rank>& input,
array::ArrayView<Value, Rank>& output, idx_t r) const {
std::array<std::array<idx_t, stencil_width()>, stencil_width()> index;
const auto& weights_j = weights.weights_j;
const idx_t Nk = output.shape(1);
const idx_t Nl = output.shape(2);

for (idx_t k = 0; k < Nk; ++k) {
for (idx_t l = 0; l < Nl; ++l) {
output(r, k, l) = 0.;
}
}
for (idx_t j = 0; j < stencil_width(); ++j) {
const auto& weights_i = weights.weights_i[j];
for (idx_t i = 0; i < stencil_width(); ++i) {
idx_t n = src_.index(stencil.i(i, j), stencil.j(j));
Value w = static_cast<Value>(weights_i[i] * weights_j[j]);
for (idx_t k = 0; k < Nk; ++k) {
for (idx_t l = 0; l < Nl; ++l) {
output(r, k, l) += w * input(n, k, l);
}
}
index[j][i] = n;
}
}

if (limiter_) {
Limiter::limit(index, input, output, r);
}
}

template <typename array_t>
typename array_t::value_type operator()(double x, double y, const array_t& input) const {
Stencil stencil;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,39 @@ class CubicHorizontalLimiter {
}
}
}

template <typename Value, int Rank>
static typename std::enable_if<(Rank == 3), void>::type limit(const std::array<std::array<idx_t, 4>, 4>& index,
const array::ArrayView<const Value, Rank>& input,
array::ArrayView<Value, Rank>& output, idx_t r) {
// Limit output to max/min of values in stencil marked by '*'
// x x x x
// x *-----* x
// / P |
// x *------ * x
// x x x x
for (idx_t k = 0; k < output.shape(1); ++k) {
for (idx_t l = 0; l < output.shape(2); ++l) {
Value maxval = std::numeric_limits<Value>::lowest();
Value minval = std::numeric_limits<Value>::max();
for (idx_t j = 1; j < 3; ++j) {
for (idx_t i = 1; i < 3; ++i) {
idx_t n = index[j][i];
Value val = input(n, k, l);
maxval = std::max(maxval, val);
minval = std::min(minval, val);
}
}
if (output(r, k, l) < minval) {
output(r, k, l) = minval;
}
else if (output(r, k, l) > maxval) {
output(r, k, l) = maxval;
}
}
}
}

};

} // namespace method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,32 @@ class LinearHorizontalKernel {
}
}

template <typename stencil_t, typename weights_t, typename Value, int Rank>
typename std::enable_if<(Rank == 3), void>::type interpolate(const stencil_t& stencil, const weights_t& weights,
const array::ArrayView<const Value, Rank>& input,
array::ArrayView<Value, Rank>& output, idx_t r) const {
const auto& weights_j = weights.weights_j;
const idx_t Nk = output.shape(1);
const idx_t Nl = output.shape(2);
for (idx_t k = 0; k < Nk; ++k) {
for (idx_t l = 0; l < Nl; ++l) {
output(r, k, l) = 0.;
}
}
for (idx_t j = 0; j < stencil_width(); ++j) {
const auto& weights_i = weights.weights_i[j];
for (idx_t i = 0; i < stencil_width(); ++i) {
idx_t n = src_.index(stencil.i(i, j), stencil.j(j));
Value w = static_cast<Value>(weights_i[i] * weights_j[j]);
for (idx_t k = 0; k < Nk; ++k) {
for (idx_t l = 0; l < Nl; ++l) {
output(r, k, l) += w * input(n, k, l);
}
}
}
}
}

template <typename array_t>
typename array_t::value_type operator()(double x, double y, const array_t& input) const {
Stencil stencil;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,57 @@ class QuasiCubicHorizontalKernel {
}
}


template <typename stencil_t, typename weights_t, typename Value, int Rank>
typename std::enable_if<(Rank == 3), void>::type interpolate(const stencil_t& stencil, const weights_t& weights,
const array::ArrayView<const Value, Rank>& input,
array::ArrayView<Value, Rank>& output, idx_t r) const {
std::array<std::array<idx_t, stencil_width()>, stencil_width()> index;
const auto& weights_j = weights.weights_j;
const idx_t Nk = output.shape(1);
const idx_t Nl = output.shape(2);
for (idx_t k = 0; k < Nk; ++k) {
for (idx_t l = 0; l < Nl; ++l) {
output(r, k, l) = 0.;
}
}

// LINEAR for outer rows ( j = {0,3} )
for (idx_t j = 0; j < 4; j += 3) {
const auto& weights_i = weights.weights_i[j];
for (idx_t i = 1; i < 3; ++i) { // i = {1,2}
idx_t n = src_.index(stencil.i(i, j), stencil.j(j));
Value w = weights_i[i] * weights_j[j];
for (idx_t k = 0; k < Nk; ++k) {
for (idx_t l = 0; l < Nl; ++l) {
output(r, k, l) += w * input(n, k, l);
}
}
index[j][i] = n;
}
}
// CUBIC for inner rows ( j = {1,2} )
for (idx_t j = 1; j < 3; ++j) {
const auto& weights_i = weights.weights_i[j];
for (idx_t i = 0; i < stencil_width(); ++i) {
idx_t n = src_.index(stencil.i(i, j), stencil.j(j));
Value w = weights_i[i] * weights_j[j];
for (idx_t k = 0; k < Nk; ++k) {
for (idx_t l = 0; l < Nl; ++l) {
output(r, k, l) += w * input(n, k, l);
}
}
index[j][i] = n;
}
}

if (limiter_) {
Limiter::limit(index, input, output, r);
}
}



template <typename array_t>
typename array_t::value_type operator()(double x, double y, const array_t& input) const {
Stencil stencil;
Expand Down

0 comments on commit 91c6430

Please sign in to comment.