-
Notifications
You must be signed in to change notification settings - Fork 148
Handle slices in mlx_funcify_IncSubtensor
#1692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
bff3685
a63e759
119e7e6
116a1bd
4f7ae9f
9632ad6
c457a13
cd7a2d0
449c2df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,5 +1,7 @@ | ||||||
| from copy import deepcopy | ||||||
|
|
||||||
| import numpy as np | ||||||
|
|
||||||
| from pytensor.link.mlx.dispatch.basic import mlx_funcify | ||||||
| from pytensor.tensor.subtensor import ( | ||||||
| AdvancedIncSubtensor, | ||||||
|
|
@@ -13,12 +15,119 @@ | |||||
| from pytensor.tensor.type_other import MakeSlice | ||||||
|
|
||||||
|
|
||||||
| def normalize_indices_for_mlx(ilist, idx_list): | ||||||
| """Convert indices to MLX-compatible format. | ||||||
|
|
||||||
| MLX has strict requirements for indexing: | ||||||
| - Integer indices must be Python int, not np.int64 or other NumPy integer types | ||||||
| - Slice components (start, stop, step) must be Python int or None, not np.int64 | ||||||
| - MLX arrays created from scalars need to be converted back to Python int | ||||||
| - Array indices for advanced indexing are handled separately | ||||||
|
|
||||||
| This function converts all integer-like indices and slice components to Python int | ||||||
| while preserving None values and passing through array indices unchanged. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| ilist : tuple | ||||||
| Runtime index values to be passed to indices_from_subtensor | ||||||
| idx_list : tuple | ||||||
| Static index specification from the Op's idx_list attribute | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| tuple | ||||||
| Normalized indices compatible with MLX array indexing | ||||||
|
|
||||||
| Examples | ||||||
| -------- | ||||||
| >>> # Single np.int64 index converted to Python int | ||||||
| >>> normalize_indices_for_mlx((np.int64(1),), (True,)) | ||||||
| (1,) | ||||||
|
|
||||||
| >>> # Slice with np.int64 components | ||||||
| >>> indices = indices_from_subtensor( | ||||||
| ... (np.int64(0), np.int64(2)), (slice(None, None),) | ||||||
| ... ) | ||||||
| >>> # After normalization, slice components are Python int | ||||||
|
|
||||||
| Notes | ||||||
| ----- | ||||||
| This conversion is necessary because MLX's C++ indexing implementation | ||||||
| does not recognize NumPy scalar types, raising ValueError when encountered. | ||||||
| Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also | ||||||
| need to be converted back to Python int for use in indexing operations. | ||||||
| Converting to Python int is zero-cost for Python int inputs and minimal | ||||||
| overhead for NumPy scalars and MLX scalar arrays. | ||||||
| """ | ||||||
| import mlx.core as mx | ||||||
|
|
||||||
| def normalize_element(element): | ||||||
| """Convert a single index element to MLX-compatible format.""" | ||||||
| if element is None: | ||||||
| # None is valid in slices (e.g., x[None:5] or x[:None]) | ||||||
| return None | ||||||
| elif isinstance(element, slice): | ||||||
| # Recursively normalize slice components | ||||||
| return slice( | ||||||
| normalize_element(element.start), | ||||||
| normalize_element(element.stop), | ||||||
| normalize_element(element.step), | ||||||
| ) | ||||||
| elif isinstance(element, mx.array): | ||||||
| # MLX arrays from mlx_typify need special handling | ||||||
| # If it's a 0-d array (scalar), convert to Python int/float | ||||||
| if element.ndim == 0: | ||||||
| # Extract the scalar value | ||||||
| item = element.item() | ||||||
| # Convert to Python int if it's an integer type | ||||||
| if element.dtype in ( | ||||||
| mx.int8, | ||||||
| mx.int16, | ||||||
| mx.int32, | ||||||
| mx.int64, | ||||||
| mx.uint8, | ||||||
| mx.uint16, | ||||||
| mx.uint32, | ||||||
| mx.uint64, | ||||||
| ): | ||||||
| return int(item) | ||||||
| else: | ||||||
| return float(item) | ||||||
| else: | ||||||
| # Multi-dimensional array for advanced indexing - pass through | ||||||
| return element | ||||||
| elif isinstance(element, (np.integer, np.floating)): | ||||||
| # Convert NumPy scalar to Python int/float | ||||||
| # This handles np.int64, np.int32, np.float64, etc. | ||||||
| return int(element) if isinstance(element, np.integer) else float(element) | ||||||
| elif isinstance(element, (int, float)): | ||||||
| # Python int/float are already compatible | ||||||
| return element | ||||||
| else: | ||||||
| # Pass through other types (arrays for advanced indexing, etc.) | ||||||
| return element | ||||||
|
|
||||||
| # Get indices from PyTensor's subtensor utility | ||||||
| raw_indices = indices_from_subtensor(ilist, idx_list) | ||||||
|
|
||||||
| # Normalize each index element | ||||||
| normalized = tuple(normalize_element(idx) for idx in raw_indices) | ||||||
|
|
||||||
| return normalized | ||||||
|
|
||||||
|
|
||||||
| @mlx_funcify.register(Subtensor) | ||||||
| def mlx_funcify_Subtensor(op, node, **kwargs): | ||||||
| """MLX implementation of Subtensor operation. | ||||||
|
|
||||||
| Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. | ||||||
| """ | ||||||
| idx_list = getattr(op, "idx_list", None) | ||||||
|
||||||
| idx_list = getattr(op, "idx_list", None) | |
| idx_list = op.idx_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a thing for Advanced indexing
| idx_list = getattr(op, "idx_list", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
Uh oh!
There was an error while loading. Please reload this page.