Commit 3aca7f3

Linting datasets module & upgrading ruff (#351)
* linting datasets module & upgrading ruff
1 parent 7833022 · commit 3aca7f3

13 files changed · +222 −198 lines changed

dbldatagen/datasets/__init__.py

Lines changed: 14 additions & 2 deletions
@@ -1,20 +1,32 @@
-from .dataset_provider import DatasetProvider, dataset_definition
 from .basic_geometries import BasicGeometriesProvider
 from .basic_process_historian import BasicProcessHistorianProvider
 from .basic_stock_ticker import BasicStockTickerProvider
 from .basic_telematics import BasicTelematicsProvider
 from .basic_user import BasicUserProvider
 from .benchmark_groupby import BenchmarkGroupByProvider
+from .dataset_provider import DatasetProvider, dataset_definition
 from .multi_table_sales_order_provider import MultiTableSalesOrderProvider
 from .multi_table_telephony_provider import MultiTableTelephonyProvider
 
-__all__ = ["dataset_provider",
+
+__all__ = [
+    "BasicGeometriesProvider",
+    "BasicProcessHistorianProvider",
+    "BasicStockTickerProvider",
+    "BasicTelematicsProvider",
+    "BasicUserProvider",
+    "BenchmarkGroupByProvider",
+    "DatasetProvider",
+    "MultiTableSalesOrderProvider",
+    "MultiTableTelephonyProvider",
     "basic_geometries",
     "basic_process_historian",
     "basic_stock_ticker",
     "basic_telematics",
     "basic_user",
     "benchmark_groupby",
+    "dataset_definition",
+    "dataset_provider",
     "multi_table_sales_order_provider",
     "multi_table_telephony_provider"
 ]
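
Note: the reordering above is behavior-neutral. __all__ only controls which names "from dbldatagen.datasets import *" exports, not their order, so sorting it alongside the imports is a pure readability fix, likely driven by ruff's import-sorting rules and its unsorted-__all__ check (RUF022). A minimal sketch, assuming dbldatagen at this commit is installed:

    # Illustrative check, not part of the commit: sorting __all__ changes
    # nothing observable except readability.
    import dbldatagen.datasets as ds

    assert "DatasetProvider" in ds.__all__
    assert "dataset_definition" in ds.__all__      # newly listed by this commit
    assert list(ds.__all__) == sorted(ds.__all__)  # holds after this change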

dbldatagen/datasets/basic_geometries.py

Lines changed: 14 additions & 11 deletions
@@ -1,4 +1,11 @@
-from .dataset_provider import DatasetProvider, dataset_definition
+import warnings as w
+from typing import Any, ClassVar
+
+from pyspark.sql import SparkSession
+
+import dbldatagen as dg
+from dbldatagen.data_generator import DataGenerator
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
 
 
 @dataset_definition(name="basic/geometries",
@@ -34,7 +41,7 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
     DEFAULT_MIN_LON = -180.0
     DEFAULT_MAX_LON = 180.0
     COLUMN_COUNT = 2
-    ALLOWED_OPTIONS = [
+    ALLOWED_OPTIONS: ClassVar[list[str]] = [
         "geometryType",
         "maxVertices",
         "minLatitude",
@@ -45,11 +52,7 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
     ]
 
     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1,
-                          **options):
-        import dbldatagen as dg
-        import warnings as w
-
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         generateRandom = options.get("random", False)
         geometryType = options.get("geometryType", "point")
         maxVertices = options.get("maxVertices", 1 if geometryType == "point" else 3)
@@ -72,7 +75,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
         )
         if geometryType == "point":
             if maxVertices > 1:
-                w.warn('Ignoring property maxVertices for point geometries')
+                w.warn("Ignoring property maxVertices for point geometries", stacklevel=2)
             df_spec = (
                 df_spec.withColumn("lat", "float", minValue=minLatitude, maxValue=maxLatitude,
                                    step=1e-5, random=generateRandom, omit=True)
@@ -83,7 +86,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
         elif geometryType == "lineString":
             if maxVertices < 2:
                 maxVertices = 2
-                w.warn("Parameter maxVertices must be >=2 for 'lineString' geometries; Setting to 2")
+                w.warn("Parameter maxVertices must be >=2 for 'lineString' geometries; Setting to 2", stacklevel=2)
             j = 0
             while j < maxVertices:
                 df_spec = (
@@ -101,7 +104,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
         elif geometryType == "polygon":
             if maxVertices < 3:
                 maxVertices = 3
-                w.warn("Parameter maxVertices must be >=3 for 'polygon' geometries; Setting to 3")
+                w.warn("Parameter maxVertices must be >=3 for 'polygon' geometries; Setting to 3", stacklevel=2)
             j = 0
             while j < maxVertices:
                 df_spec = (
@@ -111,7 +114,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
                                        step=1e-5, random=generateRandom, omit=True)
                 )
                 j = j + 1
-            vertexIndices = list(range(maxVertices)) + [0]
+            vertexIndices = [*list(range(maxVertices)), 0]
             concatCoordinatesExpr = [f"concat(lon_{j}, ' ', lat_{j}, ', ')" for j in vertexIndices]
             concatPairsExpr = f"replace(concat('POLYGON(', {', '.join(concatCoordinatesExpr)}, ')'), ', )', ')')"
             df_spec = (
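
Note: two patterns recur in this file. Every warnings.warn call gains an explicit stacklevel (ruff's bugbear rule B028), and list concatenation becomes iterable unpacking (likely RUF005). A minimal sketch of the unpacking change, with hypothetical variable names:

    # Both forms build the same closed polygon ring (first vertex repeated
    # at the end); unpacking is what ruff prefers over list concatenation.
    max_vertices = 3
    old_style = list(range(max_vertices)) + [0]
    new_style = [*range(max_vertices), 0]  # the inner list() kept in the
                                           # diff above is not strictly needed
    assert old_style == new_style == [0, 1, 2, 0]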

dbldatagen/datasets/basic_process_historian.py

Lines changed: 16 additions & 10 deletions
@@ -1,4 +1,11 @@
-from .dataset_provider import DatasetProvider, dataset_definition
+from typing import Any, ClassVar
+
+import numpy as np
+from pyspark.sql import SparkSession
+
+import dbldatagen as dg
+from dbldatagen.data_generator import DataGenerator
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
 
 
 @dataset_definition(name="basic/process_historian",
@@ -43,28 +50,27 @@ class BasicProcessHistorianProvider(DatasetProvider.NoAssociatedDatasetsMixin, D
     DEFAULT_START_TIMESTAMP = "2024-01-01 00:00:00"
     DEFAULT_END_TIMESTAMP = "2024-02-01 00:00:00"
     COLUMN_COUNT = 10
-    ALLOWED_OPTIONS = [
+    ALLOWED_OPTIONS: ClassVar[list[str]] = [
         "numDevices",
         "numPlants",
-        "numTags", 
-        "startTimestamp", 
-        "endTimestamp", 
+        "numTags",
+        "startTimestamp",
+        "endTimestamp",
         "dataQualityRatios",
         "random"
     ]
 
     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options):
-        import dbldatagen as dg  # import locally to avoid circular imports
-        import numpy as np
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
+
 
         generateRandom = options.get("random", False)
         numDevices = options.get("numDevices", self.DEFAULT_NUM_DEVICES)
         numPlants = options.get("numPlants", self.DEFAULT_NUM_PLANTS)
         numTags = options.get("numTags", self.DEFAULT_NUM_TAGS)
         startTimestamp = options.get("startTimestamp", self.DEFAULT_START_TIMESTAMP)
         endTimestamp = options.get("endTimestamp", self.DEFAULT_END_TIMESTAMP)
-        dataQualityRatios = options.get("dataQualityRatios", None)
+        dataQualityRatios = options.get("dataQualityRatios")
 
         assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name"
         if rows is None or rows < 0:
@@ -83,7 +89,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
             .withColumn("device_id", "string", format="0x%09x", baseColumn="internal_device_id")
             .withColumn("plant_id", "string", values=plant_ids, baseColumn="internal_device_id")
             .withColumn("tag_name", "string", values=tag_names, baseColumn="internal_device_id")
-            .withColumn("ts", "timestamp", begin=startTimestamp, end=endTimestamp, 
+            .withColumn("ts", "timestamp", begin=startTimestamp, end=endTimestamp,
                         interval="1 second", random=generateRandom)
             .withColumn("value", "float", minValue=self.MIN_PROPERTY_VALUE, maxValue=self.MAX_PROPERTY_VALUE,
                         step=1e-3, random=generateRandom)
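
Note: besides hoisting the dbldatagen and numpy imports to module level (dropping the old "import locally to avoid circular imports" comment), this file loses a redundant None default on dict.get, likely ruff's flake8-simplify rule SIM910. A minimal sketch:

    # dict.get already returns None for a missing key, so the explicit
    # default was noise; behavior is identical before and after.
    options = {"numDevices": 100}
    assert options.get("dataQualityRatios", None) is None  # before
    assert options.get("dataQualityRatios") is None        # after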

dbldatagen/datasets/basic_stock_ticker.py

Lines changed: 13 additions & 10 deletions
@@ -1,6 +1,11 @@
 from random import random
+from typing import ClassVar
 
-from .dataset_provider import DatasetProvider, dataset_definition
+from pyspark.sql import SparkSession
+
+import dbldatagen as dg
+from dbldatagen.data_generator import DataGenerator
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
 
 
 @dataset_definition(name="basic/stock_ticker",
@@ -21,7 +26,6 @@ class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase
         - numSymbols: number of unique stock ticker symbols
         - startDate: first date for stock ticker data
         - endDate: last date for stock ticker data
-
     As the data specification is a DataGenerator object, you can add further columns to the data set and
     add constraints (when the feature is available)
 
@@ -32,14 +36,13 @@ class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase
     DEFAULT_NUM_SYMBOLS = 100
     DEFAULT_START_DATE = "2024-10-01"
     COLUMN_COUNT = 8
-    ALLOWED_OPTIONS = [
+    ALLOWED_OPTIONS: ClassVar[list[str]] = [
         "numSymbols",
         "startDate"
     ]
 
     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options):
-        import dbldatagen as dg
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
 
         numSymbols = options.get("numSymbols", self.DEFAULT_NUM_SYMBOLS)
         startDate = options.get("startDate", self.DEFAULT_START_DATE)
@@ -59,7 +62,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
             .withColumn("rand_value", "float", minValue=0.0, maxValue=1.0, step=0.1,
                         baseColumn="symbol_id", omit=True)
             .withColumn("symbol", "string",
-                        expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''), 
+                        expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''),
                             x -> case when ascii(x) < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""")
             .withColumn("days_from_start_date", "int", expr=f"floor(try_divide(id, {numSymbols}))", omit=True)
             .withColumn("post_date", "date", expr=f"date_add(cast('{startDate}' as date), days_from_start_date)")
@@ -76,13 +79,13 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
                         expr="case when sin(id % 17) > 0 then -1.0 else 1.0 end",
                         omit=True)
             .withColumn("open_base", "decimal(11,2)",
-                        expr=f"""start_value 
-                            + (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17)) 
+                        expr=f"""start_value
+                            + (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17))
                             + (growth_rate * start_value * try_divide(days_from_start_date - 1, 365))""",
                         omit=True)
             .withColumn("close_base", "decimal(11,2)",
-                        expr="""start_value 
-                            + (volatility * start_value * sin(id % 17)) 
+                        expr="""start_value
+                            + (volatility * start_value * sin(id % 17))
                             + (growth_rate * start_value * try_divide(days_from_start_date, 365))""",
                         omit=True)
             .withColumn("high_base", "decimal(11,2)",
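
Note: several +/− pairs in this file look identical because they appear to differ only in trailing whitespace, which the linter strips. The ALLOWED_OPTIONS annotation repeated across these files addresses ruff's mutable-class-default rule (RUF012): marking the list as ClassVar declares it as shared class-level configuration rather than a per-instance default. A minimal sketch:

    from typing import ClassVar


    class Provider:
        # ClassVar documents that this mutable list is class-level
        # configuration, which is what RUF012 asks for.
        ALLOWED_OPTIONS: ClassVar[list[str]] = ["numSymbols", "startDate"]


    # The attribute is shared by the class and all its instances alike.
    assert Provider.ALLOWED_OPTIONS is Provider().ALLOWED_OPTIONS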

dbldatagen/datasets/basic_telematics.py

Lines changed: 29 additions & 25 deletions
@@ -1,4 +1,11 @@
-from .dataset_provider import DatasetProvider, dataset_definition
+import warnings as w
+from typing import Any, ClassVar
+
+from pyspark.sql import SparkSession
+
+import dbldatagen as dg
+from dbldatagen.data_generator import DataGenerator
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
 
 
 @dataset_definition(name="basic/telematics",
@@ -24,7 +31,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
         - minLon: minimum longitude
         - maxLon: maximum longitude
         - generateWKT: if `True`, generates the well-known text representation of the location
- 
+
     As the data specification is a DataGenerator object, you can add further columns to the data set and
     add constraints (when the feature is available)
 
@@ -42,7 +49,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
     DEFAULT_MIN_LON = -180.0
     DEFAULT_MAX_LON = 180.0
     COLUMN_COUNT = 6
-    ALLOWED_OPTIONS = [
+    ALLOWED_OPTIONS: ClassVar[list[str]] = [
        "numDevices",
        "startTimestamp",
        "endTimestamp",
@@ -55,10 +62,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
     ]
 
     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1,
-                          **options):
-        import dbldatagen as dg
-        import warnings as w
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
 
         generateRandom = options.get("random", False)
         numDevices = options.get("numDevices", self.DEFAULT_NUM_DEVICES)
@@ -77,52 +81,52 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
             partitions = self.autoComputePartitions(rows, self.COLUMN_COUNT)
         if minLat < -90.0:
             minLat = -90.0
-            w.warn("Received an invalid minLat value; Setting to -90.0")
+            w.warn("Received an invalid minLat value; Setting to -90.0", stacklevel=2)
         if minLat > 90.0:
             minLat = 89.0
-            w.warn("Recieved an invalid minLat value; Setting to 89.0")
+            w.warn("Recieved an invalid minLat value; Setting to 89.0", stacklevel=2)
         if maxLat < -90:
             maxLat = -89.0
-            w.warn("Recieved an invalid maxLat value; Setting to -89.0")
+            w.warn("Recieved an invalid maxLat value; Setting to -89.0", stacklevel=2)
         if maxLat > 90.0:
             maxLat = 90.0
-            w.warn("Received an invalid maxLat value; Setting to 90.0")
+            w.warn("Received an invalid maxLat value; Setting to 90.0", stacklevel=2)
         if minLon < -180.0:
             minLon = -180.0
-            w.warn("Received an invalid minLon value; Setting to -180.0")
+            w.warn("Received an invalid minLon value; Setting to -180.0", stacklevel=2)
         if minLon > 180.0:
             minLon = 179.0
-            w.warn("Received an invalid minLon value; Setting to 179.0")
+            w.warn("Received an invalid minLon value; Setting to 179.0", stacklevel=2)
         if maxLon < -180.0:
             maxLon = -179.0
-            w.warn("Received an invalid maxLon value; Setting to -179.0")
+            w.warn("Received an invalid maxLon value; Setting to -179.0", stacklevel=2)
        if maxLon > 180.0:
             maxLon = 180.0
-            w.warn("Received an invalid maxLon value; Setting to 180.0")
+            w.warn("Received an invalid maxLon value; Setting to 180.0", stacklevel=2)
         if minLon > maxLon:
             (minLon, maxLon) = (maxLon, minLon)
-            w.warn("Received minLon > maxLon; Swapping values")
+            w.warn("Received minLon > maxLon; Swapping values", stacklevel=2)
         if minLat > maxLat:
             (minLat, maxLat) = (maxLat, minLat)
-            w.warn("Received minLat > maxLat; Swapping values")
+            w.warn("Received minLat > maxLat; Swapping values", stacklevel=2)
         df_spec = (
             dg.DataGenerator(sparkSession=sparkSession, rows=rows,
                              partitions=partitions, randomSeedMethod="hash_fieldname")
-            .withColumn("device_id", "long", minValue=self.MIN_DEVICE_ID, maxValue=self.MAX_DEVICE_ID, 
+            .withColumn("device_id", "long", minValue=self.MIN_DEVICE_ID, maxValue=self.MAX_DEVICE_ID,
                         uniqueValues=numDevices, random=generateRandom)
-            .withColumn("ts", "timestamp", begin=startTimestamp, end=endTimestamp, 
+            .withColumn("ts", "timestamp", begin=startTimestamp, end=endTimestamp,
                         interval="1 second", random=generateRandom)
             .withColumn("base_lat", "float", minValue=minLat, maxValue=maxLat, step=0.5,
-                        baseColumn='device_id', omit=True)
+                        baseColumn="device_id", omit=True)
             .withColumn("base_lon", "float", minValue=minLon, maxValue=maxLon, step=0.5,
-                        baseColumn='device_id', omit=True)
+                        baseColumn="device_id", omit=True)
             .withColumn("unv_lat", "float", expr="base_lat + (0.5-format_number(rand(), 3))*1e-3", omit=True)
             .withColumn("unv_lon", "float", expr="base_lon + (0.5-format_number(rand(), 3))*1e-3", omit=True)
-            .withColumn("lat", "float", expr=f"""CASE WHEN unv_lat > {maxLat} THEN {maxLat} 
-                        ELSE CASE WHEN unv_lat < {minLat} THEN {minLat} 
+            .withColumn("lat", "float", expr=f"""CASE WHEN unv_lat > {maxLat} THEN {maxLat}
+                        ELSE CASE WHEN unv_lat < {minLat} THEN {minLat}
                         ELSE unv_lat END END""")
-            .withColumn("lon", "float", expr=f"""CASE WHEN unv_lon > {maxLon} THEN {maxLon} 
-                        ELSE CASE WHEN unv_lon < {minLon} THEN {minLon} 
+            .withColumn("lon", "float", expr=f"""CASE WHEN unv_lon > {maxLon} THEN {maxLon}
+                        ELSE CASE WHEN unv_lon < {minLon} THEN {minLon}
                         ELSE unv_lon END END""")
             .withColumn("heading", "integer", minValue=0, maxValue=359, step=1, random=generateRandom)
             .withColumn("wkt", "string", expr="concat('POINT(', lon, ' ', lat, ')')", omit=not generateWkt)
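
Note: the densest change here is mechanical: ten w.warn calls gain stacklevel=2 (ruff's bugbear rule B028), so each warning is attributed to the caller of getTableGenerator rather than to the provider internals. (The pre-existing "Recieved" misspellings pass through the linter untouched.) A minimal sketch of the effect, with a hypothetical helper name:

    import warnings


    def clamp_min_lat(min_lat: float) -> float:
        # With stacklevel=2 the warning points at clamp_min_lat's caller,
        # mirroring the provider's validation code after this commit.
        if min_lat < -90.0:
            warnings.warn("Received an invalid minLat value; Setting to -90.0", stacklevel=2)
            min_lat = -90.0
        return min_lat


    clamp_min_lat(-120.0)  # the reported source location is this call site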

dbldatagen/datasets/basic_user.py

Lines changed: 12 additions & 9 deletions
@@ -1,4 +1,10 @@
-from .dataset_provider import DatasetProvider, dataset_definition
+from typing import Any
+
+from pyspark.sql import SparkSession
+
+import dbldatagen as dg
+from dbldatagen.data_generator import DataGenerator
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
 
 
 @dataset_definition(name="basic/user", summary="Basic User Data Set", autoRegister=True, supportsStreaming=True)
@@ -27,10 +33,7 @@ class BasicUserProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvid
     COLUMN_COUNT = 5
 
     @DatasetProvider.allowed_options(options=["random", "dummyValues"])
-    def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1,
-                          **options):
-        import dbldatagen as dg
-
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         generateRandom = options.get("random", False)
         dummyValues = options.get("dummyValues", 0)
 
@@ -47,13 +50,13 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
                              randomSeedMethod="hash_fieldname")
             .withColumn("customer_id", "long", minValue=1000000, maxValue=self.MAX_LONG, random=generateRandom)
             .withColumn("name", "string",
-                        template=r'\w \w|\w \w \w', random=generateRandom)
+                        template=r"\w \w|\w \w \w", random=generateRandom)
             .withColumn("email", "string",
-                        template=r'\w.\w@\w.com|\w@\w.co.u\k', random=generateRandom)
+                        template=r"\w.\w@\w.com|\w@\w.co.u\k", random=generateRandom)
             .withColumn("ip_addr", "string",
-                        template=r'\n.\n.\n.\n', random=generateRandom)
+                        template=r"\n.\n.\n.\n", random=generateRandom)
             .withColumn("phone", "string",
-                        template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd',
+                        template=r"(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd",
                         random=generateRandom)
             )
 
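Note: the template changes in this file are pure quote normalization (ruff's flake8-quotes rules prefer double quotes). Since the templates are raw string literals, swapping the delimiters cannot change their contents. A minimal sketch:

    # Both literals denote the same raw string; only the delimiter changed.
    single = r'\w.\w@\w.com|\w@\w.co.u\k'
    double = r"\w.\w@\w.com|\w@\w.co.u\k"
    assert single == double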