@@ -51,7 +51,7 @@ namespace xt
5151
5252 template <class ... C>
5353 using common_value_type_t = typename common_value_type<C...>::type;
54-
54+
5555 /* *******************
5656 * common_size_type *
5757 ********************/
@@ -89,7 +89,7 @@ namespace xt
8989
9090 template <class ... Args>
9191 using common_difference_type_t = typename common_difference_type<Args...>::type;
92-
92+
9393 /* *****************
9494 * temporary_type *
9595 ******************/
@@ -104,7 +104,7 @@ namespace xt
104104 };
105105
106106#if defined(__GNUC__) && (__GNUC__ > 6)
107- #if __cplusplus == 201703L
107+ #if __cplusplus == 201703L
108108 template <template <class , std::size_t , class , bool > class S , class X , std::size_t N, class A , bool Init>
109109 struct xtype_for_shape <S<X, N, A, Init>>
110110 {
@@ -128,15 +128,34 @@ namespace xt
128128 using type = xtensor_fixed<T, xshape<X...>, L>;
129129 };
130130 }
131-
132- template <class T , class S , layout_type L>
131+
132+ template <class Tag , class T >
133+ struct temporary_type_from_tag ;
134+
135+ template <class T >
136+ struct temporary_type_from_tag <xtensor_expression_tag, T>
137+ {
138+ using I = std::decay_t <T>;
139+ using shape_type = typename I::shape_type;
140+ using value_type = typename I::value_type;
141+ static constexpr layout_type static_layout = XTENSOR_DEFAULT_LAYOUT;
142+ using type = typename detail::xtype_for_shape<shape_type>::template type<value_type, static_layout>;
143+ };
144+
145+ template <class T , class = void >
133146 struct temporary_type
134147 {
135- using type = typename detail::xtype_for_shape<S>::template type<T, L>;
148+ using type = typename temporary_type_from_tag<xexpression_tag_t <T>, T>::type;
149+ };
150+
151+ template <class T >
152+ struct temporary_type <T, void_t <typename std::decay_t <T>::temporary_type>>
153+ {
154+ using type = typename std::decay_t <T>::temporary_type;
136155 };
137156
138- template <class T , class S , layout_type L >
139- using temporary_type_t = typename temporary_type<T, S, L >::type;
157+ template <class T >
158+ using temporary_type_t = typename temporary_type<T>::type;
140159
141160 /* *********************
142161 * common_tensor_type *
@@ -148,9 +167,9 @@ namespace xt
148167 struct common_tensor_type_impl
149168 {
150169 static constexpr layout_type static_layout = compute_layout(std::decay_t <C>::static_layout...);
151- using type = temporary_type_t < common_value_type_t <C...>,
152- promote_shape_t <typename C::shape_type...>,
153- static_layout>;
170+ using value_type = common_value_type_t <C...>;
171+ using shape_type = promote_shape_t <typename C::shape_type...>;
172+ using type = typename xtype_for_shape<shape_type>:: template type<value_type, static_layout>;
154173 };
155174 }
156175
0 commit comments