diff --git a/python/ray/data/EXPECTATIONS_COMPARISON.md b/python/ray/data/EXPECTATIONS_COMPARISON.md new file mode 100644 index 000000000000..99d5e60db279 --- /dev/null +++ b/python/ray/data/EXPECTATIONS_COMPARISON.md @@ -0,0 +1,493 @@ +# Data Quality Framework Comparison: Ray Data Expectations vs Top 20 Frameworks + +## Executive Summary + +This document compares Ray Data's `expect()` API against 20 leading Python data quality and validation frameworks to identify best practices, missing features, and API improvements. + +## Framework Comparison Matrix + +### 1. Great Expectations +**API Pattern**: `expect_column_values_to_be_between()`, `expect_table_row_count_to_equal()` +**Key Features**: +- 300+ pre-built expectations +- JSON/YAML-based expectation definitions +- Data profiling and documentation +- Integration with Airflow, dbt, Prefect +- Expectation suites and data docs + +**API Example**: +```python +expectation_suite = ExpectationSuite(name="my_suite") +expectation_suite.add_expectation( + ExpectColumnValuesToBeBetween(column="age", min_value=0, max_value=120) +) +validator.validate(dataset, expectation_suite) +``` + +**Recommendations for Ray Data**: +- ✅ Add pre-built expectation helpers (e.g., `expect_column_min()`, `expect_column_max()`) +- ✅ Support expectation suites/groups +- ✅ Add data profiling capabilities +- ✅ Consider YAML/JSON serialization for expectations + +### 2. 
Pandera +**API Pattern**: Schema-based validation with decorators +**Key Features**: +- Type-safe schema definitions +- Runtime validation +- Integration with pandas, pyspark, dask +- Rich error messages + +**API Example**: +```python +schema = DataFrameSchema({ + "age": Column(Int, Check.in_range(0, 120)), + "email": Column(String, Check.str_matches(r".+@.+\..+")) +}) +schema.validate(df) +``` + +**Recommendations for Ray Data**: +- ✅ Add schema-based validation support +- ✅ Improve error messages with column-level details +- ✅ Support type checking beyond boolean predicates + +### 3. Pydantic +**API Pattern**: Model-based validation with type hints +**Key Features**: +- Type validation using Python type hints +- Automatic data coercion +- JSON schema generation +- Fast validation + +**API Example**: +```python +class User(BaseModel): + age: conint(ge=0, le=120) + email: EmailStr + name: constr(min_length=1) + +User(**data) # Validates automatically +``` + +**Recommendations for Ray Data**: +- ✅ Consider supporting Pydantic models for validation +- ✅ Add automatic type coercion +- ✅ Generate schemas from type hints + +### 4. Soda Core +**API Pattern**: YAML-based checks (SodaCL) +**Key Features**: +- Human-readable YAML syntax +- Data freshness checks +- Schema evolution tracking +- Integration with 18+ data sources + +**API Example**: +```yaml +checks for dataset: + - row_count > 0 + - missing_count(age) = 0 + - invalid_percent(email) < 5% +``` + +**Recommendations for Ray Data**: +- ✅ Add data freshness checks +- ✅ Support declarative YAML/JSON configuration +- ✅ Add schema evolution tracking + +### 5. 
cuallee +**API Pattern**: DataFrame-agnostic validation +**Key Features**: +- Multi-engine support (Spark, Pandas, Polars) +- High performance +- Simple API + +**API Example**: +```python +check = Check(CheckLevel.WARNING, "age_checks") +check.is_complete("age").is_between("age", (0, 120)) +check.validate(df) +``` + +**Recommendations for Ray Data**: +- ✅ Optimize for distributed execution (already good) +- ✅ Add more built-in checks +- ✅ Benchmark performance + +### 6. Cerberus +**API Pattern**: Schema-based validation +**Key Features**: +- JSON schema-like validation +- Custom validators +- Normalization + +**API Example**: +```python +schema = { + "age": {"type": "integer", "min": 0, "max": 120}, + "email": {"type": "string", "regex": r".+@.+\..+"} +} +v = Validator(schema) +v.validate(data) +``` + +**Recommendations for Ray Data**: +- ✅ Support JSON schema validation +- ✅ Add normalization/coercion options + +### 7. Voluptuous +**API Pattern**: Schema objects with validators +**Key Features**: +- Declarative schemas +- Custom validators +- Error messages + +**API Example**: +```python +schema = Schema({ + "age": All(int, Range(min=0, max=120)), + "email": Email() +}) +schema(data) +``` + +**Recommendations for Ray Data**: +- ✅ Add declarative schema support +- ✅ Improve composability of validators + +### 8. Marshmallow +**API Pattern**: Schema classes with fields +**Key Features**: +- Object serialization/deserialization +- Field validation +- Nested schemas + +**API Example**: +```python +class UserSchema(Schema): + age = fields.Int(validate=Range(min=0, max=120)) + email = fields.Email() + +UserSchema().validate(data) +``` + +**Recommendations for Ray Data**: +- ✅ Support nested schema validation +- ✅ Add serialization capabilities + +### 9. 
dbt (data tests) +**API Pattern**: SQL-based tests in YAML +**Key Features**: +- SQL-based validation +- Integration with data warehouses +- Test results in metadata + +**API Example**: +```yaml +models: + - name: users + tests: + - dbt_utils.expression_is_true: + expression: "age >= 0 AND age <= 120" +``` + +**Recommendations for Ray Data**: +- ✅ Support SQL-based expectations +- ✅ Add test result metadata/storage + +### 10. pytest-deadfixtures +**API Pattern**: pytest plugin that reports unused and duplicated fixtures (test-suite hygiene, not data validation) +**Key Features**: +- Unused (dead) fixture detection +- Duplicated fixture detection + +**Recommendations for Ray Data**: +- ✅ Add pytest integration +- ✅ Support test discovery + +### 11. Hypothesis +**API Pattern**: Property-based testing +**Key Features**: +- Generative testing +- Shrinking failures +- Statistical validation + +**API Example**: +```python +@given(st.integers(min_value=0, max_value=120)) +def test_age_valid(age): + assert validate_age(age) +``` + +**Recommendations for Ray Data**: +- ✅ Add property-based testing support +- ✅ Generate test cases automatically + +### 12. PyExamine +**API Pattern**: Code quality analysis +**Key Features**: +- Multi-level analysis +- 49 distinct metrics +- High accuracy + +**Recommendations for Ray Data**: +- ✅ Add code quality metrics for validators +- ✅ Static analysis of expectation expressions + +### 13. Stream DaQ +**API Pattern**: Streaming data quality +**Key Features**: +- Window-based validation +- Dynamic constraint adaptation +- 30+ quality checks + +**API Example**: +```python +checker = StreamChecker(window_size=1000) +checker.add_check("age", Range(0, 120)) +checker.validate(stream) +``` + +**Recommendations for Ray Data**: +- ✅ Add streaming/windowed validation +- ✅ Support dynamic threshold adjustment +- ✅ Add more built-in checks + +### 14. 
PySAD +**API Pattern**: Anomaly detection +**Key Features**: +- Streaming anomaly detection +- Multiple algorithms +- Integration with scikit-learn + +**Recommendations for Ray Data**: +- ✅ Add anomaly detection expectations +- ✅ Support statistical validation + +### 15. Frictionless Data +**API Pattern**: Table Schema validation +**Key Features**: +- Table Schema specification +- Data package validation +- CSV/JSON validation + +**API Example**: +```python +schema = { + "fields": [ + {"name": "age", "type": "integer", "constraints": {"minimum": 0, "maximum": 120}} + ] +} +validate(data, schema=schema) +``` + +**Recommendations for Ray Data**: +- ✅ Support Table Schema format +- ✅ Add data package validation + +### 16. jsonschema +**API Pattern**: JSON Schema validation +**Key Features**: +- Standard JSON Schema support +- Draft version support +- Error reporting + +**API Example**: +```python +schema = { + "type": "object", + "properties": { + "age": {"type": "integer", "minimum": 0, "maximum": 120} + } +} +validate(data, schema) +``` + +**Recommendations for Ray Data**: +- ✅ Support JSON Schema validation +- ✅ Add standard schema format support + +### 17. validators +**API Pattern**: Simple validation functions +**Key Features**: +- Email, URL, IP validation +- Simple API +- Lightweight + +**API Example**: +```python +from validators import email, url +email("test@example.com") +url("https://example.com") +``` + +**Recommendations for Ray Data**: +- ✅ Add common validators (email, URL, etc.) +- ✅ Keep API simple and composable + +### 18. schema +**API Pattern**: Schema objects +**Key Features**: +- Simple schema definitions +- Error handling +- Type validation + +**API Example**: +```python +schema = Schema({ + "age": And(int, lambda n: 0 <= n <= 120), + "email": And(str, Use(str.lower)) +}) +schema.validate(data) +``` + +**Recommendations for Ray Data**: +- ✅ Improve error messages +- ✅ Add schema composition + +### 19. 
pyvalid +**API Pattern**: Function decorators +**Key Features**: +- Function argument validation +- Type checking +- Range validation + +**API Example**: +```python +@validate_args(age=Range(0, 120), email=Email()) +def process_user(age, email): + pass +``` + +**Recommendations for Ray Data**: +- ✅ Add decorator support for validators +- ✅ Support function argument validation + +### 20. dataclass-validator +**API Pattern**: Dataclass validation +**Key Features**: +- Type validation +- Custom validators +- Integration with dataclasses + +**API Example**: +```python +@dataclass +class User: + age: int = field(validator=Range(0, 120)) + email: str = field(validator=Email()) +``` + +**Recommendations for Ray Data**: +- ✅ Support dataclass validation +- ✅ Add type-based validation + +## Key Findings and Recommendations + +### Strengths of Current Ray Data Implementation + +1. ✅ **Distributed execution** - Excellent support for large-scale validation +2. ✅ **Expression-based API** - Clean, composable predicate expressions +3. ✅ **Quarantine workflow** - Returns passed/failed datasets (unique feature) +4. ✅ **Execution time expectations** - Novel feature not found in most frameworks +5. ✅ **Integration with Ray Data** - Native integration with Dataset operations + +### Missing Features Compared to Top Frameworks + +1. **Pre-built Expectations** (Great Expectations, cuallee) + - Add helpers like `expect_column_min()`, `expect_column_max()`, `expect_column_null_count()` + - Common validations: email, URL, regex, range, uniqueness + +2. **Schema-based Validation** (Pandera, Pydantic, Marshmallow) + - Support schema objects for type checking + - Automatic type coercion + - Nested schema validation + +3. **Expectation Suites/Groups** (Great Expectations) + - Group related expectations + - Run multiple expectations together + - Aggregate results + +4. **Data Profiling** (Great Expectations) + - Automatic data profiling + - Statistical summaries + - Data documentation + +5. 
**Better Error Messages** (Pandera, Pydantic) + - Column-level error details + - Row-level error context + - Sample failed values (partially implemented) + +6. **YAML/JSON Configuration** (Soda Core, Great Expectations) + - Declarative configuration + - Version control friendly + - Reusable expectation definitions + +7. **Data Freshness Checks** (Soda Core) + - Time-based validation + - Data staleness detection + +8. **Schema Evolution Tracking** (Soda Core) + - Track schema changes over time + - Detect breaking changes + +9. **Statistical Validation** (Hypothesis, PySAD) + - Distribution checks + - Statistical tests + - Anomaly detection + +10. **Streaming/Windowed Validation** (Stream DaQ) + - Window-based checks + - Dynamic thresholds + - Real-time validation + +### API Design Improvements + +1. **Consistency**: Follow patterns from Great Expectations and Pandera + - Use `expect_*` naming convention for helpers + - Support both programmatic and declarative APIs + +2. **Composability**: Learn from Voluptuous and schema + - Make expectations composable + - Support expectation chaining + +3. **Error Handling**: Improve based on Pydantic and Pandera + - Rich error objects with context + - Column-level and row-level error details + - Error aggregation and reporting + +4. **Type Safety**: Adopt patterns from Pydantic + - Type hints for expectations + - Automatic type validation + - Schema generation + +5. **Performance**: Benchmark against cuallee + - Optimize distributed execution + - Add caching for repeated validations + - Profile and optimize hot paths + +## Recommended Next Steps + +### Phase 1: Core Improvements (High Priority) +1. Add pre-built expectation helpers (`expect_column_min`, `expect_column_max`, etc.) +2. Improve error messages with column/row context +3. Add expectation suites/groups +4. Support YAML/JSON configuration + +### Phase 2: Advanced Features (Medium Priority) +1. Add schema-based validation +2. Implement data profiling +3. 
Add data freshness checks +4. Support SQL-based expectations + +### Phase 3: Advanced Capabilities (Low Priority) +1. Add streaming/windowed validation +2. Implement anomaly detection +3. Add statistical validation +4. Support property-based testing + +## Conclusion + +Ray Data's `expect()` API has a solid foundation with unique strengths in distributed execution and quarantine workflows. By adopting best practices from frameworks like Great Expectations, Pandera, and Pydantic, we can significantly enhance the API's usability, feature set, and developer experience while maintaining its distributed-first architecture. + diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py index b0a7a188e053..8281e5677ad1 100644 --- a/python/ray/data/__init__.py +++ b/python/ray/data/__init__.py @@ -14,7 +14,9 @@ ) from ray.data._internal.logging import configure_logging from ray.data.context import DataContext, DatasetContext -from ray.data.dataset import Dataset, Schema, SinkMode, ClickHouseTableSettings +from ray.data.expectations import ( + expect, +) from ray.data.datasource import ( BlockBasedFileDatasink, Datasink, @@ -127,6 +129,7 @@ "Datasource", "ExecutionOptions", "ExecutionResources", + "expect", "FileShuffleConfig", "NodeIdStr", "ReadTask", diff --git a/python/ray/data/_internal/execution/interfaces/execution_options.py b/python/ray/data/_internal/execution/interfaces/execution_options.py index 3edfa2dceda5..851ecb48b59c 100644 --- a/python/ray/data/_internal/execution/interfaces/execution_options.py +++ b/python/ray/data/_internal/execution/interfaces/execution_options.py @@ -1,11 +1,16 @@ import math import os -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from .common import NodeIdStr from ray.data._internal.execution.util import memory_string from ray.util.annotations import DeveloperAPI +if TYPE_CHECKING: + from ray.data._internal.execution.interfaces.execution_options import 
( + ExecutionResources, + ) + class ExecutionResources: """Specifies resources usage or resource limits for execution. @@ -346,6 +351,37 @@ def is_resource_limits_default(self): """Returns True if resource_limits is the default value.""" return self._resource_limits == ExecutionResources.for_limits() + def copy( + self, + resource_limits: Optional[ExecutionResources] = None, + exclude_resources: Optional[ExecutionResources] = None, + locality_with_output: Optional[Union[bool, List[NodeIdStr]]] = None, + preserve_order: Optional[bool] = None, + actor_locality_enabled: Optional[bool] = None, + verbose_progress: Optional[bool] = None, + ) -> "ExecutionOptions": + """Return a copy of this object, overriding specified fields. + + Args: + resource_limits: Optional resource limits override. + exclude_resources: Optional exclude resources override. + locality_with_output: Optional locality override. + preserve_order: Optional preserve order override. + actor_locality_enabled: Optional actor locality override. + verbose_progress: Optional verbose progress override. + + Returns: + A new ExecutionOptions object with overridden fields. 
+ """ + return ExecutionOptions( + resource_limits=resource_limits if resource_limits is not None else self.resource_limits, + exclude_resources=exclude_resources if exclude_resources is not None else self.exclude_resources, + locality_with_output=locality_with_output if locality_with_output is not None else self.locality_with_output, + preserve_order=preserve_order if preserve_order is not None else self.preserve_order, + actor_locality_enabled=actor_locality_enabled if actor_locality_enabled is not None else self.actor_locality_enabled, + verbose_progress=verbose_progress if verbose_progress is not None else self.verbose_progress, + ) + def validate(self) -> None: """Validate the options.""" for attr in ["cpu", "gpu", "object_store_memory"]: diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index 58f7b6837c36..f353449bd878 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -90,6 +90,9 @@ def __init__( self._op_resource_allocator: Optional["OpResourceAllocator"] = None + # Use reservation ratio from data context + reservation_ratio = data_context.op_resource_reservation_ratio + if data_context.op_resource_reservation_enabled: # We'll enable memory reservation if all operators have # implemented accurate memory accounting. 
@@ -98,10 +101,11 @@ def __init__( ) if should_enable: self._op_resource_allocator = ReservationOpResourceAllocator( - self, data_context.op_resource_reservation_ratio + self, reservation_ratio ) - self._object_store_memory_limit_fraction = ( + # Set object store memory limit fraction + base_memory_fraction = ( data_context.override_object_store_memory_limit_fraction if data_context.override_object_store_memory_limit_fraction is not None else ( @@ -110,6 +114,7 @@ def __init__( else self.DEFAULT_OBJECT_STORE_MEMORY_LIMIT_FRACTION_NO_RESERVATION ) ) + self._object_store_memory_limit_fraction = base_memory_fraction self._warn_about_object_store_memory_if_needed() @@ -135,8 +140,8 @@ def _warn_about_object_store_memory_if_needed(self): ): logger.warning( f"{WARN_PREFIX} Ray's object store is configured to use only " - f"{object_store_fraction:.1%} of available memory ({object_store_memory/GiB:.1f}GiB " - f"out of {total_memory/GiB:.1f}GiB total). For optimal Ray Data performance, " + f"{object_store_fraction:.1%} of available memory ({object_store_memory / GiB:.1f}GiB " + f"out of {total_memory / GiB:.1f}GiB total). For optimal Ray Data performance, " f"we recommend setting the object store to at least 50% of available memory. " f"You can do this by setting the 'object_store_memory' parameter when calling " f"ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable." 
diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index 1e496fe0d6f7..01a7f346116e 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -18,6 +18,7 @@ from ray.data.block import BlockMetadataWithSchema, _take_first_non_empty_schema from ray.data.context import DataContext from ray.data.exceptions import omit_traceback_stdout +from ray.data.expectations import ExecutionTimeExpectation from ray.util.debug import log_once if TYPE_CHECKING: @@ -88,6 +89,57 @@ def __init__( self._context = data_context + # Track execution time expectations for optimization hints + self._execution_time_expectations: List[ExecutionTimeExpectation] = [] + + def add_execution_time_expectation( + self, expectation: ExecutionTimeExpectation + ) -> None: + """Add an execution time expectation to this execution plan. + + This allows the execution plan to track execution time requirements and + inform optimization strategies. + + Args: + expectation: The execution time expectation to add. + + Raises: + TypeError: If expectation is not an ExecutionTimeExpectation instance. + """ + if not isinstance(expectation, ExecutionTimeExpectation): + raise TypeError( + f"Expected ExecutionTimeExpectation, got {type(expectation).__name__}" + ) + self._execution_time_expectations.append(expectation) + + def get_execution_time_expectations(self) -> List[ExecutionTimeExpectation]: + """Get all execution time expectations attached to this plan. + + Returns: + List of execution time expectations. + """ + return self._execution_time_expectations.copy() + + def get_max_execution_time_seconds(self) -> Optional[float]: + """Get the maximum execution time from execution time expectations. + + Returns: + Minimum max execution time from all execution time expectations, or None if no time constraints. 
+ """ + if not self._execution_time_expectations: + return None + + max_times = [ + exp.get_max_execution_time_seconds() + for exp in self._execution_time_expectations + if exp.get_max_execution_time_seconds() is not None + ] + + if not max_times: + return None + + return min(max_times) + def get_dataset_id(self) -> str: """Unique ID of the dataset, including the dataset name, UUID, and current execution index. @@ -101,6 +153,7 @@ def create_executor(self) -> "StreamingExecutor": from ray.data._internal.execution.streaming_executor import StreamingExecutor self._run_index += 1 + executor = StreamingExecutor(self._context, self.get_dataset_id()) return executor @@ -128,7 +181,6 @@ def explain(self) -> str: sections = [] for title, convert_fn in zip(titles, convert_fns): - # 2. Convert plan to new plan plan = convert_fn(plan) @@ -359,6 +411,7 @@ def copy(self) -> "ExecutionPlan": plan_copy._snapshot_operator = self._snapshot_operator plan_copy._snapshot_stats = self._snapshot_stats plan_copy._dataset_name = self._dataset_name + plan_copy._execution_time_expectations = self._execution_time_expectations.copy() return plan_copy def deep_copy(self) -> "ExecutionPlan": @@ -379,6 +432,9 @@ def deep_copy(self) -> "ExecutionPlan": plan_copy._snapshot_operator = copy.copy(self._snapshot_operator) plan_copy._snapshot_stats = copy.copy(self._snapshot_stats) plan_copy._dataset_name = self._dataset_name + plan_copy._execution_time_expectations = copy.deepcopy( + self._execution_time_expectations + ) return plan_copy def initial_num_blocks(self) -> Optional[int]: diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index f6b630c34d09..735622c0de08 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -5,17 +5,15 @@ import logging import time import warnings +from collections.abc import Iterable, Iterator, Mapping from typing import ( TYPE_CHECKING, Any, Callable, Dict, Generic, - Iterable, - Iterator, List, Literal, - Mapping, Optional, Tuple, 
TypeVar, @@ -134,9 +132,16 @@ from tensorflow_metadata.proto.v0 import schema_pb2 from ray.data._internal.execution.interfaces import Executor, NodeIdStr + from ray.data.expressions import Expr from ray.data.grouped_data import GroupedData -from ray.data.expressions import Expr, StarExpr, col +from ray.data.expectations import ( + DataQualityExpectation, + ExecutionTimeExpectation, + Expectation, + ExpectationResult, + ExpectationType, +) logger = logging.getLogger(__name__) @@ -174,7 +179,8 @@ class Dataset: Datasets can be created in multiple ways: - * from external storage systems such as local disk, S3, HDFS etc. via the ``read_*()`` APIs. + * from external storage systems such as local disk, S3, HDFS etc. + via the ``read_*()`` APIs. * from existing memory data via ``from_*()`` APIs * from synthetic data via ``range_*()`` APIs @@ -210,7 +216,8 @@ class Dataset: * transformations such as :py:meth:`.map_batches()` * aggregations such as :py:meth:`.min()`/:py:meth:`.max()`/:py:meth:`.mean()`, * grouping via :py:meth:`.groupby()`, - * shuffling operations such as :py:meth:`.sort()`, :py:meth:`.random_shuffle()`, and :py:meth:`.repartition()` + * shuffling operations such as :py:meth:`.sort()`, + :py:meth:`.random_shuffle()`, and :py:meth:`.repartition()` * joining via :py:meth:`.join()` Examples: @@ -405,6 +412,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]: ) plan = self._plan.copy() + map_op = MapRows( self._logical_plan.dag, fn, @@ -769,6 +777,7 @@ def _map_batches_without_batch_size_validation( ) plan = self._plan.copy() + map_batches_op = MapBatches( self._logical_plan.dag, fn, @@ -792,7 +801,7 @@ def _map_batches_without_batch_size_validation( def with_column( self, column_name: str, - expr: Expr, + expr: "Expr", **ray_remote_args, ) -> "Dataset": """ @@ -838,7 +847,7 @@ def with_column( from ray.data._internal.logical.operators.one_to_one_operator import Download # TODO: Once the expression API supports UDFs, we can clean up the code here. 
- from ray.data.expressions import DownloadExpr + from ray.data.expressions import DownloadExpr, StarExpr plan = self._plan.copy() if isinstance(expr, DownloadExpr): @@ -1187,6 +1196,8 @@ def rename_columns( "to be strings." ) + from ray.data.expressions import col + exprs = [col(prev)._rename(new) for prev, new in names.items()] elif isinstance(names, list): @@ -1212,6 +1223,8 @@ def rename_columns( f"schema names: {current_names}." ) + from ray.data.expressions import col + exprs = [col(prev)._rename(new) for prev, new in zip(current_names, names)] else: raise TypeError( @@ -1227,6 +1240,7 @@ def rename_columns( # Construct the plan and project operation from ray.data._internal.compute import TaskPoolStrategy + from ray.data.expressions import StarExpr compute = TaskPoolStrategy(size=concurrency) @@ -1383,7 +1397,7 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: def filter( self, fn: Optional[UserDefinedFunction[Dict[str, Any], bool]] = None, - expr: Optional[Union[str, Expr]] = None, + expr: Optional[Union[str, "Expr"]] = None, *, compute: Union[str, ComputeStrategy] = None, fn_args: Optional[Iterable[Any]] = None, @@ -1500,7 +1514,7 @@ def _check_fn_params_incompatible(param_type): # Initialize Filter operator arguments with proper types input_op = self._logical_plan.dag - predicate_expr: Optional[Expr] = None + predicate_expr: Optional["Expr"] = None filter_fn: Optional[UserDefinedFunction] = None filter_fn_args: Optional[Iterable[Any]] = None filter_fn_kwargs: Optional[Dict[str, Any]] = None @@ -1578,6 +1592,1054 @@ def _check_fn_params_incompatible(param_type): logical_plan = LogicalPlan(filter_op, self.context) return Dataset(plan, logical_plan) + @PublicAPI(api_group=BT_API_GROUP, stability="alpha") + def expect( + self, + expectation: Optional[Union[Expectation, "Expr", List[Expectation]]] = None, + *, + expr: Optional["Expr"] = None, + name: Optional[str] = None, + description: Optional[str] = None, + validator_fn: 
Optional[Callable[[Any], bool]] = None, + max_execution_time_seconds: Optional[float] = None, + error_on_failure: bool = True, + compute: Optional[ComputeStrategy] = None, + num_cpus: Optional[float] = None, + num_gpus: Optional[float] = None, + memory: Optional[float] = None, + **ray_remote_args, + ) -> Union[ + Tuple["Dataset", "Dataset", ExpectationResult], + Tuple["Dataset", "Dataset", List[ExpectationResult]], + ]: + """Apply a data quality expectation to validate this dataset. + + This method validates data quality constraints and returns datasets containing + rows that passed and failed validation. The dataset is materialized to ensure + all validation runs complete, and results are aggregated across all batches. + + .. tip:: + If you use the `expr` parameter with a predicate expression, Ray Data + optimizes your validation with native Arrow interfaces, similar to + :meth:`~Dataset.filter`. For validator functions, batches are processed + efficiently using :meth:`~Dataset.map_batches`. + + Examples: + >>> import ray + >>> from ray.data.expectations import expect + >>> from ray.data.expressions import col + >>> + >>> # Simple expression-based validation (most common) + >>> ds = ray.data.from_items([{"value": 1}, {"value": 2}, {"value": -1}]) + >>> passed_ds, failed_ds, result = ds.expect(expr=col("value") > 0) + >>> print(result.passed) + False + >>> print(failed_ds.take_all()) + [{'value': -1}] + >>> + >>> # Pass expression as positional argument with optional name + >>> passed_ds, failed_ds, result = ds.expect(col("value") > 0, name="positive_values") + >>> + >>> # Or pass expression as positional argument + >>> passed_ds, failed_ds, result = ds.expect(col("value") > 0) + >>> + >>> # Multiple expectations (pass as list) + >>> expectations = [ + ... expect(expr=col("age") >= 0), + ... expect(expr=col("email").is_not_null()) + ... 
] + >>> passed_ds, failed_ds, results = ds.expect(expectations) + >>> + >>> # Quarantine workflows + >>> raw_ds = ray.data.from_items([{"user_id": 1, "score": 95}, {"user_id": 2, "score": -5}]) + >>> valid_ds, invalid_ds, result = raw_ds.expect(expr=col("score") >= 0) + >>> valid_ds.write_parquet("s3://bucket/valid/") + >>> invalid_ds.write_parquet("s3://bucket/quarantine/") + >>> + >>> # Execution time expectations: execution time constraints + >>> ds = ray.data.range(1000000) + >>> processed_ds, remaining_ds, result = ds.expect( + ... max_execution_time_seconds=60 + ... ) + + Time complexity: O(dataset size / parallelism) + + Args: + expectation: An Expectation object returned from `expect()`, an expression (Expr), + or a list of expectations. If None, `expr` or other parameters must be provided. + Mutually exclusive with `expr`. + expr: An expression that represents a predicate (boolean condition) for + validation. Uses the same expression API as :meth:`~Dataset.filter`. + Mutually exclusive with `expectation`. Can also be passed as a positional + argument: ``ds.expect(col("value") > 0, name="check")``. + name: Name for the expectation (only used when creating from expr/validator_fn/execution time). + description: Description of what this expectation checks (only used when creating + from expr/validator_fn/execution time). + validator_fn: Function for data quality validation (takes batch, returns bool). + Mutually exclusive with `expr` and `expectation`. + max_execution_time_seconds: Maximum execution time in seconds (for execution time expectations). + Mutually exclusive with `expr` and `expectation`. + error_on_failure: If True, raise exception on failure; if False, log warning. + compute: The compute strategy to use for the validation operation. + + * If ``compute`` is not specified for a function, will use ``ray.data.TaskPoolStrategy()`` to launch concurrent tasks based on the available resources and number of input blocks. 
+ + * Use ``ray.data.TaskPoolStrategy(size=n)`` to launch at most ``n`` concurrent Ray tasks. + + * If ``compute`` is not specified for a callable class, will use ``ray.data.ActorPoolStrategy(min_size=1, max_size=None)`` to launch an autoscaling actor pool from 1 to unlimited workers. + + * Use ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed size actor pool of ``n`` workers. + + * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` to use an autoscaling actor pool from ``m`` to ``n`` workers. + + * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n, initial_size=initial)`` to use an autoscaling actor pool from ``m`` to ``n`` workers, with an initial size of ``initial``. + + num_cpus: The number of CPUs to reserve for each parallel worker. + num_gpus: The number of GPUs to reserve for each parallel worker. For + example, specify `num_gpus=1` to request 1 GPU for each parallel worker. + memory: The heap memory in bytes to reserve for each parallel worker. + **ray_remote_args: Additional resource requirements to request from + Ray for each parallel worker. See :func:`ray.remote` for details. + + Returns: + A tuple of (passed_ds, failed_ds, ExpectationResult) where: + + - passed_ds: Dataset containing rows that passed validation (for data quality) + or data processed before timeout (for execution time expectations). + - failed_ds: Dataset containing rows that failed validation (for data quality) + or remaining unprocessed data (for execution time expectations with timeout). + - ExpectationResult: The result of the expectation validation containing + pass/fail status, failure counts, execution time (for execution time expectations), and a descriptive message. + For lists of expectations, returns a list of ExpectationResult objects. + + Raises: + TypeError: If expectation is not an Expectation object or Expr. + ValueError: If neither expectation nor expr is provided, or if both are provided. + + .. 
seealso:: + + :meth:`~Dataset.filter` + Filter rows based on a predicate expression. + + :meth:`~Dataset.map_batches` + Transform batches of data. + """ + # Import here to avoid circular dependencies + from ray.data.expressions import Expr as _Expr + + # Validate max_execution_time_seconds + if max_execution_time_seconds is not None: + if max_execution_time_seconds <= 0: + raise ValueError( + f"max_execution_time_seconds must be positive, " + f"got {max_execution_time_seconds}" + ) + + # Validate that at least one of expectation or expr is provided + if ( + expectation is None + and expr is None + and validator_fn is None + and max_execution_time_seconds is None + ): + raise ValueError( + "Must provide at least one of: `expectation`, `expr`, `validator_fn`, " + "or `max_execution_time_seconds`. " + "Examples: ds.expect(expr=col('value') > 0) or ds.expect(max_execution_time_seconds=60)" + ) + + # Handle lists of expectations (simplified suite functionality) + if isinstance(expectation, list): + return self._expect_list(expectation) + + # Handle expression-based expectations (more Pythonic and Ray-like) + # Similar to how filter() accepts expressions + if expr is not None or ( + expectation is not None and isinstance(expectation, _Expr) + ): + if expr is not None and expectation is not None: + raise ValueError( + "Cannot specify both `expectation` and `expr`. " + "Use either an Expectation object or an expression." + ) + if validator_fn is not None: + raise ValueError( + "Cannot specify both `expr` and `validator_fn`. " + "Use either expression-based validation or a validator function." + ) + if max_execution_time_seconds is not None: + raise ValueError( + "Cannot specify both `expr` and `max_execution_time_seconds`. " + "Use `expr` for data quality validation or `max_execution_time_seconds` for execution time expectations." 
+ ) + + # Convert expression to expectation + if isinstance(expectation, _Expr): + expr = expectation + expectation = None + + from ray.data.expectations import expect as _expect + + # Create expectation from expression with optional parameters + expectation = _expect( + expr=expr, + name=name or "Data Quality Check", + description=description or f"Validate expression: {expr}", + error_on_failure=error_on_failure, + ) + elif expectation is None: + # No expression provided, check if execution time parameters are provided + if max_execution_time_seconds is not None: + if validator_fn is not None: + raise ValueError( + "Cannot specify both `validator_fn` and `max_execution_time_seconds`. " + "Use `validator_fn` for data quality validation or `max_execution_time_seconds` for execution time expectations." + ) + from ray.data.expectations import expect as _expect + + expectation = _expect( + max_execution_time_seconds=max_execution_time_seconds, + name=name or "Execution Time Requirement", + description=description or "Execution time constraint", + error_on_failure=error_on_failure, + ) + elif validator_fn is not None: + # Validate validator_fn is callable + if not callable(validator_fn): + raise TypeError( + f"validator_fn must be callable, got {type(validator_fn).__name__}" + ) + # Validate name and description are strings if provided + if name is not None and not isinstance(name, str): + raise TypeError(f"name must be str, got {type(name).__name__}") + if description is not None and not isinstance(description, str): + raise TypeError(f"description must be str, got {type(description).__name__}") + + from ray.data.expectations import expect as _expect + + expectation = _expect( + validator_fn=validator_fn, + name=name or "Data Quality Check", + description=description or "Data quality validation", + error_on_failure=error_on_failure, + ) + else: + # This should not happen due to earlier validation, but add check for clarity + raise ValueError( + "Must provide either 
`expr`, `validator_fn`, `max_execution_time_seconds`, " + "or an `expectation` object." + ) + + # Validate expectation object + if not isinstance(expectation, Expectation): + raise TypeError( + f"expectation must be an Expectation object or Expr, " + f"got {type(expectation).__name__ if expectation is not None else 'None'}. " + f"If you passed a list, use ds.expect([exp1, exp2]) instead of ds.expect(exp1, exp2)." + ) + + # Handle execution time expectations (time-based execution constraints) + if isinstance(expectation, ExecutionTimeExpectation): + return self._expect_execution_time(expectation) + + # Handle DataQualityExpectation + if expectation.expectation_type != ExpectationType.DATA_QUALITY: + raise ValueError( + f"Dataset.expect() only supports DataQualityExpectation or ExecutionTimeExpectation, " + f"got {expectation.expectation_type.value}" + ) + + # Ensure it's a DataQualityExpectation (should be guaranteed by now, but double-check) + if not isinstance(expectation, DataQualityExpectation): + raise TypeError( + f"expectation must be a DataQualityExpectation, got {type(expectation).__name__}" + ) + + # For expression-based expectations, use filter() pattern for efficient validation + # This reuses Ray Data's existing expression evaluation infrastructure + if isinstance(expectation, DataQualityExpectation) and hasattr( + expectation, "_expr" + ): + # Expression-based validation: use filter() to count passing rows + expr = getattr(expectation, "_expr", None) + if expr is not None: + total_rows = self.count() + if total_rows == 0: + # Empty dataset passes validation + empty_ds = self + result = ExpectationResult( + expectation=expectation, + passed=True, + message=f"Expectation '{expectation.name}' passed: empty dataset", + total_count=0, + failure_count=0, + ) + return empty_ds, empty_ds, result + + # Optimize: Use with_column to evaluate expression once, then filter twice + # This avoids evaluating the expression twice (once for passed, once for failed) + # 
Add validation flag column based on expression + # Check if column already exists (shouldn't happen, but be defensive) + validation_col_name = self._get_validation_column_name() + try: + validated_ds = self.with_column(validation_col_name, expr) + except Exception as e: + raise RuntimeError( + f"Failed to evaluate expression {expr} on dataset: {e}. " + f"This may indicate an issue with the expression or dataset schema." + ) from e + + # Split into passed and failed datasets using the validation flag + # Filtering on a boolean column is fast (expression already evaluated) + from ray.data.expressions import col, lit + + # Filter passed rows: must be True (not False, not NULL) + # NULL values from expression evaluation are treated as False (failed) + passed_ds = validated_ds.filter( + expr=col(validation_col_name).is_not_null() & (col(validation_col_name) == lit(True)) + ) + # Filter failed rows: False or NULL + failed_ds = validated_ds.filter( + expr=col(validation_col_name).is_null() | (col(validation_col_name) == lit(False)) + ) + + # Remove validation flag column from both datasets + # Get original columns (exclude validation column) + schema = self.schema() + if schema and hasattr(schema, "names"): + original_cols = [ + c for c in schema.names if c != validation_col_name + ] + # Only select columns if we have original columns and they exist in validated schema + if original_cols: + try: + validated_schema = validated_ds.schema() + if validated_schema and hasattr(validated_schema, "names"): + # Ensure all original columns exist in validated schema + validated_cols = set(validated_schema.names) + existing_cols = [c for c in original_cols if c in validated_cols] + if existing_cols: + passed_ds = passed_ds.select_columns(cols=existing_cols) + failed_ds = failed_ds.select_columns(cols=existing_cols) + except Exception: + # If schema access fails, skip column removal + pass + + failed_rows = failed_ds.count() + passed = failed_rows == 0 + + if passed: + message = ( + 
f"Expectation '{expectation.name}' passed: " + f"all {total_rows} rows validated successfully" + ) + else: + failure_rate = ( + (failed_rows / total_rows * 100) if total_rows > 0 else 0 + ) + + # Build detailed error message with context + # Sample failed values for better debugging + # Constants for error message sampling + MAX_SAMPLE_ROWS = 5 + MAX_SAMPLE_VALUES = 3 + + sample_failed_values = [] + failed_columns = [] + try: + # Get a sample of failed rows for context + # Limit to MAX_SAMPLE_ROWS to avoid performance issues + failed_sample = failed_ds.take(MAX_SAMPLE_ROWS) + if failed_sample and isinstance(failed_sample, list) and len(failed_sample) > 0: + # Extract column values from failed rows + first_row = failed_sample[0] + if isinstance(first_row, dict): + # Try to identify which columns are involved in the expectation + if hasattr(expectation, "_expr"): + expr = getattr(expectation, "_expr", None) + if expr is not None: + # Extract column references from expression + try: + from ray.data._internal.planner.plan_expression.expression_visitors import ( + _ColumnReferenceCollector, + ) + collector = _ColumnReferenceCollector() + collector.visit(expr) + col_refs = collector.get_column_refs() + failed_columns = list(col_refs) if col_refs else [] + except (ImportError, AttributeError) as e: + # Fallback to string parsing if collector not available + logger.debug( + f"Could not use _ColumnReferenceCollector to extract column references: {e}. " + "Falling back to regex parsing." + ) + expr_str = str(expr) + import re + matches = re.findall(r"col\(['\"](.*?)['\"]\)", expr_str) + failed_columns = list(set(matches)) + except Exception as e: + # Log unexpected errors but still fall back to regex + logger.warning( + f"Unexpected error extracting column references from expression: {e}. 
" + "Falling back to regex parsing.", + exc_info=True + ) + expr_str = str(expr) + import re + matches = re.findall(r"col\(['\"](.*?)['\"]\)", expr_str) + failed_columns = list(set(matches)) + + # Extract sample values from first column that exists + for col_name in failed_columns: + if col_name in first_row: + # Extract up to MAX_SAMPLE_VALUES sample values + sample_failed_values = [ + str(row.get(col_name, "N/A")) + for row in failed_sample[:min(MAX_SAMPLE_VALUES, len(failed_sample))] + ] + break # Use first available column + except Exception as e: + # Log sampling failure but continue without samples + logger.debug( + f"Failed to sample failed values for error message: {e}. " + "Continuing without sample values." + ) + + message = ( + f"Expectation '{expectation.name}' failed: " + f"{failed_rows}/{total_rows} rows failed validation ({failure_rate:.1f}%)" + ) + + # Add column context if available + if failed_columns: + if len(failed_columns) == 1: + message += f". Column: {failed_columns[0]}" + else: + message += f". Columns: {', '.join(failed_columns)}" + + # Add context about failed values + if sample_failed_values: + sample_str = ", ".join(sample_failed_values) + if len(sample_failed_values) >= 3: + message += f". Sample failed values: [{sample_str}]" + else: + message += f". Failed values: [{sample_str}]" + + # Add expectation expression/description for context + if hasattr(expectation, "_expr"): + expr_str = str(getattr(expectation, "_expr", "")) + message += f". Expectation: {expr_str}" + elif expectation.description: + message += f". {expectation.description}" + + # Add suggestion for fixing + if failure_rate > 50: + message += ( + ". Suggestion: Check data source for systematic issues or " + "review expectation criteria." + ) + elif failure_rate > 0: + message += ". Suggestion: Review failed rows to identify data quality issues." 
+ + result = ExpectationResult( + expectation=expectation, + passed=passed, + message=message, + total_count=total_rows, + failure_count=failed_rows, + ) + return passed_ds, failed_ds, result + + # For validator function-based expectations, split into passed/failed datasets + # This enables quarantine workflows and data quality checks + @ray.remote(num_cpus=0, max_restarts=-1, max_task_retries=-1) + class ValidationAggregator: + """Simple actor to aggregate validation results across distributed tasks. + + This actor is lightweight and only aggregates counts, so it uses num_cpus=0 + to avoid consuming CPU resources unnecessarily. + Fault tolerance is enabled (max_restarts=-1, max_task_retries=-1) to ensure + validation results are not lost if the actor crashes. + """ + + def __init__(self): + self.failure_count = 0 + self.total_count = 0 + + def add_result(self, passed: bool): + """Add a validation result.""" + self.total_count += 1 + if not passed: + self.failure_count += 1 + + def get_results(self): + """Get aggregated validation results.""" + # If no batches were processed, consider it passing (empty dataset) + if self.total_count == 0: + return { + "passed": True, + "failure_count": 0, + "total_count": 0, + } + return { + "passed": self.failure_count == 0, + "failure_count": self.failure_count, + "total_count": self.total_count, + } + + try: + aggregator = ValidationAggregator.remote() + except Exception as e: + raise RuntimeError( + f"Failed to create validation aggregator actor: {e}. " + "This may indicate insufficient cluster resources or Ray configuration issues." + ) from e + + def _add_validation_flag(batch: Any, flag_value: bool, is_empty: bool) -> Any: + """Helper function to add validation flag to batch regardless of format. + + Creates a copy before modifying to avoid side effects on input batch. 
+ """ + try: + import pandas as pd + + if isinstance(batch, pd.DataFrame): + # Create a copy to avoid modifying input batch + batch = batch.copy() + batch["_validation_passed"] = flag_value + return batch + except Exception: + pass + try: + import pyarrow as pa + + if isinstance(batch, pa.Table): + if is_empty: + # For empty tables, create an empty array with the flag value + # Use pa.nulls(0) and fill with scalar, or create array directly + passed_array = pa.array([], type=pa.bool_()) + # Empty table - just add empty column with correct type + passed_col = passed_array + else: + passed_array = pa.array([flag_value] * len(batch)) + passed_col = passed_array + return batch.append_column("_validation_passed", passed_col) + except Exception: + pass + # Fallback: add validation flag to dict batch + if isinstance(batch, dict): + import numpy as np + + # Create a copy to avoid modifying input batch + batch = dict(batch) + + # Handle empty dict + if not batch: + batch["_validation_passed"] = np.array([], dtype=bool) + else: + # Get first key to determine num_rows + first_key = next(iter(batch.keys())) + first_value = batch.get(first_key, []) + if isinstance(first_value, (list, tuple, np.ndarray)): + num_rows = len(first_value) + else: + # Single value - treat as 1 row + num_rows = 1 + batch["_validation_passed"] = np.array([flag_value] * num_rows) + return batch + + def validation_fn(batch: Any) -> Dict[str, Any]: + """Wrapper function that validates the batch and marks rows.""" + from ray.data.block import BlockAccessor + + # Use BlockAccessor for consistent empty batch detection across formats + try: + block_accessor = BlockAccessor.for_block(batch) + is_empty = block_accessor.num_rows() == 0 + except Exception: + # Fallback for unsupported formats + is_empty = False + try: + if hasattr(batch, "__len__"): + is_empty = len(batch) == 0 + except Exception: + pass + + if is_empty: + # Empty batches pass validation (no data to validate) + aggregator.add_result.remote(True) + 
return _add_validation_flag(batch, True, is_empty=True) + + try: + passed = expectation.validate(batch) + aggregator.add_result.remote(passed) + batch = _add_validation_flag(batch, passed, is_empty=False) + + if not passed and expectation.error_on_failure: + raise ValueError( + f"Expectation '{expectation.name}' failed for batch. " + f"{expectation.description}" + ) + return batch + except Exception as e: + aggregator.add_result.remote(False) + if expectation.error_on_failure: + raise + logger.warning( + f"Expectation '{expectation.name}' failed: {e}. " + f"Batch marked as failed." + ) + # Mark batch as failed + return _add_validation_flag(batch, False, is_empty=False) + + # Apply validation using map_batches + # Always use map_batches for validator functions as they expect batch input + validated_ds = self.map_batches( + validation_fn, + batch_format="default", + compute=compute, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, + **ray_remote_args, + ) + + # Materialize to ensure all validation runs complete + validated_ds.materialize() + + # Split into passed and failed datasets using the validation flag + # Use filter on the validation flag column + from ray.data.expressions import col, lit + + # Determine validation column name (might be _validation_passed or _validation_passed_2, etc.) 
+ validation_col_name = self._get_validation_column_name_from_schema(validated_ds) + + # Filter passed rows: must be True (not False, not NULL) + # NULL values from expression evaluation are treated as False (failed) + passed_ds = validated_ds.filter( + expr=col(validation_col_name).is_not_null() & (col(validation_col_name) == lit(True)) + ) + # Filter failed rows: False or NULL + failed_ds = validated_ds.filter( + expr=col(validation_col_name).is_null() | (col(validation_col_name) == lit(False)) + ) + + # Remove validation flag column from both datasets + # Get original columns (exclude validation column) + schema = self.schema() + if schema and hasattr(schema, "names"): + original_cols = [c for c in schema.names if c != validation_col_name] + # Only select columns if we have original columns and they exist in validated schema + if original_cols: + try: + validated_schema = validated_ds.schema() + if validated_schema and hasattr(validated_schema, "names"): + # Ensure all original columns exist in validated schema + validated_cols = set(validated_schema.names) + existing_cols = [c for c in original_cols if c in validated_cols] + if existing_cols: + passed_ds = passed_ds.select_columns(cols=existing_cols) + failed_ds = failed_ds.select_columns(cols=existing_cols) + except Exception: + # If schema access fails, skip column removal + pass + + # After materialization, all map_batches tasks have completed. + # The actor calls (add_result.remote()) are processed sequentially by the actor, + # and materialize() ensures all tasks complete before returning. + # However, to ensure all actor method calls are fully processed, we wait for + # any pending actor calls to complete. Since actors process calls sequentially, + # after materialization all results should be ready, but we add a small delay + # as a safety measure to ensure actor state is fully updated. 
+ import time + time.sleep(0.01) # Small delay to ensure actor calls are processed + + # Get aggregated validation results from the actor + # Actors process calls sequentially, so after materialization all results are ready + validation_results = ray.get(aggregator.get_results.remote()) + + # Clean up the aggregator actor to avoid resource leaks + # Note: We use ray.kill() here because the actor is only used for this validation + # and we want to free resources immediately. For long-lived actors, Ray will + # automatically clean them up when references go out of scope. + try: + ray.kill(aggregator) + except Exception as e: + # Log the failure but don't raise - actor cleanup is best effort + logger.debug( + f"Failed to kill validation aggregator actor: {e}. " + "Actor will be cleaned up automatically when reference goes out of scope." + ) + + # Create expectation result with detailed message + passed = validation_results["passed"] + failure_count = validation_results["failure_count"] + total_count = validation_results["total_count"] + + if passed: + message = ( + f"Expectation '{expectation.name}' passed: " + f"all {total_count} batches validated successfully" + ) + else: + failure_rate = (failure_count / total_count * 100) if total_count > 0 else 0 + message = ( + f"Expectation '{expectation.name}' failed: " + f"{failure_count}/{total_count} batches failed ({failure_rate:.1f}%)" + ) + if expectation.description: + message += f". {expectation.description}" + + result = ExpectationResult( + expectation=expectation, + passed=passed, + message=message, + failure_count=failure_count, + total_count=total_count, + ) + + return passed_ds, failed_ds, result + + def _get_validation_column_name(self) -> str: + """Get a validation column name that doesn't conflict with existing columns. + + Returns: + A column name like "_validation_passed", "_validation_passed_2", etc. 
+ """ + schema = self.schema() + base_name = "_validation_passed" + validation_col_name = base_name + + if schema and hasattr(schema, "names"): + counter = 1 + while validation_col_name in schema.names: + validation_col_name = f"{base_name}_{counter}" + counter += 1 + if counter > 1: + import warnings + warnings.warn( + f"Column '{base_name}' already exists in dataset. " + f"Using '{validation_col_name}' instead.", + UserWarning, + stacklevel=3 + ) + + return validation_col_name + + def _get_validation_column_name_from_schema(self, validated_ds: "Dataset") -> str: + """Get validation column name from validated dataset schema. + + Args: + validated_ds: Dataset that has been validated (contains validation column). + + Returns: + The name of the validation column in the validated dataset. + """ + validation_col_name = "_validation_passed" + try: + validated_schema = validated_ds.schema() + if validated_schema and hasattr(validated_schema, "names"): + # Check for validation column names in order of preference + base_name = "_validation_passed" + counter = 1 + while True: + if validation_col_name in validated_schema.names: + break + validation_col_name = f"{base_name}_{counter}" + counter += 1 + # Safety check to avoid infinite loop + if counter > 100: + validation_col_name = base_name + break + except Exception: + # Fallback to default name if schema access fails + pass + + return validation_col_name + + def _count_rows_in_batch(self, batch: Any) -> int: + """Count rows in a batch regardless of format. + + Supports dict, pandas DataFrame, PyArrow Table, and other formats. + Uses BlockAccessor for consistent row counting across formats. + + Args: + batch: Batch in any supported format. + + Returns: + Number of rows in the batch, or 0 if counting fails. 
+ """ + try: + from ray.data.block import BlockAccessor + + accessor = BlockAccessor.for_block(batch) + return accessor.num_rows() + except Exception: + # Fallback for unsupported formats + try: + if isinstance(batch, dict): + if not batch: + return 0 + first_key = next(iter(batch.keys())) + first_value = batch.get(first_key, []) + if isinstance(first_value, (list, tuple)): + return len(first_value) + # Single value or array-like + try: + return len(first_value) + except (TypeError, ValueError): + return 1 # Single value + else: + import pandas as pd # https://pandas.pydata.org/docs/ + + if isinstance(batch, pd.DataFrame): + return len(batch) + else: + import pyarrow as pa # https://arrow.apache.org/docs/python/ + + if isinstance(batch, pa.Table): + return len(batch) + except Exception: + pass + return 0 + + def _create_dataset_from_batches(self, batches: List[Any]) -> "Dataset": + """Create a dataset from a list of batches. + + Uses DelegatingBlockBuilder to efficiently combine batches into blocks. + + Args: + batches: List of batches in any supported format. + + Returns: + Dataset containing all batches, or empty dataset if batches is empty. + Empty dataset preserves schema from original dataset. + """ + if not batches: + # Return empty dataset with schema preserved from original dataset + # Use limit(0) to preserve schema instead of from_items([]) + return self.limit(0) + + from ray.data._internal.delegating_block_builder import ( + DelegatingBlockBuilder, + ) + from ray.data.read_api import from_blocks + + builder = DelegatingBlockBuilder() + for batch in batches: + builder.add_batch(batch) + processed_block = builder.build() + + return from_blocks([processed_block]) + + def _expect_execution_time( + self, expectation: ExecutionTimeExpectation + ) -> Tuple["Dataset", "Dataset", ExpectationResult]: + """Handle execution time expectations with timeout monitoring. 
+ + This method monitors execution time and halts processing if the timeout + is exceeded, returning data processed before timeout vs remaining data. + + Args: + expectation: Execution time expectation with time constraints. + + Returns: + Tuple of (passed_ds, failed_ds, result) where: + - passed_ds: Dataset containing data processed before timeout. + - failed_ds: Dataset containing unprocessed data (empty if completed in time). + - result: ExpectationResult with execution time and pass/fail status. + """ + import time + + max_time_seconds = expectation.get_max_execution_time_seconds() + if max_time_seconds is None: + raise ValueError( + "Execution time expectation must have a valid max_execution_time_seconds" + ) + if max_time_seconds <= 0: + raise ValueError( + f"Execution time expectation max_execution_time_seconds must be positive, " + f"got {max_time_seconds}" + ) + + start_time = time.perf_counter() + + # Collect processed batches incrementally with timeout monitoring + processed_batches = [] + processed_rows = 0 + timeout_exceeded = False + + try: + # Use iter_batches to process incrementally and monitor time + # Note: iter_batches() may not respect timeout for very large datasets + # as it processes batches sequentially and timeout check happens between batches + batch_iter = self.iter_batches() + for batch in batch_iter: + elapsed_time = time.perf_counter() - start_time + + # Check if timeout exceeded before processing this batch + if elapsed_time >= max_time_seconds: + timeout_exceeded = True + # Halt execution by breaking out of iteration + # The executor will be shut down when iterator is exhausted + break + + processed_batches.append(batch) + # Count rows in batch using helper function + processed_rows += self._count_rows_in_batch(batch) + + # Check timeout again after processing batch (in case batch took a long time) + elapsed_time = time.perf_counter() - start_time + if elapsed_time >= max_time_seconds: + timeout_exceeded = True + break + + except 
Exception as e: + # If execution fails, don't treat as timeout - it's a different error + elapsed_time = time.perf_counter() - start_time + timeout_exceeded = False # Not a timeout, but execution failed + + if expectation.error_on_failure: + raise + logger.warning( + f"Execution time expectation '{expectation.name}' execution failed: {e}", + exc_info=True + ) + else: + # Only calculate elapsed_time here if no exception occurred + elapsed_time = time.perf_counter() - start_time + + # Calculate passed status: must not have timed out AND must be within time limit + passed = not timeout_exceeded and elapsed_time <= max_time_seconds + + # Create datasets from processed batches using helper function + passed_ds = self._create_dataset_from_batches(processed_batches) + + # Failed dataset = remaining unprocessed data + # For V1, we return empty dataset as we can't easily track unprocessed data + # This could be enhanced in future versions + # Use limit(0) to preserve schema from original dataset + failed_ds = self.limit(0) + + # Shutdown executor if timeout exceeded to halt execution + # Note: iter_batches() doesn't actually halt distributed execution, + # this is a limitation of the current implementation + if timeout_exceeded and hasattr(self, "_current_executor") and self._current_executor is not None: + try: + self._current_executor.shutdown(force=True) + self._current_executor = None + except Exception: + # Ignore errors during executor shutdown + pass + + # Create result + if passed: + message = ( + f"Execution time expectation '{expectation.name}' passed: " + f"execution completed in {elapsed_time:.2f}s " + f"(limit: {max_time_seconds:.2f}s)" + ) + else: + message = ( + f"Execution time expectation '{expectation.name}' failed: " + f"execution exceeded time limit ({elapsed_time:.2f}s > " + f"{max_time_seconds:.2f}s). Processed {processed_rows} rows " + "before timeout." 
+ ) + if expectation.description: + message += f" {expectation.description}" + + result = ExpectationResult( + expectation=expectation, + passed=passed, + message=message, + execution_time_seconds=elapsed_time, + total_count=processed_rows, + failure_count=0, # Timeout is tracked via passed=False, not failure_count + ) + + return passed_ds, failed_ds, result + + def _expect_list( + self, expectations: List[Expectation] + ) -> Tuple["Dataset", "Dataset", List[ExpectationResult]]: + """Handle list of expectations by applying them sequentially. + + This method applies each expectation sequentially, accumulating + all failures. The passed dataset contains rows that passed ALL expectations, + while the failed dataset contains rows that failed ANY expectation. + + Args: + expectations: List of Expectation objects to apply. + + Returns: + Tuple of (passed_ds, failed_ds, results) where: + - passed_ds: Dataset containing rows that passed all expectations. + - failed_ds: Dataset containing rows that failed any expectation. + - results: List of ExpectationResult objects, one per expectation. + """ + if not expectations: + raise ValueError( + "List of expectations cannot be empty. " + "Provide at least one expectation: ds.expect([expectation1, expectation2])" + ) + + # Validate all items are Expectations + for i, exp in enumerate(expectations): + if not isinstance(exp, Expectation): + raise TypeError( + f"All items in expectations list must be Expectation objects. " + f"Item at index {i} is {type(exp).__name__}. 
" + f"Use expect() to create expectations: expect(expr=col('x') > 0)" + ) + + # Start with the full dataset + current_ds = self + all_failed_datasets = [] + results = [] + + # Apply each expectation sequentially + for exp in expectations: + passed_ds, failed_ds, result = current_ds.expect(exp) + results.append(result) + + # Collect failed datasets instead of rows to avoid memory issues + # Use a lazy check to avoid materializing the dataset unnecessarily + # Check if failed_ds has any rows by checking if it's different from empty dataset + # This is approximate but avoids expensive count() call + try: + # Try to get schema - if it exists and has rows, add to failed list + # This is a heuristic to avoid materialization + failed_schema = failed_ds.schema() + if failed_schema and hasattr(failed_schema, "names") and failed_schema.names: + # For now, we'll add it and let union handle empty datasets + # A more sophisticated check would peek at the first block + all_failed_datasets.append(failed_ds) + except Exception: + # If schema check fails, fall back to count (which materializes) + # This is a trade-off between performance and correctness + if failed_ds.count() > 0: + all_failed_datasets.append(failed_ds) + + # Continue with passed dataset for next expectation + current_ds = passed_ds + + # Union all failed datasets to preserve distributed data + # Note: Union may contain duplicate rows if a row fails multiple expectations + # This is expected behavior - a row that fails expectation 1 and expectation 2 + # will appear in both failed datasets, and union will include it twice + # Users can deduplicate if needed: failed_ds.distinct() + if all_failed_datasets: + failed_ds = all_failed_datasets[0] + for ds in all_failed_datasets[1:]: + failed_ds = failed_ds.union(ds) + else: + # Create empty dataset with correct schema + # Use limit(0) to preserve schema from original dataset + failed_ds = self.limit(0) + + return current_ds, failed_ds, results + 
@PublicAPI(api_group=SSR_API_GROUP) def repartition( self, @@ -2684,8 +3746,13 @@ def union(self, *other: List["Dataset"]) -> "Dataset": parent=[d._plan.stats() for d in datasets], ) stats.time_total_s = time.perf_counter() - start_time + + # Preserve execution time expectations from the first dataset (self) + plan = self._plan.copy() + plan._in_stats = stats + return Dataset( - ExecutionPlan(stats, self.context.copy()), + plan, logical_plan, ) @@ -5953,9 +7020,9 @@ def to_arrow_refs(self) -> List[ObjectRef["pyarrow.Table"]]: import pyarrow as pa ref_bundle: RefBundle = self._plan.execute() - block_refs: List[ - ObjectRef["pyarrow.Table"] - ] = _ref_bundles_iterator_to_block_refs_list([ref_bundle]) + block_refs: List[ObjectRef["pyarrow.Table"]] = ( + _ref_bundles_iterator_to_block_refs_list([ref_bundle]) + ) # Schema is safe to call since we have already triggered execution with # self._plan.execute(), which will cache the schema schema = self.schema(fetch_if_missing=True) @@ -6533,7 +7600,8 @@ def __setstate__(self, state): self._current_executor = None def __del__(self): - if not self._current_executor: + # Check if _current_executor exists and is not None before accessing + if not hasattr(self, "_current_executor") or self._current_executor is None: return # When Python shuts down, `ray` might evaluate to ``. 
@@ -6620,7 +7688,7 @@ def types(self) -> List[Union[type[object], "pyarrow.lib.DataType"]]: from ray.data.extensions import ArrowTensorType, TensorDtype def _convert_to_pa_type( - dtype: Union[np.dtype, pd.ArrowDtype, BaseMaskedDtype] + dtype: Union[np.dtype, pd.ArrowDtype, BaseMaskedDtype], ) -> pa.DataType: if isinstance(dtype, pd.ArrowDtype): return dtype.pyarrow_dtype diff --git a/python/ray/data/expectations.py b/python/ray/data/expectations.py new file mode 100644 index 000000000000..8a89e96cca2e --- /dev/null +++ b/python/ray/data/expectations.py @@ -0,0 +1,958 @@ +import datetime +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Optional, + Set, + Union, +) + +from ray.util.annotations import DeveloperAPI, PublicAPI + +if TYPE_CHECKING: + from ray.data.expressions import Expr + + +class ExpectationType(str, Enum): + """Type of expectation.""" + + DATA_QUALITY = "data_quality" + EXECUTION_TIME = "execution_time" + + +@DeveloperAPI +@dataclass +class Expectation: + """Base class for all expectations. + + Expectations can be attached to dataset operations or functions to + express data quality requirements or execution time constraints. + + Attributes: + name: Human-readable name for this expectation. + description: Detailed description of what this expectation checks. + expectation_type: Type of expectation (data quality or execution time). + error_on_failure: If True, raise an exception when expectation fails. + If False, log a warning. 
+ """ + + name: str + description: str + expectation_type: ExpectationType + error_on_failure: bool = True + + def __post_init__(self): + if not self.name or (isinstance(self.name, str) and not self.name.strip()): + raise ValueError("Expectation name cannot be empty or whitespace-only") + if not self.description or ( + isinstance(self.description, str) and not self.description.strip() + ): + raise ValueError( + "Expectation description cannot be empty or whitespace-only" + ) + + def validate(self, *args, **kwargs) -> bool: + """Validate this expectation. Subclasses must implement this.""" + raise NotImplementedError("Subclasses must implement validate()") + + +@DeveloperAPI +@dataclass +class DataQualityExpectation(Expectation): + """Data quality expectation for validating data correctness. + + Use this to express constraints on data values, schema, completeness, + or other data quality metrics. + + Attributes: + name: Human-readable name for this expectation. + description: Detailed description of what this expectation checks. + validator_fn: Function that takes a batch (dict or pandas DataFrame) + and returns True if validation passes, False otherwise. + Can also raise exceptions for more detailed error reporting. + error_on_failure: If True, raise an exception when expectation fails. + If False, log a warning. + """ + + validator_fn: Callable[[Any], bool] + expectation_type: ExpectationType = field( + default=ExpectationType.DATA_QUALITY, init=False + ) + + def validate(self, batch: Any) -> bool: + """Validate a batch of data against this expectation. + + Args: + batch: A batch of data in any supported format (dict, pandas DataFrame, + PyArrow Table, etc.). Can be empty. + + Returns: + True if validation passes, False otherwise. + + Raises: + ValueError: If validator function returns non-boolean value. + Exception: If error_on_failure is True and validation fails. 
+ """ + try: + # Handle empty batches gracefully using BlockAccessor + # This reuses Ray Data's standard batch format handling + from ray.data.block import BlockAccessor + + if batch is None: + return False + + # Use BlockAccessor for consistent empty batch detection across formats + try: + block_accessor = BlockAccessor.for_block(batch) + if block_accessor.num_rows() == 0: + # Empty batches pass validation (no data to validate) + return True + except Exception: + # Fallback for unsupported formats + pass + + result = self.validator_fn(batch) + if not isinstance(result, bool): + raise ValueError( + f"Validator function must return bool, got {type(result).__name__}. " + "Validator functions should return True if validation passes, " + "False otherwise." + ) + return result + except Exception as e: + if self.error_on_failure: + raise + # Log the exception for debugging when error_on_failure=False + import logging + + logger = logging.getLogger(__name__) + logger.debug( + f"Expectation '{self.name}' validation raised exception (error_on_failure=False): {e}", + exc_info=True, + ) + return False + + +@DeveloperAPI +@dataclass +class ExecutionTimeExpectation(Expectation): + """Execution time expectation for expressing timing requirements. + + Use this to express execution time constraints like "Job must finish by Y time". + + Attributes: + name: Human-readable name for this expectation. + description: Detailed description of what this execution time constraint requires. + max_execution_time_seconds: Maximum allowed execution time in seconds. + If None, no time constraint is enforced. + max_execution_time: Maximum allowed execution time as datetime.timedelta. + Alternative to max_execution_time_seconds. + target_completion_time: Target completion time as datetime.datetime. + Used for deadline-based optimization. + error_on_failure: If True, raise an exception when execution time constraint is violated. + If False, log a warning. 
+ """ + + max_execution_time_seconds: Optional[float] = None + max_execution_time: Optional[datetime.timedelta] = None + target_completion_time: Optional[datetime.datetime] = None + expectation_type: ExpectationType = field( + default=ExpectationType.EXECUTION_TIME, init=False + ) + + def __post_init__(self): + super().__post_init__() + if ( + self.max_execution_time_seconds is not None + and self.max_execution_time is not None + ): + raise ValueError( + "Cannot specify both max_execution_time_seconds and max_execution_time" + ) + if self.max_execution_time_seconds is None and self.max_execution_time is None: + if self.target_completion_time is None: + raise ValueError( + "Must specify at least one time constraint: max_execution_time_seconds, " + "max_execution_time, or target_completion_time" + ) + + def get_max_execution_time_seconds(self) -> Optional[float]: + """Get maximum execution time in seconds.""" + if self.max_execution_time_seconds is not None: + return self.max_execution_time_seconds + if self.max_execution_time is not None: + return self.max_execution_time.total_seconds() + if self.target_completion_time is not None: + now = datetime.datetime.now() + if self.target_completion_time <= now: + # Target time is in the past - return a very small positive value + # to indicate timeout immediately, but still positive for validation + return 0.001 + return (self.target_completion_time - now).total_seconds() + return None + + def validate(self, execution_time_seconds: float) -> bool: + """Validate that execution time meets requirements.""" + max_time = self.get_max_execution_time_seconds() + if max_time is None: + return True + return execution_time_seconds <= max_time + + +@DeveloperAPI +@dataclass +class ExpectationResult: + """Result of validating an expectation. + + Attributes: + expectation: The expectation that was validated. + passed: Whether the expectation passed. + message: Human-readable message describing the result. 
+ execution_time_seconds: Execution time in seconds (for execution time expectations). + failure_count: Number of batches/rows that failed validation (for data quality). + total_count: Total number of batches/rows validated (for data quality). + """ + + expectation: Expectation + passed: bool + message: str + execution_time_seconds: Optional[float] = None + failure_count: int = 0 + total_count: int = 0 + + def __repr__(self) -> str: + status = "PASSED" if self.passed else "FAILED" + return ( + f"ExpectationResult(expectation={self.expectation.name}, " + f"status={status}, message={self.message})" + ) + + +@PublicAPI(stability="alpha") +def expect( + *, + name: Optional[str] = None, + description: Optional[str] = None, + validator_fn: Optional[Callable[[Any], bool]] = None, + expr: Optional["Expr"] = None, + max_execution_time_seconds: Optional[float] = None, + max_execution_time: Optional[datetime.timedelta] = None, + target_completion_time: Optional[datetime.datetime] = None, + error_on_failure: bool = True, + expectation_type: Optional[ExpectationType] = None, +) -> Expectation: + """Create an expectation object for data quality or execution time requirements. + + Examples: + >>> from ray.data.expressions import col + >>> from ray.data.expectations import expect + >>> + >>> # Expression-based data quality + >>> exp = expect(expr=col("value") > 0) + >>> ds = ray.data.from_items([{"value": 1}, {"value": -1}]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + >>> + >>> # Validator function + >>> exp = expect(validator_fn=lambda batch: batch["value"].min() > 0) + >>> + >>> # Execution time requirement + >>> exp = expect(max_execution_time_seconds=60.0) + + Args: + name: Name for the expectation. + description: Description of what this expectation checks. + validator_fn: Function for data quality validation (takes batch, returns bool). + Mutually exclusive with `expr`. + expr: Expression for data quality validation (e.g., col("value") > 0). 
+ Mutually exclusive with `validator_fn`. + max_execution_time_seconds: Maximum execution time in seconds (for execution time expectations). + max_execution_time: Maximum execution time as datetime.timedelta (for execution time expectations). + target_completion_time: Target completion time as datetime.datetime (for execution time expectations). + error_on_failure: If True, raise exception on failure; if False, log warning. + expectation_type: Type of expectation (auto-detected if not specified). + + Returns: + An Expectation object that can be used with Dataset.expect(). + """ + # Validate input types + if validator_fn is not None and not callable(validator_fn): + raise TypeError( + f"validator_fn must be callable, got {type(validator_fn).__name__}" + ) + if max_execution_time is not None and not isinstance( + max_execution_time, datetime.timedelta + ): + raise TypeError( + f"max_execution_time must be datetime.timedelta, got {type(max_execution_time).__name__}" + ) + if target_completion_time is not None and not isinstance( + target_completion_time, datetime.datetime + ): + raise TypeError( + f"target_completion_time must be datetime.datetime, got {type(target_completion_time).__name__}" + ) + if name is not None and not isinstance(name, str): + raise TypeError(f"name must be str, got {type(name).__name__}") + if description is not None and not isinstance(description, str): + raise TypeError(f"description must be str, got {type(description).__name__}") + + # Handle expression-based data quality expectations + _expr = expr + + # Validate expr is an Expr object if provided + if _expr is not None: + from ray.data.expressions import Expr as _Expr + + if not isinstance(_expr, _Expr): + raise TypeError( + f"expr must be a Ray Data Expr object, got {type(_expr).__name__}. " + f"Use col('column_name') > 0 or similar expression." 
+ ) + + # Check if expr is provided as keyword argument + if _expr is not None: + if validator_fn is not None: + raise ValueError( + "Cannot specify both `validator_fn` and `expr` for data quality expectations. " + "Use either `validator_fn` for custom validation logic or `expr` for " + "expression-based validation." + ) + if ( + max_execution_time_seconds is not None + or max_execution_time is not None + or target_completion_time is not None + ): + raise ValueError( + "Cannot specify both `expr` (data quality) and time constraints (execution time). " + "Use `expr` for data quality validation or time constraints for execution time requirements." + ) + validator_fn = _create_validator_from_expression(_expr) + expectation_type = ExpectationType.DATA_QUALITY + + # Determine expectation type if not specified + if expectation_type is None: + if validator_fn is not None or _expr is not None: + expectation_type = ExpectationType.DATA_QUALITY + elif ( + max_execution_time_seconds is not None + or max_execution_time is not None + or target_completion_time is not None + ): + expectation_type = ExpectationType.EXECUTION_TIME + else: + raise ValueError( + "Must specify either validator_fn or expr (for data quality) " + "or time constraints (for execution time). " + "Examples: expect(expr=col('x') > 0) or expect(max_execution_time_seconds=60)" + ) + + # Create expectation object + if expectation_type == ExpectationType.DATA_QUALITY: + if validator_fn is None: + raise ValueError( + "Either validator_fn or expr is required for data quality expectations. " + "This should not happen - please report this error." 
+ ) + if name is None: + name = "Data Quality Check" + if description is None: + description = ( + f"Data quality validation: {_expr}" + if _expr is not None + else "Data quality validation" + ) + exp = DataQualityExpectation( + name=name, + description=description, + validator_fn=validator_fn, + error_on_failure=error_on_failure, + ) + # Store expression for efficient filter()-based validation + if _expr is not None: + exp._expr = _expr + else: + # Validate execution time parameters + if max_execution_time_seconds is not None and max_execution_time_seconds <= 0: + raise ValueError( + f"max_execution_time_seconds must be positive, " + f"got {max_execution_time_seconds}" + ) + if max_execution_time is not None and max_execution_time.total_seconds() <= 0: + raise ValueError( + f"max_execution_time must be positive, got {max_execution_time}" + ) + + if name is None: + name = "Execution Time Requirement" + if description is None: + description = "Execution time constraint" + exp = ExecutionTimeExpectation( + name=name, + description=description, + max_execution_time_seconds=max_execution_time_seconds, + max_execution_time=max_execution_time, + target_completion_time=target_completion_time, + error_on_failure=error_on_failure, + ) + + # Return expectation object + return exp + + +def _convert_batch_to_arrow_block(batch: Any) -> Any: + """Convert a batch to PyArrow Table format for expression evaluation. + + Uses Ray Data's BlockAccessor pattern to handle all batch formats consistently. + This is the same pattern used throughout Ray Data for batch format handling. + + Supports: + - PyArrow Tables (https://arrow.apache.org/docs/python/) + - Pandas DataFrames + - Dict[str, np.ndarray] format + - Any format supported by BlockAccessor.for_block() + + Args: + batch: Batch in any supported format. + + Returns: + PyArrow Table suitable for expression evaluation. 
+ """ + import pyarrow as pa + + from ray.data.block import BlockAccessor + + # Use BlockAccessor - this is the standard Ray Data way to handle batches + try: + accessor = BlockAccessor.for_block(batch) + return accessor.to_arrow() + except (TypeError, AttributeError, ValueError): + # Fallback for edge cases + if isinstance(batch, pa.Table): + return batch + elif hasattr(batch, "to_arrow"): + return batch.to_arrow() + elif isinstance(batch, dict): + return pa.table(batch) + else: + # Try pandas conversion if available + # Pandas: https://pandas.pydata.org/docs/ + try: + import pandas as pd + + if isinstance(batch, pd.DataFrame): + return pa.Table.from_pandas(batch) + except Exception: + pass + # Final fallback: assume it's already a PyArrow Table or compatible + return batch + + +def _extract_boolean_result(result: Any) -> bool: + """Extract boolean result from expression evaluation. + + Handles PyArrow Arrays, scalars, and other array-like objects. + Returns True if all non-null values are True, False otherwise. + Empty batches/arrays return True (nothing to validate). 
+ """ + import pyarrow as pa + + if isinstance(result, bool): + return result + elif isinstance(result, (pa.Array, pa.ChunkedArray)): + # PyArrow Array/ChunkedArray - check if all values are True + if len(result) == 0: + return True # Empty batch passes validation + + values = result.to_pylist() + if not values: + return True # Empty list passes validation + + # Filter out None values (nulls) and check if all remaining are True + non_null_values = [v for v in values if v is not None] + if not non_null_values: + # All values are null - for validation purposes, NULL typically means "unknown" + # and should be treated as failing validation (conservative approach) + # However, empty batches pass validation (no data to validate) + # This case (all nulls in non-empty batch) means expression evaluated to NULL + # which should be treated as False (failed validation) + return False + + # All non-null values must be True + # Convert to bool to handle truthy/falsy values correctly + return all(bool(v) is True for v in non_null_values) + elif isinstance(result, (list, tuple)): + if not result: + return True + non_null_values = [v for v in result if v is not None] + if not non_null_values: + # All values are null - treat as failed validation (consistent with PyArrow behavior) + return False + return all(v is True for v in non_null_values) + elif hasattr(result, "to_pylist"): + values = result.to_pylist() + if not values: + return True + non_null_values = [v for v in values if v is not None] + if not non_null_values: + # All values are null - treat as failed validation (consistent with PyArrow behavior) + return False + return all(v is True for v in non_null_values) + else: + # Try to convert to bool (for scalar results) + try: + return bool(result) + except (TypeError, ValueError): + # If conversion fails, assume False + return False + + +def _create_validator_from_expression(expr: "Expr") -> Callable[[Any], bool]: + """Create a validator function from a Ray Data expression. 
+ + Uses Ray Data's existing expression evaluation infrastructure (eval_expr), + ensuring consistent behavior with filter() and other expression-based operations. + """ + + def validator_fn(batch: Any) -> bool: + """Validate that all rows in batch satisfy the expression.""" + try: + # Import here to avoid circular dependencies + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + eval_expr, + ) + from ray.data.block import BlockAccessor + + # Use BlockAccessor for consistent empty batch detection + # This reuses Ray Data's standard batch format handling + try: + block_accessor = BlockAccessor.for_block(batch) + if block_accessor.num_rows() == 0: + # Empty batches pass validation (no data to validate) + return True + # Convert to Arrow format using BlockAccessor + block = block_accessor.to_arrow() + except Exception: + # Fallback to manual conversion if BlockAccessor fails + block = _convert_batch_to_arrow_block(batch) + try: + import pyarrow as pa + + if isinstance(block, pa.Table) and len(block) == 0: + return True + except Exception: + pass + + # Evaluate expression using the same path as filter(expr=...) 
and with_column() + result = eval_expr(expr, block) + + # Extract boolean result from various return types + return _extract_boolean_result(result) + + except Exception as e: + # If evaluation fails, preserve the original exception type but add context + # Include the expression in error message for debugging + error_msg = f"Failed to evaluate expression {expr} on batch" + if hasattr(e, "__cause__") and e.__cause__: + error_msg += f": {e.__cause__}" + elif str(e): + error_msg += f": {e}" + + # Preserve the original exception type if it's informative + if isinstance(e, (ValueError, TypeError, AttributeError, KeyError)): + raise type(e)(error_msg) from e + else: + # For other exceptions, wrap in ValueError but preserve original + raise ValueError(error_msg) from e + + return validator_fn + + +# Convenience functions for common expectations +# These follow Ray Data's expression-based API pattern + + +@PublicAPI(stability="alpha") +def expect_column_min( + column: str, + min_value: Union[int, float], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, +) -> Expectation: + """Create an expectation that a column's minimum value meets a threshold. + + This is a convenience function for the common pattern of checking column minimums. + + Examples: + >>> from ray.data.expectations import expect_column_min + >>> exp = expect_column_min("age", 0) + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + column: Name of the column to validate. + min_value: Minimum allowed value (inclusive). + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expressions import col + + if name is None: + name = f"Column '{column}' minimum >= {min_value}" + if description is None: + description = ( + f"Validate that column '{column}' has minimum value >= {min_value}" + ) + + return expect( + expr=col(column) >= min_value, + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + +@PublicAPI(stability="alpha") +def expect_column_max( + column: str, + max_value: Union[int, float], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, +) -> Expectation: + """Create an expectation that a column's maximum value meets a threshold. + + This is a convenience function for the common pattern of checking column maximums. + + Examples: + >>> from ray.data.expectations import expect_column_max + >>> exp = expect_column_max("age", 120) + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + column: Name of the column to validate. + max_value: Maximum allowed value (inclusive). + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expressions import col + + if name is None: + name = f"Column '{column}' maximum <= {max_value}" + if description is None: + description = ( + f"Validate that column '{column}' has maximum value <= {max_value}" + ) + + return expect( + expr=col(column) <= max_value, + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + +@PublicAPI(stability="alpha") +def expect_column_range( + column: str, + min_value: Union[int, float], + max_value: Union[int, float], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, +) -> Expectation: + """Create an expectation that a column's values are within a range. + + This is a convenience function for the common pattern of checking value ranges. + + Args: + column: Name of the column to validate. + min_value: Minimum allowed value (inclusive). Must be <= max_value. + max_value: Maximum allowed value (inclusive). Must be >= min_value. + + Examples: + >>> from ray.data.expectations import expect_column_range + >>> exp = expect_column_range("age", 0, 120) + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + column: Name of the column to validate. + min_value: Minimum allowed value (inclusive). + max_value: Maximum allowed value (inclusive). + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expressions import col + + if min_value > max_value: + raise ValueError( + f"min_value ({min_value}) must be <= max_value ({max_value}) " + f"for expect_column_range" + ) + + if name is None: + name = f"Column '{column}' in range [{min_value}, {max_value}]" + if description is None: + description = ( + f"Validate that column '{column}' values are in range " + f"[{min_value}, {max_value}]" + ) + + return expect( + expr=(col(column) >= min_value) & (col(column) <= max_value), + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + +@PublicAPI(stability="alpha") +def expect_column_not_null( + column: str, + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, +) -> Expectation: + """Create an expectation that a column has no null values. + + This is a convenience function for the common pattern of checking for nulls. + + Examples: + >>> from ray.data.expectations import expect_column_not_null + >>> exp = expect_column_not_null("email") + >>> ds = ray.data.from_items([{"email": "test@example.com"}, {"email": None}]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + column: Name of the column to validate. + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expressions import col + + if name is None: + name = f"Column '{column}' is not null" + if description is None: + description = f"Validate that column '{column}' has no null values" + + return expect( + expr=col(column).is_not_null(), + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + +@PublicAPI(stability="alpha") +def expect_column_in( + column: str, + values: Union[List[Any], Set[Any]], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, +) -> Expectation: + """Create an expectation that a column's values are in a set of allowed values. + + This is a convenience function for the common pattern of checking allowed values. + + .. note:: + If `values` is empty, no values will be allowed (validation will always fail). + This may not be the intended behavior - ensure `values` is not empty. + + Examples: + >>> from ray.data.expectations import expect_column_in + >>> exp = expect_column_in("status", ["active", "inactive", "pending"]) + >>> ds = ray.data.from_items([{"status": "active"}, {"status": "invalid"}]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + column: Name of the column to validate. + values: Set or list of allowed values. + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). + """ + from ray.data.expressions import col + + if not isinstance(values, (list, set, tuple)): + raise TypeError( + f"values must be list, set, or tuple, got {type(values).__name__}" + ) + + values_set = set(values) + if not values_set: + import warnings + warnings.warn( + "expect_column_in called with empty values list. " + "This will cause all rows to fail validation. 
" + "Ensure this is the intended behavior.", + UserWarning, + stacklevel=2 + ) + if name is None: + name = f"Column '{column}' in allowed values" + if description is None: + description = f"Validate that column '{column}' values are in {values_set}" + + return expect( + expr=col(column).is_in(list(values_set)), + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + +@PublicAPI(stability="alpha") +def expect_column_unique( + column: str, + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, +) -> Expectation: + """Create an expectation that a column has unique values. + + .. warning:: + This function currently only checks uniqueness **within each batch**, not across + the entire dataset. A value duplicated across different batches will still pass + validation. For true dataset-wide uniqueness checking, use a different approach + (e.g., groupby().count() and check for counts > 1). + + Note: This checks that all values in the column are unique within each batch. + For large datasets, this may be expensive as it requires checking all values. + + Examples: + >>> from ray.data.expectations import expect_column_unique + >>> exp = expect_column_unique("id") + >>> ds = ray.data.from_items([{"id": 1}, {"id": 2}, {"id": 1}]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + column: Name of the column to validate. + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + if name is None: + name = f"Column '{column}' is unique" + if description is None: + description = f"Validate that column '{column}' has unique values" + + def validator_fn(batch: Any) -> bool: + """Check if column values are unique in batch.""" + from ray.data.block import BlockAccessor + + try: + accessor = BlockAccessor.for_block(batch) + if accessor.num_rows() == 0: + return True + + arrow_table = accessor.to_arrow() + if arrow_table is None: + return False + + column_data = arrow_table[column] + if column_data is None: + return False + + # Convert to Python list and check uniqueness + values = column_data.to_pylist() + # Filter out None values for uniqueness check + non_null_values = [v for v in values if v is not None] + unique_values = set(non_null_values) + # All non-null values must be unique + return len(non_null_values) == len(unique_values) + except Exception: + return False + + return expect( + validator_fn=validator_fn, + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + +@PublicAPI(stability="alpha") +def expect_suite( + expectations: List[Expectation], + *, + name: Optional[str] = None, + description: Optional[str] = None, +) -> List[Expectation]: + """Create a suite of expectations to apply together. + + This is a convenience function for grouping related expectations. + The suite can be passed directly to Dataset.expect(). + + Examples: + >>> from ray.data.expectations import expect_suite, expect_column_min, expect_column_not_null + >>> suite = expect_suite([ + ... expect_column_min("age", 0), + ... expect_column_not_null("email") + ... ]) + >>> ds = ray.data.from_items([{"age": 25, "email": "test@example.com"}]) + >>> passed_ds, failed_ds, results = ds.expect(suite) + + Args: + expectations: List of Expectation objects to include in the suite. + name: Optional name for the suite (for documentation purposes). + description: Optional description of what this suite validates. 
+ + Returns: + List of Expectation objects that can be passed to Dataset.expect(). + + Raises: + ValueError: If expectations list is empty or contains non-Expectation objects. + """ + if not expectations: + raise ValueError("expectations list cannot be empty") + if not isinstance(expectations, list): + raise TypeError( + f"expectations must be a list, got {type(expectations).__name__}" + ) + + # Validate all items are Expectations + for i, exp in enumerate(expectations): + if not isinstance(exp, Expectation): + raise TypeError( + f"All items in expectations list must be Expectation objects. " + f"Item at index {i} is {type(exp).__name__}." + ) + + # Return the list as-is (Dataset.expect() handles lists directly) + # The name and description are for documentation only + return expectations diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index b9fa36c70dcc..0523be096cdf 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -4,7 +4,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Dict, Generic, List, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + List, + Optional, + Set, + TypeVar, + Union, +) import pyarrow @@ -12,6 +23,9 @@ from ray.data.datatype import DataType from ray.util.annotations import DeveloperAPI, PublicAPI +if TYPE_CHECKING: + from ray.data.expectations import Expectation + T = TypeVar("T") @@ -408,6 +422,279 @@ def alias(self, name: str) -> "Expr": def _unalias(self) -> "Expr": return self + # Expectation methods - integrate with data quality expectations + def expect( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, + ) -> "Expectation": + """Create a data quality expectation from this expression. 
+ + This method allows you to create expectations directly from expressions, + making the API more fluent and integrated with Ray Data's expression system. + + Examples: + >>> from ray.data.expressions import col + >>> import ray + >>> ds = ray.data.from_items([{"age": 25}, {"age": -5}]) + >>> # Create expectation directly from expression + >>> exp = (col("age") > 0).expect(name="age_positive", description="Age must be positive") + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). + """ + from ray.data.expectations import expect as _expect + + return _expect( + expr=self, + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + def expect_min( + self, + min_value: Union[int, float], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, + ) -> "Expectation": + """Create an expectation that this expression's values meet a minimum threshold. + + This is a convenience method for the common pattern of checking minimum values. + + Examples: + >>> from ray.data.expressions import col + >>> import ray + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> # Create expectation directly from column expression + >>> exp = col("age").expect_min(0) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + min_value: Minimum allowed value (inclusive). + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expectations import expect as _expect + + col_name = self.name if self.name is not None else "column" + if name is None: + name = f"Column '{col_name}' minimum >= {min_value}" + if description is None: + description = ( + f"Validate that column '{col_name}' has minimum value >= {min_value}" + ) + + return _expect( + expr=self >= min_value, + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + def expect_max( + self, + max_value: Union[int, float], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, + ) -> "Expectation": + """Create an expectation that this expression's values meet a maximum threshold. + + This is a convenience method for the common pattern of checking maximum values. + + Examples: + >>> from ray.data.expressions import col + >>> import ray + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> # Create expectation directly from column expression + >>> exp = col("age").expect_max(120) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + max_value: Maximum allowed value (inclusive). + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expectations import expect as _expect + + col_name = self.name if self.name is not None else "column" + if name is None: + name = f"Column '{col_name}' maximum <= {max_value}" + if description is None: + description = ( + f"Validate that column '{col_name}' has maximum value <= {max_value}" + ) + + return _expect( + expr=self <= max_value, + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + def expect_range( + self, + min_value: Union[int, float], + max_value: Union[int, float], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, + ) -> "Expectation": + """Create an expectation that this expression's values are within a range. + + This is a convenience method for the common pattern of checking value ranges. + + Examples: + >>> from ray.data.expressions import col + >>> import ray + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> # Create expectation directly from column expression + >>> exp = col("age").expect_range(0, 120) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + min_value: Minimum allowed value (inclusive). + max_value: Maximum allowed value (inclusive). + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expectations import expect as _expect + + col_name = self.name if self.name is not None else "column" + if name is None: + name = f"Column '{col_name}' in range [{min_value}, {max_value}]" + if description is None: + description = ( + f"Validate that column '{col_name}' values are in range " + f"[{min_value}, {max_value}]" + ) + + return _expect( + expr=(self >= min_value) & (self <= max_value), + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + def expect_not_null( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, + ) -> "Expectation": + """Create an expectation that this expression has no null values. + + This is a convenience method for the common pattern of checking for nulls. + + Examples: + >>> from ray.data.expressions import col + >>> import ray + >>> ds = ray.data.from_items([{"email": "test@example.com"}, {"email": None}]) + >>> # Create expectation directly from column expression + >>> exp = col("email").expect_not_null() + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expectations import expect as _expect + + col_name = self.name if self.name is not None else "column" + if name is None: + name = f"Column '{col_name}' is not null" + if description is None: + description = f"Validate that column '{col_name}' has no null values" + + return _expect( + expr=self.is_not_null(), + name=name, + description=description, + error_on_failure=error_on_failure, + ) + + def expect_in( + self, + values: Union[List[Any], Set[Any]], + *, + name: Optional[str] = None, + description: Optional[str] = None, + error_on_failure: bool = True, + ) -> "Expectation": + """Create an expectation that this expression's values are in a set of allowed values. + + This is a convenience method for the common pattern of checking allowed values. + + Examples: + >>> from ray.data.expressions import col + >>> import ray + >>> ds = ray.data.from_items([{"status": "active"}, {"status": "invalid"}]) + >>> # Create expectation directly from column expression + >>> exp = col("status").expect_in(["active", "inactive", "pending"]) + >>> passed_ds, failed_ds, result = ds.expect(exp) + + Args: + values: Set or list of allowed values. + name: Optional name for the expectation. + description: Optional description of what this expectation checks. + error_on_failure: If True, raise exception on failure; if False, log warning. + + Returns: + An Expectation object that can be used with Dataset.expect(). 
+ """ + from ray.data.expectations import expect as _expect + + if not isinstance(values, (list, set, tuple)): + raise TypeError( + f"values must be list, set, or tuple, got {type(values).__name__}" + ) + + values_set = set(values) + col_name = self.name if self.name is not None else "column" + if name is None: + name = f"Column '{col_name}' in allowed values" + if description is None: + description = ( + f"Validate that column '{col_name}' values are in {values_set}" + ) + + return _expect( + expr=self.is_in(list(values_set)), + name=name, + description=description, + error_on_failure=error_on_failure, + ) + @DeveloperAPI(stability="alpha") @dataclass(frozen=True, eq=False, repr=False) diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index c58408480c1f..9467c7644c38 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -33,7 +33,6 @@ DeltaSharingDatasource, ) from ray.data._internal.datasource.hudi_datasource import HudiDatasource -from ray.data._internal.datasource.iceberg_datasource import IcebergDatasource from ray.data._internal.datasource.image_datasource import ( ImageDatasource, ImageFileMetadataProvider, @@ -43,6 +42,10 @@ ArrowJSONDatasource, PandasJSONDatasource, ) +from ray.data._internal.datasource.kafka_datasource import ( + KafkaAuthConfig, + KafkaDatasource, +) from ray.data._internal.datasource.lance_datasource import LanceDatasource from ray.data._internal.datasource.mcap_datasource import MCAPDatasource, TimeRange from ray.data._internal.datasource.mongo_datasource import MongoDatasource @@ -3988,19 +3991,23 @@ def read_iceberg( Examples: >>> import ray - >>> from pyiceberg.expressions import EqualTo #doctest: +SKIP + >>> from ray.data.expressions import col #doctest: +SKIP + >>> # Read the table and apply filters using Ray Data expressions >>> ds = ray.data.read_iceberg( #doctest: +SKIP ... table_identifier="db_name.table_name", - ... row_filter=EqualTo("column_name", "literal_value"), ... 
catalog_kwargs={"name": "default", "type": "glue"} - ... ) + ... ).filter(col("column_name") == "literal_value") + >>> # Select specific columns + >>> ds = ds.select_columns(["col1", "col2"]) #doctest: +SKIP Args: table_identifier: Fully qualified table identifier (``db_name.table_name``) - row_filter: A PyIceberg :class:`~pyiceberg.expressions.BooleanExpression` - to use to filter the data *prior* to reading + row_filter: **Deprecated**. Use ``.filter()`` method on the dataset instead. + A PyIceberg :class:`~pyiceberg.expressions.BooleanExpression` + to use to filter the data *prior* to reading. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. - selected_fields: Which columns from the data to read, passed directly to + selected_fields: **Deprecated**. Use ``.select_columns()`` method on the dataset instead. + Which columns from the data to read, passed directly to PyIceberg's load functions. Should be an tuple of string column names. snapshot_id: Optional snapshot ID for the Iceberg table, by default the latest snapshot is used @@ -4027,6 +4034,27 @@ def read_iceberg( Returns: :class:`~ray.data.Dataset` with rows from the Iceberg table. """ + from ray.data._internal.datasource.iceberg_datasource import IcebergDatasource + + # Deprecation warning for row_filter parameter + if row_filter is not None: + warnings.warn( + "The 'row_filter' parameter is deprecated and will be removed in a " + "future release. Use the .filter() method on the dataset instead. " + "For example: ds = ray.data.read_iceberg(...).filter(col('column') > 5)", + DeprecationWarning, + stacklevel=2, + ) + + # Deprecation warning for selected_fields parameter + if selected_fields != ("*",): + warnings.warn( + "The 'selected_fields' parameter is deprecated and will be removed in a " + "future release. Use the .select_columns() method on the dataset instead. 
" + "For example: ds = ray.data.read_iceberg(...).select_columns(['col1', 'col2'])", + DeprecationWarning, + stacklevel=2, + ) # Setup the Datasource datasource = IcebergDatasource( @@ -4055,6 +4083,7 @@ def read_iceberg( def read_lance( uri: str, *, + version: Optional[Union[int, str]] = None, columns: Optional[List[str]] = None, filter: Optional[str] = None, storage_options: Optional[Dict[str, str]] = None, @@ -4081,6 +4110,9 @@ def read_lance( Args: uri: The URI of the Lance dataset to read from. Local file paths, S3, and GCS are supported. + version: Load a specific version of the Lance dataset. This can be an + integer version number or a string tag. By default, the + latest version is loaded. columns: The columns to read. By default, all columns are read. filter: Read returns only the rows matching the filter. By default, no filter is applied. @@ -4112,6 +4144,7 @@ def read_lance( """ # noqa: E501 datasource = LanceDatasource( uri=uri, + version=version, columns=columns, filter=filter, storage_options=storage_options, @@ -4282,6 +4315,7 @@ def read_unity_catalog( @PublicAPI(stability="alpha") def read_delta( path: Union[str, List[str]], + version: Optional[int] = None, *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, columns: Optional[List[str]] = None, @@ -4309,6 +4343,7 @@ def read_delta( Args: path: A single file path for a Delta Lake table. Multiple tables are not yet supported. + version: The version of the Delta Lake table to read. If not specified, the latest version is read. filesystem: The PyArrow filesystem implementation to read from. These filesystems are specified in the `pyarrow docs Dataset: + """Read data from Kafka topics. + + This function supports bounded reads from Kafka topics, reading messages + between a start and end offset. Only the "once" trigger is + supported for now, which performs a single bounded read. Currently we only + have one read task for each partition. + + Examples: + + .. 
testcode:: + :skipif: True + + import ray + + # Read from a single topic with offset range + ds = ray.data.read_kafka( + topics="my-topic", + bootstrap_servers="localhost:9092", + start_offset=0, + end_offset=1000, + ) + + + Args: + topics: Kafka topic name(s) to read from. Can be a single topic name + or a list of topic names. + bootstrap_servers: Kafka broker addresses. Can be a single string or + a list of strings. + trigger: Trigger mode for reading. Only "once" is supported, which + performs a single bounded read. + start_offset: Starting position for reading. Can be: + - int: Offset number + - str: "earliest" + end_offset: Ending position for reading (exclusive). Can be: + - int: Offset number + - str: "latest" + kafka_auth_config: Authentication configuration. See KafkaAuthConfig for details. + num_cpus: The number of CPUs to reserve for each parallel read worker. + num_gpus: The number of GPUs to reserve for each parallel read worker. + memory: The heap memory in bytes to reserve for each parallel read worker. + ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. + override_num_blocks: Override the number of output blocks from all read tasks. + By default, the number of output blocks is dynamically decided based on + input data size and available resources. You shouldn't manually set this + value in most cases. + timeout_ms: Timeout in milliseconds for every read task to poll until reaching end_offset (default 10000ms). + If the read task does not reach end_offset within the timeout, it will stop polling and return the messages + it has read so far. 
+ + Returns: + A :class:`~ray.data.Dataset` containing Kafka messages with the following schema: + - offset: int64 - Message offset within partition + - key: binary - Message key as raw bytes + - value: binary - Message value as raw bytes + - topic: string - Topic name + - partition: int32 - Partition ID + - timestamp: int64 - Message timestamp in milliseconds + - timestamp_type: int32 - 0=CreateTime, 1=LogAppendTime + - headers: map - Message headers (keys as strings, values as bytes) + + Raises: + ValueError: If invalid parameters are provided. + ImportError: If kafka-python is not installed. + """ # noqa: E501 + if trigger != "once": + raise ValueError(f"Only trigger='once' is supported. Got trigger={trigger!r}") + + return ray.data.read_datasource( + KafkaDatasource( + topics=topics, + bootstrap_servers=bootstrap_servers, + start_offset=start_offset, + end_offset=end_offset, + kafka_auth_config=kafka_auth_config, + timeout_ms=timeout_ms, + ), + parallelism=-1, + num_cpus=num_cpus, + num_gpus=num_gpus, + memory=memory, + ray_remote_args=ray_remote_args, + override_num_blocks=override_num_blocks, + ) + + def _get_datasource_or_legacy_reader( ds: Datasource, ctx: DataContext, diff --git a/python/ray/data/tests/test_expectations.py b/python/ray/data/tests/test_expectations.py new file mode 100644 index 000000000000..9dfebdc3a8d8 --- /dev/null +++ b/python/ray/data/tests/test_expectations.py @@ -0,0 +1,627 @@ +"""Tests for Ray Data Expectations API.""" + +import datetime +import pickle + +import numpy as np +import pandas as pd +import pytest + +import ray +from ray.data._internal.plan import ExecutionPlan +from ray.data._internal.stats import DatasetStats +from ray.data.context import DataContext +from ray.data.expectations import ( + DataQualityExpectation, + ExecutionTimeExpectation, + Expectation, + ExpectationResult, + ExpectationType, + expect, +) +from ray.data.expressions import col +from ray.tests.conftest import * # noqa + + +def 
test_expectation_creation(): + """Test creating basic expectation objects.""" + exp = Expectation( + name="test", + description="Test expectation", + expectation_type=ExpectationType.DATA_QUALITY, + ) + assert exp.name == "test" + assert exp.description == "Test expectation" + assert exp.expectation_type == ExpectationType.DATA_QUALITY + assert exp.error_on_failure is True + + +def test_expectation_validation_raises(): + """Test that base expectation.validate() raises NotImplementedError.""" + exp = Expectation( + name="test", + description="Test", + expectation_type=ExpectationType.DATA_QUALITY, + ) + with pytest.raises(NotImplementedError): + exp.validate(None) + + +def test_expectation_empty_name(): + """Test that empty name raises ValueError.""" + with pytest.raises(ValueError, match="name cannot be empty"): + Expectation( + name="", + description="Test", + expectation_type=ExpectationType.DATA_QUALITY, + ) + + +def test_expectation_empty_description(): + """Test that empty description raises ValueError.""" + with pytest.raises(ValueError, match="description cannot be empty"): + Expectation( + name="test", + description="", + expectation_type=ExpectationType.DATA_QUALITY, + ) + + +def test_data_quality_creation(): + """Test creating data quality expectations.""" + + def validator(batch): + return batch["value"].min() > 0 + + exp = DataQualityExpectation( + name="positive_values", + description="Values must be positive", + validator_fn=validator, + ) + + assert exp.name == "positive_values" + assert exp.expectation_type == ExpectationType.DATA_QUALITY + assert exp.validator_fn == validator + + +def test_data_quality_validation_passes(): + """Test data quality validation when it passes.""" + + def validator(batch): + return batch["value"].min() > 0 + + exp = DataQualityExpectation( + name="positive_values", + description="Values must be positive", + validator_fn=validator, + ) + + batch = {"value": np.array([1, 2, 3, 4])} + assert exp.validate(batch) is True + + 
+def test_data_quality_validation_fails(): + """Test data quality validation when it fails.""" + + def validator(batch): + return batch["value"].min() > 0 + + exp = DataQualityExpectation( + name="positive_values", + description="Values must be positive", + validator_fn=validator, + ) + + batch = {"value": np.array([-1, 2, 3, 4])} + assert exp.validate(batch) is False + + +def test_data_quality_with_pandas(): + """Test data quality validation with pandas DataFrames.""" + + def validator(batch): + return batch["value"].min() > 0 + + exp = DataQualityExpectation( + name="positive_values", + description="Values must be positive", + validator_fn=validator, + ) + + batch = pd.DataFrame({"value": [1, 2, 3, 4]}) + assert exp.validate(batch) is True + + +def test_execution_time_creation_with_seconds(): + """Test creating execution time expectations with seconds.""" + exp = ExecutionTimeExpectation( + name="fast_job", + description="Job must finish in 60 seconds", + max_execution_time_seconds=60.0, + ) + + assert exp.name == "fast_job" + assert exp.max_execution_time_seconds == 60.0 + + +def test_execution_time_creation_with_timedelta(): + """Test creating execution time expectations with timedelta.""" + exp = ExecutionTimeExpectation( + name="fast_job", + description="Job must finish in 60 seconds", + max_execution_time=datetime.timedelta(seconds=60), + ) + assert exp.max_execution_time_seconds == 60.0 + + +def test_execution_time_creation_with_target_time(): + """Test creating execution time expectations with target completion time.""" + target = datetime.datetime.now() + datetime.timedelta(seconds=60) + + exp = ExecutionTimeExpectation( + name="fast_job", + description="Job must finish by target time", + target_completion_time=target, + ) + + assert exp.max_execution_time_seconds is not None + assert exp.max_execution_time_seconds > 0 + + +def test_execution_time_validation_passes(): + """Test execution time validation when it passes.""" + exp = ExecutionTimeExpectation( 
+ name="fast_job", + description="Job must finish in 60 seconds", + max_execution_time_seconds=60.0, + ) + + assert exp.validate(30.0) is True + + +def test_execution_time_validation_fails(): + """Test execution time validation when it fails.""" + exp = ExecutionTimeExpectation( + name="fast_job", + description="Job must finish in 60 seconds", + max_execution_time_seconds=60.0, + ) + assert exp.validate(90.0) is False + + +def test_execution_time_no_time_constraints_error(): + """Test that execution time expectation without time constraints raises error.""" + with pytest.raises(ValueError, match="at least one time constraint"): + ExecutionTimeExpectation( + name="fast_job", + description="Job must finish quickly", + ) + + +def test_expect_with_expression(ray_start_regular_shared): + """Test dataset.expect() with expression.""" + ds = ray.data.from_items([{"value": 1}, {"value": -1}]) + passed_ds, failed_ds, result = ds.expect(expr=col("value") > 0) + + assert result.passed is False + assert failed_ds.count() == 1 + + +def test_expect_complex_expression(ray_start_regular_shared): + """Test expression with AND/OR conditions.""" + ds = ray.data.from_items([{"value": 50}, {"value": 150}]) + passed_ds, failed_ds, result = ds.expect( + expr=(col("value") >= 0) & (col("value") <= 100) + ) + assert result.passed is False + assert failed_ds.count() == 1 + + +def test_expect_expression_with_null_handling(ray_start_regular_shared): + """Test expression with null value handling.""" + ds = ray.data.from_items([{"value": 1}, {"value": None}, {"value": 3}]) + + passed_ds, failed_ds, result = ds.expect(expr=col("value").is_not_null()) + assert result.passed is False + + +def test_expect_cannot_specify_both_validator_fn_and_expr(): + """Test that specifying both validator_fn and expr raises error.""" + with pytest.raises(ValueError, match="Cannot specify both"): + expect( + validator_fn=lambda batch: True, + expr=col("value") > 0, + ) + + +def 
test_dataset_expect_basic(ray_start_regular_shared): + """Test basic dataset.expect() usage.""" + ds = ray.data.range(10) + + def validator(batch): + return batch["id"] >= 0 + + exp = DataQualityExpectation( + name="non_negative", + description="IDs must be non-negative", + validator_fn=validator, + ) + + passed_ds, failed_ds, result = ds.expect(exp) + + assert result.passed is True + assert passed_ds.count() == 10 + assert failed_ds.count() == 0 + + +def test_dataset_expect_validation_fails(ray_start_regular_shared): + """Test dataset.expect() when validation fails.""" + ds = ray.data.from_items([{"value": -1}, {"value": 2}]) + + def validator(batch): + return batch["value"] > 0 + + exp = DataQualityExpectation( + name="positive", + description="Values must be positive", + validator_fn=validator, + error_on_failure=False, + ) + + passed_ds, failed_ds, result = ds.expect(exp) + + assert result.passed is False + assert failed_ds.count() > 0 + + +def test_dataset_expect_with_expression(ray_start_regular_shared): + """Test dataset.expect() with expression.""" + ds = ray.data.range(10) + + passed_ds, failed_ds, result = ds.expect(expr=col("id") >= 0) + + assert result.passed is True + assert passed_ds.count() == 10 + assert failed_ds.count() == 0 + + +def test_dataset_expect_execution_time_expectation(ray_start_regular_shared): + """Test dataset.expect() with execution time expectation.""" + ds = ray.data.range(10) + + execution_time_exp = ExecutionTimeExpectation( + name="fast", + description="Must be fast", + max_execution_time_seconds=60.0, + ) + + passed_ds, failed_ds, result = ds.expect(execution_time_exp) + assert isinstance(result, ExpectationResult) + + +def test_dataset_expect_data_unchanged(ray_start_regular_shared): + """Test that dataset.expect() doesn't modify data.""" + original_data = [{"value": i} for i in range(10)] + ds = ray.data.from_items(original_data) + + def validator(batch): + return batch["value"] >= 0 + + exp = DataQualityExpectation( + 
name="non_negative", + description="Values must be non-negative", + validator_fn=validator, + ) + + passed_ds, failed_ds, result = ds.expect(exp) + + assert passed_ds.take_all() == original_data + + +def test_execution_plan_add_execution_time_expectation(): + """Test adding execution time expectations to execution plan.""" + context = DataContext() + stats = DatasetStats(metadata={}, parent=None) + plan = ExecutionPlan(stats, run_by_consumer=False, dataset_context=context) + + exp = ExecutionTimeExpectation( + name="fast_job", + description="Job must finish in 60 seconds", + max_execution_time_seconds=60.0, + ) + + plan.add_execution_time_expectation(exp) + + assert plan.get_max_execution_time_seconds() == 60.0 + assert len(plan.get_execution_time_expectations()) == 1 + + +def test_execution_plan_multiple_execution_time_expectations(): + """Test multiple execution time expectations with minimum time.""" + context = DataContext() + stats = DatasetStats(metadata={}, parent=None) + plan = ExecutionPlan(stats, run_by_consumer=False, dataset_context=context) + + plan.add_execution_time_expectation( + ExecutionTimeExpectation( + name="job1", + description="Job 1", + max_execution_time_seconds=120.0, + ) + ) + + plan.add_execution_time_expectation( + ExecutionTimeExpectation( + name="job2", + description="Job 2", + max_execution_time_seconds=60.0, + ) + ) + + assert plan.get_max_execution_time_seconds() == 60.0 + assert len(plan.get_execution_time_expectations()) == 2 + + +def test_execution_plan_copy_preserves_execution_time_expectations(): + """Test that copying execution plan preserves execution time expectations.""" + context = DataContext() + stats = DatasetStats(metadata={}, parent=None) + plan = ExecutionPlan(stats, run_by_consumer=False, dataset_context=context) + + exp = ExecutionTimeExpectation( + name="fast_job", + description="Job must finish in 60 seconds", + max_execution_time_seconds=60.0, + ) + plan.add_execution_time_expectation(exp) + + copied_plan = 
plan.copy() + + assert copied_plan.get_max_execution_time_seconds() == 60.0 + + +def test_expectations_persist_through_filter(ray_start_regular_shared): + """Test execution time expectations persist through filter operation.""" + exp = expect(max_execution_time_seconds=60.0) + ds = ray.data.range(100) + ds._plan.add_execution_time_expectation(exp) + filtered_ds = ds.filter(lambda row: row["id"] % 2 == 0) + + assert filtered_ds._plan.get_max_execution_time_seconds() == 60.0 + + +def test_expectations_persist_through_union(ray_start_regular_shared): + """Test execution time expectations persist through union operation.""" + exp = expect(max_execution_time_seconds=60.0) + ds1 = ray.data.range(50) + ds1._plan.add_execution_time_expectation(exp) + ds2 = ray.data.range(50, 100) + + unioned_ds = ds1.union(ds2) + + assert unioned_ds._plan.get_max_execution_time_seconds() == 60.0 + + +def test_expectations_persist_through_groupby(ray_start_regular_shared): + """Test execution time expectations persist through groupby operation.""" + exp = expect(max_execution_time_seconds=60.0) + ds = ray.data.range(100) + ds._plan.add_execution_time_expectation(exp) + grouped_ds = ds.groupby("id").count() + + assert grouped_ds._plan.get_max_execution_time_seconds() == 60.0 + + +def test_expectations_persist_through_repartition(ray_start_regular_shared): + """Test execution time expectations persist through repartition.""" + exp = expect(max_execution_time_seconds=60.0) + ds = ray.data.range(100) + ds._plan.add_execution_time_expectation(exp) + repartitioned_ds = ds.repartition(5) + + assert repartitioned_ds._plan.get_max_execution_time_seconds() == 60.0 + + +def test_expectations_persist_through_select_columns(ray_start_regular_shared): + """Test execution time expectations persist through select_columns.""" + exp = expect(max_execution_time_seconds=60.0) + ds = ray.data.range(100) + ds._plan.add_execution_time_expectation(exp) + selected_ds = ds.select_columns(cols=["id"]) + + assert 
selected_ds._plan.get_max_execution_time_seconds() == 60.0 + + +def test_expect_with_empty_dataset(ray_start_regular_shared): + """Test expectations with empty dataset.""" + ds = ray.data.range(0) + + def validator(batch): + return True + + exp = DataQualityExpectation( + name="always_true", + description="Always passes", + validator_fn=validator, + ) + + passed_ds, failed_ds, result = ds.expect(exp) + + assert result.passed is True + assert passed_ds.count() == 0 + assert failed_ds.count() == 0 + + +def test_expression_with_all_nulls(ray_start_regular_shared): + """Test expression validation with all null values.""" + ds = ray.data.from_items([{"value": None}, {"value": None}]) + + passed_ds, failed_ds, result = ds.expect( + expr=col("value").is_null(), error_on_failure=False + ) + + assert result.passed is True + + +def test_validator_function_raises_exception(ray_start_regular_shared): + """Test that validator exceptions are handled.""" + + def bad_validator(batch): + raise RuntimeError("Validation error") + + ds = ray.data.range(10) + exp = DataQualityExpectation( + name="bad_validator", + description="Validator that raises", + validator_fn=bad_validator, + error_on_failure=False, + ) + + passed_ds, failed_ds, result = ds.expect(exp) + assert result.passed is False + + +def test_execution_time_with_zero_time(): + """Test execution time expectation with zero execution time.""" + with pytest.raises(ValueError, match="must be positive"): + ExecutionTimeExpectation( + name="instant_job", + description="Job must finish instantly", + max_execution_time_seconds=0.0, + ) + + +def test_execution_time_with_negative_time(): + """Test execution time expectation with negative execution time.""" + with pytest.raises(ValueError, match="must be positive"): + ExecutionTimeExpectation( + name="time_travel_job", + description="Job finished yesterday", + max_execution_time_seconds=-10.0, + ) + + +def test_expectation_serialization(): + """Test that expectations can be pickled for 
distributed execution.""" + exp = DataQualityExpectation( + name="positive_values", + description="Values must be positive", + validator_fn=lambda batch: batch["value"] > 0, + ) + + pickled = pickle.dumps(exp) + unpickled = pickle.loads(pickled) + + assert unpickled.name == exp.name + assert unpickled.description == exp.description + + +def test_execution_time_expectation_serialization(): + """Test that execution time expectations can be pickled.""" + exp = ExecutionTimeExpectation( + name="fast_job", + description="Job must finish in 60 seconds", + max_execution_time_seconds=60.0, + ) + + pickled = pickle.dumps(exp) + unpickled = pickle.loads(pickled) + + assert unpickled.name == exp.name + assert unpickled.max_execution_time_seconds == 60.0 + + +def test_expectation_in_pipeline(ray_start_regular_shared): + """Test expectations in multi-stage pipeline.""" + ds = ray.data.range(100) + + passed_ds, failed_ds, result1 = ds.expect(expr=col("id") >= 0) + assert result1.passed is True + + processed_ds, remaining_ds, result2 = passed_ds.expect( + max_execution_time_seconds=60.0 + ) + assert isinstance(result2, ExpectationResult) + + assert processed_ds.count() == 100 + + +def test_expectation_chaining(ray_start_regular_shared): + """Test chaining multiple dataset.expect() calls.""" + ds = ray.data.from_items([{"value": i} for i in range(1, 11)]) + + passed_ds1, failed_ds1, result1 = ds.expect(expr=col("value") > 0) + passed_ds2, failed_ds2, result2 = passed_ds1.expect(expr=col("value") < 100) + + assert result1.passed is True + assert result2.passed is True + assert passed_ds2.count() == 10 + + +def test_both_execution_time_and_data_quality_expectations(ray_start_regular_shared): + """Test using both execution time and data quality expectations together.""" + ds = ray.data.from_items([{"value": i} for i in range(1, 11)]) + + passed_ds, failed_ds, dq_result = ds.expect(expr=col("value") > 0) + assert dq_result.passed is True + + processed_ds, remaining_ds, 
execution_time_result = passed_ds.expect( + max_execution_time_seconds=60.0 + ) + assert isinstance(execution_time_result, ExpectationResult) + + +def test_expectation_with_large_dataset(ray_start_regular_shared): + """Test expectations work with larger datasets.""" + ds = ray.data.range(10000) + + passed_ds, failed_ds, result = ds.expect(expr=col("id") >= 0) + + assert result.passed is True + assert passed_ds.count() == 10000 + + +def test_expectation_result_with_all_fields(ray_start_regular_shared): + """Test ExpectationResult with all fields.""" + from ray.data.expressions import col + + exp = expect(expr=col("value") > 0, name="test_expectation") + result = ExpectationResult( + expectation=exp, + passed=True, + total_count=1000, + failure_count=0, + execution_time_seconds=5.2, + message="All checks passed for test_expectation", + ) + + assert result.passed is True + assert result.total_count == 1000 + assert result.failure_count == 0 + assert result.execution_time_seconds == 5.2 + assert "test_expectation" in result.message + + +def test_expectations_import(): + """Test importing expectations from ray.data.""" + from ray.data import ( + DataQualityExpectation, + ExecutionTimeExpectation, + Expectation, + ExpectationResult, + ExpectationType, + expect, + ) + + assert Expectation is not None + assert DataQualityExpectation is not None + assert ExecutionTimeExpectation is not None + assert ExpectationResult is not None + assert ExpectationType is not None + assert expect is not None + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__]))