1- from .compat import BaseModel , validator , root_validator , field_validator
2- from typing import Dict , Optional , Union , Any , Literal , List
1+ from __future__ import annotations
2+
3+ import logging
4+ from typing import Any , Literal , Union
5+
36import pandas as pd
4- from IPython .display import display , HTML
5-
6- DbldatagenBasicType = Literal [
7- "string" ,
8- "int" ,
9- "long" ,
10- "float" ,
11- "double" ,
12- "decimal" ,
13- "boolean" ,
14- "date" ,
15- "timestamp" ,
16- "short" ,
17- "byte" ,
18- "binary" ,
19- "integer" ,
20- "bigint" ,
21- "tinyint" ,
22- ]
23-
24- class ColumnDefinition (BaseModel ):
25- name : str
26- type : Optional [DbldatagenBasicType ] = None
27- primary : bool = False
28- options : Optional [Dict [str , Any ]] = {}
29- nullable : Optional [bool ] = False
30- omit : Optional [bool ] = False
31- baseColumn : Optional [str ] = "id"
32- baseColumnType : Optional [str ] = "auto"
33-
34- @root_validator (skip_on_failure = True )
35- def check_model_constraints (cls , values : Dict [str , Any ]) -> Dict [str , Any ]:
36- """
37- Validates constraints across the entire model after individual fields are processed.
38- """
39- is_primary = values .get ("primary" )
40- options = values .get ("options" , {})
41- name = values .get ("name" )
42- is_nullable = values .get ("nullable" )
43- column_type = values .get ("type" )
7+ from IPython .display import HTML , display
448
45- if is_primary :
46- if "min" in options or "max" in options :
47- raise ValueError (f"Primary column '{ name } ' cannot have min/max options." )
9+ from dbldatagen .spec .column_spec import ColumnDefinition
4810
49- if is_nullable :
50- raise ValueError (f"Primary column '{ name } ' cannot be nullable." )
11+ from .compat import BaseModel , validator
5112
52- if column_type is None :
53- raise ValueError (f"Primary column '{ name } ' must have a type defined." )
54- return values
5513
14+ logger = logging .getLogger (__name__ )
5615
class UCSchemaTarget(BaseModel):
    """Output target identified by a catalog name and a schema name.

    ``catalog`` and ``schema_`` are passed through
    :meth:`validate_identifiers`; ``output_format`` defaults to ``"delta"``.
    """

    catalog: str
    schema_: str
    output_format: str = "delta"  # Default to delta for UC Schema

    @validator("catalog", "schema_")
    def validate_identifiers(cls, v: str) -> str:
        """Strip surrounding whitespace and reject blank identifiers.

        Args:
            v: Raw value supplied for ``catalog`` or ``schema_``.

        Returns:
            The value with leading and trailing whitespace removed.

        Raises:
            ValueError: If the value is empty or contains only whitespace.
        """
        # Normalize once so every check below inspects the value that is
        # actually returned; checking the raw value would warn about
        # identifiers that are valid once their padding is removed.
        identifier = v.strip()
        if not identifier:
            raise ValueError("Identifier must be non-empty.")
        if not identifier.isidentifier():
            # Advisory only: the value is still accepted and returned.
            logger.warning(
                "'%s' is not a basic Python identifier. Ensure validity for Unity Catalog.",
                identifier,
            )
        return identifier

    def __str__(self) -> str:
        """Return a one-line description of the target."""
        return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)"
7332
7433
class FilePathTarget(BaseModel):
    """Output target described by a base path and an explicit file format."""

    base_path: str
    output_format: Literal["csv", "parquet"]  # No default, must be specified

    @validator("base_path")
    def validate_base_path(cls, v: str) -> str:
        """Return ``v`` without surrounding whitespace.

        Raises:
            ValueError: If ``v`` is empty or contains only whitespace.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError("base_path must be non-empty.")

    def __str__(self) -> str:
        """Return a one-line description of the target."""
        location = self.base_path
        file_format = self.output_format
        return f"{location} (Format: {file_format}, Type: File Path)"
8746
8847
class TableDefinition(BaseModel):
    """Definition of one table: a row count, an optional partition count and its columns."""

    # Number of rows declared for the table.
    number_of_rows: int
    # Partition count; stays None when the caller does not supply one.
    partitions: int | None = None
    # Column definitions belonging to this table.
    columns: list[ColumnDefinition]
9352
9453
9554class ValidationResult :
9655 """Container for validation results with errors and warnings."""
9756
9857 def __init__ (self ) -> None :
99- self .errors : List [str ] = []
100- self .warnings : List [str ] = []
58+ self .errors : list [str ] = []
59+ self .warnings : list [str ] = []
10160
10261 def add_error (self , message : str ) -> None :
10362 """Add an error message."""
@@ -132,16 +91,16 @@ def __str__(self) -> str:
13291 return "\n " .join (lines )
13392
13493class DatagenSpec (BaseModel ):
135- tables : Dict [str , TableDefinition ]
136- output_destination : Optional [ Union [UCSchemaTarget , FilePathTarget ]] = None # there is a abstraction, may be we can use that? talk to Greg
137- generator_options : Optional [ Dict [ str , Any ]] = {}
138- intended_for_databricks : Optional [ bool ] = None # May be infered.
94+ tables : dict [str , TableDefinition ]
95+ output_destination : Union [UCSchemaTarget , FilePathTarget ] | None = None # there is an abstraction, maybe we can use that? talk to Greg
96+ generator_options : dict [ str , Any ] | None = None
97+ intended_for_databricks : bool | None = None # May be inferred.
13998
14099 def _check_circular_dependencies (
141100 self ,
142101 table_name : str ,
143- columns : List [ColumnDefinition ]
144- ) -> List [str ]:
102+ columns : list [ColumnDefinition ]
103+ ) -> list [str ]:
145104 """
146105 Check for circular dependencies in baseColumn references.
147106 Returns a list of error messages if circular dependencies are found.
@@ -152,13 +111,13 @@ def _check_circular_dependencies(
152111 for col in columns :
153112 if col .baseColumn and col .baseColumn != "id" :
154113 # Track the dependency chain
155- visited = set ()
114+ visited : set [ str ] = set ()
156115 current = col .name
157116
158117 while current :
159118 if current in visited :
160119 # Found a cycle
161- cycle_path = " -> " .join (list (visited ) + [ current ])
120+ cycle_path = " -> " .join ([ * list (visited ), current ])
162121 errors .append (
163122 f"Table '{ table_name } ': Circular dependency detected in column '{ col .name } ': { cycle_path } "
164123 )
@@ -182,7 +141,7 @@ def _check_circular_dependencies(
182141
183142 return errors
184143
185- def validate (self , strict : bool = True ) -> ValidationResult :
144+ def validate (self , strict : bool = True ) -> ValidationResult : # type: ignore[override]
186145 """
187146 Validates the entire DatagenSpec configuration.
188147 Always runs all validation checks and collects all errors and warnings.
@@ -284,17 +243,15 @@ def validate(self, strict: bool = True) -> ValidationResult:
284243 "random" , "randomSeed" , "randomSeedMethod" , "verbose" ,
285244 "debug" , "seedColumnName"
286245 ]
287- for key in self .generator_options . keys () :
246+ for key in self .generator_options :
288247 if key not in known_options :
289248 result .add_warning (
290249 f"Unknown generator option: '{ key } '. "
291250 "This may be ignored during generation."
292251 )
293252
294253 # Now that all validations are complete, decide whether to raise
295- if strict and (result .errors or result .warnings ):
296- raise ValueError (str (result ))
297- elif not strict and result .errors :
254+ if (strict and (result .errors or result .warnings )) or (not strict and result .errors ):
298255 raise ValueError (str (result ))
299256
300257 return result
0 commit comments