Skip to content

Commit 22105f3

Browse files
trevorsm7alliepiper
authored andcommitted
Add transform_input_output_iterator
1 parent 1c926c1 commit 22105f3

7 files changed

+501
-6
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#include <thrust/device_vector.h>
2+
#include <thrust/functional.h>
3+
#include <thrust/gather.h>
4+
#include <thrust/iterator/transform_input_output_iterator.h>
5+
#include <thrust/sequence.h>
6+
#include <iostream>
7+
8+
// Base 2 fixed point
9+
class ScaledInteger
10+
{
11+
int value_;
12+
int scale_;
13+
14+
public:
15+
__host__ __device__
16+
ScaledInteger(int value, int scale): value_{value}, scale_{scale} {}
17+
18+
__host__ __device__
19+
int value() const { return value_; }
20+
21+
__host__ __device__
22+
ScaledInteger rescale(int scale) const
23+
{
24+
int shift = scale - scale_;
25+
int result = shift < 0 ? value_ << (-shift) : value_ >> shift;
26+
return ScaledInteger{result, scale};
27+
}
28+
29+
__host__ __device__
30+
friend ScaledInteger operator+(ScaledInteger a, ScaledInteger b)
31+
{
32+
// Rescale inputs to the lesser of the two scales
33+
if (b.scale_ < a.scale_)
34+
a = a.rescale(b.scale_);
35+
else if (a.scale_ < b.scale_)
36+
b = b.rescale(a.scale_);
37+
return ScaledInteger{a.value_ + b.value_, a.scale_};
38+
}
39+
};
40+
41+
struct ValueToScaledInteger
42+
{
43+
int scale;
44+
45+
__host__ __device__
46+
ScaledInteger operator()(const int& value) const
47+
{
48+
return ScaledInteger{value, scale};
49+
}
50+
};
51+
52+
struct ScaledIntegerToValue
53+
{
54+
int scale;
55+
56+
__host__ __device__
57+
int operator()(const ScaledInteger& scaled) const
58+
{
59+
return scaled.rescale(scale).value();
60+
}
61+
};
62+
63+
int main(void)
64+
{
65+
const size_t size = 4;
66+
thrust::device_vector<int> A(size);
67+
thrust::device_vector<int> B(size);
68+
thrust::device_vector<int> C(size);
69+
70+
thrust::sequence(A.begin(), A.end(), 1);
71+
thrust::sequence(B.begin(), B.end(), 5);
72+
73+
const int A_scale = 16; // Values in A are left shifted by 16
74+
const int B_scale = 8; // Values in B are left shifted by 8
75+
const int C_scale = 4; // Values in C are left shifted by 4
76+
77+
auto A_begin = thrust::make_transform_input_output_iterator(A.begin(),
78+
ValueToScaledInteger{A_scale}, ScaledIntegerToValue{A_scale});
79+
auto A_end = thrust::make_transform_input_output_iterator(A.end(),
80+
ValueToScaledInteger{A_scale}, ScaledIntegerToValue{A_scale});
81+
auto B_begin = thrust::make_transform_input_output_iterator(B.begin(),
82+
ValueToScaledInteger{B_scale}, ScaledIntegerToValue{B_scale});
83+
auto C_begin = thrust::make_transform_input_output_iterator(C.begin(),
84+
ValueToScaledInteger{C_scale}, ScaledIntegerToValue{C_scale});
85+
86+
// Sum A and B as ScaledIntegers, storing the scaled result in C
87+
thrust::transform(A_begin, A_end, B_begin, C_begin, thrust::plus<ScaledInteger>{});
88+
89+
thrust::host_vector<int> A_h(A);
90+
thrust::host_vector<int> B_h(B);
91+
thrust::host_vector<int> C_h(C);
92+
93+
std::cout << std::hex;
94+
95+
std::cout << "Expected [ ";
96+
for (size_t i = 0; i < size; i++) {
97+
const int expected = ((A_h[i] << A_scale) + (B_h[i] << B_scale)) >> C_scale;
98+
std::cout << expected << " ";
99+
}
100+
std::cout << "] \n";
101+
102+
std::cout << "Result [ ";
103+
for (size_t i = 0; i < size; i++) {
104+
std::cout << C_h[i] << " ";
105+
}
106+
std::cout << "] \n";
107+
108+
return 0;
109+
}
110+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
CHECK: Expected [ 1050 2060 3070 4080 ]
2+
CHECK-NEXT: Result [ 1050 2060 3070 4080 ]
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#include <unittest/unittest.h>
2+
#include <thrust/iterator/transform_input_output_iterator.h>
3+
4+
#include <thrust/copy.h>
5+
#include <thrust/reduce.h>
6+
#include <thrust/functional.h>
7+
#include <thrust/sequence.h>
8+
#include <thrust/iterator/counting_iterator.h>
9+
10+
template <class Vector>
11+
void TestTransformInputOutputIterator(void)
12+
{
13+
typedef typename Vector::value_type T;
14+
15+
typedef thrust::negate<T> InputFunction;
16+
typedef thrust::square<T> OutputFunction;
17+
typedef typename Vector::iterator Iterator;
18+
19+
Vector input(4);
20+
Vector squared(4);
21+
Vector negated(4);
22+
23+
// initialize input
24+
thrust::sequence(input.begin(), input.end(), 1);
25+
26+
// construct transform_iterator
27+
thrust::transform_input_output_iterator<InputFunction, OutputFunction, Iterator>
28+
transform_iter(squared.begin(), InputFunction(), OutputFunction());
29+
30+
// transform_iter writes squared value
31+
thrust::copy(input.begin(), input.end(), transform_iter);
32+
33+
Vector gold_squared(4);
34+
gold_squared[0] = 1;
35+
gold_squared[1] = 4;
36+
gold_squared[2] = 9;
37+
gold_squared[3] = 16;
38+
39+
ASSERT_EQUAL(squared, gold_squared);
40+
41+
// negated value read from transform_iter
42+
thrust::copy_n(transform_iter, squared.size(), negated.begin());
43+
44+
Vector gold_negated(4);
45+
gold_negated[0] = -1;
46+
gold_negated[1] = -4;
47+
gold_negated[2] = -9;
48+
gold_negated[3] = -16;
49+
50+
ASSERT_EQUAL(negated, gold_negated);
51+
52+
}
53+
DECLARE_VECTOR_UNITTEST(TestTransformInputOutputIterator);
54+
55+
template <class Vector>
56+
void TestMakeTransformInputOutputIterator(void)
57+
{
58+
typedef typename Vector::value_type T;
59+
60+
typedef thrust::negate<T> InputFunction;
61+
typedef thrust::square<T> OutputFunction;
62+
63+
Vector input(4);
64+
Vector negated(4);
65+
Vector squared(4);
66+
67+
// initialize input
68+
thrust::sequence(input.begin(), input.end(), 1);
69+
70+
// negated value read from transform iterator
71+
thrust::copy_n(thrust::make_transform_input_output_iterator(input.begin(), InputFunction(), OutputFunction()),
72+
input.size(), negated.begin());
73+
74+
Vector gold_negated(4);
75+
gold_negated[0] = -1;
76+
gold_negated[1] = -2;
77+
gold_negated[2] = -3;
78+
gold_negated[3] = -4;
79+
80+
ASSERT_EQUAL(negated, gold_negated);
81+
82+
// squared value writen by transform iterator
83+
thrust::copy(negated.begin(), negated.end(),
84+
thrust::make_transform_input_output_iterator(squared.begin(), InputFunction(), OutputFunction()));
85+
86+
Vector gold_squared(4);
87+
gold_squared[0] = 1;
88+
gold_squared[1] = 4;
89+
gold_squared[2] = 9;
90+
gold_squared[3] = 16;
91+
92+
ASSERT_EQUAL(squared, gold_squared);
93+
94+
}
95+
DECLARE_VECTOR_UNITTEST(TestMakeTransformInputOutputIterator);
96+
97+
template <typename T>
98+
struct TestTransformInputOutputIteratorScan
99+
{
100+
void operator()(const size_t n)
101+
{
102+
thrust::host_vector<T> h_data = unittest::random_samples<T>(n);
103+
thrust::device_vector<T> d_data = h_data;
104+
105+
thrust::host_vector<T> h_result(n);
106+
thrust::device_vector<T> d_result(n);
107+
108+
// run on host (uses forward iterator negate)
109+
thrust::inclusive_scan(thrust::make_transform_input_output_iterator(h_data.begin(), thrust::negate<T>(), thrust::identity<T>()),
110+
thrust::make_transform_input_output_iterator(h_data.end(), thrust::negate<T>(), thrust::identity<T>()),
111+
h_result.begin());
112+
// run on device (uses reverse iterator negate)
113+
thrust::inclusive_scan(d_data.begin(), d_data.end(),
114+
thrust::make_transform_input_output_iterator(
115+
d_result.begin(), thrust::square<T>(), thrust::negate<T>()));
116+
117+
118+
ASSERT_EQUAL(h_result, d_result);
119+
}
120+
};
121+
VariableUnitTest<TestTransformInputOutputIteratorScan, IntegralTypes> TestTransformInputOutputIteratorScanInstance;
122+
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2020 NVIDIA Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <thrust/iterator/iterator_adaptor.h>
18+
19+
namespace thrust
20+
{
21+
22+
template <typename InputFunction, typename OutputFunction, typename Iterator>
23+
class transform_input_output_iterator;
24+
25+
namespace detail
26+
{
27+
28+
// Proxy reference that invokes InputFunction when reading from and
29+
// OutputFunction when writing to the dereferenced iterator
30+
template <typename InputFunction, typename OutputFunction, typename Iterator>
31+
class transform_input_output_iterator_proxy
32+
{
33+
using Value = typename std::result_of<InputFunction(typename thrust::iterator_value<Iterator>::type)>::type;
34+
35+
public:
36+
__host__ __device__
37+
transform_input_output_iterator_proxy(const Iterator& io, InputFunction input_function, OutputFunction output_function)
38+
: io(io), input_function(input_function), output_function(output_function)
39+
{
40+
}
41+
42+
transform_input_output_iterator_proxy(const transform_input_output_iterator_proxy&) = default;
43+
44+
__thrust_exec_check_disable__
45+
__host__ __device__
46+
operator Value const() const
47+
{
48+
return input_function(*io);
49+
}
50+
51+
__thrust_exec_check_disable__
52+
template <typename T>
53+
__host__ __device__
54+
transform_input_output_iterator_proxy operator=(const T& x)
55+
{
56+
*io = output_function(x);
57+
return *this;
58+
}
59+
60+
__thrust_exec_check_disable__
61+
__host__ __device__
62+
transform_input_output_iterator_proxy operator=(const transform_input_output_iterator_proxy& x)
63+
{
64+
*io = output_function(x);
65+
return *this;
66+
}
67+
68+
private:
69+
Iterator io;
70+
InputFunction input_function;
71+
OutputFunction output_function;
72+
};
73+
74+
// Compute the iterator_adaptor instantiation to be used for transform_input_output_iterator
75+
template <typename InputFunction, typename OutputFunction, typename Iterator>
76+
struct transform_input_output_iterator_base
77+
{
78+
typedef thrust::iterator_adaptor
79+
<
80+
transform_input_output_iterator<InputFunction, OutputFunction, Iterator>
81+
, Iterator
82+
, typename std::result_of<InputFunction(typename thrust::iterator_value<Iterator>::type)>::type
83+
, thrust::use_default
84+
, thrust::use_default
85+
, transform_input_output_iterator_proxy<InputFunction, OutputFunction, Iterator>
86+
> type;
87+
};
88+
89+
// Register transform_input_output_iterator_proxy with 'is_proxy_reference' from
90+
// type_traits to enable its use with algorithms.
91+
template <typename InputFunction, typename OutputFunction, typename Iterator>
92+
struct is_proxy_reference<
93+
transform_input_output_iterator_proxy<InputFunction, OutputFunction, Iterator> >
94+
: public thrust::detail::true_type {};
95+
96+
} // end detail
97+
} // end thrust
98+

thrust/iterator/detail/transform_output_iterator.inl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
namespace thrust
2121
{
2222

23-
template <typename OutputIterator, typename UnaryFunction>
23+
template <typename UnaryFunction, typename OutputIterator>
2424
class transform_output_iterator;
2525

2626
namespace detail
2727
{
2828

29-
// Proxy reference that uses Unary Functiont o transform the rhs of assigment
29+
// Proxy reference that uses Unary Function to transform the rhs of assigment
3030
// operator before writing the result to OutputIterator
3131
template <typename UnaryFunction, typename OutputIterator>
3232
class transform_output_iterator_proxy
@@ -66,11 +66,11 @@ struct transform_output_iterator_base
6666
> type;
6767
};
6868

69-
// Register trasnform_output_iterator_proxy with 'is_proxy_reference' from
69+
// Register transform_output_iterator_proxy with 'is_proxy_reference' from
7070
// type_traits to enable its use with algorithms.
71-
template <class OutputIterator, class UnaryFunction>
71+
template <class UnaryFunction, class OutputIterator>
7272
struct is_proxy_reference<
73-
transform_output_iterator_proxy<OutputIterator, UnaryFunction> >
73+
transform_output_iterator_proxy<UnaryFunction, OutputIterator> >
7474
: public thrust::detail::true_type {};
7575

7676
} // end detail

0 commit comments

Comments
 (0)