66 * LICENSE file in the root directory of this source tree.
77 */
88
9- use std:: collections:: HashMap ;
10-
119use derive_more:: From ;
12- use derive_more:: TryInto ;
13- use enum_as_inner:: EnumAsInner ;
1410use hyperactor:: Named ;
1511use monarch_types:: PickledPyObject ;
16- use monarch_types:: TryIntoPyObjectUnsafe ;
1712use pyo3:: IntoPyObjectExt ;
18- use pyo3:: exceptions:: PyValueError ;
1913use pyo3:: prelude:: * ;
20- use pyo3:: types:: PyBool ;
21- use pyo3:: types:: PyDict ;
22- use pyo3:: types:: PyFloat ;
23- use pyo3:: types:: PyList ;
2414use pyo3:: types:: PyNone ;
25- use pyo3:: types:: PyString ;
26- use pyo3:: types:: PyTuple ;
2715use serde:: Deserialize ;
2816use serde:: Serialize ;
2917use torch_sys:: Device ;
3018use torch_sys:: Layout ;
3119use torch_sys:: MemoryFormat ;
32- use torch_sys:: OpaqueIValue ;
3320use torch_sys:: ScalarType ;
3421
3522use crate :: worker:: Ref ;
36- use crate :: worker:: ResolvableFunction ;
3723
3824/// A value used as an input to CallFunction.
3925// TODO, this is basically the same as RValue, but with TensorIndices swapped
4026// out for refs. And IValue is the same as RValue, but with real tensors and
4127// C++ types. I wonder if there is a nicer way to express this relationship.
4228// TODO extend this to support other types of values, like bytes, dicts etc.
43- #[ derive(
44- Serialize ,
45- Deserialize ,
46- Debug ,
47- Clone ,
48- TryInto ,
49- Named ,
50- From ,
51- EnumAsInner
52- ) ]
29+ #[ derive( Serialize , Deserialize , Debug , Clone , Named ) ]
5330pub enum WireValue {
5431 // Make sure boolean goes ealier than int as bool is a subclass of int.
5532 // Otherwise, bool will be converted to int.
@@ -68,81 +45,20 @@ pub enum WireValue {
6845 // empty enum variants.
6946 None ( ( ) ) ,
7047 PyObject ( PickledPyObject ) ,
71- // It is ok to just have IValue without an alias tracking cell as we just use
72- // WireValue as a way to serialize and send args to workers. We dont mutate the
73- // IValue and use the opaque wrapper to make accessing the IValue directly
74- // an unsafe op.
75- IValue ( torch_sys:: OpaqueIValue ) ,
7648}
7749
7850impl FromPyObject < ' _ > for WireValue {
7951 fn extract_bound ( obj : & Bound < ' _ , PyAny > ) -> PyResult < Self > {
80- if let Ok ( ref_) = Ref :: from_py_object ( obj) {
81- Ok ( WireValue :: Ref ( ref_) )
82- } else if let Ok ( list) = obj. downcast :: < PyList > ( ) {
83- let len = list. len ( ) ;
84- if len == 0 {
85- // TODO: This is done for now as this seems to be the most common case for empty lists
86- // in torch ops but we should use the op schema to do this correctly.
87- return Ok ( WireValue :: IntList ( vec ! [ ] ) ) ;
88- }
89-
90- // SAFETY: We know it is within bounds
91- let item = unsafe { list. get_item_unchecked ( 0 ) } ;
92- let len = list. len ( ) ;
93- if let Ok ( int) = item. extract :: < i64 > ( ) {
94- let mut int_list = Vec :: with_capacity ( len) ;
95- int_list. push ( int) ;
96- for item in list. iter ( ) . skip ( 1 ) {
97- int_list. push ( item. extract :: < i64 > ( ) . map_err ( |_| {
98- PyValueError :: new_err ( format ! (
99- "Expected homogeneous list of ints got: {:?}" ,
100- list
101- ) )
102- } ) ?) ;
103- }
104- return Ok ( WireValue :: IntList ( int_list) ) ;
105- }
106- if let Ok ( ref_) = Ref :: from_py_object ( & item) {
107- let mut ref_list = Vec :: with_capacity ( len) ;
108- ref_list. push ( ref_) ;
109- for item in list. iter ( ) . skip ( 1 ) {
110- ref_list. push ( Ref :: from_py_object ( & item) . map_err ( |_| {
111- PyValueError :: new_err ( format ! (
112- "Expected homogeneous list of ints got: {:?}" ,
113- list
114- ) )
115- } ) ?) ;
116- }
117- return Ok ( WireValue :: RefList ( ref_list) ) ;
118- }
119- Ok ( WireValue :: PyObject ( PickledPyObject :: pickle ( obj) ?) )
120- } else if obj. is_none ( ) {
121- Ok ( WireValue :: None ( ( ) ) )
122- } else if let Ok ( bool_) = obj. downcast :: < PyBool > ( ) {
123- Ok ( WireValue :: Bool ( bool_. is_true ( ) ) )
124- } else if let Ok ( int) = obj. extract :: < i64 > ( ) {
125- Ok ( WireValue :: Int ( int) )
126- } else if let Ok ( double) = obj. downcast :: < PyFloat > ( ) {
127- Ok ( WireValue :: Double ( double. value ( ) ) )
128- } else if let Ok ( string) = obj. downcast :: < PyString > ( ) {
129- Ok ( WireValue :: String ( string. to_str ( ) ?. to_string ( ) ) )
130- } else if let Ok ( device) = obj. extract :: < Device > ( ) {
131- Ok ( WireValue :: Device ( device) )
132- } else if let Ok ( layout) = obj. extract :: < Layout > ( ) {
133- Ok ( WireValue :: Layout ( layout) )
134- } else if let Ok ( scalar_type) = obj. extract :: < ScalarType > ( ) {
135- Ok ( WireValue :: ScalarType ( scalar_type) )
136- } else if let Ok ( memory_format) = obj. extract :: < MemoryFormat > ( ) {
137- Ok ( WireValue :: MemoryFormat ( memory_format) )
138- } else {
139- Ok ( WireValue :: PyObject ( PickledPyObject :: pickle ( obj) ?) )
140- }
52+ Ok ( WireValue :: PyObject ( PickledPyObject :: pickle ( obj) ?) )
14153 }
14254}
14355
144- impl < ' py > TryIntoPyObjectUnsafe < ' py , PyAny > for WireValue {
145- unsafe fn try_to_object_unsafe ( self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
56+ impl < ' py > IntoPyObject < ' py > for WireValue {
57+ type Target = PyAny ;
58+ type Output = Bound < ' py , PyAny > ;
59+ type Error = PyErr ;
60+
61+ fn into_pyobject ( self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
14662 match self {
14763 WireValue :: Ref ( ref_) => ref_. into_bound_py_any ( py) ,
14864 WireValue :: RefList ( ref_list) => ref_list. clone ( ) . into_bound_py_any ( py) ,
@@ -157,10 +73,6 @@ impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
15773 WireValue :: MemoryFormat ( val) => val. into_bound_py_any ( py) ,
15874 WireValue :: None ( ( ) ) => PyNone :: get ( py) . into_bound_py_any ( py) ,
15975 WireValue :: PyObject ( val) => val. unpickle ( py) ,
160- // SAFETY: WireValue is only used for serde between client and worker.
161- // This function is used to access the args / kwargs of a function call
162- // on the client side only.
163- WireValue :: IValue ( val) => unsafe { val. try_to_object_unsafe ( py) } ,
16476 }
16577 }
16678}
@@ -170,167 +82,3 @@ impl From<PyObject> for WireValue {
17082 Python :: with_gil ( |py| WireValue :: PyObject ( PickledPyObject :: pickle ( obj. bind ( py) ) . unwrap ( ) ) )
17183 }
17284}
173-
174- pub fn func_call_args_to_wire_values (
175- _func : Option < & ResolvableFunction > ,
176- args : & Bound < ' _ , PyTuple > ,
177- kwargs : & Bound < ' _ , PyDict > ,
178- ) -> PyResult < ( Vec < WireValue > , HashMap < String , WireValue > ) > {
179- python_func_args_to_wire_value ( args, kwargs)
180- }
181-
182- fn python_func_args_to_wire_value (
183- args : & Bound < ' _ , PyTuple > ,
184- kwargs : & Bound < ' _ , PyDict > ,
185- ) -> PyResult < ( Vec < WireValue > , HashMap < String , WireValue > ) > {
186- let args = args
187- . iter ( )
188- . map ( |arg| Ok ( WireValue :: PyObject ( PickledPyObject :: pickle ( & arg) ?) ) )
189- . collect :: < PyResult < _ > > ( ) ?;
190- let kwargs = kwargs
191- . iter ( )
192- . map ( |( k, v) | {
193- Ok ( (
194- k. extract :: < String > ( ) ?,
195- WireValue :: PyObject ( PickledPyObject :: pickle ( & v) ?) ,
196- ) )
197- } )
198- . collect :: < Result < HashMap < _ , _ > , PyErr > > ( ) ?;
199- Ok ( ( args, kwargs) )
200- }
201-
202- #[ cfg( test) ]
203- mod tests {
204- use std:: assert_matches:: assert_matches;
205-
206- use anyhow:: Result ;
207- use anyhow:: bail;
208- use paste:: paste;
209- use pyo3:: Python ;
210- use pyo3:: ffi:: c_str;
211- use pyo3:: types:: PyDict ;
212- use torch_sys:: DeviceType ;
213- use torch_sys:: ScalarType ;
214-
215- use super :: * ;
216- use crate :: worker:: Ref ;
217-
218- const MOCK_REFERNCABLE_MODULE : & std:: ffi:: CStr = c_str ! (
219- r#"
220- class Referencable:
221- def __init__(self, ref: int):
222- self.ref = ref
223-
224- def __monarch_ref__(self):
225- return self.ref
226- "#
227- ) ;
228-
229- fn setup ( ) -> Result < ( ) > {
230- pyo3:: prepare_freethreaded_python ( ) ;
231- // We need to load torch to initialize some internal structures used by
232- // the FFI funcs we use to convert ivalues to/from py objects.
233- Python :: with_gil ( |py| py. run ( c_str ! ( "import torch" ) , None , None ) ) ?;
234- Ok ( ( ) )
235- }
236-
237- fn create_py_object ( ) -> PyObject {
238- pyo3:: prepare_freethreaded_python ( ) ;
239- Python :: with_gil ( |py| {
240- let dict = PyDict :: new ( py) ;
241- dict. set_item ( "foo" , "bar" ) . unwrap ( ) ;
242- dict. into_any ( ) . clone ( ) . unbind ( )
243- } )
244- }
245-
246- macro_rules! generate_wire_value_from_py_tests {
247- ( $( $kind: ident, $input: expr) ;* $( ; ) ?) => {
248- paste! {
249- $(
250- #[ test]
251- fn [ <test_wire_value_from_py_$kind: snake: lower>] ( ) -> Result <( ) > {
252- setup( ) ?;
253- Python :: with_gil( |py| {
254- let actual = $input. into_pyobject( py) ?. extract:: <WireValue >( ) ?;
255- assert_matches!( actual, WireValue :: $kind( _) ) ;
256- anyhow:: Ok ( ( ) )
257- } )
258- }
259- ) *
260-
261- #[ test]
262- fn test_wire_value_from_py_none( ) -> Result <( ) > {
263- setup( ) ?;
264- Python :: with_gil( |py| {
265- let obj = PyNone :: get( py) . into_pyobject( py) ?;
266- let actual = obj. extract:: <WireValue >( ) ?;
267- assert_matches!( actual, WireValue :: None ( _) ) ;
268- anyhow:: Ok ( ( ) )
269- } )
270- }
271-
272- #[ test]
273- fn test_wire_value_from_py_empty_list( ) -> Result <( ) > {
274- setup( ) ?;
275- Python :: with_gil( |py| {
276- let obj: PyObject = PyList :: empty( py) . into_any( ) . unbind( ) ;
277- let actual = obj. extract:: <WireValue >( py) ?;
278- match actual {
279- WireValue :: IntList ( list) if list. len( ) == 0 => ( ) ,
280- _ => bail!( "Expected empty list to be converted to empty int list" ) ,
281- }
282- anyhow:: Ok ( ( ) )
283- } )
284- }
285-
286- #[ test]
287- fn test_wire_value_from_py_referencable_class( ) -> Result <( ) > {
288- setup( ) ?;
289- Python :: with_gil( |py| {
290- let referencable = PyModule :: from_code(
291- py,
292- MOCK_REFERNCABLE_MODULE ,
293- c_str!( "referencable.py" ) ,
294- c_str!( "referencable" ) ,
295- ) ?;
296- let ref_ = referencable. getattr( "Referencable" ) ?. call1( ( 1 , ) ) ?. unbind( ) ;
297- let actual = ref_. extract:: <WireValue >( py) ?;
298- assert_matches!( actual, WireValue :: Ref ( Ref { id: 1 } ) ) ;
299- anyhow:: Ok ( ( ) )
300- } )
301- }
302-
303- #[ test]
304- fn test_wire_value_from_py_roundtrip_was_exhaustive( ) {
305- let val = WireValue :: Int ( 0 ) ;
306- match val {
307- $( WireValue :: $kind( _) => ( ) , ) *
308- WireValue :: None ( _) => ( ) ,
309- // Can't test from py here as PyObject behaves as catch all for conversion from PY.
310- // We will manually convert torch ops args to IValue respecting the schema so its
311- // not super important to have this.
312- WireValue :: IValue ( _) => ( ) ,
313- }
314- }
315- }
316- }
317- }
318-
319- // Generate exhaustive roundtrip tests for all IValue kind.
320- // If you got a "non-exhaustive patterns" error here, you need to add a new
321- // test entry for your IValue kind!
322- generate_wire_value_from_py_tests ! {
323- Bool , false ;
324- Double , 1.23f64 ;
325- Int , 123i64 ;
326- IntList , vec![ 1i64 ] ;
327- Ref , Ref :: from( 1 ) ;
328- RefList , vec![ Ref :: from( 1 ) , Ref :: from( 2 ) ] ;
329- String , "foobar" . to_owned( ) ;
330- Device , Device :: new( DeviceType :: CPU ) ;
331- Layout , Layout ( 2 ) ;
332- ScalarType , ScalarType ( 3 ) ;
333- MemoryFormat , MemoryFormat ( 1 ) ;
334- PyObject , create_py_object( ) ;
335- }
336- }
0 commit comments