Skip to content

Commit a404c21

Browse files
committed
Remove torch-op special call path in tensor engine
This is part of the work to remove a direct libtorch dependency: we now route everything through Python instead of the torch-op fast path. A few months ago we verified that this is not measurably slower, since the existing path already spent its time wrapping and unwrapping Python objects. We can re-enable a faster native execution path in the future if needed.

Differential Revision: [D87807456](https://our.internmc.facebook.com/intern/diff/D87807456/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D87807456/)!

ghstack-source-id: 325483803
Pull Request resolved: #1990
1 parent c30a724 commit a404c21

File tree

5 files changed

+30
-256
lines changed

5 files changed

+30
-256
lines changed

monarch_extension/src/convert.rs

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -330,21 +330,6 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
330330
to_stream: p.parseStreamRef("to_stream")?,
331331
})
332332
});
333-
m.insert(key("CreatePipe"), |p| {
334-
let function = p.parseFunction("function")?;
335-
let args = p.parse("args")?;
336-
let kwargs = p.parse("kwargs")?;
337-
let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?;
338-
Ok(WorkerMessage::CreatePipe {
339-
result: p.parseRef("result")?,
340-
key: p.parse("key")?,
341-
function,
342-
max_messages: p.parse("max_messages")?,
343-
mesh: p.parseRef("device_mesh")?,
344-
args,
345-
kwargs,
346-
})
347-
});
348333
m.insert(key("SendValue"), |p| {
349334
let function = p.parseOptionalFunction("function")?;
350335
let args: Bound<'_, PyTuple> = p.parse("args")?;

monarch_messages/src/wire_value.rs

Lines changed: 2 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -171,110 +171,12 @@ impl From<PyObject> for WireValue {
171171
}
172172
}
173173

174-
impl WireValue {
175-
fn from_pyobject_with_torch_op_arg_type(
176-
obj: Bound<'_, PyAny>,
177-
type_: &torch_sys::call_op::TypePtr,
178-
num_elements: i32,
179-
allow_nums_as_tensors: bool,
180-
) -> PyResult<Self> {
181-
if type_.is_tensor() || type_.is_optional_tensor() {
182-
if type_.is_optional_tensor() && obj.is_none() {
183-
return Ok(WireValue::None(()));
184-
} else if let Ok(ref_) = Ref::from_py_object(&obj) {
185-
return Ok(WireValue::Ref(ref_));
186-
}
187-
}
188-
if type_.is_tensor_list() || type_.is_optional_tensor_list() {
189-
if type_.is_optional_tensor_list() && obj.is_none() {
190-
return Ok(WireValue::None(()));
191-
}
192-
let list = obj.downcast::<PyList>()?;
193-
let len = list.len();
194-
if len == 0 {
195-
return Ok(WireValue::RefList(vec![]));
196-
}
197-
// SAFETY: We know it is within bounds
198-
let item = unsafe { list.get_item_unchecked(0) };
199-
if let Ok(ref_) = Ref::from_py_object(&item) {
200-
let mut ref_list = Vec::with_capacity(len);
201-
ref_list.push(ref_);
202-
for item in list.iter().skip(1) {
203-
ref_list.push(Ref::from_py_object(&item).map_err(|_| {
204-
PyValueError::new_err(format!(
205-
"Expected homogeneous list of refs got: {:?}",
206-
list
207-
))
208-
})?);
209-
}
210-
return Ok(WireValue::RefList(ref_list));
211-
}
212-
}
213-
OpaqueIValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
214-
.map(WireValue::IValue)
215-
}
216-
}
217-
218174
pub fn func_call_args_to_wire_values(
219-
func: Option<&ResolvableFunction>,
220-
args: &Bound<'_, PyTuple>,
221-
kwargs: &Bound<'_, PyDict>,
222-
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
223-
if let Some((op, overload)) = func.and_then(|func| func.as_torch_op()) {
224-
torch_op_args_to_wire_values(&op, &overload, args, kwargs)
225-
} else {
226-
python_func_args_to_wire_value(args, kwargs)
227-
}
228-
}
229-
230-
fn torch_op_args_to_wire_values(
231-
op: &str,
232-
overload: &str,
175+
_func: Option<&ResolvableFunction>,
233176
args: &Bound<'_, PyTuple>,
234177
kwargs: &Bound<'_, PyDict>,
235178
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
236-
let args_info = torch_sys::call_op::get_schema_args_info(op, overload).map_err(|err| {
237-
PyValueError::new_err(format!(
238-
"Failed to get the operator schema for {}::{}: {}",
239-
op, overload, err
240-
))
241-
})?;
242-
243-
let args = args
244-
.iter()
245-
.zip(&args_info)
246-
.map(|(arg, arg_info)| {
247-
WireValue::from_pyobject_with_torch_op_arg_type(
248-
arg,
249-
arg_info.type_,
250-
arg_info.num_elements,
251-
arg_info.allows_number_as_tensor,
252-
)
253-
})
254-
.collect::<Result<Vec<_>, _>>()?;
255-
let kwargs = kwargs
256-
.iter()
257-
.map(|(k, v)| {
258-
let key = k.extract::<String>()?;
259-
let arg_info = args_info
260-
.iter()
261-
.find(|arg_info| arg_info.name == key)
262-
.ok_or_else(|| {
263-
PyValueError::new_err(format!(
264-
"Torch op {}::{} does not support kwarg {}",
265-
op, overload, key
266-
))
267-
})?;
268-
let val = WireValue::from_pyobject_with_torch_op_arg_type(
269-
v,
270-
arg_info.type_,
271-
arg_info.num_elements,
272-
arg_info.allows_number_as_tensor,
273-
)?;
274-
Ok((key, val))
275-
})
276-
.collect::<Result<HashMap<_, _>, PyErr>>()?;
277-
Ok((args, kwargs))
179+
python_func_args_to_wire_value(args, kwargs)
278180
}
279181

280182
fn python_func_args_to_wire_value(

monarch_messages/src/worker.rs

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -340,21 +340,6 @@ impl ResolvableFunction {
340340
}
341341
}
342342

343-
pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> {
344-
match self {
345-
Self::FunctionPath(func) => match func.path.split(".").collect::<Vec<_>>().as_slice() {
346-
["torch", "ops", namespace, op_name, "default"] => {
347-
Some((format!("{}::{}", namespace, op_name), String::new()))
348-
}
349-
["torch", "ops", namespace, op_name, overload] => {
350-
Some((format!("{}::{}", namespace, op_name), overload.to_string()))
351-
}
352-
_ => None,
353-
},
354-
_ => None,
355-
}
356-
}
357-
358343
/// For testing: this is a special remote function path that induces a panic
359344
/// when called.
360345
pub fn panic_if_requested(&self) {
@@ -367,13 +352,6 @@ impl ResolvableFunction {
367352
_ => (),
368353
}
369354
}
370-
371-
pub fn supports_pytree_args(&self) -> bool {
372-
match self {
373-
Self::Cloudpickle(_) => true,
374-
Self::FunctionPath(_) => self.as_torch_op().is_none(),
375-
}
376-
}
377355
}
378356

379357
impl<T: Into<String>> From<T> for ResolvableFunction {
@@ -800,16 +778,6 @@ pub enum WorkerMessage {
800778
to_stream: StreamRef,
801779
},
802780

803-
CreatePipe {
804-
result: Ref,
805-
key: String,
806-
function: ResolvableFunction,
807-
max_messages: i64,
808-
mesh: Ref,
809-
args: Vec<WireValue>,
810-
kwargs: HashMap<String, WireValue>,
811-
},
812-
813781
SendValue {
814782
seq: Seq,
815783
/// Pipe to send value to. If `None`, value is sent to controller.

monarch_tensor_worker/src/lib.rs

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ use monarch_messages::worker::StreamRef;
7979
use monarch_messages::worker::WorkerMessage;
8080
use monarch_messages::worker::WorkerMessageHandler;
8181
use monarch_messages::worker::WorkerParams;
82-
use monarch_types::PyTree;
8382
use ndslice::Slice;
8483
use pyo3::Python;
8584
use pyo3::types::PyAnyMethods;
@@ -92,7 +91,6 @@ use stream::StreamParams;
9291
use torch_sys::CudaDevice;
9392
use torch_sys::DeviceIndex;
9493
use torch_sys::Layout;
95-
use torch_sys::RValue;
9694
use torch_sys::ScalarType;
9795
use torch_sys::TensorCell;
9896
use torch_sys::factory_zeros;
@@ -383,14 +381,11 @@ impl WorkerMessageHandler for WorkerActor {
383381
self.maybe_add_stream_to_recording(cx, params.stream)
384382
.await?;
385383

386-
let device_meshes = if params.function.as_torch_op().is_some() {
387-
HashMap::new()
388-
} else {
389-
self.device_meshes
390-
.iter()
391-
.map(|(k, v)| (k.clone(), v.0.clone()))
392-
.collect()
393-
};
384+
let device_meshes = self
385+
.device_meshes
386+
.iter()
387+
.map(|(k, v)| (k.clone(), v.0.clone()))
388+
.collect();
394389

395390
let mut remote_process_groups = HashMap::new();
396391
for remote_process_group_ref in &params.remote_process_groups {
@@ -638,22 +633,6 @@ impl WorkerMessageHandler for WorkerActor {
638633
Ok(())
639634
}
640635

641-
async fn create_pipe(
642-
&mut self,
643-
_cx: &hyperactor::Context<Self>,
644-
_result: Ref,
645-
// TODO(agallagher): This is used in the python impl to name the socket
646-
// path to use for comms, but we don't currently use a named socket.
647-
_key: String,
648-
_function: ResolvableFunction,
649-
_max_messages: i64,
650-
_device_mesh: Ref,
651-
_args: Vec<WireValue>,
652-
_kwargs: HashMap<String, WireValue>,
653-
) -> Result<()> {
654-
panic!("create_pipe is no longer implemented")
655-
}
656-
657636
async fn send_tensor(
658637
&mut self,
659638
cx: &hyperactor::Context<Self>,
@@ -772,7 +751,7 @@ impl WorkerMessageHandler for WorkerActor {
772751
// Resolve the stream.
773752
let stream = self.try_get_stream(stream)?;
774753

775-
let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
754+
let device_meshes = if function.is_none() {
776755
HashMap::new()
777756
} else {
778757
self.device_meshes

monarch_tensor_worker/src/stream.rs

Lines changed: 22 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ use monarch_types::PyTree;
5454
use monarch_types::SerializablePyErr;
5555
use monarch_types::TryIntoPyObjectUnsafe;
5656
use pyo3::prelude::*;
57-
use pyo3::types::PyTuple;
5857
use tokio::runtime::Handle;
5958
use tokio::sync::Mutex;
6059
use tokio::task::JoinHandle;
@@ -740,34 +739,6 @@ impl StreamActor {
740739
Ok(())
741740
}
742741

743-
fn call_torch_op(
744-
&self,
745-
op: String,
746-
overload: String,
747-
args: Vec<WireValue>,
748-
kwargs: HashMap<String, WireValue>,
749-
) -> Result<Vec<RValue>, CallFunctionError> {
750-
let args = args
751-
.into_iter()
752-
.map(|arg| self.wire_to_rvalue(arg))
753-
.collect::<Result<Vec<_>, _>>()?;
754-
let kwargs = kwargs
755-
.into_iter()
756-
.map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
757-
.collect::<Result<HashMap<_, _>, CallFunctionError>>()?;
758-
759-
let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;
760-
761-
// Handle the case where the op returns nothing and convert it to a list of None.
762-
// This is to ensure handle results does not error out as the client will call
763-
// such a function with expected results of size 1.
764-
Ok(if results.is_empty() {
765-
vec![RValue::None]
766-
} else {
767-
results
768-
})
769-
}
770-
771742
fn call_python_fn<'py>(
772743
&mut self,
773744
py: Python<'py>,
@@ -1118,21 +1089,17 @@ impl StreamMessageHandler for StreamActor {
11181089
params.results,
11191090
&params.mutates,
11201091
async |self| {
1121-
tokio::task::block_in_place(|| match params.function.as_torch_op() {
1122-
Some((op, overload)) => {
1123-
self.call_torch_op(op, overload, params.args, params.kwargs)
1124-
}
1125-
_ => self
1126-
.call_python_fn_pytree(
1127-
cx,
1128-
params.function,
1129-
params.args,
1130-
params.kwargs,
1131-
&params.mutates,
1132-
device_meshes,
1133-
remote_process_groups,
1134-
)
1135-
.map(|results| results.into_leaves()),
1092+
tokio::task::block_in_place(|| {
1093+
self.call_python_fn_pytree(
1094+
cx,
1095+
params.function,
1096+
params.args,
1097+
params.kwargs,
1098+
&params.mutates,
1099+
device_meshes,
1100+
remote_process_groups,
1101+
)
1102+
.map(|results| results.into_leaves())
11361103
})
11371104
},
11381105
)
@@ -1562,44 +1529,17 @@ impl StreamMessageHandler for StreamActor {
15621529
}
15631530
let result = if let Some(function) = function {
15641531
// If a function was provided, use that to resolve the value.
1565-
match function.as_torch_op() {
1566-
Some((op, overload)) => {
1567-
self.call_torch_op(op, overload, args, kwargs)
1568-
.map(|rvalues| {
1569-
if rvalues.len() == 1 {
1570-
Ok(rvalues[0].clone().into())
1571-
} else {
1572-
// TODO: Replace with native pytrees when possible
1573-
Python::with_gil(|py| {
1574-
Ok((|| {
1575-
let py_rvalues = rvalues
1576-
.into_iter()
1577-
// SAFETY: This inherits the unsafety of `try_to_object_unsafe`.
1578-
.map(|rvalue| unsafe {
1579-
rvalue.try_to_object_unsafe(py)
1580-
})
1581-
.collect::<Result<Vec<_>, _>>()?;
1582-
PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
1583-
})()
1584-
.map_err(SerializablePyErr::from_fn(py))?)
1585-
})
1586-
}
1587-
})?
1588-
}
1589-
// Use block-in-place to allow nested callbacks to re-enter the
1590-
// runtime to run async code.
1591-
_ => tokio::task::block_in_place(|| {
1592-
self.call_python_fn_pytree(
1593-
cx,
1594-
function,
1595-
args,
1596-
kwargs,
1597-
&mutates,
1598-
device_meshes,
1599-
HashMap::new(),
1600-
)
1601-
}),
1602-
}
1532+
tokio::task::block_in_place(|| {
1533+
self.call_python_fn_pytree(
1534+
cx,
1535+
function,
1536+
args,
1537+
kwargs,
1538+
&mutates,
1539+
device_meshes,
1540+
HashMap::new(),
1541+
)
1542+
})
16031543
} else {
16041544
// If there's no function provided, there should be exactly one arg
16051545
// and no kwargs.

0 commit comments

Comments
 (0)