Commit d37de68

fixing tests
1 parent 4e9cba5 commit d37de68

File tree

6 files changed: +104, -90 lines


dbldatagen/spec/column_spec.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from .compat import BaseModel, root_validator
+
+
+DbldatagenBasicType = Literal[
+    "string",
+    "int",
+    "long",
+    "float",
+    "double",
+    "decimal",
+    "boolean",
+    "date",
+    "timestamp",
+    "short",
+    "byte",
+    "binary",
+    "integer",
+    "bigint",
+    "tinyint",
+]
+class ColumnDefinition(BaseModel):
+    name: str
+    type: DbldatagenBasicType | None = None
+    primary: bool = False
+    options: dict[str, Any] | None = None
+    nullable: bool | None = False
+    omit: bool | None = False
+    baseColumn: str | None = "id"
+    baseColumnType: str | None = "auto"
+
+    @root_validator()
+    def check_model_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
+        """
+        Validates constraints across the entire model after individual fields are processed.
+        """
+        is_primary = values.get("primary")
+        options = values.get("options") or {}  # Handle None case
+        name = values.get("name")
+        is_nullable = values.get("nullable")
+        column_type = values.get("type")
+
+        if is_primary:
+            if "min" in options or "max" in options:
+                raise ValueError(f"Primary column '{name}' cannot have min/max options.")
+
+            if is_nullable:
+                raise ValueError(f"Primary column '{name}' cannot be nullable.")
+
+            if column_type is None:
+                raise ValueError(f"Primary column '{name}' must have a type defined.")
+        return values
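
As a quick illustration of the relocated model (import path from this diff; a sketch, not repo code): the root_validator runs after field parsing and rejects primary columns that are nullable, untyped, or constrained with min/max.

from dbldatagen.spec.column_spec import ColumnDefinition

# A typed, non-nullable primary key passes validation.
pk = ColumnDefinition(name="id", type="long", primary=True)

# Primary columns may not carry min/max options; under Pydantic V1 the
# ValueError surfaces as a pydantic ValidationError (a ValueError subclass).
try:
    ColumnDefinition(name="id", type="long", primary=True,
                     options={"min": 0, "max": 100})
except ValueError as exc:
    print(exc)  # message mentions: cannot have min/max options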

dbldatagen/spec/compat.py

Lines changed: 4 additions & 5 deletions
@@ -2,13 +2,12 @@

 try:
     # This will succeed on environments with Pydantic V2.x
-    # It imports the V1 API that is bundled within V2.
-    from pydantic.v1 import BaseModel, Field, validator, constr
-
+    from pydantic.v1 import BaseModel, Field, constr, root_validator, validator
 except ImportError:
     # This will be executed on environments with only Pydantic V1.x
-    from pydantic import BaseModel, Field, validator, constr, root_validator, field_validator
+    from pydantic import BaseModel, Field, constr, root_validator, validator # type: ignore[assignment,no-redef]

+__all__ = ["BaseModel", "Field", "constr", "root_validator", "validator"]
 # In your application code, do this:
 # from .compat import BaseModel
 # NOT this:
@@ -28,4 +27,4 @@

 Future-Ready: When you eventually decide to migrate fully to the Pydantic V2 API (to take advantage of its speed and features),
 you only need to change your application code and your compat.py import statements, making the transition much clearer.
-"""
+"""

dbldatagen/spec/generator_spec.py

Lines changed: 30 additions & 73 deletions
@@ -1,103 +1,62 @@
-from .compat import BaseModel, validator, root_validator, field_validator
-from typing import Dict, Optional, Union, Any, Literal, List
+from __future__ import annotations
+
+import logging
+from typing import Any, Literal, Union
+
 import pandas as pd
-from IPython.display import display, HTML
-
-DbldatagenBasicType = Literal[
-    "string",
-    "int",
-    "long",
-    "float",
-    "double",
-    "decimal",
-    "boolean",
-    "date",
-    "timestamp",
-    "short",
-    "byte",
-    "binary",
-    "integer",
-    "bigint",
-    "tinyint",
-]
-
-class ColumnDefinition(BaseModel):
-    name: str
-    type: Optional[DbldatagenBasicType] = None
-    primary: bool = False
-    options: Optional[Dict[str, Any]] = {}
-    nullable: Optional[bool] = False
-    omit: Optional[bool] = False
-    baseColumn: Optional[str] = "id"
-    baseColumnType: Optional[str] = "auto"
-
-    @root_validator(skip_on_failure=True)
-    def check_model_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Validates constraints across the entire model after individual fields are processed.
-        """
-        is_primary = values.get("primary")
-        options = values.get("options", {})
-        name = values.get("name")
-        is_nullable = values.get("nullable")
-        column_type = values.get("type")
+from IPython.display import HTML, display

-        if is_primary:
-            if "min" in options or "max" in options:
-                raise ValueError(f"Primary column '{name}' cannot have min/max options.")
+from dbldatagen.spec.column_spec import ColumnDefinition

-            if is_nullable:
-                raise ValueError(f"Primary column '{name}' cannot be nullable.")
+from .compat import BaseModel, validator

-            if column_type is None:
-                raise ValueError(f"Primary column '{name}' must have a type defined.")
-        return values

+logger = logging.getLogger(__name__)

 class UCSchemaTarget(BaseModel):
     catalog: str
     schema_: str
     output_format: str = "delta" # Default to delta for UC Schema

-    @field_validator("catalog", "schema_", mode="after")
-    def validate_identifiers(cls, v): # noqa: N805, pylint: disable=no-self-argument
+    @validator("catalog", "schema_")
+    def validate_identifiers(cls, v: str) -> str:
         if not v.strip():
             raise ValueError("Identifier must be non-empty.")
         if not v.isidentifier():
             logger.warning(
                 f"'{v}' is not a basic Python identifier. Ensure validity for Unity Catalog.")
         return v.strip()

-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)"


 class FilePathTarget(BaseModel):
     base_path: str
     output_format: Literal["csv", "parquet"] # No default, must be specified

-    @field_validator("base_path", mode="after")
-    def validate_base_path(cls, v): # noqa: N805, pylint: disable=no-self-argument
+    @validator("base_path")
+    def validate_base_path(cls, v: str) -> str:
         if not v.strip():
             raise ValueError("base_path must be non-empty.")
         return v.strip()

-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.base_path} (Format: {self.output_format}, Type: File Path)"


 class TableDefinition(BaseModel):
     number_of_rows: int
-    partitions: Optional[int] = None
-    columns: List[ColumnDefinition]
+    partitions: int | None = None
+    columns: list[ColumnDefinition]


 class ValidationResult:
     """Container for validation results with errors and warnings."""

     def __init__(self) -> None:
-        self.errors: List[str] = []
-        self.warnings: List[str] = []
+        self.errors: list[str] = []
+        self.warnings: list[str] = []

     def add_error(self, message: str) -> None:
         """Add an error message."""
@@ -132,16 +91,16 @@ def __str__(self) -> str:
         return "\n".join(lines)

 class DatagenSpec(BaseModel):
-    tables: Dict[str, TableDefinition]
-    output_destination: Optional[Union[UCSchemaTarget, FilePathTarget]] = None # there is a abstraction, may be we can use that? talk to Greg
-    generator_options: Optional[Dict[str, Any]] = {}
-    intended_for_databricks: Optional[bool] = None # May be infered.
+    tables: dict[str, TableDefinition]
+    output_destination: Union[UCSchemaTarget, FilePathTarget] | None = None # there is a abstraction, may be we can use that? talk to Greg
+    generator_options: dict[str, Any] | None = None
+    intended_for_databricks: bool | None = None # May be infered.

     def _check_circular_dependencies(
         self,
         table_name: str,
-        columns: List[ColumnDefinition]
-    ) -> List[str]:
+        columns: list[ColumnDefinition]
+    ) -> list[str]:
         """
         Check for circular dependencies in baseColumn references.
         Returns a list of error messages if circular dependencies are found.
@@ -152,13 +111,13 @@ def _check_circular_dependencies(
         for col in columns:
             if col.baseColumn and col.baseColumn != "id":
                 # Track the dependency chain
-                visited = set()
+                visited: set[str] = set()
                 current = col.name

                 while current:
                     if current in visited:
                         # Found a cycle
-                        cycle_path = " -> ".join(list(visited) + [current])
+                        cycle_path = " -> ".join([*list(visited), current])
                         errors.append(
                             f"Table '{table_name}': Circular dependency detected in column '{col.name}': {cycle_path}"
                         )
@@ -182,7 +141,7 @@ def _check_circular_dependencies(

         return errors

-    def validate(self, strict: bool = True) -> ValidationResult:
+    def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[override]
         """
         Validates the entire DatagenSpec configuration.
         Always runs all validation checks and collects all errors and warnings.
@@ -284,17 +243,15 @@ def validate(self, strict: bool = True) -> ValidationResult:
             "random", "randomSeed", "randomSeedMethod", "verbose",
             "debug", "seedColumnName"
         ]
-        for key in self.generator_options.keys():
+        for key in self.generator_options:
             if key not in known_options:
                 result.add_warning(
                     f"Unknown generator option: '{key}'. "
                     "This may be ignored during generation."
                 )

         # Now that all validations are complete, decide whether to raise
-        if strict and (result.errors or result.warnings):
-            raise ValueError(str(result))
-        elif not strict and result.errors:
+        if (strict and (result.errors or result.warnings)) or (not strict and result.errors):
            raise ValueError(str(result))

         return result
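
A sketch of the collect-then-raise contract consolidated here, assuming a minimal one-table spec that passes the checks not shown in this hunk: validate() gathers all errors and warnings first; strict mode then raises on either, while non-strict mode raises only on errors.

from dbldatagen.spec.column_spec import ColumnDefinition
from dbldatagen.spec.generator_spec import DatagenSpec, FilePathTarget, TableDefinition

spec = DatagenSpec(
    tables={
        "users": TableDefinition(
            number_of_rows=1000,
            columns=[ColumnDefinition(name="id", type="long", primary=True)],
        )
    },
    output_destination=FilePathTarget(base_path="/tmp/users_demo", output_format="parquet"),
    generator_options={"randomSeed": 42, "frobnicate": True},  # "frobnicate" assumed unknown -> warning
)

result = spec.validate(strict=False)  # collects everything; raises only if errors exist
print(result.warnings)                # e.g. the unknown-option warning

spec.validate(strict=True)            # strict mode: the same warning now raises ValueError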

dbldatagen/spec/generator_spec_impl.py

Lines changed: 12 additions & 10 deletions
@@ -1,11 +1,13 @@
 import logging
-from typing import Dict, Union
 import posixpath
+from typing import Any, Union

-from dbldatagen.spec.generator_spec import TableDefinition
 from pyspark.sql import SparkSession
+
 import dbldatagen as dg
-from .generator_spec import DatagenSpec, UCSchemaTarget, FilePathTarget, ColumnDefinition
+from dbldatagen.spec.generator_spec import TableDefinition
+
+from .generator_spec import ColumnDefinition, DatagenSpec, FilePathTarget, UCSchemaTarget


 logging.basicConfig(
@@ -41,7 +43,7 @@ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") ->
        self.app_name = app_name
        logger.info("Generator initialized with SparkSession")

-    def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> Dict[str, str]:
+    def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> dict[str, Any]:
         """
         Convert a ColumnDefinition to dbldatagen column specification.
         Args:
@@ -95,7 +97,7 @@ def _prepare_data_generators(
         self,
         config: DatagenSpec,
         config_source_name: str = "PydanticConfig"
-    ) -> Dict[str, dg.DataGenerator]:
+    ) -> dict[str, dg.DataGenerator]:
         """
         Prepare DataGenerator specifications for each table based on the configuration.
         Args:
@@ -117,10 +119,10 @@ def _prepare_data_generators(
             raise RuntimeError(
                 "SparkSession is not available. Cannot prepare data generators")

-        tables_config: Dict[str, TableDefinition] = config.tables
+        tables_config: dict[str, TableDefinition] = config.tables
         global_gen_options = config.generator_options if config.generator_options else {}

-        prepared_generators: Dict[str, dg.DataGenerator] = {}
+        prepared_generators: dict[str, dg.DataGenerator] = {}
         generation_order = list(tables_config.keys()) # This becomes impotant when we get into multitable

         for table_name in generation_order:
@@ -156,7 +158,7 @@ def _prepare_data_generators(

     def write_prepared_data(
         self,
-        prepared_generators: Dict[str, dg.DataGenerator],
+        prepared_generators: dict[str, dg.DataGenerator],
         output_destination: Union[UCSchemaTarget, FilePathTarget, None],
         config_source_name: str = "PydanticConfig",
     ) -> None:
@@ -188,7 +190,7 @@ def write_prepared_data(
             logger.info(
                 f"Built DataFrame for '{table_name}': {actual_row_count} rows (requested: {requested_rows})")

-            if actual_row_count == 0 and requested_rows > 0:
+            if actual_row_count == 0 and requested_rows is not None and requested_rows > 0:
                 logger.warning(f"Table '{table_name}': Requested {requested_rows} rows but built 0")

             # Write data based on destination type
@@ -251,4 +253,4 @@ def generate_and_write_data(
             logger.error(
                 f"Error during combined data generation and writing: {e}")
             raise RuntimeError(
-                f"Error during combined data generation and writing: {e}") from e
+                f"Error during combined data generation and writing: {e}") from e

makefile

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ clean:

 .venv/bin/python:
	pip install hatch
-	hatch env create test-pydantic.pydantic==1.10.6-v1
+	hatch env create

 dev: .venv/bin/python
	@hatch run which python

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -103,6 +103,7 @@ dependencies = [
     "jmespath>=0.10.0",
     "py4j>=0.10.9",
     "pickleshare>=0.7.5",
+    "ipython>=7.32.0",
 ]

 python="3.10"
@@ -431,7 +432,7 @@ check_untyped_defs = true
 disallow_untyped_decorators = false
 no_implicit_optional = true
 warn_redundant_casts = true
-warn_unused_ignores = true
+warn_unused_ignores = false
 warn_no_return = true
 warn_unreachable = true
 strict_equality = true

0 commit comments
