Skip to content

Commit

Permalink
Intrepid2: fix for trilinos#12037; resolves a test failure on certain CUDA platforms. (PR trilinos#12047)
Browse files Browse the repository at this point in the history
  • Loading branch information
CamelliaDPG authored and JacobDomagala committed Aug 4, 2023
1 parent 266f63d commit ab18e35
Showing 1 changed file with 43 additions and 8 deletions.
51 changes: 43 additions & 8 deletions packages/intrepid2/src/Shared/Intrepid2_DataCombiners.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,42 @@ namespace Intrepid2 {
}
};

//! Functor for the constant-data case: combines the single entry of A with the single entry of B
//! using the supplied binary operator, and stores the outcome in the single entry of the destination view.
template<class BinaryOperator, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType>
struct InPlaceCombinationFunctorConstantCase
{
  //! Keeps copies of the three view handles and of the operator.
  //! Each view is checked for extent(0) == 1; a failed check is reported through
  //! INTREPID2_TEST_FOR_EXCEPTION with std::invalid_argument as the exception type.
  InPlaceCombinationFunctorConstantCase(ThisUnderlyingViewType this_underlying,
                                        AUnderlyingViewType A_underlying,
                                        BUnderlyingViewType B_underlying,
                                        BinaryOperator binaryOperator)
  : this_underlying_(this_underlying),
    A_underlying_(A_underlying),
    B_underlying_(B_underlying),
    binaryOperator_(binaryOperator)
  {
    INTREPID2_TEST_FOR_EXCEPTION(this_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
    INTREPID2_TEST_FOR_EXCEPTION(A_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
    INTREPID2_TEST_FOR_EXCEPTION(B_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
  }

  //! Kernel body.  The iteration index is not used: entry 0 of each view is the only one touched.
  KOKKOS_INLINE_FUNCTION
  void operator()(const int /* iteration index, ignored */) const
  {
    const auto & lhs = A_underlying_(0);
    const auto & rhs = B_underlying_(0);

    // write the combined value directly into the destination entry
    this_underlying_(0) = binaryOperator_(lhs, rhs);
  }

private:
  ThisUnderlyingViewType this_underlying_; //!< destination view; entry 0 receives the result
  AUnderlyingViewType    A_underlying_;    //!< first operand view; entry 0 is read
  BUnderlyingViewType    B_underlying_;    //!< second operand view; entry 0 is read
  BinaryOperator         binaryOperator_;  //!< callable invoked as binaryOperator_(A entry, B entry)
};

//! For use with Data object into which a value will be stored. We use passThroughBlockDiagonalArgs = true for storeInPlaceCombination().
template<bool passThroughBlockDiagonalArgs>
struct FullArgExtractorWritableData
Expand Down Expand Up @@ -489,17 +525,16 @@ namespace Intrepid2 {
{
// constant data
Kokkos::RangePolicy<ExecutionSpace> policy(ExecutionSpace(),0,1); // just 1 entry

auto this_underlying = thisData.template getUnderlyingView<1>();
auto A_underlying = A.template getUnderlyingView<1>();
auto B_underlying = B.template getUnderlyingView<1>();
Kokkos::parallel_for("compute in-place", policy,
KOKKOS_LAMBDA (const int &i0) {
auto & result = this_underlying(0);
const auto & A_val = A_underlying(0);
const auto & B_val = B_underlying(0);

result = binaryOperator(A_val,B_val);
});

using ConstantCaseFunctor = InPlaceCombinationFunctorConstantCase<decltype(binaryOperator), decltype(this_underlying),
decltype(A_underlying), decltype(B_underlying)>;

ConstantCaseFunctor functor(this_underlying, A_underlying, B_underlying, binaryOperator);
Kokkos::parallel_for("compute in-place", policy,functor);
}
else
{
Expand Down

0 comments on commit ab18e35

Please sign in to comment.