Skip to content

Commit

Permalink
Fixed issues in TBB/OpenMP backends on MSVC.
Browse files Browse the repository at this point in the history
- Define `iterator_category` for testing iterator.
- Add missing include to shuffle unit test
- Specialize wrapped_function for void return types
  - MSVC is not a fan of the pattern "return static_cast<void>(expr);"
- Replace buggy SFINAE check with static_assert
  - SFINAE expression was evaluating wrong type
    (should have been Tuple, not the type of Tuple's element at index 1)
  - Rather than fix SFINAE expression, switch to a static_assert for
    the better diagnostic
- Replace deprecated `tbb/tbb_thread.h` with `<thread>`.
- Fix overcounting of initial value in tbb scans.
  - Apparently reverse_join may be called before operator()
- Fix partial sum value type to support edge case from unit test:

```
     // float -> float with plus<int> operator (int accumulator)
     thrust::inclusive_scan(float_input.begin(), float_input.end(), float_output.begin(), thrust::plus<int>());
     ASSERT_EQUAL(float_output[0],  1.5);
     ASSERT_EQUAL(float_output[1],  3.0);
     ASSERT_EQUAL(float_output[2],  6.0);
     ASSERT_EQUAL(float_output[3], 10.0);
```

- Use `thrust::advance` instead of `+=` for generic iterators.
- Wrap the OMP flags in -Xcompiler for NVCC
- Extend ASSERT_STATIC_ASSERT skip for HOST=OMP, too
- Add missing header caught by tbb.cuda configs.
  • Loading branch information
alliepiper committed May 29, 2020
1 parent 83eff97 commit 6b1f456
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 80 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ if ("MSVC" STREQUAL "${CMAKE_CXX_COMPILER_ID}")
# object files:
append_option_if_available("/bigobj" THRUST_CXX_WARNINGS)

# "Oh right, this is Visual Studio."
# Define NOMINMAX to keep the Windows headers from defining min/max macros.
add_compile_definitions("NOMINMAX")

set(THRUST_TREAT_FILE_AS_CXX "/TP")
else ()
append_option_if_available("-Werror" THRUST_CXX_WARNINGS)
Expand Down Expand Up @@ -679,6 +682,13 @@ foreach (THRUST_EXAMPLE_SOURCE IN LISTS THRUST_EXAMPLES)
endif ()
endif ()

if ("MSVC" STREQUAL "${CMAKE_CXX_COMPILER_ID}")
# Some examples use unsafe APIs (e.g. fopen) that MSVC will complain about
# unless this is set:
set_target_properties(${THRUST_EXAMPLE}
PROPERTIES COMPILE_DEFINITIONS "_CRT_SECURE_NO_WARNINGS")
endif()

add_test(NAME ${THRUST_EXAMPLE}
COMMAND ${CMAKE_COMMAND}
-DTHRUST_EXAMPLE=${THRUST_EXAMPLE}
Expand Down
3 changes: 3 additions & 0 deletions testing/copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,8 @@ struct only_set_when_expected_it
__host__ __device__ only_set_when_expected_it operator*() const { return *this; }
template<typename Difference>
__host__ __device__ only_set_when_expected_it operator+(Difference) const { return *this; }
template<typename Difference>
__host__ __device__ only_set_when_expected_it operator+=(Difference) const { return *this; }
template<typename Index>
__host__ __device__ only_set_when_expected_it operator[](Index) const { return *this; }

Expand All @@ -744,6 +746,7 @@ struct iterator_traits<only_set_when_expected_it>
{
typedef long long value_type;
typedef only_set_when_expected_it reference;
typedef thrust::random_access_device_iterator_tag iterator_category;
};
}

Expand Down
1 change: 1 addition & 0 deletions testing/shuffle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#if THRUST_CPP_DIALECT >= 2011
#include <thrust/random.h>
#include <thrust/sequence.h>
#include <thrust/shuffle.h>
#include <thrust/sort.h>
#include <unittest/unittest.h>
Expand Down
2 changes: 1 addition & 1 deletion testing/unittest_static_assert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct static_assertion
template<typename V>
void TestStaticAssertAssert()
{
#if THRUST_DEVICE_SYSTEM != THRUST_DEVICE_SYSTEM_OMP
#if THRUST_DEVICE_SYSTEM != THRUST_DEVICE_SYSTEM_OMP && THRUST_HOST_SYSTEM != THRUST_HOST_SYSTEM_OMP
V test(10);
ASSERT_STATIC_ASSERT(thrust::generate(test.begin(), test.end(), static_assertion<int>()));
#endif
Expand Down
11 changes: 11 additions & 0 deletions thrust/cmake/thrust-config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,16 @@ macro(_thrust_find_TBB required)
endif()
endmacro()

# Wrap the OpenMP flags for CUDA targets
#
# Reads INTERFACE_COMPILE_OPTIONS from `omp_target`. If the value contains a
# generator expression of the form $<$<COMPILE_LANGUAGE:CXX>:flags>, the
# captured flags are added to the target's interface compile options again,
# this time restricted to the CUDA language and prefixed with -Xcompiler= (the
# NVCC option for passing flags through to the host compiler). If the options
# do not match that pattern, the target is left unchanged.
function(thrust_fixup_omp_target omp_target)
get_target_property(opts ${omp_target} INTERFACE_COMPILE_OPTIONS)
# On a match, CMAKE_MATCH_1 holds the text captured from inside the CXX-only
# generator expression (everything up to the first '>').
if (opts MATCHES "\\$<\\$<COMPILE_LANGUAGE:CXX>:([^>]*)>")
target_compile_options(${omp_target} INTERFACE
$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=${CMAKE_MATCH_1}>
)
endif()
endfunction()

# This must be a macro instead of a function to ensure that backends passed to
# find_package(Thrust COMPONENTS [...]) have their full configuration loaded
# into the current scope. This provides at least some remedy for CMake issue
Expand All @@ -568,6 +578,7 @@ macro(_thrust_find_OMP required)
)

if (TARGET OpenMP::OpenMP_CXX)
thrust_fixup_omp_target(OpenMP::OpenMP_CXX)
thrust_set_OMP_target(OpenMP::OpenMP_CXX)
else()
thrust_debug("OpenMP::OpenMP_CXX not found!" internal)
Expand Down
129 changes: 91 additions & 38 deletions thrust/detail/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,80 +24,133 @@ namespace thrust
namespace detail
{


template<typename Function, typename Result>
struct wrapped_function
template <typename Function, typename Result>
struct wrapped_function
{
// mutable because Function::operator() might not be const, while the call operators below are
mutable Function m_f;

inline __host__ __device__
wrapped_function()
: m_f()
: m_f()
{}

inline __host__ __device__
wrapped_function(const Function &f)
: m_f(f)
wrapped_function(const Function& f)
: m_f(f)
{}

__thrust_exec_check_disable__
template<typename Argument>
template <typename Argument>
inline __host__ __device__
Result operator()(Argument &x) const
Result operator()(Argument& x) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x)));
return m_f(thrust::raw_reference_cast(x));
}

__thrust_exec_check_disable__
template<typename Argument>
inline __host__ __device__ Result operator()(const Argument &x) const
template <typename Argument>
inline __host__ __device__
Result operator()(const Argument& x) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x)));
return m_f(thrust::raw_reference_cast(x));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(Argument1 &x, Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(Argument1& x, Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(const Argument1 &x, Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(const Argument1& x, Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(const Argument1 &x, const Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(const Argument1& x, const Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(Argument1 &x, const Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(Argument1& x, const Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
}; // end wrapped_function

// Specialize for void return types:
template <typename Function>
struct wrapped_function<Function, void>
{
// mutable because Function::operator() might not be const, while the call operators below are
mutable Function m_f;
inline __host__ __device__
wrapped_function()
: m_f()
{}

inline __host__ __device__
wrapped_function(const Function& f)
: m_f(f)
{}

__thrust_exec_check_disable__
template <typename Argument>
inline __host__ __device__
void operator()(Argument& x) const
{
m_f(thrust::raw_reference_cast(x));
}

} // end detail
} // end thrust
__thrust_exec_check_disable__
template <typename Argument>
inline __host__ __device__
void operator()(const Argument& x) const
{
m_f(thrust::raw_reference_cast(x));
}

__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(Argument1& x, Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(const Argument1& x, Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(const Argument1& x, const Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(Argument1& x, const Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
}; // end wrapped_function

} // namespace detail
} // namespace thrust
9 changes: 5 additions & 4 deletions thrust/detail/internal_functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <thrust/tuple.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/detail/static_assert.h>
#include <thrust/detail/type_traits.h>
#include <thrust/iterator/detail/tuple_of_iterator_references.h>
#include <thrust/detail/raw_reference_cast.h>
Expand Down Expand Up @@ -317,11 +318,11 @@ template<typename UnaryFunction>
__thrust_exec_check_disable__
template<typename Tuple>
inline __host__ __device__
typename enable_if_non_const_reference_or_tuple_of_iterator_references<
typename thrust::tuple_element<1,Tuple>::type
>::type
operator()(Tuple t)
void operator()(Tuple t)
{
THRUST_STATIC_ASSERT_MSG(is_non_const_reference<Tuple>::value ||
is_tuple_of_iterator_references<Tuple>::value,
"Expected a non-const reference or thrust::detail::tuple_of_iterator_references");
thrust::get<1>(t) = f(thrust::get<0>(t));
}
};
Expand Down
5 changes: 3 additions & 2 deletions thrust/iterator/detail/zip_iterator_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <thrust/advance.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_categories.h>
Expand Down Expand Up @@ -45,12 +46,12 @@ class advance_iterator
public:
inline __host__ __device__
advance_iterator(DiffType step) : m_step(step) {}

__thrust_exec_check_disable__
template<typename Iterator>
inline __host__ __device__
void operator()(Iterator& it) const
{ it += m_step; }
{ thrust::advance(it, m_step); }

private:
DiffType m_step;
Expand Down
5 changes: 3 additions & 2 deletions thrust/system/tbb/detail/reduce_by_key.inl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
#include <thrust/detail/range/tail_flags.h>
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <tbb/tbb_thread.h>

#include <cassert>
#include <thread>


namespace thrust
Expand Down Expand Up @@ -281,7 +282,7 @@ template<typename DerivedPolicy, typename Iterator1, typename Iterator2, typenam
}

// count the number of processors
const unsigned int p = thrust::max<unsigned int>(1u, ::tbb::tbb_thread::hardware_concurrency());
const unsigned int p = thrust::max<unsigned int>(1u, std::thread::hardware_concurrency());

// generate O(P) intervals of sequential work
// XXX oversubscribing is a tuning opportunity
Expand Down
Loading

0 comments on commit 6b1f456

Please sign in to comment.