Skip to content

Commit 2799672

Browse files
authored
Merge pull request #2223 from gouarin/improve_xeval
Improve xeval
2 parents af3ee1a + 73bd889 commit 2799672

File tree

5 files changed

+41
-42
lines changed

5 files changed

+41
-42
lines changed

include/xtensor/xbuilder.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ namespace xt
147147
template <class E>
148148
inline auto empty_like(const xexpression<E>& e)
149149
{
150-
using xtype = temporary_type_t<typename E::value_type, typename E::shape_type, E::static_layout>;
150+
using xtype = temporary_type_t<E>;
151151
auto res = xtype::from_shape(e.derived_cast().shape());
152152
return res;
153153
}
@@ -162,7 +162,7 @@ namespace xt
162162
template <class E>
163163
inline auto full_like(const xexpression<E>& e, typename E::value_type fill_value)
164164
{
165-
using xtype = temporary_type_t<typename E::value_type, typename E::shape_type, E::static_layout>;
165+
using xtype = temporary_type_t<E>;
166166
auto res = xtype::from_shape(e.derived_cast().shape());
167167
res.fill(fill_value);
168168
return res;

include/xtensor/xeval.hpp

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#ifndef XTENSOR_EVAL_HPP
1111
#define XTENSOR_EVAL_HPP
1212

13+
#include "xexpression_traits.hpp"
1314
#include "xtensor_forward.hpp"
1415
#include "xshape.hpp"
1516

@@ -40,28 +41,12 @@ namespace xt
4041
}
4142

4243
/// @cond DOXYGEN_INCLUDE_SFINAE
43-
template <class T, class I = std::decay_t<T>>
44-
inline auto eval(T&& t)
45-
-> std::enable_if_t<!detail::is_container<I>::value && detail::is_array<typename I::shape_type>::value && !detail::is_fixed<typename I::shape_type>::value, xtensor<typename I::value_type, std::tuple_size<typename I::shape_type>::value>>
46-
{
47-
return xtensor<typename I::value_type, std::tuple_size<typename I::shape_type>::value>(std::forward<T>(t));
48-
}
49-
50-
template <class T, class I = std::decay_t<T>>
51-
inline auto eval(T&& t)
52-
-> std::enable_if_t<!detail::is_container<I>::value && !detail::is_array<typename I::shape_type>::value && !detail::is_fixed<typename I::shape_type>::value, xt::xarray<typename I::value_type>>
53-
{
54-
return xarray<typename I::value_type>(std::forward<T>(t));
55-
}
56-
57-
template <class T, class I = std::decay_t<T>>
44+
template <class T>
5845
inline auto eval(T&& t)
59-
-> std::enable_if_t<!detail::is_container<I>::value && detail::is_fixed<typename I::shape_type>::value && !detail::is_array<typename I::shape_type>::value,
60-
xt::xtensor_fixed<typename I::value_type, typename I::shape_type>>
46+
-> std::enable_if_t<!detail::is_container<std::decay_t<T>>::value, temporary_type_t<T>>
6147
{
62-
return xtensor_fixed<typename I::value_type, typename I::shape_type>(std::forward<T>(t));
48+
return std::forward<T>(t);
6349
}
64-
/// @endcond
6550
}
6651

6752
#endif

include/xtensor/xexpression_traits.hpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace xt
5151

5252
template <class... C>
5353
using common_value_type_t = typename common_value_type<C...>::type;
54-
54+
5555
/********************
5656
* common_size_type *
5757
********************/
@@ -89,7 +89,7 @@ namespace xt
8989

9090
template <class... Args>
9191
using common_difference_type_t = typename common_difference_type<Args...>::type;
92-
92+
9393
/******************
9494
* temporary_type *
9595
******************/
@@ -104,7 +104,7 @@ namespace xt
104104
};
105105

106106
#if defined(__GNUC__) && (__GNUC__ > 6)
107-
#if __cplusplus == 201703L
107+
#if __cplusplus == 201703L
108108
template <template <class, std::size_t, class, bool> class S, class X, std::size_t N, class A, bool Init>
109109
struct xtype_for_shape<S<X, N, A, Init>>
110110
{
@@ -128,15 +128,34 @@ namespace xt
128128
using type = xtensor_fixed<T, xshape<X...>, L>;
129129
};
130130
}
131-
132-
template <class T, class S, layout_type L>
131+
132+
template <class Tag, class T>
133+
struct temporary_type_from_tag;
134+
135+
template <class T>
136+
struct temporary_type_from_tag<xtensor_expression_tag, T>
137+
{
138+
using I = std::decay_t<T>;
139+
using shape_type = typename I::shape_type;
140+
using value_type = typename I::value_type;
141+
static constexpr layout_type static_layout = XTENSOR_DEFAULT_LAYOUT;
142+
using type = typename detail::xtype_for_shape<shape_type>::template type<value_type, static_layout>;
143+
};
144+
145+
template <class T, class = void>
133146
struct temporary_type
134147
{
135-
using type = typename detail::xtype_for_shape<S>::template type<T, L>;
148+
using type = typename temporary_type_from_tag<xexpression_tag_t<T>, T>::type;
149+
};
150+
151+
template <class T>
152+
struct temporary_type<T, void_t<typename std::decay_t<T>::temporary_type>>
153+
{
154+
using type = typename std::decay_t<T>::temporary_type;
136155
};
137156

138-
template <class T, class S, layout_type L>
139-
using temporary_type_t = typename temporary_type<T, S, L>::type;
157+
template <class T>
158+
using temporary_type_t = typename temporary_type<T>::type;
140159

141160
/**********************
142161
* common_tensor_type *
@@ -148,9 +167,9 @@ namespace xt
148167
struct common_tensor_type_impl
149168
{
150169
static constexpr layout_type static_layout = compute_layout(std::decay_t<C>::static_layout...);
151-
using type = temporary_type_t<common_value_type_t<C...>,
152-
promote_shape_t<typename C::shape_type...>,
153-
static_layout>;
170+
using value_type = common_value_type_t<C...>;
171+
using shape_type = promote_shape_t<typename C::shape_type...>;
172+
using type = typename xtype_for_shape<shape_type>::template type<value_type, static_layout>;
154173
};
155174
}
156175

include/xtensor/xpad.hpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,7 @@ namespace xt
8080
XTENSOR_ASSERT(detail::check_pad_width(pad_width, e.shape()));
8181

8282
using size_type = typename std::decay_t<E>::size_type;
83-
using value_type = typename std::decay_t<E>::value_type;
84-
using return_type = temporary_type_t<value_type,
85-
typename std::decay_t<E>::shape_type,
86-
std::decay_t<E>::static_layout>;
83+
using return_type = temporary_type_t<E>;
8784

8885
// place the original array in the center
8986

@@ -239,10 +236,8 @@ namespace xt
239236
inline auto tile(E&& e, const S& reps)
240237
{
241238
using size_type = typename std::decay_t<E>::size_type;
242-
using value_type = typename std::decay_t<E>::value_type;
243-
using return_type = temporary_type_t<value_type,
244-
typename std::decay_t<E>::shape_type,
245-
std::decay_t<E>::static_layout>;
239+
240+
using return_type = temporary_type_t<E>;
246241

247242
XTENSOR_ASSERT(e.shape().size() == reps.size());
248243

@@ -262,7 +257,7 @@ namespace xt
262257

263258
xt::xstrided_slice_vector svs(e.shape().size(), xt::all());
264259
xt::xstrided_slice_vector svt(e.shape().size(), xt::all());
265-
260+
266261
for (size_type axis = 0; axis < e.shape().size(); ++axis)
267262
{
268263
for (size_type i = 1; i < static_cast<size_type>(reps[axis]); ++i)

include/xtensor/xstrided_view.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ namespace xt
8181
using undecay_shape = S;
8282
using storage_getter = FST;
8383
using inner_storage_type = typename storage_getter::type;
84-
using temporary_type = temporary_type_t<typename xexpression_type::value_type, S, L>;
84+
using temporary_type = typename detail::xtype_for_shape<S>::template type<typename xexpression_type::value_type, L>;
8585
using storage_type = std::remove_reference_t<inner_storage_type>;
8686
static constexpr layout_type layout = L;
8787
};

0 commit comments

Comments
 (0)