This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 757
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1c926c1
commit 22105f3
Showing
7 changed files
with
501 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
#include <thrust/device_vector.h> | ||
#include <thrust/functional.h> | ||
#include <thrust/gather.h> | ||
#include <thrust/iterator/transform_input_output_iterator.h> | ||
#include <thrust/sequence.h> | ||
#include <iostream> | ||
|
||
// Base 2 fixed point | ||
class ScaledInteger | ||
{ | ||
int value_; | ||
int scale_; | ||
|
||
public: | ||
__host__ __device__ | ||
ScaledInteger(int value, int scale): value_{value}, scale_{scale} {} | ||
|
||
__host__ __device__ | ||
int value() const { return value_; } | ||
|
||
__host__ __device__ | ||
ScaledInteger rescale(int scale) const | ||
{ | ||
int shift = scale - scale_; | ||
int result = shift < 0 ? value_ << (-shift) : value_ >> shift; | ||
return ScaledInteger{result, scale}; | ||
} | ||
|
||
__host__ __device__ | ||
friend ScaledInteger operator+(ScaledInteger a, ScaledInteger b) | ||
{ | ||
// Rescale inputs to the lesser of the two scales | ||
if (b.scale_ < a.scale_) | ||
a = a.rescale(b.scale_); | ||
else if (a.scale_ < b.scale_) | ||
b = b.rescale(a.scale_); | ||
return ScaledInteger{a.value_ + b.value_, a.scale_}; | ||
} | ||
}; | ||
|
||
struct ValueToScaledInteger | ||
{ | ||
int scale; | ||
|
||
__host__ __device__ | ||
ScaledInteger operator()(const int& value) const | ||
{ | ||
return ScaledInteger{value, scale}; | ||
} | ||
}; | ||
|
||
struct ScaledIntegerToValue | ||
{ | ||
int scale; | ||
|
||
__host__ __device__ | ||
int operator()(const ScaledInteger& scaled) const | ||
{ | ||
return scaled.rescale(scale).value(); | ||
} | ||
}; | ||
|
||
int main(void) | ||
{ | ||
const size_t size = 4; | ||
thrust::device_vector<int> A(size); | ||
thrust::device_vector<int> B(size); | ||
thrust::device_vector<int> C(size); | ||
|
||
thrust::sequence(A.begin(), A.end(), 1); | ||
thrust::sequence(B.begin(), B.end(), 5); | ||
|
||
const int A_scale = 16; // Values in A are left shifted by 16 | ||
const int B_scale = 8; // Values in B are left shifted by 8 | ||
const int C_scale = 4; // Values in C are left shifted by 4 | ||
|
||
auto A_begin = thrust::make_transform_input_output_iterator(A.begin(), | ||
ValueToScaledInteger{A_scale}, ScaledIntegerToValue{A_scale}); | ||
auto A_end = thrust::make_transform_input_output_iterator(A.end(), | ||
ValueToScaledInteger{A_scale}, ScaledIntegerToValue{A_scale}); | ||
auto B_begin = thrust::make_transform_input_output_iterator(B.begin(), | ||
ValueToScaledInteger{B_scale}, ScaledIntegerToValue{B_scale}); | ||
auto C_begin = thrust::make_transform_input_output_iterator(C.begin(), | ||
ValueToScaledInteger{C_scale}, ScaledIntegerToValue{C_scale}); | ||
|
||
// Sum A and B as ScaledIntegers, storing the scaled result in C | ||
thrust::transform(A_begin, A_end, B_begin, C_begin, thrust::plus<ScaledInteger>{}); | ||
|
||
thrust::host_vector<int> A_h(A); | ||
thrust::host_vector<int> B_h(B); | ||
thrust::host_vector<int> C_h(C); | ||
|
||
std::cout << std::hex; | ||
|
||
std::cout << "Expected [ "; | ||
for (size_t i = 0; i < size; i++) { | ||
const int expected = ((A_h[i] << A_scale) + (B_h[i] << B_scale)) >> C_scale; | ||
std::cout << expected << " "; | ||
} | ||
std::cout << "] \n"; | ||
|
||
std::cout << "Result [ "; | ||
for (size_t i = 0; i < size; i++) { | ||
std::cout << C_h[i] << " "; | ||
} | ||
std::cout << "] \n"; | ||
|
||
return 0; | ||
} | ||
|
2 changes: 2 additions & 0 deletions
2
internal/test/thrust.example.transform_input_output_iterator.filecheck
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
CHECK: Expected [ 1050 2060 3070 4080 ] | ||
CHECK-NEXT: Result [ 1050 2060 3070 4080 ] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#include <unittest/unittest.h> | ||
#include <thrust/iterator/transform_input_output_iterator.h> | ||
|
||
#include <thrust/copy.h> | ||
#include <thrust/reduce.h> | ||
#include <thrust/functional.h> | ||
#include <thrust/sequence.h> | ||
#include <thrust/iterator/counting_iterator.h> | ||
|
||
template <class Vector> | ||
void TestTransformInputOutputIterator(void) | ||
{ | ||
typedef typename Vector::value_type T; | ||
|
||
typedef thrust::negate<T> InputFunction; | ||
typedef thrust::square<T> OutputFunction; | ||
typedef typename Vector::iterator Iterator; | ||
|
||
Vector input(4); | ||
Vector squared(4); | ||
Vector negated(4); | ||
|
||
// initialize input | ||
thrust::sequence(input.begin(), input.end(), 1); | ||
|
||
// construct transform_iterator | ||
thrust::transform_input_output_iterator<InputFunction, OutputFunction, Iterator> | ||
transform_iter(squared.begin(), InputFunction(), OutputFunction()); | ||
|
||
// transform_iter writes squared value | ||
thrust::copy(input.begin(), input.end(), transform_iter); | ||
|
||
Vector gold_squared(4); | ||
gold_squared[0] = 1; | ||
gold_squared[1] = 4; | ||
gold_squared[2] = 9; | ||
gold_squared[3] = 16; | ||
|
||
ASSERT_EQUAL(squared, gold_squared); | ||
|
||
// negated value read from transform_iter | ||
thrust::copy_n(transform_iter, squared.size(), negated.begin()); | ||
|
||
Vector gold_negated(4); | ||
gold_negated[0] = -1; | ||
gold_negated[1] = -4; | ||
gold_negated[2] = -9; | ||
gold_negated[3] = -16; | ||
|
||
ASSERT_EQUAL(negated, gold_negated); | ||
|
||
} | ||
DECLARE_VECTOR_UNITTEST(TestTransformInputOutputIterator); | ||
|
||
template <class Vector> | ||
void TestMakeTransformInputOutputIterator(void) | ||
{ | ||
typedef typename Vector::value_type T; | ||
|
||
typedef thrust::negate<T> InputFunction; | ||
typedef thrust::square<T> OutputFunction; | ||
|
||
Vector input(4); | ||
Vector negated(4); | ||
Vector squared(4); | ||
|
||
// initialize input | ||
thrust::sequence(input.begin(), input.end(), 1); | ||
|
||
// negated value read from transform iterator | ||
thrust::copy_n(thrust::make_transform_input_output_iterator(input.begin(), InputFunction(), OutputFunction()), | ||
input.size(), negated.begin()); | ||
|
||
Vector gold_negated(4); | ||
gold_negated[0] = -1; | ||
gold_negated[1] = -2; | ||
gold_negated[2] = -3; | ||
gold_negated[3] = -4; | ||
|
||
ASSERT_EQUAL(negated, gold_negated); | ||
|
||
// squared value writen by transform iterator | ||
thrust::copy(negated.begin(), negated.end(), | ||
thrust::make_transform_input_output_iterator(squared.begin(), InputFunction(), OutputFunction())); | ||
|
||
Vector gold_squared(4); | ||
gold_squared[0] = 1; | ||
gold_squared[1] = 4; | ||
gold_squared[2] = 9; | ||
gold_squared[3] = 16; | ||
|
||
ASSERT_EQUAL(squared, gold_squared); | ||
|
||
} | ||
DECLARE_VECTOR_UNITTEST(TestMakeTransformInputOutputIterator); | ||
|
||
template <typename T> | ||
struct TestTransformInputOutputIteratorScan | ||
{ | ||
void operator()(const size_t n) | ||
{ | ||
thrust::host_vector<T> h_data = unittest::random_samples<T>(n); | ||
thrust::device_vector<T> d_data = h_data; | ||
|
||
thrust::host_vector<T> h_result(n); | ||
thrust::device_vector<T> d_result(n); | ||
|
||
// run on host (uses forward iterator negate) | ||
thrust::inclusive_scan(thrust::make_transform_input_output_iterator(h_data.begin(), thrust::negate<T>(), thrust::identity<T>()), | ||
thrust::make_transform_input_output_iterator(h_data.end(), thrust::negate<T>(), thrust::identity<T>()), | ||
h_result.begin()); | ||
// run on device (uses reverse iterator negate) | ||
thrust::inclusive_scan(d_data.begin(), d_data.end(), | ||
thrust::make_transform_input_output_iterator( | ||
d_result.begin(), thrust::square<T>(), thrust::negate<T>())); | ||
|
||
|
||
ASSERT_EQUAL(h_result, d_result); | ||
} | ||
}; | ||
VariableUnitTest<TestTransformInputOutputIteratorScan, IntegralTypes> TestTransformInputOutputIteratorScanInstance; | ||
|
98 changes: 98 additions & 0 deletions
98
thrust/iterator/detail/transform_input_output_iterator.inl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* | ||
* Copyright 2020 NVIDIA Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <thrust/iterator/iterator_adaptor.h> | ||
|
||
namespace thrust | ||
{ | ||
|
||
template <typename InputFunction, typename OutputFunction, typename Iterator> | ||
class transform_input_output_iterator; | ||
|
||
namespace detail | ||
{ | ||
|
||
// Proxy reference that invokes InputFunction when reading from and | ||
// OutputFunction when writing to the dereferenced iterator | ||
template <typename InputFunction, typename OutputFunction, typename Iterator> | ||
class transform_input_output_iterator_proxy | ||
{ | ||
using Value = typename std::result_of<InputFunction(typename thrust::iterator_value<Iterator>::type)>::type; | ||
|
||
public: | ||
__host__ __device__ | ||
transform_input_output_iterator_proxy(const Iterator& io, InputFunction input_function, OutputFunction output_function) | ||
: io(io), input_function(input_function), output_function(output_function) | ||
{ | ||
} | ||
|
||
transform_input_output_iterator_proxy(const transform_input_output_iterator_proxy&) = default; | ||
|
||
__thrust_exec_check_disable__ | ||
__host__ __device__ | ||
operator Value const() const | ||
{ | ||
return input_function(*io); | ||
} | ||
|
||
__thrust_exec_check_disable__ | ||
template <typename T> | ||
__host__ __device__ | ||
transform_input_output_iterator_proxy operator=(const T& x) | ||
{ | ||
*io = output_function(x); | ||
return *this; | ||
} | ||
|
||
__thrust_exec_check_disable__ | ||
__host__ __device__ | ||
transform_input_output_iterator_proxy operator=(const transform_input_output_iterator_proxy& x) | ||
{ | ||
*io = output_function(x); | ||
return *this; | ||
} | ||
|
||
private: | ||
Iterator io; | ||
InputFunction input_function; | ||
OutputFunction output_function; | ||
}; | ||
|
||
// Compute the iterator_adaptor instantiation to be used for transform_input_output_iterator | ||
template <typename InputFunction, typename OutputFunction, typename Iterator> | ||
struct transform_input_output_iterator_base | ||
{ | ||
typedef thrust::iterator_adaptor | ||
< | ||
transform_input_output_iterator<InputFunction, OutputFunction, Iterator> | ||
, Iterator | ||
, typename std::result_of<InputFunction(typename thrust::iterator_value<Iterator>::type)>::type | ||
, thrust::use_default | ||
, thrust::use_default | ||
, transform_input_output_iterator_proxy<InputFunction, OutputFunction, Iterator> | ||
> type; | ||
}; | ||
|
||
// Register transform_input_output_iterator_proxy with 'is_proxy_reference' from | ||
// type_traits to enable its use with algorithms. | ||
template <typename InputFunction, typename OutputFunction, typename Iterator> | ||
struct is_proxy_reference< | ||
transform_input_output_iterator_proxy<InputFunction, OutputFunction, Iterator> > | ||
: public thrust::detail::true_type {}; | ||
|
||
} // end detail | ||
} // end thrust | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.