Skip to content

Commit

Permalink
Fixed issues in TBB/OpenMP backends on MSVC.
Browse files Browse the repository at this point in the history
- CMake
  - Define `NOMINMAX` on msvc because aaaargh.
  - Define `_CRT_SECURE_NO_WARNINGS` for examples (fopen, etc warnings)
- Define `iterator_category` for testing iterator.
- Add missing include to shuffle unit test
- Specialize wrapped_function for void return types
  - MSVC is not a fan of the pattern "return static_cast<void>(expr);"
- Replace buggy SFINAE check with static_assert
  - SFINAE expression was evaluating wrong type
    (should have been Tuple, not the first element type of Tuple)
  - Rather than fix SFINAE expression, switch to a static_assert for
    the better diagnostic
- Replace deprecated `tbb/tbb_thread.h` with `<thread>`.
- Fix overcounting of initial value in tbb scans.
  - Apparently reverse_join may be called before operator()
- Fix partial sum value type to support edge case from unit test:

```
# testing/scan.cu:260:

     // float -> float with plus<int> operator (int accumulator)
     thrust::inclusive_scan(float_input.begin(), float_input.end(), float_output.begin(), thrust::plus<int>());
     ASSERT_EQUAL(float_output[0],  1.5);
     ASSERT_EQUAL(float_output[1],  3.0);
     ASSERT_EQUAL(float_output[2],  6.0);
     ASSERT_EQUAL(float_output[3], 10.0);
```
  • Loading branch information
alliepiper committed May 28, 2020
1 parent e478243 commit e779d2e
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 95 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ if ("MSVC" STREQUAL "${CMAKE_CXX_COMPILER_ID}")
# object files:
append_option_if_available("/bigobj" THRUST_CXX_WARNINGS)

# "Oh right, this is Visual Studio."
add_compile_definitions("NOMINMAX")

set(THRUST_TREAT_FILE_AS_CXX "/TP")
else ()
append_option_if_available("-Werror" THRUST_CXX_WARNINGS)
Expand Down Expand Up @@ -679,6 +682,13 @@ foreach (THRUST_EXAMPLE_SOURCE IN LISTS THRUST_EXAMPLES)
endif ()
endif ()

if ("MSVC" STREQUAL "${CMAKE_CXX_COMPILER_ID}")
# Some examples use unsafe APIs (e.g. fopen) that MSVC will complain about
# unless this is set:
set_target_properties(${THRUST_EXAMPLE}
PROPERTIES COMPILE_DEFINITIONS "_CRT_SECURE_NO_WARNINGS")
endif()

add_test(NAME ${THRUST_EXAMPLE}
COMMAND ${CMAKE_COMMAND}
-DTHRUST_EXAMPLE=${THRUST_EXAMPLE}
Expand Down
3 changes: 3 additions & 0 deletions testing/copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,8 @@ struct only_set_when_expected_it
__host__ __device__ only_set_when_expected_it operator*() const { return *this; }
template<typename Difference>
__host__ __device__ only_set_when_expected_it operator+(Difference) const { return *this; }
template<typename Difference>
__host__ __device__ only_set_when_expected_it operator+=(Difference) const { return *this; }
template<typename Index>
__host__ __device__ only_set_when_expected_it operator[](Index) const { return *this; }

Expand All @@ -744,6 +746,7 @@ struct iterator_traits<only_set_when_expected_it>
{
typedef long long value_type;
typedef only_set_when_expected_it reference;
typedef thrust::random_access_device_iterator_tag iterator_category;
};
}

Expand Down
1 change: 1 addition & 0 deletions testing/shuffle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#if THRUST_CPP_DIALECT >= 2011
#include <thrust/random.h>
#include <thrust/sequence.h>
#include <thrust/shuffle.h>
#include <thrust/sort.h>
#include <unittest/unittest.h>
Expand Down
129 changes: 91 additions & 38 deletions thrust/detail/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,80 +24,133 @@ namespace thrust
namespace detail
{


template<typename Function, typename Result>
struct wrapped_function
template <typename Function, typename Result>
struct wrapped_function
{
// mutable because Function::operator() might be const
mutable Function m_f;

inline __host__ __device__
wrapped_function()
: m_f()
: m_f()
{}

inline __host__ __device__
wrapped_function(const Function &f)
: m_f(f)
wrapped_function(const Function& f)
: m_f(f)
{}

__thrust_exec_check_disable__
template<typename Argument>
template <typename Argument>
inline __host__ __device__
Result operator()(Argument &x) const
Result operator()(Argument& x) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x)));
return m_f(thrust::raw_reference_cast(x));
}

__thrust_exec_check_disable__
template<typename Argument>
inline __host__ __device__ Result operator()(const Argument &x) const
template <typename Argument>
inline __host__ __device__
Result operator()(const Argument& x) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x)));
return m_f(thrust::raw_reference_cast(x));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(Argument1 &x, Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(Argument1& x, Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(const Argument1 &x, Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(const Argument1& x, Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(const Argument1 &x, const Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(const Argument1& x, const Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template<typename Argument1, typename Argument2>
inline __host__ __device__ Result operator()(Argument1 &x, const Argument2 &y) const
template <typename Argument1, typename Argument2>
inline __host__ __device__
Result operator()(Argument1& x, const Argument2& y) const
{
// we static cast to Result to handle void Result without error
// in case Function's result is non-void
return static_cast<Result>(m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y)));
return m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
}; // end wrapped_function

// Specialize for void return types:
template <typename Function>
struct wrapped_function<Function, void>
{
// mutable because Function::operator() might be const
mutable Function m_f;
inline __host__ __device__
wrapped_function()
: m_f()
{}

inline __host__ __device__
wrapped_function(const Function& f)
: m_f(f)
{}

__thrust_exec_check_disable__
template <typename Argument>
inline __host__ __device__
void operator()(Argument& x) const
{
m_f(thrust::raw_reference_cast(x));
}

} // end detail
} // end thrust
__thrust_exec_check_disable__
template <typename Argument>
inline __host__ __device__
void operator()(const Argument& x) const
{
m_f(thrust::raw_reference_cast(x));
}

__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(Argument1& x, Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}

__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(const Argument1& x, Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(const Argument1& x, const Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
__thrust_exec_check_disable__
template <typename Argument1, typename Argument2>
inline __host__ __device__
void operator()(Argument1& x, const Argument2& y) const
{
m_f(thrust::raw_reference_cast(x), thrust::raw_reference_cast(y));
}
}; // end wrapped_function

} // namespace detail
} // namespace thrust
9 changes: 5 additions & 4 deletions thrust/detail/internal_functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <thrust/tuple.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/detail/static_assert.h>
#include <thrust/detail/type_traits.h>
#include <thrust/iterator/detail/tuple_of_iterator_references.h>
#include <thrust/detail/raw_reference_cast.h>
Expand Down Expand Up @@ -317,11 +318,11 @@ template<typename UnaryFunction>
__thrust_exec_check_disable__
template<typename Tuple>
inline __host__ __device__
typename enable_if_non_const_reference_or_tuple_of_iterator_references<
typename thrust::tuple_element<1,Tuple>::type
>::type
operator()(Tuple t)
void operator()(Tuple t)
{
THRUST_STATIC_ASSERT_MSG(is_non_const_reference<Tuple>::value ||
is_tuple_of_iterator_references<Tuple>::value,
"Expected a non-const reference or thrust::detail::tuple_of_iterator_references");
thrust::get<1>(t) = f(thrust::get<0>(t));
}
};
Expand Down
5 changes: 3 additions & 2 deletions thrust/system/tbb/detail/reduce_by_key.inl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
#include <thrust/detail/range/tail_flags.h>
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <tbb/tbb_thread.h>

#include <cassert>
#include <thread>


namespace thrust
Expand Down Expand Up @@ -281,7 +282,7 @@ template<typename DerivedPolicy, typename Iterator1, typename Iterator2, typenam
}

// count the number of processors
const unsigned int p = thrust::max<unsigned int>(1u, ::tbb::tbb_thread::hardware_concurrency());
const unsigned int p = thrust::max<unsigned int>(1u, std::thread::hardware_concurrency());

// generate O(P) intervals of sequential work
// XXX oversubscribing is a tuning opportunity
Expand Down
Loading

0 comments on commit e779d2e

Please sign in to comment.