Skip to content

Commit e202479

Browse files
Add .pyi file
Signed-off-by: Goutam <goutam@anyscale.com>
1 parent d6a6229 commit e202479

File tree

4 files changed

+900
-20
lines changed

4 files changed

+900
-20
lines changed

python/ray/data/BUILD.bazel

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,20 @@ py_test(
16351635
],
16361636
)
16371637

1638+
py_test(
1639+
name = "test_namespace_expressions",
1640+
size = "medium",
1641+
srcs = ["tests/test_namespace_expressions.py"],
1642+
tags = [
1643+
"exclusive",
1644+
"team:data",
1645+
],
1646+
deps = [
1647+
":conftest",
1648+
"//:ray_lib",
1649+
],
1650+
)
1651+
16381652
py_test(
16391653
name = "test_context",
16401654
size = "small",

python/ray/data/expressions.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,6 @@ def _add_methods_from_config(
549549
"len": _PyArrowMethodConfig(
550550
"list_value_length", DataType.int32(), docstring="Get the length of each list."
551551
),
552-
# Note: sort is manually defined below to use Literal type for order parameter
553552
"flatten": _PyArrowMethodConfig(
554553
"list_flatten", DataType(object), docstring="Flatten nested lists."
555554
),
@@ -630,6 +629,11 @@ def _add_methods_from_config(
630629
DataType.bool(),
631630
docstring="Check if strings contain only printable characters.",
632631
),
632+
"is_ascii": _PyArrowMethodConfig(
633+
"string_is_ascii",
634+
DataType.bool(),
635+
docstring="Check if strings contain only ASCII characters.",
636+
),
633637
# Searching (parameterized)
634638
"starts_with": _PyArrowMethodConfig(
635639
"starts_with",
@@ -667,6 +671,24 @@ def _add_methods_from_config(
667671
params=["pattern", "ignore_case"],
668672
docstring="Count occurrences of a substring.",
669673
),
674+
"find_regex": _PyArrowMethodConfig(
675+
"find_substring_regex",
676+
DataType.int32(),
677+
params=["pattern", "ignore_case"],
678+
docstring="Find the first occurrence matching a regex pattern.",
679+
),
680+
"count_regex": _PyArrowMethodConfig(
681+
"count_substring_regex",
682+
DataType.int32(),
683+
params=["pattern", "ignore_case"],
684+
docstring="Count occurrences matching a regex pattern.",
685+
),
686+
"match_regex": _PyArrowMethodConfig(
687+
"match_substring_regex",
688+
DataType.bool(),
689+
params=["pattern", "ignore_case"],
690+
docstring="Check if strings match a regex pattern.",
691+
),
670692
# Transformations
671693
"reverse": _PyArrowMethodConfig(
672694
"utf8_reverse", DataType.string(), docstring="Reverse each string."
@@ -754,25 +776,6 @@ def _list_slice(arr):
754776

755777
return _list_slice(self._expr)
756778

757-
def sort(
758-
self, order: Literal["ascending", "descending"] = "ascending"
759-
) -> "UDFExpr":
760-
"""Sort each list.
761-
762-
Args:
763-
order: Sort order, either "ascending" or "descending".
764-
765-
Returns:
766-
UDFExpr that sorts each list in the column.
767-
"""
768-
import pyarrow.compute as pc
769-
770-
@udf(return_dtype=DataType(object))
771-
def _list_sort(arr):
772-
return pc.list_sort(arr, order=order)
773-
774-
return _list_sort(self._expr)
775-
776779

777780
@dataclass
778781
class _StringNamespace:

python/ray/data/expressions.pyi

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
"""Type stubs for ray.data.expressions module.
2+
3+
This file provides type hints for dynamically generated namespace methods
4+
to enable IDE autocomplete and type checking.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from enum import Enum
10+
from typing import Any, Callable, Dict, List, Literal, Union
11+
12+
from ray.data.block import BatchColumn
13+
from ray.data.datatype import DataType
14+
15+
# Re-export main classes
16+
class Operation(Enum):
17+
ADD: str
18+
SUB: str
19+
MUL: str
20+
DIV: str
21+
FLOORDIV: str
22+
GT: str
23+
LT: str
24+
GE: str
25+
LE: str
26+
EQ: str
27+
NE: str
28+
AND: str
29+
OR: str
30+
NOT: str
31+
IS_NULL: str
32+
IS_NOT_NULL: str
33+
IN: str
34+
NOT_IN: str
35+
36+
class Expr:
37+
data_type: DataType
38+
39+
@property
40+
def name(self) -> str | None: ...
41+
def structurally_equals(self, other: Any) -> bool: ...
42+
def to_pyarrow(self) -> Any: ...
43+
def __repr__(self) -> str: ...
44+
def _bin(self, other: Any, op: Operation) -> Expr: ...
45+
46+
# Arithmetic
47+
def __add__(self, other: Any) -> Expr: ...
48+
def __radd__(self, other: Any) -> Expr: ...
49+
def __sub__(self, other: Any) -> Expr: ...
50+
def __rsub__(self, other: Any) -> Expr: ...
51+
def __mul__(self, other: Any) -> Expr: ...
52+
def __rmul__(self, other: Any) -> Expr: ...
53+
def __truediv__(self, other: Any) -> Expr: ...
54+
def __rtruediv__(self, other: Any) -> Expr: ...
55+
def __floordiv__(self, other: Any) -> Expr: ...
56+
def __rfloordiv__(self, other: Any) -> Expr: ...
57+
58+
# Comparison
59+
def __gt__(self, other: Any) -> Expr: ...
60+
def __lt__(self, other: Any) -> Expr: ...
61+
def __ge__(self, other: Any) -> Expr: ...
62+
def __le__(self, other: Any) -> Expr: ...
63+
def __eq__(self, other: Any) -> Expr: ...
64+
def __ne__(self, other: Any) -> Expr: ...
65+
66+
# Boolean
67+
def __and__(self, other: Any) -> Expr: ...
68+
def __or__(self, other: Any) -> Expr: ...
69+
def __invert__(self) -> Expr: ...
70+
71+
# Predicates
72+
def is_null(self) -> Expr: ...
73+
def is_not_null(self) -> Expr: ...
74+
def is_in(self, values: Union[List[Any], Expr]) -> Expr: ...
75+
def not_in(self, values: Union[List[Any], Expr]) -> Expr: ...
76+
77+
def alias(self, name: str) -> Expr: ...
78+
79+
# Namespace accessors
80+
@property
81+
def list(self) -> _ListNamespace: ...
82+
@property
83+
def str(self) -> _StringNamespace: ...
84+
@property
85+
def struct(self) -> _StructNamespace: ...
86+
87+
def _unalias(self) -> Expr: ...
88+
89+
class ColumnExpr(Expr):
90+
_name: str
91+
@property
92+
def name(self) -> str: ...
93+
def _rename(self, name: str) -> AliasExpr: ...
94+
def structurally_equals(self, other: Any) -> bool: ...
95+
96+
class LiteralExpr(Expr):
97+
value: Any
98+
def structurally_equals(self, other: Any) -> bool: ...
99+
100+
class BinaryExpr(Expr):
101+
op: Operation
102+
left: Expr
103+
right: Expr
104+
def structurally_equals(self, other: Any) -> bool: ...
105+
106+
class UnaryExpr(Expr):
107+
op: Operation
108+
operand: Expr
109+
def structurally_equals(self, other: Any) -> bool: ...
110+
111+
class UDFExpr(Expr):
112+
fn: Callable[..., BatchColumn]
113+
args: List[Expr]
114+
kwargs: Dict[str, Expr]
115+
def structurally_equals(self, other: Any) -> bool: ...
116+
117+
class DownloadExpr(Expr):
118+
uri_column_name: str
119+
def structurally_equals(self, other: Any) -> bool: ...
120+
121+
class AliasExpr(Expr):
122+
expr: Expr
123+
_name: str
124+
_is_rename: bool
125+
@property
126+
def name(self) -> str: ...
127+
def alias(self, name: str) -> Expr: ...
128+
def _unalias(self) -> Expr: ...
129+
def structurally_equals(self, other: Any) -> bool: ...
130+
131+
class StarExpr(Expr):
132+
def structurally_equals(self, other: Any) -> bool: ...
133+
134+
class _ListNamespace:
135+
"""Namespace for list operations."""
136+
_expr: Expr
137+
138+
# Indexing and slicing
139+
def __getitem__(self, key: Union[int, slice]) -> UDFExpr: ...
140+
def get(self, index: int) -> UDFExpr: ...
141+
def slice(self, start: int = None, stop: int = None, step: int = None) -> UDFExpr: ...
142+
143+
# Auto-generated methods
144+
def len(self) -> UDFExpr: ...
145+
def flatten(self) -> UDFExpr: ...
146+
147+
class _StringNamespace:
148+
"""Namespace for string operations."""
149+
_expr: Expr
150+
151+
# Auto-generated length methods
152+
def len(self) -> UDFExpr: ...
153+
def byte_len(self) -> UDFExpr: ...
154+
155+
# Auto-generated case conversion
156+
def upper(self) -> UDFExpr: ...
157+
def lower(self) -> UDFExpr: ...
158+
def capitalize(self) -> UDFExpr: ...
159+
def title(self) -> UDFExpr: ...
160+
def swapcase(self) -> UDFExpr: ...
161+
162+
# Auto-generated predicates
163+
def is_alpha(self) -> UDFExpr: ...
164+
def is_alnum(self) -> UDFExpr: ...
165+
def is_digit(self) -> UDFExpr: ...
166+
def is_decimal(self) -> UDFExpr: ...
167+
def is_numeric(self) -> UDFExpr: ...
168+
def is_space(self) -> UDFExpr: ...
169+
def is_lower(self) -> UDFExpr: ...
170+
def is_upper(self) -> UDFExpr: ...
171+
def is_title(self) -> UDFExpr: ...
172+
def is_printable(self) -> UDFExpr: ...
173+
def is_ascii(self) -> UDFExpr: ...
174+
175+
# Auto-generated searching (parameterized)
176+
def starts_with(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
177+
def ends_with(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
178+
def contains(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
179+
def match(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
180+
def find(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
181+
def count(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
182+
def find_regex(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
183+
def count_regex(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
184+
def match_regex(self, pattern: str, ignore_case: bool = False) -> UDFExpr: ...
185+
186+
# Auto-generated transformations
187+
def reverse(self) -> UDFExpr: ...
188+
189+
# Manual methods (complex logic)
190+
def strip(self, characters: str = None) -> UDFExpr: ...
191+
def lstrip(self, characters: str = None) -> UDFExpr: ...
192+
def rstrip(self, characters: str = None) -> UDFExpr: ...
193+
def pad(
194+
self,
195+
width: int,
196+
fillchar: str = " ",
197+
side: Literal["left", "right", "both"] = "right",
198+
) -> UDFExpr: ...
199+
def center(self, width: int, fillchar: str = " ") -> UDFExpr: ...
200+
def slice(self, start: int, stop: int = None, step: int = 1) -> UDFExpr: ...
201+
def replace(
202+
self, pattern: str, replacement: str, max_replacements: int = None
203+
) -> UDFExpr: ...
204+
def replace_regex(
205+
self, pattern: str, replacement: str, max_replacements: int = None
206+
) -> UDFExpr: ...
207+
def replace_slice(self, start: int, stop: int, replacement: str) -> UDFExpr: ...
208+
def split(
209+
self, pattern: str, max_splits: int = None, reverse: bool = False
210+
) -> UDFExpr: ...
211+
def split_regex(
212+
self, pattern: str, max_splits: int = None, reverse: bool = False
213+
) -> UDFExpr: ...
214+
def split_whitespace(
215+
self, max_splits: int = None, reverse: bool = False
216+
) -> UDFExpr: ...
217+
def extract(self, pattern: str) -> UDFExpr: ...
218+
def repeat(self, n: int) -> UDFExpr: ...
219+
220+
class _StructNamespace:
221+
"""Namespace for struct operations."""
222+
_expr: Expr
223+
224+
def __getitem__(self, field_name: str) -> UDFExpr: ...
225+
def field(self, field_name: str) -> UDFExpr: ...
226+
227+
# ──────────────────────────────────────
228+
# Public API Functions
229+
# ──────────────────────────────────────
230+
231+
def col(name: str) -> ColumnExpr: ...
232+
def lit(value: Any) -> LiteralExpr: ...
233+
def star() -> StarExpr: ...
234+
def download(uri_column_name: str) -> DownloadExpr: ...
235+
def udf(return_dtype: DataType) -> Callable[..., Callable[..., UDFExpr]]: ...
236+
237+
__all__ = [
238+
"Operation",
239+
"Expr",
240+
"ColumnExpr",
241+
"LiteralExpr",
242+
"BinaryExpr",
243+
"UnaryExpr",
244+
"UDFExpr",
245+
"DownloadExpr",
246+
"AliasExpr",
247+
"StarExpr",
248+
"udf",
249+
"col",
250+
"lit",
251+
"download",
252+
"star",
253+
]

0 commit comments

Comments
 (0)