Skip to content

Commit 0bbbb9d

Browse files
committed
Remove WireValue from messaging path
This moves WireValue off of the path that any message from the client makes to the workers. This is a necessary step before replacing the way the streams store values for RValues to python objects so that we can remove direct torch dependency. Differential Revision: [D87826891](https://our.internmc.facebook.com/intern/diff/D87826891/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D87826891/)! ghstack-source-id: 325483805 Pull Request resolved: #1991
1 parent a404c21 commit 0bbbb9d

File tree

5 files changed

+268
-415
lines changed

5 files changed

+268
-415
lines changed

monarch_extension/src/convert.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use hyperactor::ActorId;
1313
use monarch_hyperactor::ndslice::PySlice;
1414
use monarch_hyperactor::proc::PyActorId;
1515
use monarch_messages::controller::Seq;
16-
use monarch_messages::wire_value::func_call_args_to_wire_values;
1716
use monarch_messages::worker;
17+
use monarch_messages::worker::ArgsKwargs;
1818
use monarch_messages::worker::CallFunctionParams;
1919
use monarch_messages::worker::Cloudpickle;
2020
use monarch_messages::worker::Factory;
@@ -220,17 +220,16 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
220220
});
221221
m.insert(key("CallFunction"), |p| {
222222
let function = p.parseFunction("function")?;
223-
let args = p.parse("args")?;
224-
let kwargs = p.parse("kwargs")?;
223+
let args: Bound<'_, PyTuple> = p.parse("args")?;
224+
let kwargs: Bound<'_, PyDict> = p.parse("kwargs")?;
225225

226-
let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?;
226+
let args_kwargs = ArgsKwargs::from_python(args.into_any(), kwargs.into_any())?;
227227
Ok(WorkerMessage::CallFunction(CallFunctionParams {
228228
seq: p.parseSeq("ident")?,
229229
results: p.parseFlatReferences("result")?,
230230
mutates: p.parseRefList("mutates")?,
231231
function,
232-
args,
233-
kwargs,
232+
args_kwargs,
234233
stream: p.parseStreamRef("stream")?,
235234
remote_process_groups: p.parseRefList("remote_process_groups")?,
236235
}))
@@ -340,14 +339,13 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
340339
"SendValue with no function must have exactly one argument and no keyword arguments",
341340
));
342341
}
343-
let (args, kwargs) = func_call_args_to_wire_values(function.as_ref(), &args, &kwargs)?;
342+
let args_kwargs = ArgsKwargs::from_python(args.into_any(), kwargs.into_any())?;
344343
Ok(WorkerMessage::SendValue {
345344
seq: p.parseSeq("ident")?,
346345
destination: p.parseOptionalRef("destination")?,
347346
mutates: p.parseRefList("mutates")?,
348347
function,
349-
args,
350-
kwargs,
348+
args_kwargs,
351349
stream: p.parseStreamRef("stream")?,
352350
})
353351
});

monarch_messages/src/wire_value.rs

Lines changed: 8 additions & 260 deletions
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,27 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
use std::collections::HashMap;
10-
119
use derive_more::From;
12-
use derive_more::TryInto;
13-
use enum_as_inner::EnumAsInner;
1410
use hyperactor::Named;
1511
use monarch_types::PickledPyObject;
16-
use monarch_types::TryIntoPyObjectUnsafe;
1712
use pyo3::IntoPyObjectExt;
18-
use pyo3::exceptions::PyValueError;
1913
use pyo3::prelude::*;
20-
use pyo3::types::PyBool;
21-
use pyo3::types::PyDict;
22-
use pyo3::types::PyFloat;
23-
use pyo3::types::PyList;
2414
use pyo3::types::PyNone;
25-
use pyo3::types::PyString;
26-
use pyo3::types::PyTuple;
2715
use serde::Deserialize;
2816
use serde::Serialize;
2917
use torch_sys::Device;
3018
use torch_sys::Layout;
3119
use torch_sys::MemoryFormat;
32-
use torch_sys::OpaqueIValue;
3320
use torch_sys::ScalarType;
3421

3522
use crate::worker::Ref;
36-
use crate::worker::ResolvableFunction;
3723

3824
/// A value used as an input to CallFunction.
3925
// TODO, this is basically the same as RValue, but with TensorIndices swapped
4026
// out for refs. And IValue is the same as RValue, but with real tensors and
4127
// C++ types. I wonder if there is a nicer way to express this relationship.
4228
// TODO extend this to support other types of values, like bytes, dicts etc.
43-
#[derive(
44-
Serialize,
45-
Deserialize,
46-
Debug,
47-
Clone,
48-
TryInto,
49-
Named,
50-
From,
51-
EnumAsInner
52-
)]
29+
#[derive(Serialize, Deserialize, Debug, Clone, Named)]
5330
pub enum WireValue {
5431
// Make sure boolean goes ealier than int as bool is a subclass of int.
5532
// Otherwise, bool will be converted to int.
@@ -68,81 +45,20 @@ pub enum WireValue {
6845
// empty enum variants.
6946
None(()),
7047
PyObject(PickledPyObject),
71-
// It is ok to just have IValue without an alias tracking cell as we just use
72-
// WireValue as a way to serialize and send args to workers. We dont mutate the
73-
// IValue and use the opaque wrapper to make accessing the IValue directly
74-
// an unsafe op.
75-
IValue(torch_sys::OpaqueIValue),
7648
}
7749

7850
impl FromPyObject<'_> for WireValue {
7951
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
80-
if let Ok(ref_) = Ref::from_py_object(obj) {
81-
Ok(WireValue::Ref(ref_))
82-
} else if let Ok(list) = obj.downcast::<PyList>() {
83-
let len = list.len();
84-
if len == 0 {
85-
// TODO: This is done for now as this seems to be the most common case for empty lists
86-
// in torch ops but we should use the op schema to do this correctly.
87-
return Ok(WireValue::IntList(vec![]));
88-
}
89-
90-
// SAFETY: We know it is within bounds
91-
let item = unsafe { list.get_item_unchecked(0) };
92-
let len = list.len();
93-
if let Ok(int) = item.extract::<i64>() {
94-
let mut int_list = Vec::with_capacity(len);
95-
int_list.push(int);
96-
for item in list.iter().skip(1) {
97-
int_list.push(item.extract::<i64>().map_err(|_| {
98-
PyValueError::new_err(format!(
99-
"Expected homogeneous list of ints got: {:?}",
100-
list
101-
))
102-
})?);
103-
}
104-
return Ok(WireValue::IntList(int_list));
105-
}
106-
if let Ok(ref_) = Ref::from_py_object(&item) {
107-
let mut ref_list = Vec::with_capacity(len);
108-
ref_list.push(ref_);
109-
for item in list.iter().skip(1) {
110-
ref_list.push(Ref::from_py_object(&item).map_err(|_| {
111-
PyValueError::new_err(format!(
112-
"Expected homogeneous list of ints got: {:?}",
113-
list
114-
))
115-
})?);
116-
}
117-
return Ok(WireValue::RefList(ref_list));
118-
}
119-
Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
120-
} else if obj.is_none() {
121-
Ok(WireValue::None(()))
122-
} else if let Ok(bool_) = obj.downcast::<PyBool>() {
123-
Ok(WireValue::Bool(bool_.is_true()))
124-
} else if let Ok(int) = obj.extract::<i64>() {
125-
Ok(WireValue::Int(int))
126-
} else if let Ok(double) = obj.downcast::<PyFloat>() {
127-
Ok(WireValue::Double(double.value()))
128-
} else if let Ok(string) = obj.downcast::<PyString>() {
129-
Ok(WireValue::String(string.to_str()?.to_string()))
130-
} else if let Ok(device) = obj.extract::<Device>() {
131-
Ok(WireValue::Device(device))
132-
} else if let Ok(layout) = obj.extract::<Layout>() {
133-
Ok(WireValue::Layout(layout))
134-
} else if let Ok(scalar_type) = obj.extract::<ScalarType>() {
135-
Ok(WireValue::ScalarType(scalar_type))
136-
} else if let Ok(memory_format) = obj.extract::<MemoryFormat>() {
137-
Ok(WireValue::MemoryFormat(memory_format))
138-
} else {
139-
Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
140-
}
52+
Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
14153
}
14254
}
14355

144-
impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
145-
unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
56+
impl<'py> IntoPyObject<'py> for WireValue {
57+
type Target = PyAny;
58+
type Output = Bound<'py, PyAny>;
59+
type Error = PyErr;
60+
61+
fn into_pyobject(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
14662
match self {
14763
WireValue::Ref(ref_) => ref_.into_bound_py_any(py),
14864
WireValue::RefList(ref_list) => ref_list.clone().into_bound_py_any(py),
@@ -157,10 +73,6 @@ impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
15773
WireValue::MemoryFormat(val) => val.into_bound_py_any(py),
15874
WireValue::None(()) => PyNone::get(py).into_bound_py_any(py),
15975
WireValue::PyObject(val) => val.unpickle(py),
160-
// SAFETY: WireValue is only used for serde between client and worker.
161-
// This function is used to access the args / kwargs of a function call
162-
// on the client side only.
163-
WireValue::IValue(val) => unsafe { val.try_to_object_unsafe(py) },
16476
}
16577
}
16678
}
@@ -170,167 +82,3 @@ impl From<PyObject> for WireValue {
17082
Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap()))
17183
}
17284
}
173-
174-
pub fn func_call_args_to_wire_values(
175-
_func: Option<&ResolvableFunction>,
176-
args: &Bound<'_, PyTuple>,
177-
kwargs: &Bound<'_, PyDict>,
178-
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
179-
python_func_args_to_wire_value(args, kwargs)
180-
}
181-
182-
fn python_func_args_to_wire_value(
183-
args: &Bound<'_, PyTuple>,
184-
kwargs: &Bound<'_, PyDict>,
185-
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
186-
let args = args
187-
.iter()
188-
.map(|arg| Ok(WireValue::PyObject(PickledPyObject::pickle(&arg)?)))
189-
.collect::<PyResult<_>>()?;
190-
let kwargs = kwargs
191-
.iter()
192-
.map(|(k, v)| {
193-
Ok((
194-
k.extract::<String>()?,
195-
WireValue::PyObject(PickledPyObject::pickle(&v)?),
196-
))
197-
})
198-
.collect::<Result<HashMap<_, _>, PyErr>>()?;
199-
Ok((args, kwargs))
200-
}
201-
202-
#[cfg(test)]
203-
mod tests {
204-
use std::assert_matches::assert_matches;
205-
206-
use anyhow::Result;
207-
use anyhow::bail;
208-
use paste::paste;
209-
use pyo3::Python;
210-
use pyo3::ffi::c_str;
211-
use pyo3::types::PyDict;
212-
use torch_sys::DeviceType;
213-
use torch_sys::ScalarType;
214-
215-
use super::*;
216-
use crate::worker::Ref;
217-
218-
const MOCK_REFERNCABLE_MODULE: &std::ffi::CStr = c_str!(
219-
r#"
220-
class Referencable:
221-
def __init__(self, ref: int):
222-
self.ref = ref
223-
224-
def __monarch_ref__(self):
225-
return self.ref
226-
"#
227-
);
228-
229-
fn setup() -> Result<()> {
230-
pyo3::prepare_freethreaded_python();
231-
// We need to load torch to initialize some internal structures used by
232-
// the FFI funcs we use to convert ivalues to/from py objects.
233-
Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;
234-
Ok(())
235-
}
236-
237-
fn create_py_object() -> PyObject {
238-
pyo3::prepare_freethreaded_python();
239-
Python::with_gil(|py| {
240-
let dict = PyDict::new(py);
241-
dict.set_item("foo", "bar").unwrap();
242-
dict.into_any().clone().unbind()
243-
})
244-
}
245-
246-
macro_rules! generate_wire_value_from_py_tests {
247-
($($kind:ident, $input:expr);* $(;)?) => {
248-
paste! {
249-
$(
250-
#[test]
251-
fn [<test_wire_value_from_py_$kind:snake:lower>]() -> Result<()> {
252-
setup()?;
253-
Python::with_gil(|py| {
254-
let actual = $input.into_pyobject(py)?.extract::<WireValue>()?;
255-
assert_matches!(actual, WireValue::$kind(_));
256-
anyhow::Ok(())
257-
})
258-
}
259-
)*
260-
261-
#[test]
262-
fn test_wire_value_from_py_none() -> Result<()> {
263-
setup()?;
264-
Python::with_gil(|py| {
265-
let obj = PyNone::get(py).into_pyobject(py)?;
266-
let actual = obj.extract::<WireValue>()?;
267-
assert_matches!(actual, WireValue::None(_));
268-
anyhow::Ok(())
269-
})
270-
}
271-
272-
#[test]
273-
fn test_wire_value_from_py_empty_list() -> Result<()> {
274-
setup()?;
275-
Python::with_gil(|py| {
276-
let obj: PyObject = PyList::empty(py).into_any().unbind();
277-
let actual = obj.extract::<WireValue>(py)?;
278-
match actual {
279-
WireValue::IntList(list) if list.len() == 0 => (),
280-
_ => bail!("Expected empty list to be converted to empty int list"),
281-
}
282-
anyhow::Ok(())
283-
})
284-
}
285-
286-
#[test]
287-
fn test_wire_value_from_py_referencable_class() -> Result<()> {
288-
setup()?;
289-
Python::with_gil(|py| {
290-
let referencable = PyModule::from_code(
291-
py,
292-
MOCK_REFERNCABLE_MODULE,
293-
c_str!("referencable.py"),
294-
c_str!("referencable"),
295-
)?;
296-
let ref_ = referencable.getattr("Referencable")?.call1((1,))?.unbind();
297-
let actual = ref_.extract::<WireValue>(py)?;
298-
assert_matches!(actual, WireValue::Ref(Ref { id: 1 }));
299-
anyhow::Ok(())
300-
})
301-
}
302-
303-
#[test]
304-
fn test_wire_value_from_py_roundtrip_was_exhaustive() {
305-
let val = WireValue::Int(0);
306-
match val {
307-
$(WireValue::$kind(_) => (),)*
308-
WireValue::None(_) => (),
309-
// Can't test from py here as PyObject behaves as catch all for conversion from PY.
310-
// We will manually convert torch ops args to IValue respecting the schema so its
311-
// not super important to have this.
312-
WireValue::IValue(_) => (),
313-
}
314-
}
315-
}
316-
}
317-
}
318-
319-
// Generate exhaustive roundtrip tests for all IValue kind.
320-
// If you got a "non-exhaustive patterns" error here, you need to add a new
321-
// test entry for your IValue kind!
322-
generate_wire_value_from_py_tests! {
323-
Bool, false;
324-
Double, 1.23f64;
325-
Int, 123i64;
326-
IntList, vec![1i64];
327-
Ref, Ref::from(1);
328-
RefList, vec![Ref::from(1), Ref::from(2)];
329-
String, "foobar".to_owned();
330-
Device, Device::new(DeviceType::CPU);
331-
Layout, Layout(2);
332-
ScalarType, ScalarType(3);
333-
MemoryFormat, MemoryFormat(1);
334-
PyObject, create_py_object();
335-
}
336-
}

0 commit comments

Comments
 (0)