1- from .dataset_provider import DatasetProvider , dataset_definition
1+ import warnings as w
2+ from typing import Any , ClassVar
3+
4+ from pyspark .sql import SparkSession
5+
6+ import dbldatagen as dg
7+ from dbldatagen .data_generator import DataGenerator
8+ from dbldatagen .datasets .dataset_provider import DatasetProvider , dataset_definition
29
310
411@dataset_definition (name = "basic/telematics" ,
@@ -24,7 +31,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
2431 - minLon: minimum longitude
2532 - maxLon: maximum longitude
2633 - generateWKT: if `True`, generates the well-known text representation of the location
27-
34+
2835 As the data specification is a DataGenerator object, you can add further columns to the data set and
2936 add constraints (when the feature is available)
3037
@@ -42,7 +49,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
4249 DEFAULT_MIN_LON = - 180.0
4350 DEFAULT_MAX_LON = 180.0
4451 COLUMN_COUNT = 6
45- ALLOWED_OPTIONS = [
52+ ALLOWED_OPTIONS : ClassVar [ list [ str ]] = [
4653 "numDevices" ,
4754 "startTimestamp" ,
4855 "endTimestamp" ,
@@ -55,10 +62,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
5562 ]
5663
5764 @DatasetProvider .allowed_options (options = ALLOWED_OPTIONS )
58- def getTableGenerator (self , sparkSession , * , tableName = None , rows = - 1 , partitions = - 1 ,
59- ** options ):
60- import dbldatagen as dg
61- import warnings as w
65+ def getTableGenerator (self , sparkSession : SparkSession , * , tableName : str | None = None , rows : int = - 1 , partitions : int = - 1 , ** options : dict [str , Any ]) -> DataGenerator :
6266
6367 generateRandom = options .get ("random" , False )
6468 numDevices = options .get ("numDevices" , self .DEFAULT_NUM_DEVICES )
@@ -77,52 +81,52 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
7781 partitions = self .autoComputePartitions (rows , self .COLUMN_COUNT )
7882 if minLat < - 90.0 :
7983 minLat = - 90.0
80- w .warn ("Received an invalid minLat value; Setting to -90.0" )
84+ w .warn ("Received an invalid minLat value; Setting to -90.0" , stacklevel = 2 )
8185 if minLat > 90.0 :
8286 minLat = 89.0
83- w .warn ("Recieved an invalid minLat value; Setting to 89.0" )
87+ w .warn ("Recieved an invalid minLat value; Setting to 89.0" , stacklevel = 2 )
8488 if maxLat < - 90 :
8589 maxLat = - 89.0
86- w .warn ("Recieved an invalid maxLat value; Setting to -89.0" )
90+ w .warn ("Recieved an invalid maxLat value; Setting to -89.0" , stacklevel = 2 )
8791 if maxLat > 90.0 :
8892 maxLat = 90.0
89- w .warn ("Received an invalid maxLat value; Setting to 90.0" )
93+ w .warn ("Received an invalid maxLat value; Setting to 90.0" , stacklevel = 2 )
9094 if minLon < - 180.0 :
9195 minLon = - 180.0
92- w .warn ("Received an invalid minLon value; Setting to -180.0" )
96+ w .warn ("Received an invalid minLon value; Setting to -180.0" , stacklevel = 2 )
9397 if minLon > 180.0 :
9498 minLon = 179.0
95- w .warn ("Received an invalid minLon value; Setting to 179.0" )
99+ w .warn ("Received an invalid minLon value; Setting to 179.0" , stacklevel = 2 )
96100 if maxLon < - 180.0 :
97101 maxLon = - 179.0
98- w .warn ("Received an invalid maxLon value; Setting to -179.0" )
102+ w .warn ("Received an invalid maxLon value; Setting to -179.0" , stacklevel = 2 )
99103 if maxLon > 180.0 :
100104 maxLon = 180.0
101- w .warn ("Received an invalid maxLon value; Setting to 180.0" )
105+ w .warn ("Received an invalid maxLon value; Setting to 180.0" , stacklevel = 2 )
102106 if minLon > maxLon :
103107 (minLon , maxLon ) = (maxLon , minLon )
104- w .warn ("Received minLon > maxLon; Swapping values" )
108+ w .warn ("Received minLon > maxLon; Swapping values" , stacklevel = 2 )
105109 if minLat > maxLat :
106110 (minLat , maxLat ) = (maxLat , minLat )
107- w .warn ("Received minLat > maxLat; Swapping values" )
111+ w .warn ("Received minLat > maxLat; Swapping values" , stacklevel = 2 )
108112 df_spec = (
109113 dg .DataGenerator (sparkSession = sparkSession , rows = rows ,
110114 partitions = partitions , randomSeedMethod = "hash_fieldname" )
111- .withColumn ("device_id" , "long" , minValue = self .MIN_DEVICE_ID , maxValue = self .MAX_DEVICE_ID ,
115+ .withColumn ("device_id" , "long" , minValue = self .MIN_DEVICE_ID , maxValue = self .MAX_DEVICE_ID ,
112116 uniqueValues = numDevices , random = generateRandom )
113- .withColumn ("ts" , "timestamp" , begin = startTimestamp , end = endTimestamp ,
117+ .withColumn ("ts" , "timestamp" , begin = startTimestamp , end = endTimestamp ,
114118 interval = "1 second" , random = generateRandom )
115119 .withColumn ("base_lat" , "float" , minValue = minLat , maxValue = maxLat , step = 0.5 ,
116- baseColumn = ' device_id' , omit = True )
120+ baseColumn = " device_id" , omit = True )
117121 .withColumn ("base_lon" , "float" , minValue = minLon , maxValue = maxLon , step = 0.5 ,
118- baseColumn = ' device_id' , omit = True )
122+ baseColumn = " device_id" , omit = True )
119123 .withColumn ("unv_lat" , "float" , expr = "base_lat + (0.5-format_number(rand(), 3))*1e-3" , omit = True )
120124 .withColumn ("unv_lon" , "float" , expr = "base_lon + (0.5-format_number(rand(), 3))*1e-3" , omit = True )
121- .withColumn ("lat" , "float" , expr = f"""CASE WHEN unv_lat > { maxLat } THEN { maxLat }
122- ELSE CASE WHEN unv_lat < { minLat } THEN { minLat }
125+ .withColumn ("lat" , "float" , expr = f"""CASE WHEN unv_lat > { maxLat } THEN { maxLat }
126+ ELSE CASE WHEN unv_lat < { minLat } THEN { minLat }
123127 ELSE unv_lat END END""" )
124- .withColumn ("lon" , "float" , expr = f"""CASE WHEN unv_lon > { maxLon } THEN { maxLon }
125- ELSE CASE WHEN unv_lon < { minLon } THEN { minLon }
128+ .withColumn ("lon" , "float" , expr = f"""CASE WHEN unv_lon > { maxLon } THEN { maxLon }
129+ ELSE CASE WHEN unv_lon < { minLon } THEN { minLon }
126130 ELSE unv_lon END END""" )
127131 .withColumn ("heading" , "integer" , minValue = 0 , maxValue = 359 , step = 1 , random = generateRandom )
128132 .withColumn ("wkt" , "string" , expr = "concat('POINT(', lon, ' ', lat, ')')" , omit = not generateWkt )
0 commit comments