Skip to content

Commit 60e50de

Browse files
committed
fix: Update thrift_backend.py to use host_url instead of session_id_hex
Changes:
1. Added a `self._host` attribute to store `server_hostname`.
2. Updated all error raises to pass `host_url=self._host`.
3. Changed method signatures from `session_id_hex` to `host_url` in: `_check_response_for_error`, `_hive_schema_to_arrow_schema`, `_col_to_description`, `_hive_schema_to_description`, and `_check_direct_results_for_error`.
4. Updated all call sites to pass `self._host` instead of `self._session_id_hex`.

This completes the migration from session-level to host-level error reporting.
1 parent b11a461 commit 60e50de

File tree

1 file changed: 33 additions (+), 32 deletions (-)

src/databricks/sql/backend/thrift_backend.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def __init__(
163163
else:
164164
raise ValueError("No valid connection settings.")
165165

166+
self._host = server_hostname
166167
self._initialize_retry_args(kwargs)
167168
self._use_arrow_native_complex_types = kwargs.get(
168169
"_use_arrow_native_complex_types", True
@@ -279,14 +280,14 @@ def _initialize_retry_args(self, kwargs):
279280
)
280281

281282
@staticmethod
282-
def _check_response_for_error(response, session_id_hex=None):
283+
def _check_response_for_error(response, host_url=None):
283284
if response.status and response.status.statusCode in [
284285
ttypes.TStatusCode.ERROR_STATUS,
285286
ttypes.TStatusCode.INVALID_HANDLE_STATUS,
286287
]:
287288
raise DatabaseError(
288289
response.status.errorMessage,
289-
session_id_hex=session_id_hex,
290+
host_url=host_url,
290291
)
291292

292293
@staticmethod
@@ -340,7 +341,7 @@ def _handle_request_error(self, error_info, attempt, elapsed):
340341
network_request_error = RequestError(
341342
user_friendly_error_message,
342343
full_error_info_context,
343-
self._session_id_hex,
344+
self._host,
344345
error_info.error,
345346
)
346347
logger.info(network_request_error.message_with_context())
@@ -517,7 +518,7 @@ def attempt_request(attempt):
517518
# log nothing here, presume that main request logging covers
518519
response = response_or_error_info
519520
ThriftDatabricksClient._check_response_for_error(
520-
response, self._session_id_hex
521+
response, self._host
521522
)
522523
return response
523524

@@ -533,7 +534,7 @@ def _check_protocol_version(self, t_open_session_resp):
533534
"Error: expected server to use a protocol version >= "
534535
"SPARK_CLI_SERVICE_PROTOCOL_V2, "
535536
"instead got: {}".format(protocol_version),
536-
session_id_hex=self._session_id_hex,
537+
host_url=self._host,
537538
)
538539

539540
def _check_initial_namespace(self, catalog, schema, response):
@@ -547,15 +548,15 @@ def _check_initial_namespace(self, catalog, schema, response):
547548
raise InvalidServerResponseError(
548549
"Setting initial namespace not supported by the DBR version, "
549550
"Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.",
550-
session_id_hex=self._session_id_hex,
551+
host_url=self._host,
551552
)
552553

553554
if catalog:
554555
if not response.canUseMultipleCatalogs:
555556
raise InvalidServerResponseError(
556557
"Unexpected response from server: Trying to set initial catalog to {}, "
557558
+ "but server does not support multiple catalogs.".format(catalog), # type: ignore
558-
session_id_hex=self._session_id_hex,
559+
host_url=self._host,
559560
)
560561

561562
def _check_session_configuration(self, session_configuration):
@@ -570,7 +571,7 @@ def _check_session_configuration(self, session_configuration):
570571
TIMESTAMP_AS_STRING_CONFIG,
571572
session_configuration[TIMESTAMP_AS_STRING_CONFIG],
572573
),
573-
session_id_hex=self._session_id_hex,
574+
host_url=self._host,
574575
)
575576

576577
def open_session(self, session_configuration, catalog, schema) -> SessionId:
@@ -639,7 +640,7 @@ def _check_command_not_in_error_or_closed_state(
639640
and guid_to_hex_id(op_handle.operationId.guid),
640641
"diagnostic-info": get_operations_resp.diagnosticInfo,
641642
},
642-
session_id_hex=self._session_id_hex,
643+
host_url=self._host,
643644
)
644645
else:
645646
raise ServerOperationError(
@@ -649,7 +650,7 @@ def _check_command_not_in_error_or_closed_state(
649650
and guid_to_hex_id(op_handle.operationId.guid),
650651
"diagnostic-info": None,
651652
},
652-
session_id_hex=self._session_id_hex,
653+
host_url=self._host,
653654
)
654655
elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE:
655656
raise DatabaseError(
@@ -660,7 +661,7 @@ def _check_command_not_in_error_or_closed_state(
660661
"operation-id": op_handle
661662
and guid_to_hex_id(op_handle.operationId.guid)
662663
},
663-
session_id_hex=self._session_id_hex,
664+
host_url=self._host,
664665
)
665666

666667
def _poll_for_status(self, op_handle):
@@ -683,7 +684,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
683684
else:
684685
raise OperationalError(
685686
"Unsupported TRowSet instance {}".format(t_row_set),
686-
session_id_hex=self._session_id_hex,
687+
host_url=self._host,
687688
)
688689
return convert_decimals_in_arrow_table(arrow_table, description), num_rows
689690

@@ -692,7 +693,7 @@ def _get_metadata_resp(self, op_handle):
692693
return self.make_request(self._client.GetResultSetMetadata, req)
693694

694695
@staticmethod
695-
def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None):
696+
def _hive_schema_to_arrow_schema(t_table_schema, host_url=None):
696697
def map_type(t_type_entry):
697698
if t_type_entry.primitiveEntry:
698699
return {
@@ -724,7 +725,7 @@ def map_type(t_type_entry):
724725
# even for complex types
725726
raise OperationalError(
726727
"Thrift protocol error: t_type_entry not a primitiveEntry",
727-
session_id_hex=session_id_hex,
728+
host_url=host_url,
728729
)
729730

730731
def convert_col(t_column_desc):
@@ -735,7 +736,7 @@ def convert_col(t_column_desc):
735736
return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
736737

737738
@staticmethod
738-
def _col_to_description(col, field=None, session_id_hex=None):
739+
def _col_to_description(col, field=None, host_url=None):
739740
type_entry = col.typeDesc.types[0]
740741

741742
if type_entry.primitiveEntry:
@@ -745,7 +746,7 @@ def _col_to_description(col, field=None, session_id_hex=None):
745746
else:
746747
raise OperationalError(
747748
"Thrift protocol error: t_type_entry not a primitiveEntry",
748-
session_id_hex=session_id_hex,
749+
host_url=host_url,
749750
)
750751

751752
if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
@@ -759,7 +760,7 @@ def _col_to_description(col, field=None, session_id_hex=None):
759760
raise OperationalError(
760761
"Decimal type did not provide typeQualifier precision, scale in "
761762
"primitiveEntry {}".format(type_entry.primitiveEntry),
762-
session_id_hex=session_id_hex,
763+
host_url=host_url,
763764
)
764765
else:
765766
precision, scale = None, None
@@ -779,7 +780,7 @@ def _col_to_description(col, field=None, session_id_hex=None):
779780

780781
@staticmethod
781782
def _hive_schema_to_description(
782-
t_table_schema, schema_bytes=None, session_id_hex=None
783+
t_table_schema, schema_bytes=None, host_url=None
783784
):
784785
field_dict = {}
785786
if pyarrow and schema_bytes:
@@ -795,7 +796,7 @@ def _hive_schema_to_description(
795796
ThriftDatabricksClient._col_to_description(
796797
col,
797798
field_dict.get(col.columnName) if field_dict else None,
798-
session_id_hex,
799+
host_url,
799800
)
800801
for col in t_table_schema.columns
801802
]
@@ -818,7 +819,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
818819
t_result_set_metadata_resp.resultFormat
819820
]
820821
),
821-
session_id_hex=self._session_id_hex,
822+
host_url=self._host,
822823
)
823824
direct_results = resp.directResults
824825
has_been_closed_server_side = direct_results and direct_results.closeOperation
@@ -833,7 +834,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
833834
schema_bytes = (
834835
t_result_set_metadata_resp.arrowSchema
835836
or self._hive_schema_to_arrow_schema(
836-
t_result_set_metadata_resp.schema, self._session_id_hex
837+
t_result_set_metadata_resp.schema, self._host
837838
)
838839
.serialize()
839840
.to_pybytes()
@@ -844,7 +845,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
844845
description = self._hive_schema_to_description(
845846
t_result_set_metadata_resp.schema,
846847
schema_bytes,
847-
self._session_id_hex,
848+
self._host,
848849
)
849850

850851
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
@@ -895,7 +896,7 @@ def get_execution_result(
895896
schema_bytes = (
896897
t_result_set_metadata_resp.arrowSchema
897898
or self._hive_schema_to_arrow_schema(
898-
t_result_set_metadata_resp.schema, self._session_id_hex
899+
t_result_set_metadata_resp.schema, self._host
899900
)
900901
.serialize()
901902
.to_pybytes()
@@ -906,7 +907,7 @@ def get_execution_result(
906907
description = self._hive_schema_to_description(
907908
t_result_set_metadata_resp.schema,
908909
schema_bytes,
909-
self._session_id_hex,
910+
self._host,
910911
)
911912

912913
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
@@ -971,27 +972,27 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
971972
return state
972973

973974
@staticmethod
974-
def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None):
975+
def _check_direct_results_for_error(t_spark_direct_results, host_url=None):
975976
if t_spark_direct_results:
976977
if t_spark_direct_results.operationStatus:
977978
ThriftDatabricksClient._check_response_for_error(
978979
t_spark_direct_results.operationStatus,
979-
session_id_hex,
980+
host_url,
980981
)
981982
if t_spark_direct_results.resultSetMetadata:
982983
ThriftDatabricksClient._check_response_for_error(
983984
t_spark_direct_results.resultSetMetadata,
984-
session_id_hex,
985+
host_url,
985986
)
986987
if t_spark_direct_results.resultSet:
987988
ThriftDatabricksClient._check_response_for_error(
988989
t_spark_direct_results.resultSet,
989-
session_id_hex,
990+
host_url,
990991
)
991992
if t_spark_direct_results.closeOperation:
992993
ThriftDatabricksClient._check_response_for_error(
993994
t_spark_direct_results.closeOperation,
994-
session_id_hex,
995+
host_url,
995996
)
996997

997998
def execute_command(
@@ -1260,7 +1261,7 @@ def _handle_execute_response(self, resp, cursor):
12601261
raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}")
12611262

12621263
cursor.active_command_id = command_id
1263-
self._check_direct_results_for_error(resp.directResults, self._session_id_hex)
1264+
self._check_direct_results_for_error(resp.directResults, self._host)
12641265

12651266
final_operation_state = self._wait_until_command_done(
12661267
resp.operationHandle,
@@ -1275,7 +1276,7 @@ def _handle_execute_response_async(self, resp, cursor):
12751276
raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}")
12761277

12771278
cursor.active_command_id = command_id
1278-
self._check_direct_results_for_error(resp.directResults, self._session_id_hex)
1279+
self._check_direct_results_for_error(resp.directResults, self._host)
12791280

12801281
def fetch_results(
12811282
self,
@@ -1313,7 +1314,7 @@ def fetch_results(
13131314
"fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format(
13141315
expected_row_start_offset, resp.results.startRowOffset
13151316
),
1316-
session_id_hex=self._session_id_hex,
1317+
host_url=self._host,
13171318
)
13181319

13191320
queue = ThriftResultSetQueueFactory.build_queue(

Comments (0)