diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 7f51e3e..5b70790 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -107,6 +107,12 @@ class ExecuteSqlError(Exception): params = sql_alchemy_dict.get("params") sql_alchemy_dict["params"] = _build_params_for_bigquery_oauth(params) + requires_trino_oauth = sql_alchemy_dict["url"].startswith("trino://") and dnenv.get_env("TRINO_OAUTH_TOKEN") + + if requires_trino_oauth: + params = sql_alchemy_dict.get("params", {}) + sql_alchemy_dict["params"] = _build_params_for_trino_oauth(params) + # When using key-pair authentication with Snowflake, the private key will be # passed as a base64 encoded string as 'snowflake_private_key'. # @@ -457,6 +463,23 @@ class BigQueryCredentialsError(Exception): return {"connect_args": {"client": client}} +def _build_params_for_trino_oauth(params): + import trino.auth + + oauth_token = dnenv.get_env("TRINO_OAUTH_TOKEN") + if not oauth_token: + raise Exception("TRINO_OAUTH_TOKEN environment variable is not set") + + auth = trino.auth.JWTAuthentication(oauth_token) + + result_params = params.copy() + connect_args = result_params.get("connect_args", {}) + connect_args["auth"] = auth + result_params["connect_args"] = connect_args + + return result_params + + def _sanitize_dataframe_for_parquet(dataframe): """Sanitizes the dataframe so that we can safely call .to_parquet on it"""