-
Notifications
You must be signed in to change notification settings - Fork 150
Handle slices in mlx_funcify_IncSubtensor
#1692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
bff3685
a63e759
119e7e6
116a1bd
4f7ae9f
9632ad6
c457a13
cd7a2d0
449c2df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,5 +1,8 @@ | ||||||
| from copy import deepcopy | ||||||
|
|
||||||
| import mlx.core as mx | ||||||
| import numpy as np | ||||||
|
|
||||||
| from pytensor.link.mlx.dispatch.basic import mlx_funcify | ||||||
| from pytensor.tensor.subtensor import ( | ||||||
| AdvancedIncSubtensor, | ||||||
|
|
@@ -13,12 +16,51 @@ | |||||
| from pytensor.tensor.type_other import MakeSlice | ||||||
|
|
||||||
|
|
||||||
| def normalize_indices_for_mlx(ilist, idx_list): | ||||||
| """Convert numpy integers to Python integers for MLX indexing. | ||||||
|
|
||||||
| MLX requires index values to be Python int, not np.int64 or other NumPy types. | ||||||
| """ | ||||||
|
|
||||||
| def to_int(value, element): | ||||||
| """Convert value to Python int with helpful error message.""" | ||||||
| try: | ||||||
| return int(value) | ||||||
| except (TypeError, ValueError) as e: | ||||||
| raise TypeError( | ||||||
| "MLX backend does not support symbolic indices. " | ||||||
| "Index values must be concrete (constant) integers, not symbolic variables. " | ||||||
| f"Got: {element}" | ||||||
| ) from e | ||||||
|
|
||||||
| def normalize_element(element): | ||||||
| if element is None: | ||||||
| return None | ||||||
| elif isinstance(element, slice): | ||||||
| return slice( | ||||||
| normalize_element(element.start), | ||||||
| normalize_element(element.stop), | ||||||
| normalize_element(element.step), | ||||||
| ) | ||||||
| elif isinstance(element, mx.array) and element.ndim == 0: | ||||||
| return to_int(element.item(), element) | ||||||
| elif isinstance(element, np.integer): | ||||||
| return to_int(element, element) | ||||||
| else: | ||||||
| return element | ||||||
|
|
||||||
| indices = indices_from_subtensor(ilist, idx_list) | ||||||
|
||||||
| return tuple(normalize_element(idx) for idx in indices) | ||||||
|
|
||||||
|
|
||||||
| @mlx_funcify.register(Subtensor) | ||||||
| def mlx_funcify_Subtensor(op, node, **kwargs): | ||||||
| """MLX implementation of Subtensor.""" | ||||||
| idx_list = getattr(op, "idx_list", None) | ||||||
|
||||||
| idx_list = getattr(op, "idx_list", None) | |
| idx_list = op.idx_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`idx_list` is not a thing for Advanced indexing
| idx_list = getattr(op, "idx_list", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment(): | |
| assert not out_pt.owner.op.set_instead_of_inc | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Increment slice | ||
| out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
|
|
||
| def test_mlx_AdvancedIncSubtensor_set(): | ||
| """Test advanced set operations using AdvancedIncSubtensor.""" | ||
|
|
@@ -232,9 +245,12 @@ def test_mlx_subtensor_edge_cases(): | |
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
|
|
||
| @pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") | ||
| def test_mlx_subtensor_with_variables(): | ||
| """Test subtensor operations with PyTensor variables as inputs.""" | ||
| """Test subtensor operations with PyTensor variables as inputs. | ||
|
|
||
| This test now works thanks to the fix for np.int64 indexing, which also | ||
| handles the conversion of MLX scalar arrays in slice components. | ||
| """ | ||
| # Test with variable arrays (not constants) | ||
| x_pt = pt.matrix("x", dtype="float32") | ||
| y_pt = pt.vector("y", dtype="float32") | ||
|
|
@@ -245,3 +261,150 @@ def test_mlx_subtensor_with_variables(): | |
| # Set operation with variables | ||
| out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) | ||
| compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) | ||
|
|
||
|
|
||
| def test_mlx_subtensor_with_numpy_int64(): | ||
| """Test Subtensor operations with np.int64 indices. | ||
|
|
||
| This tests the fix for MLX's strict requirement that indices must be | ||
| Python int, not np.int64 or other NumPy integer types. | ||
| """ | ||
| # Test data | ||
| x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) | ||
| x_pt = pt.constant(x_np) | ||
|
|
||
| # Single np.int64 index - this was failing before the fix | ||
| idx = np.int64(1) | ||
| out_pt = x_pt[idx] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Multiple np.int64 indices | ||
| out_pt = x_pt[np.int64(1), np.int64(2)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Negative np.int64 index | ||
| out_pt = x_pt[np.int64(-1)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Mixed Python int and np.int64 | ||
| out_pt = x_pt[1, np.int64(2)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
|
|
||
| def test_mlx_subtensor_slices_with_numpy_int64(): | ||
| """Test Subtensor with slices containing np.int64 components. | ||
|
|
||
| This tests that slice start/stop/step values can be np.int64. | ||
| """ | ||
| x_np = np.arange(20, dtype=np.float32) | ||
| x_pt = pt.constant(x_np) | ||
|
|
||
| # Slice with np.int64 start | ||
| out_pt = x_pt[np.int64(2) :] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Slice with np.int64 stop | ||
| out_pt = x_pt[: np.int64(5)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Slice with np.int64 start and stop | ||
| out_pt = x_pt[np.int64(2) : np.int64(8)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Slice with np.int64 step | ||
| out_pt = x_pt[:: np.int64(2)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Slice with all np.int64 components | ||
| out_pt = x_pt[np.int64(1) : np.int64(10) : np.int64(2)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Negative np.int64 in slice | ||
| out_pt = x_pt[np.int64(-5) :] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
|
|
||
| def test_mlx_incsubtensor_with_numpy_int64(): | ||
| """Test IncSubtensor (set/inc) with np.int64 indices and slices. | ||
|
|
||
| This is the main test for the reported issue with inc_subtensor. | ||
| """ | ||
| # Test data | ||
| x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) | ||
| x_pt = pt.constant(x_np) | ||
| y_pt = pt.as_tensor_variable(np.array(10.0, dtype=np.float32)) | ||
|
|
||
| # Set with np.int64 index | ||
| out_pt = pt_subtensor.set_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Increment with np.int64 index | ||
| out_pt = pt_subtensor.inc_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Set with slice containing np.int64 - THE ORIGINAL FAILING CASE | ||
| out_pt = pt_subtensor.set_subtensor(x_pt[:, : np.int64(2)], y_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Increment with slice containing np.int64 - THE ORIGINAL FAILING CASE | ||
| out_pt = pt_subtensor.inc_subtensor(x_pt[:, : np.int64(2)], y_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Complex slice with np.int64 | ||
| y2_pt = pt.as_tensor_variable(np.ones((2, 2), dtype=np.float32)) | ||
| out_pt = pt_subtensor.inc_subtensor( | ||
| x_pt[np.int64(0) : np.int64(2), np.int64(1) : np.int64(3)], y2_pt | ||
| ) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
|
|
||
| def test_mlx_incsubtensor_original_issue(): | ||
| """Test the exact example from the issue report. | ||
|
|
||
| This was failing with: ValueError: Slice indices must be integers or None. | ||
| """ | ||
| x_np = np.arange(9, dtype=np.float64).reshape((3, 3)) | ||
| x_pt = pt.constant(x_np, dtype="float64") | ||
|
|
||
| # The exact failing case from the issue | ||
| out_pt = pt_subtensor.inc_subtensor(x_pt[:, :2], 10) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Verify it also works with set_subtensor | ||
| out_pt = pt_subtensor.set_subtensor(x_pt[:, :2], 10) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
|
|
||
| def test_mlx_advanced_subtensor_with_numpy_int64(): | ||
| """Test AdvancedSubtensor with np.int64 in mixed indexing.""" | ||
| x_np = np.arange(24, dtype=np.float32).reshape((3, 4, 2)) | ||
| x_pt = pt.constant(x_np) | ||
|
|
||
| # Advanced indexing with list, but other dimensions use np.int64 | ||
| # Note: This creates AdvancedSubtensor, not basic Subtensor | ||
| out_pt = x_pt[[0, 2], np.int64(1)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Mixed advanced and basic indexing with np.int64 in slice | ||
| out_pt = x_pt[[0, 2], np.int64(1) : np.int64(3)] | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
|
|
||
| def test_mlx_advanced_incsubtensor_with_numpy_int64(): | ||
| """Test AdvancedIncSubtensor with np.int64.""" | ||
| x_np = np.arange(15, dtype=np.float32).reshape((5, 3)) | ||
| x_pt = pt.constant(x_np) | ||
|
|
||
| # Value to set/increment | ||
| y_pt = pt.as_tensor_variable( | ||
| np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) | ||
| ) | ||
|
|
||
| # Advanced indexing set with array indices | ||
| indices = [np.int64(0), np.int64(2)] | ||
|
||
| out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
|
|
||
| # Advanced indexing increment | ||
| out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt) | ||
| compare_mlx_and_py([], [out_pt], []) | ||
Uh oh!
There was an error while loading. Please reload this page.