Skip to content

Commit

Permalink
Redesign address stability trait
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jul 26, 2024
1 parent 4dba0dd commit 445a3aa
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 34 deletions.
25 changes: 16 additions & 9 deletions thrust/testing/address_stability.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,29 @@ struct Overloaded
}
};

struct Addable
{
_CCCL_HOST_DEVICE friend auto operator+(const Addable&, const Addable&) -> Addable
{
return Addable{};
}
};

void TestAddressStability()
{
using thrust::can_copy_arguments;
using thrust::is_input_address_oblivious;

static_assert(is_input_address_oblivious<thrust::plus<int>>::value, "");
static_assert(!is_input_address_oblivious<thrust::plus<>>::value, "");
static_assert(!is_input_address_oblivious<thrust::plus<MyPlus>>::value, "");
static_assert(is_input_address_oblivious<thrust::plus<int>, int, int>::value, "");
static_assert(is_input_address_oblivious<thrust::plus<>, int, int>::value, "");
static_assert(!is_input_address_oblivious<thrust::plus<MyPlus>, int, int>::value, ""); // TODO should be fine

// TODO(bgruber): Overloaded when called with (int, int) is oblivious, but (float, float) isn't
// static_assert(!is_input_address_oblivious<Overloaded>::value, "");
static_assert(!is_input_address_oblivious<Overloaded, int, int>::value, ""); // TODO should be fine
static_assert(!is_input_address_oblivious<Overloaded, float, float>::value, "");

static_assert(can_copy_arguments<thrust::plus<int>, int*, int*>::value, "");
// TODO(bgruber): with some effort we can make this one work:
// static_assert(can_copy_arguments<thrust::plus<>, int*, int*>::value, "");
static_assert(!can_copy_arguments<thrust::plus<MyPlus>, MyPlus*, MyPlus*>::value, "");
static_assert(!can_copy_arguments<thrust::plus<>, MyPlus*, MyPlus*>::value, "");
static_assert(can_copy_arguments<thrust::plus<>, int*, int*>::value, "");
static_assert(!can_copy_arguments<thrust::plus<Addable>, Addable*, Addable*>::value, "");
static_assert(!can_copy_arguments<thrust::plus<>, Addable*, Addable*>::value, "");
}
DECLARE_UNITTEST(TestAddressStability);
63 changes: 40 additions & 23 deletions thrust/thrust/address_stability.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,46 @@

THRUST_NAMESPACE_BEGIN

// TODO(bgruber): we may need to include the parameter types in the check, since the call operator of the functor could
// be overloaded.
// TODO(bgruber): bikeshed name, e.g., allow_copied_parameter
/// Trait telling whether a function object relies on the memory address of the input arguments. The nested value is
/// true when the addres of the inputs do not matter.
template <typename F, typename SFINAE = void>
struct is_input_address_oblivious : std::false_type
namespace detail
{
// need a separate implementation trait because we SFINAE with a type parameter before the variadic pack
template <typename F, typename SFINAE, typename... Args>
struct is_input_address_oblivious_impl : std::false_type
{};

template <typename F>
struct is_input_address_oblivious<F, ::cuda::std::void_t<decltype(F::is_input_address_oblivious)>>
template <typename F, typename... Args>
struct is_input_address_oblivious_impl<F, ::cuda::std::void_t<decltype(F::is_input_address_oblivious)>, Args...>
{
static constexpr bool value = F::is_input_address_oblivious;
};

#define MARK_INPUT_ADDRESS_OBLIVIOUS(functor) \
template <typename T> \
struct is_input_address_oblivious<functor<T>> \
{ \
/*we know what thrust::plus<T> etc. do if T is not a type that could have a weird operatorX() */ \
static constexpr bool value = \
!::cuda::std::is_class<T>::value && !::cuda::std::is_enum<T>::value && !::cuda::std::is_void<T>::value; \
}
template <typename T>
struct has_builtin_operators
: ::cuda::std::bool_constant<!::cuda::std::is_class<T>::value && !::cuda::std::is_enum<T>::value
&& !::cuda::std::is_void<T>::value>
{};
} // namespace detail

// TODO(bgruber): we may need to include the parameter types in the check, since the call operator of the functor could
// be overloaded.
// TODO(bgruber): bikeshed name, e.g., allow_copied_parameter
/// Trait telling whether a function object relies on the memory address of the input arguments when called with the
/// given set of types. The nested value is true when the addres of the inputs do not matter.
template <typename F, typename... Args>
using is_input_address_oblivious = detail::is_input_address_oblivious_impl<F, void, Args...>;

#define MARK_INPUT_ADDRESS_OBLIVIOUS(functor) \
/*we know what thrust::plus<T> etc. do if T is not a type that could have a weird operatorX() */ \
template <typename T, typename... Args> \
struct detail::is_input_address_oblivious_impl<functor<T>, void, Args...> \
{ \
static constexpr bool value = detail::has_builtin_operators<T>::value; \
}; \
/*we know what thrust::plus<void> etc. do if T is not a type that could have a weird operatorX() */ \
template <typename... Args> \
struct detail::is_input_address_oblivious_impl<functor<void>, void, Args...> \
: ::cuda::std::conjunction<detail::has_builtin_operators<Args>...> \
{};

// TODO(bgruber): move those close to where the functors are defined
MARK_INPUT_ADDRESS_OBLIVIOUS(thrust::plus);
Expand All @@ -51,8 +68,9 @@ MARK_INPUT_ADDRESS_OBLIVIOUS(thrust::negate);

#undef MARK_INPUT_ADDRESS_OBLIVIOUS

template <typename F>
struct is_input_address_oblivious<detail::not_fun_t<F>> : is_input_address_oblivious<F>
template <typename F, typename... Args>
struct detail::is_input_address_oblivious_impl<detail::not_fun_t<F>, void, Args...>
: is_input_address_oblivious<F, Args...>
{};

namespace detail
Expand All @@ -74,16 +92,15 @@ _CCCL_HOST_DEVICE constexpr auto mark_input_address_oblivious(F f) -> detail::in
return detail::input_address_oblivious_wrapper<F>{std::move(f)};
}

// TODO(bgruber): should we take the iterator reference directly instead?
template <typename TransformOp, typename... Its>
struct can_copy_arguments
{
// TODO(bgruber): add detection whether user takes arguments by value, similar to how cub::DeviceFor does it
static constexpr bool value =
::cuda::std::conjunction<::cuda::std::is_trivially_copyable<iterator_value_t<Its>>...>::value
&& is_input_address_oblivious<TransformOp>::value;
// TODO(bgruber): if TransformOp is a transparent functor (e.g. thrust::plus<void>) we should retry the check with
// thrust::plus<iterator_value_t<Its>>
&& is_input_address_oblivious<TransformOp, iterator_value_t<Its>...>::value; // TODO(bgruber): is
// iterator_value_t<Its> correct? Why
// not iterator_reference_t<Its>?
};

#if _CCCL_STD_VER >= 2014
Expand Down
4 changes: 2 additions & 2 deletions thrust/thrust/zip_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ class zip_function
mutable Function func;
};

template <typename F>
struct is_input_address_oblivious<zip_function<F>, void> : is_input_address_oblivious<F>
template <typename F, typename... Args>
struct detail::is_input_address_oblivious_impl<zip_function<F>, void, Args...> : is_input_address_oblivious<F, Args...>
{};

/*! \p make_zip_function creates a \p zip_function from a function object.
Expand Down

0 comments on commit 445a3aa

Please sign in to comment.