Skip to content

Commit 43572b3

Browse files
Add Literal typing for mode and projection_types (#2191)
1 parent 8c7f50f commit 43572b3

File tree

13 files changed, with 93 additions and 50 deletions.

awswrangler/catalog/_create.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""AWS Glue Catalog Module."""
22

33
import logging
4-
from typing import TYPE_CHECKING, Any, Dict, Optional
4+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional
55

66
import boto3
77

@@ -679,7 +679,7 @@ def create_parquet_table(
679679
description: Optional[str] = None,
680680
parameters: Optional[Dict[str, str]] = None,
681681
columns_comments: Optional[Dict[str, str]] = None,
682-
mode: str = "overwrite",
682+
mode: Literal["overwrite", "append"] = "overwrite",
683683
catalog_versioning: bool = False,
684684
transaction_id: Optional[str] = None,
685685
athena_partition_projection_settings: Optional[typing.AthenaPartitionProjectionSettings] = None,
@@ -840,7 +840,7 @@ def create_csv_table( # pylint: disable=too-many-arguments,too-many-locals
840840
description: Optional[str] = None,
841841
parameters: Optional[Dict[str, str]] = None,
842842
columns_comments: Optional[Dict[str, str]] = None,
843-
mode: str = "overwrite",
843+
mode: Literal["overwrite", "append"] = "overwrite",
844844
catalog_versioning: bool = False,
845845
schema_evolution: bool = False,
846846
sep: str = ",",
@@ -1033,7 +1033,7 @@ def create_json_table( # pylint: disable=too-many-arguments
10331033
description: Optional[str] = None,
10341034
parameters: Optional[Dict[str, str]] = None,
10351035
columns_comments: Optional[Dict[str, str]] = None,
1036-
mode: str = "overwrite",
1036+
mode: Literal["overwrite", "append"] = "overwrite",
10371037
catalog_versioning: bool = False,
10381038
schema_evolution: bool = False,
10391039
serde_library: Optional[str] = None,

awswrangler/data_quality/_create.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import pprint
55
import uuid
6-
from typing import Any, Dict, List, Optional, Union, cast
6+
from typing import Any, Dict, List, Literal, Optional, Union, cast
77

88
import boto3
99

@@ -121,7 +121,7 @@ def create_ruleset(
121121
@apply_configs
122122
def update_ruleset(
123123
name: str,
124-
mode: str = "overwrite",
124+
mode: Literal["overwrite", "upsert"] = "overwrite",
125125
df_rules: Optional[pd.DataFrame] = None,
126126
dqdl_rules: Optional[str] = None,
127127
description: str = "",

awswrangler/emr.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import logging
55
import pprint
6-
from typing import Any, Dict, List, Optional, Union, cast
6+
from typing import Any, Dict, List, Literal, Optional, Union, cast
77

88
import boto3
99

@@ -12,6 +12,9 @@
1212
_logger: logging.Logger = logging.getLogger(__name__)
1313

1414

15+
_ActionOnFailureLiteral = Literal["TERMINATE_JOB_FLOW", "TERMINATE_CLUSTER", "CANCEL_AND_WAIT", "CONTINUE"]
16+
17+
1518
def _get_ecr_credentials_refresh_content(region: str) -> str:
1619
return f"""
1720
import subprocess
@@ -815,7 +818,7 @@ def submit_step(
815818
cluster_id: str,
816819
command: str,
817820
name: str = "my-step",
818-
action_on_failure: str = "CONTINUE",
821+
action_on_failure: _ActionOnFailureLiteral = "CONTINUE",
819822
script: bool = False,
820823
boto3_session: Optional[boto3.Session] = None,
821824
) -> str:
@@ -865,7 +868,7 @@ def submit_step(
865868
def build_step(
866869
command: str,
867870
name: str = "my-step",
868-
action_on_failure: str = "CONTINUE",
871+
action_on_failure: _ActionOnFailureLiteral = "CONTINUE",
869872
script: bool = False,
870873
region: Optional[str] = None,
871874
boto3_session: Optional[boto3.Session] = None,
@@ -951,7 +954,10 @@ def get_step_state(cluster_id: str, step_id: str, boto3_session: Optional[boto3.
951954

952955

953956
def submit_ecr_credentials_refresh(
954-
cluster_id: str, path: str, action_on_failure: str = "CONTINUE", boto3_session: Optional[boto3.Session] = None
957+
cluster_id: str,
958+
path: str,
959+
action_on_failure: _ActionOnFailureLiteral = "CONTINUE",
960+
boto3_session: Optional[boto3.Session] = None,
955961
) -> str:
956962
"""Update internal ECR credentials.
957963
@@ -999,10 +1005,10 @@ def submit_ecr_credentials_refresh(
9991005
def build_spark_step(
10001006
path: str,
10011007
args: Optional[List[str]] = None,
1002-
deploy_mode: str = "cluster",
1008+
deploy_mode: Literal["cluster", "client"] = "cluster",
10031009
docker_image: Optional[str] = None,
10041010
name: str = "my-step",
1005-
action_on_failure: str = "CONTINUE",
1011+
action_on_failure: _ActionOnFailureLiteral = "CONTINUE",
10061012
region: Optional[str] = None,
10071013
boto3_session: Optional[boto3.Session] = None,
10081014
) -> Dict[str, Any]:
@@ -1074,10 +1080,10 @@ def submit_spark_step(
10741080
cluster_id: str,
10751081
path: str,
10761082
args: Optional[List[str]] = None,
1077-
deploy_mode: str = "cluster",
1083+
deploy_mode: Literal["cluster", "client"] = "cluster",
10781084
docker_image: Optional[str] = None,
10791085
name: str = "my-step",
1080-
action_on_failure: str = "CONTINUE",
1086+
action_on_failure: _ActionOnFailureLiteral = "CONTINUE",
10811087
region: Optional[str] = None,
10821088
boto3_session: Optional[boto3.Session] = None,
10831089
) -> str:

awswrangler/mysql.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import logging
55
import uuid
6-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, overload
6+
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union, cast, overload
77

88
import boto3
99
import pyarrow as pa
@@ -402,14 +402,19 @@ def read_sql_table(
402402
)
403403

404404

405+
_ToSqlModeLiteral = Literal[
406+
"append", "overwrite", "upsert_replace_into", "upsert_duplicate_key", "upsert_distinct", "ignore"
407+
]
408+
409+
405410
@_utils.check_optional_dependency(pymysql, "pymysql")
406411
@apply_configs
407412
def to_sql(
408413
df: pd.DataFrame,
409414
con: "pymysql.connections.Connection[Any]",
410415
table: str,
411416
schema: str,
412-
mode: str = "append",
417+
mode: _ToSqlModeLiteral = "append",
413418
index: bool = False,
414419
dtype: Optional[Dict[str, str]] = None,
415420
varchar_lengths: Optional[Dict[str, int]] = None,
@@ -430,7 +435,7 @@ def to_sql(
430435
schema : str
431436
Schema name
432437
mode : str
433-
Append, overwrite, upsert_duplicate_key, upsert_replace_into, upsert_distinct, ignore.
438+
append, overwrite, upsert_duplicate_key, upsert_replace_into, upsert_distinct, ignore.
434439
append: Inserts new records into table.
435440
overwrite: Drops table and recreates.
436441
upsert_duplicate_key: Performs an upsert using `ON DUPLICATE KEY` clause. Requires table schema to have
@@ -484,7 +489,7 @@ def to_sql(
484489
if df.empty is True:
485490
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
486491

487-
mode = mode.strip().lower()
492+
mode = cast(_ToSqlModeLiteral, mode.strip().lower())
488493
allowed_modes = [
489494
"append",
490495
"overwrite",

awswrangler/oracle.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,19 @@
33

44
import logging
55
from decimal import Decimal
6-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union, overload
6+
from typing import (
7+
Any,
8+
Callable,
9+
Dict,
10+
Iterator,
11+
List,
12+
Literal,
13+
Optional,
14+
Tuple,
15+
TypeVar,
16+
Union,
17+
overload,
18+
)
719

820
import boto3
921
import pyarrow as pa
@@ -404,7 +416,7 @@ def to_sql(
404416
con: "oracledb.Connection",
405417
table: str,
406418
schema: str,
407-
mode: str = "append",
419+
mode: Literal["append", "overwrite"] = "append",
408420
index: bool = False,
409421
dtype: Optional[Dict[str, str]] = None,
410422
varchar_lengths: Optional[Dict[str, int]] = None,

awswrangler/postgresql.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import logging
55
from ssl import SSLContext
6-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, overload
6+
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, cast, overload
77

88
import boto3
99
import pyarrow as pa
@@ -393,14 +393,17 @@ def read_sql_table(
393393
)
394394

395395

396+
_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
397+
398+
396399
@_utils.check_optional_dependency(pg8000, "pg8000")
397400
@apply_configs
398401
def to_sql(
399402
df: pd.DataFrame,
400403
con: "pg8000.Connection",
401404
table: str,
402405
schema: str,
403-
mode: str = "append",
406+
mode: _ToSqlModeLiteral = "append",
404407
index: bool = False,
405408
dtype: Optional[Dict[str, str]] = None,
406409
varchar_lengths: Optional[Dict[str, int]] = None,
@@ -472,7 +475,7 @@ def to_sql(
472475
if df.empty is True:
473476
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
474477

475-
mode = mode.strip().lower()
478+
mode = cast(_ToSqlModeLiteral, mode.strip().lower())
476479
allowed_modes = ["append", "overwrite", "upsert"]
477480
_db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes)
478481
if mode == "upsert" and not upsert_conflict_columns:

awswrangler/quicksight/_create.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import uuid
5-
from typing import Any, Dict, List, Optional, Union, cast
5+
from typing import Any, Dict, List, Literal, Optional, Union, cast
66

77
import boto3
88

@@ -202,7 +202,7 @@ def create_athena_dataset(
202202
sql_name: Optional[str] = None,
203203
data_source_name: Optional[str] = None,
204204
data_source_arn: Optional[str] = None,
205-
import_mode: str = "DIRECT_QUERY",
205+
import_mode: Literal["SPICE", "DIRECT_QUERY"] = "DIRECT_QUERY",
206206
allowed_to_use: Optional[List[str]] = None,
207207
allowed_to_manage: Optional[List[str]] = None,
208208
logical_table_alias: str = "LogicalTable",

awswrangler/redshift.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import logging
77
import uuid
8-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, overload
8+
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, overload
99

1010
import boto3
1111
import botocore
@@ -24,6 +24,11 @@
2424
_RS_DISTSTYLES: List[str] = ["AUTO", "EVEN", "ALL", "KEY"]
2525
_RS_SORTSTYLES: List[str] = ["COMPOUND", "INTERLEAVED"]
2626

27+
_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
28+
_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "delete"]
29+
_ToSqlDistStyleLiteral = Literal["AUTO", "EVEN", "ALL", "KEY"]
30+
_ToSqlSortStyleLiteral = Literal["COMPOUND", "INTERLEAVED"]
31+
2732

2833
def _validate_connection(con: "redshift_connector.Connection") -> None:
2934
if not isinstance(con, redshift_connector.Connection):
@@ -909,13 +914,13 @@ def to_sql( # pylint: disable=too-many-locals
909914
con: "redshift_connector.Connection",
910915
table: str,
911916
schema: str,
912-
mode: str = "append",
913-
overwrite_method: str = "drop",
917+
mode: _ToSqlModeLiteral = "append",
918+
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
914919
index: bool = False,
915920
dtype: Optional[Dict[str, str]] = None,
916-
diststyle: str = "AUTO",
921+
diststyle: _ToSqlDistStyleLiteral = "AUTO",
917922
distkey: Optional[str] = None,
918-
sortstyle: str = "COMPOUND",
923+
sortstyle: _ToSqlSortStyleLiteral = "COMPOUND",
919924
sortkey: Optional[List[str]] = None,
920925
primary_keys: Optional[List[str]] = None,
921926
varchar_lengths_default: int = 256,
@@ -1084,7 +1089,7 @@ def unload_to_files(
10841089
aws_secret_access_key: Optional[str] = None,
10851090
aws_session_token: Optional[str] = None,
10861091
region: Optional[str] = None,
1087-
unload_format: Optional[str] = None,
1092+
unload_format: Optional[Literal["CSV", "PARQUET"]] = None,
10881093
max_file_size: Optional[float] = None,
10891094
kms_key_id: Optional[str] = None,
10901095
manifest: bool = False,
@@ -1380,11 +1385,11 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
13801385
aws_secret_access_key: Optional[str] = None,
13811386
aws_session_token: Optional[str] = None,
13821387
parquet_infer_sampling: float = 1.0,
1383-
mode: str = "append",
1384-
overwrite_method: str = "drop",
1385-
diststyle: str = "AUTO",
1388+
mode: _ToSqlModeLiteral = "append",
1389+
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
1390+
diststyle: _ToSqlDistStyleLiteral = "AUTO",
13861391
distkey: Optional[str] = None,
1387-
sortstyle: str = "COMPOUND",
1392+
sortstyle: _ToSqlSortStyleLiteral = "COMPOUND",
13881393
sortkey: Optional[List[str]] = None,
13891394
primary_keys: Optional[List[str]] = None,
13901395
varchar_lengths_default: int = 256,
@@ -1608,11 +1613,11 @@ def copy( # pylint: disable=too-many-arguments,too-many-locals
16081613
aws_session_token: Optional[str] = None,
16091614
index: bool = False,
16101615
dtype: Optional[Dict[str, str]] = None,
1611-
mode: str = "append",
1612-
overwrite_method: str = "drop",
1613-
diststyle: str = "AUTO",
1616+
mode: _ToSqlModeLiteral = "append",
1617+
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
1618+
diststyle: _ToSqlDistStyleLiteral = "AUTO",
16141619
distkey: Optional[str] = None,
1615-
sortstyle: str = "COMPOUND",
1620+
sortstyle: _ToSqlSortStyleLiteral = "COMPOUND",
16161621
sortkey: Optional[List[str]] = None,
16171622
primary_keys: Optional[List[str]] = None,
16181623
varchar_lengths_default: int = 256,

awswrangler/s3/_copy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import itertools
44
import logging
5-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
5+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
66

77
import boto3
88
from boto3.s3.transfer import TransferConfig
@@ -73,7 +73,7 @@ def _copy(
7373
def merge_datasets(
7474
source_path: str,
7575
target_path: str,
76-
mode: str = "append",
76+
mode: Literal["append", "overwrite", "overwrite_partitions"] = "append",
7777
ignore_empty: bool = False,
7878
use_threads: Union[bool, int] = True,
7979
boto3_session: Optional[boto3.Session] = None,

awswrangler/s3/_write_parquet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import math
55
import uuid
66
from contextlib import contextmanager
7-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union, cast
7+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, cast
88

99
import boto3
1010
import pandas as pd
@@ -241,7 +241,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals,too-many-b
241241
partition_cols: Optional[List[str]] = None,
242242
bucketing_info: Optional[BucketingInfoTuple] = None,
243243
concurrent_partitioning: bool = False,
244-
mode: Optional[str] = None,
244+
mode: Optional[Literal["append", "overwrite", "overwrite_partitions"]] = None,
245245
catalog_versioning: bool = False,
246246
schema_evolution: bool = True,
247247
database: Optional[str] = None,
@@ -821,7 +821,7 @@ def store_parquet_metadata( # pylint: disable=too-many-arguments,too-many-local
821821
parameters: Optional[Dict[str, str]] = None,
822822
columns_comments: Optional[Dict[str, str]] = None,
823823
compression: Optional[str] = None,
824-
mode: str = "overwrite",
824+
mode: Literal["append", "overwrite"] = "overwrite",
825825
catalog_versioning: bool = False,
826826
regular_partitions: bool = True,
827827
athena_partition_projection_settings: Optional[typing.AthenaPartitionProjectionSettings] = None,

Comments: 0 commit comments