Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SoA range checking: adds inter operability between range checked and non range checked #41928

Merged
merged 4 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions DataFormats/SoATemplate/interface/SoACommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#define _VALUE_TYPE_EIGEN_COLUMN 2

/* The size type need to be "hardcoded" in the template parameters for classes serialized by ROOT */
/* In practice, using a typedef as a template parameter to the Layout or its ViewTemplateFreeParams member
* declaration fails ROOT dictionary generation. */
#define CMS_SOA_BYTE_SIZE_TYPE std::size_t

namespace cms::soa {
Expand Down Expand Up @@ -132,6 +134,8 @@ namespace cms::soa {
return reinterpret_cast<intptr_t>(addr) % alignment;
}

TupleOrPointerType tupleOrPointer() { return addr_; }

public:
// scalar or column
ValueType const* addr_ = nullptr;
Expand Down Expand Up @@ -166,6 +170,8 @@ namespace cms::soa {
return reinterpret_cast<intptr_t>(addr) % alignment;
}

TupleOrPointerType tupleOrPointer() { return {addr_, stride_}; }

public:
// address and stride
ScalarType const* addr_ = nullptr;
Expand Down Expand Up @@ -201,6 +207,8 @@ namespace cms::soa {
return reinterpret_cast<intptr_t>(addr) % alignment;
}

TupleOrPointerType tupleOrPointer() { return addr_; }

public:
// scalar or column
ValueType* addr_ = nullptr;
Expand Down Expand Up @@ -234,6 +242,8 @@ namespace cms::soa {
return reinterpret_cast<intptr_t>(addr) % alignment;
}

TupleOrPointerType tupleOrPointer() { return {addr_, stride_}; }

public:
// address and stride
ScalarType* addr_ = nullptr;
Expand Down
57 changes: 53 additions & 4 deletions DataFormats/SoATemplate/interface/SoAView.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ namespace cms::soa {
#define _DECLARE_VIEW_MEMBER_LIST(R, DATA, LAYOUT_MEMBER_NAME) \
BOOST_PP_EXPAND(_DECLARE_VIEW_MEMBER_LIST_IMPL LAYOUT_MEMBER_NAME)

/**
* Generator of view member list.
*/
#define _DECLARE_VIEW_OTHER_MEMBER_LIST_IMPL(LAYOUT, MEMBER, NAME) \
(const_cast_SoAParametersImpl(other.BOOST_PP_CAT(NAME, Parameters_)).tupleOrPointer())

#define _DECLARE_VIEW_OTHER_MEMBER_LIST(R, DATA, LAYOUT_MEMBER_NAME) \
BOOST_PP_EXPAND(_DECLARE_VIEW_OTHER_MEMBER_LIST_IMPL LAYOUT_MEMBER_NAME)

/**
* Generator of member initializer for copy constructor.
*/
Expand Down Expand Up @@ -390,7 +399,7 @@ namespace cms::soa {
template RestrictQualifier<restrictQualify>::ParamReturnType \
LOCAL_NAME(size_type _soa_impl_index) { \
if constexpr (rangeChecking == cms::soa::RangeChecking::enabled) { \
if (_soa_impl_index >= base_type::elements_) \
if (_soa_impl_index >= base_type::elements_ or _soa_impl_index < 0) \
SOA_THROW_OUT_OF_RANGE("Out of range index in mutable " #LOCAL_NAME "(size_type index)") \
} \
return typename cms::soa::SoAAccessors<typename BOOST_PP_CAT(Metadata::TypeOf_, LOCAL_NAME)>:: \
Expand Down Expand Up @@ -428,7 +437,7 @@ namespace cms::soa {
template RestrictQualifier<restrictQualify>::ParamReturnType \
LOCAL_NAME(size_type _soa_impl_index) const { \
if constexpr (rangeChecking == cms::soa::RangeChecking::enabled) { \
if (_soa_impl_index >= elements_) \
if (_soa_impl_index >= elements_ or _soa_impl_index < 0) \
SOA_THROW_OUT_OF_RANGE("Out of range index in const " #LOCAL_NAME "(size_type index)") \
} \
return typename cms::soa::SoAAccessors<typename BOOST_PP_CAT(Metadata::TypeOf_, LOCAL_NAME)>:: \
Expand Down Expand Up @@ -535,6 +544,9 @@ namespace cms::soa {
template <cms::soa::SoAColumnType COLUMN_TYPE, class C> \
using SoAConstValueWithConf = cms::soa::SoAConstValue<COLUMN_TYPE, C, conditionalAlignment, restrictQualify>; \
\
template <CMS_SOA_BYTE_SIZE_TYPE, bool, bool, bool> \
friend struct VIEW; \
\
/** \
* Helper/friend class allowing SoA introspection. \
*/ \
Expand Down Expand Up @@ -582,6 +594,23 @@ namespace cms::soa {
VIEW(VIEW const&) = default; \
VIEW& operator=(VIEW const&) = default; \
\
/* Copy constructor for other parameters */ \
template <CMS_SOA_BYTE_SIZE_TYPE OTHER_VIEW_ALIGNMENT, \
bool OTHER_VIEW_ALIGNMENT_ENFORCEMENT, \
bool OTHER_RESTRICT_QUALIFY, \
bool OTHER_RANGE_CHECKING> \
VIEW(VIEW<OTHER_VIEW_ALIGNMENT, OTHER_VIEW_ALIGNMENT_ENFORCEMENT, OTHER_RESTRICT_QUALIFY, \
OTHER_RANGE_CHECKING> const& other): base_type{other.elements_, \
_ITERATE_ON_ALL_COMMA(_DECLARE_VIEW_OTHER_MEMBER_LIST, BOOST_PP_EMPTY(), VALUE_LIST) \
} {} \
/* Copy operator for other parameters */ \
template <CMS_SOA_BYTE_SIZE_TYPE OTHER_VIEW_ALIGNMENT, \
bool OTHER_VIEW_ALIGNMENT_ENFORCEMENT, \
bool OTHER_RESTRICT_QUALIFY, \
bool OTHER_RANGE_CHECKING> \
VIEW& operator=(VIEW<OTHER_VIEW_ALIGNMENT, OTHER_VIEW_ALIGNMENT_ENFORCEMENT, OTHER_RESTRICT_QUALIFY, \
OTHER_RANGE_CHECKING> const& other) { static_cast<base_type>(*this) = static_cast<base_type>(other); } \
\
/* Movable */ \
VIEW(VIEW &&) = default; \
VIEW& operator=(VIEW &&) = default; \
Expand Down Expand Up @@ -620,7 +649,7 @@ namespace cms::soa {
SOA_HOST_DEVICE SOA_INLINE \
element operator[](size_type _soa_impl_index) { \
if constexpr (rangeChecking == cms::soa::RangeChecking::enabled) { \
if (_soa_impl_index >= base_type::elements_) \
if (_soa_impl_index >= base_type::elements_ or _soa_impl_index < 0) \
SOA_THROW_OUT_OF_RANGE("Out of range index in " #VIEW "::operator[]") \
} \
return element{_soa_impl_index, _ITERATE_ON_ALL_COMMA(_DECLARE_VIEW_ELEMENT_CONSTR_CALL, ~, VALUE_LIST)}; \
Expand Down Expand Up @@ -673,6 +702,9 @@ namespace cms::soa {
template <CMS_SOA_BYTE_SIZE_TYPE, bool, bool, bool> \
friend struct VIEW; \
\
template <CMS_SOA_BYTE_SIZE_TYPE, bool, bool, bool> \
friend struct CONST_VIEW; \
\
/* For CUDA applications, we align to the 128 bytes of the cache lines. \
* See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#global-memory-3-0 this is still valid \
* up to compute capability 8.X. \
Expand Down Expand Up @@ -739,6 +771,23 @@ namespace cms::soa {
CONST_VIEW(CONST_VIEW const&) = default; \
CONST_VIEW& operator=(CONST_VIEW const&) = default; \
\
/* Copy constructor for other parameters */ \
template <CMS_SOA_BYTE_SIZE_TYPE OTHER_VIEW_ALIGNMENT, \
bool OTHER_VIEW_ALIGNMENT_ENFORCEMENT, \
bool OTHER_RESTRICT_QUALIFY, \
bool OTHER_RANGE_CHECKING> \
CONST_VIEW(CONST_VIEW<OTHER_VIEW_ALIGNMENT, OTHER_VIEW_ALIGNMENT_ENFORCEMENT, OTHER_RESTRICT_QUALIFY, \
OTHER_RANGE_CHECKING> const& other): CONST_VIEW{other.elements_, \
_ITERATE_ON_ALL_COMMA(_DECLARE_VIEW_OTHER_MEMBER_LIST, BOOST_PP_EMPTY(), VALUE_LIST) \
} {} \
/* Copy operator for other parameters */ \
template <CMS_SOA_BYTE_SIZE_TYPE OTHER_VIEW_ALIGNMENT, \
bool OTHER_VIEW_ALIGNMENT_ENFORCEMENT, \
bool OTHER_RESTRICT_QUALIFY, \
bool OTHER_RANGE_CHECKING> \
CONST_VIEW& operator=(CONST_VIEW<OTHER_VIEW_ALIGNMENT, OTHER_VIEW_ALIGNMENT_ENFORCEMENT, OTHER_RESTRICT_QUALIFY, \
OTHER_RANGE_CHECKING> const& other) { *this = other; } \
\
/* Movable */ \
CONST_VIEW(CONST_VIEW &&) = default; \
CONST_VIEW& operator=(CONST_VIEW &&) = default; \
Expand All @@ -761,7 +810,7 @@ namespace cms::soa {
SOA_HOST_DEVICE SOA_INLINE \
const_element operator[](size_type _soa_impl_index) const { \
if constexpr (rangeChecking == cms::soa::RangeChecking::enabled) { \
if (_soa_impl_index >= elements_) \
if (_soa_impl_index >= elements_ or _soa_impl_index < 0) \
SOA_THROW_OUT_OF_RANGE("Out of range index in " #CONST_VIEW "::operator[]") \
} \
return const_element{ \
Expand Down
71 changes: 50 additions & 21 deletions DataFormats/SoATemplate/test/SoALayoutAndView_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ GENERATE_SOA_LAYOUT(SoAHostDeviceLayoutTemplate,

using SoAHostDeviceLayout = SoAHostDeviceLayoutTemplate<>;
using SoAHostDeviceView = SoAHostDeviceLayout::View;
using SoAHostDeviceRangeCheckingView =
SoAHostDeviceLayout::ViewTemplate<cms::soa::RestrictQualify::enabled, cms::soa::RangeChecking::enabled>;
using SoAHostDeviceConstView = SoAHostDeviceLayout::ConstView;

GENERATE_SOA_LAYOUT(SoADeviceOnlyLayoutTemplate,
Expand Down Expand Up @@ -126,6 +128,12 @@ int main(void) {
cudaCheck(cudaMallocHost(&h_buf, hostDeviceSize));
SoAHostDeviceLayout h_soahdLayout(h_buf, numElements);
SoAHostDeviceView h_soahd(h_soahdLayout);

// Validation of range checking variants initialization
SoAHostDeviceRangeCheckingView h_soahdrc(h_soahdLayout);
[[maybe_unused]] SoAHostDeviceRangeCheckingView h_soahdrc2 = h_soahdLayout;
SoAHostDeviceRangeCheckingView h_soahdrc3{h_soahd};
[[maybe_unused]] SoAHostDeviceRangeCheckingView h_soahdrc4 = h_soahd;
SoAHostDeviceConstView h_soahd_c(h_soahdLayout);

// Alocate buffer, stores and views on the device (single, shared buffer).
Expand Down Expand Up @@ -248,29 +256,50 @@ int main(void) {
}
}

// Validation of range checking
try {
// Get a view like the default, except for range checking
SoAHostDeviceLayout::ViewTemplate<SoAHostDeviceView::restrictQualify, cms::soa::RangeChecking::enabled>
soa1viewRangeChecking(h_soahdLayout);
// This should throw an exception
[[maybe_unused]] auto si = soa1viewRangeChecking[soa1viewRangeChecking.metadata().size()];
std::cout << "Fail: expected range-check exception (operator[]) not caught on the host." << std::endl;
assert(false);
} catch (const std::out_of_range&) {
std::cout << "Pass: expected range-check exception (operator[]) successfully caught on the host." << std::endl;
{
// Get a view like the default, except for range checking (direct initialization from layout)
SoAHostDeviceRangeCheckingView soa1viewRangeChecking(h_soahdLayout);
try {
[[maybe_unused]] auto si = soa1viewRangeChecking[soa1viewRangeChecking.metadata().size()];
std::cout << "Fail: expected range-check exception (view-level index access) not caught on the host (overflow)."
<< std::endl;
assert(false);
} catch (const std::out_of_range&) {
}
try {
[[maybe_unused]] auto si = soa1viewRangeChecking[-1];
std::cout << "Fail: expected range-check exception (view-level index access) not caught on the host (underflow)."
<< std::endl;
assert(false);
} catch (const std::out_of_range&) {
}
[[maybe_unused]] auto si = soa1viewRangeChecking[soa1viewRangeChecking.metadata().size() - 1];
[[maybe_unused]] auto si2 = soa1viewRangeChecking[0];
std::cout << "Pass: expected range-check exceptions (view-level index access) successfully caught on the host "
"(layout initialization)."
<< std::endl;
}

try {
// Get a view like the default, except for range checking
SoAHostDeviceLayout::ViewTemplate<SoAHostDeviceView::restrictQualify, cms::soa::RangeChecking::enabled>
soa1viewRangeChecking(h_soahdLayout);
// This should throw an exception
[[maybe_unused]] auto si = soa1viewRangeChecking[soa1viewRangeChecking.metadata().size()];
std::cout << "Fail: expected range-check exception (view-level index access) not caught on the host." << std::endl;
assert(false);
} catch (const std::out_of_range&) {
std::cout << "Pass: expected range-check exception (view-level index access) successfully caught on the host."
{
// Validation of view initialized range checking view initialization
try {
[[maybe_unused]] auto si = h_soahdrc3[h_soahdrc3.metadata().size()];
std::cout << "Fail: expected range-check exception (view-level index access) not caught on the host (overflow)."
<< std::endl;
assert(false);
} catch (const std::out_of_range&) {
}
try {
[[maybe_unused]] auto si = h_soahdrc3[-1];
std::cout << "Fail: expected range-check exception (view-level index access) not caught on the host (underflow)."
<< std::endl;
assert(false);
} catch (const std::out_of_range&) {
}
[[maybe_unused]] auto si = h_soahdrc3[h_soahdrc3.metadata().size() - 1];
[[maybe_unused]] auto si2 = h_soahdrc3[0];
std::cout << "Pass: expected range-check exceptions (view-level index access) successfully caught on the host "
"(view initialization)."
<< std::endl;
}

Expand Down
Loading