@@ -23,6 +23,97 @@ namespace pybind11
2323{
2424 namespace detail
2525 {
26+ template <typename T, xt::layout_type L>
27+ struct pybind_array_getter_impl
28+ {
29+ static auto run (handle src)
30+ {
31+ return array_t <T, array::c_style | array::forcecast>::ensure (src);
32+ }
33+ };
34+
35+ template <typename T>
36+ struct pybind_array_getter_impl <T, xt::layout_type::column_major>
37+ {
38+ static auto run (handle src)
39+ {
40+ return array_t <T, array::f_style | array::forcecast>::ensure (src);
41+ }
42+ };
43+
44+ template <class T >
45+ struct pybind_array_getter
46+ {
47+ };
48+
49+ template <class T , xt::layout_type L>
50+ struct pybind_array_getter <xt::xarray<T, L>>
51+ {
52+ static auto run (handle src)
53+ {
54+ return pybind_array_getter_impl<T, L>::run (src);
55+ }
56+ };
57+
58+ template <class T , std::size_t N, xt::layout_type L>
59+ struct pybind_array_getter <xt::xtensor<T, N, L>>
60+ {
61+ static auto run (handle src)
62+ {
63+ return pybind_array_getter_impl<T, L>::run (src);
64+ }
65+ };
66+
67+ template <class CT , class S , xt::layout_type L, class FST >
68+ struct pybind_array_getter <xt::xstrided_view<CT, S, L, FST>>
69+ {
70+ static auto run (handle /* src*/ )
71+ {
72+ return false ;
73+ }
74+ };
75+
76+ template <class EC , xt::layout_type L, class SC , class Tag >
77+ struct pybind_array_getter <xt::xarray_adaptor<EC, L, SC, Tag>>
78+ {
79+ static auto run (handle src)
80+ {
81+ auto buf = pybind_array_getter_impl<EC, L>::run (src);
82+ return buf;
83+ }
84+ };
85+
86+ template <class EC , std::size_t N, xt::layout_type L, class Tag >
87+ struct pybind_array_getter <xt::xtensor_adaptor<EC, N, L, Tag>>
88+ {
89+ static auto run (handle /* src*/ )
90+ {
91+ return false ;
92+ }
93+ };
94+
95+
96+ template <class T >
97+ struct pybind_array_dim_checker
98+ {
99+ template <class B >
100+ static bool run (const B& buf)
101+ {
102+ return true ;
103+ }
104+ };
105+
106+ template <class T , std::size_t N, xt::layout_type L>
107+ struct pybind_array_dim_checker <xt::xtensor<T, N, L>>
108+ {
109+ template <class B >
110+ static bool run (const B& buf)
111+ {
112+ return buf.ndim () == N;
113+ }
114+ };
115+
116+
26117 // Casts a strided expression type to numpy array.If given a base,
27118 // the numpy array references the src data, otherwise it'll make a copy.
28119 // The writeable attributes lets you specify writeable flag for the array.
@@ -74,10 +165,6 @@ namespace pybind11
74165 template <class Type >
75166 struct xtensor_type_caster_base
76167 {
77- bool load (handle /* src*/ , bool )
78- {
79- return false ;
80- }
81168
82169 private:
83170
@@ -106,6 +193,36 @@ namespace pybind11
106193
107194 public:
108195
196+ PYBIND11_TYPE_CASTER (Type, _(" numpy.ndarray[" ) + npy_format_descriptor<typename Type::value_type>::name + _(" ]" ));
197+
198+ bool load (handle src, bool convert)
199+ {
200+ using T = typename Type::value_type;
201+
202+ if (!convert && !array_t <T>::check_ (src))
203+ {
204+ return false ;
205+ }
206+
207+ auto buf = pybind_array_getter<Type>::run (src);
208+
209+ if (!buf)
210+ {
211+ return false ;
212+ }
213+ if (!pybind_array_dim_checker<Type>::run (buf))
214+ {
215+ return false ;
216+ }
217+
218+ std::vector<size_t > shape (buf.ndim ());
219+ std::copy (buf.shape (), buf.shape () + buf.ndim (), shape.begin ());
220+ value = Type::from_shape (shape);
221+ std::copy (buf.data (), buf.data () + buf.size (), value.data ());
222+
223+ return true ;
224+ }
225+
109226 // Normal returned non-reference, non-const value:
110227 static handle cast (Type&& src, return_value_policy /* policy */ , handle parent)
111228 {
@@ -151,18 +268,6 @@ namespace pybind11
151268 {
152269 return cast_impl (src, policy, parent);
153270 }
154-
155- #ifdef PYBIND11_DESCR // The macro is removed from pybind11 since 2.3
156- static PYBIND11_DESCR name ()
157- {
158- return _ (" xt::xtensor" );
159- }
160- #else
161- static constexpr auto name = _(" xt::xtensor" );
162- #endif
163-
164- template <typename T>
165- using cast_op_type = cast_op_type<T>;
166271 };
167272 }
168273}
0 commit comments