Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,12 @@ def setencoding(
),
)

# Enforce UTF-16 encoding restriction for SQL_WCHAR
if ctype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
log('warning', "SQL_WCHAR only supports UTF-16 encodings. Attempted encoding '%s' is not allowed. Using default 'utf-16le' instead.",
sanitize_user_input(encoding))
encoding = 'utf-16le'

# Store the encoding settings
self._encoding_settings = {"encoding": encoding, "ctype": ctype}

Expand Down Expand Up @@ -543,13 +549,26 @@ def setdecoding(
# Normalize encoding to lowercase for consistency
encoding = encoding.lower()

# Enforce UTF-16 encoding restriction for SQL_WCHAR and SQL_WMETADATA
if (sqltype == ConstantsDDBC.SQL_WCHAR.value or sqltype == SQL_WMETADATA) and encoding not in UTF16_ENCODINGS:
sqltype_name = "SQL_WCHAR" if sqltype == ConstantsDDBC.SQL_WCHAR.value else "SQL_WMETADATA"
log('warning', "%s only supports UTF-16 encodings. Attempted encoding '%s' is not allowed. Using default 'utf-16le' instead.",
sqltype_name, sanitize_user_input(encoding))
encoding = 'utf-16le'

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
ctype = ConstantsDDBC.SQL_CHAR.value

# Additional validation: if user explicitly sets ctype to SQL_WCHAR but encoding is not UTF-16
if ctype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
log('warning', "SQL_WCHAR ctype only supports UTF-16 encodings. Attempted encoding '%s' is not compatible. Using default 'utf-16le' instead.",
sanitize_user_input(encoding))
encoding = 'utf-16le'

# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
Expand Down
82 changes: 73 additions & 9 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes
from mssql_python.helpers import check_error, log
from mssql_python import ddbc_bindings
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError, OperationalError, DatabaseError
from mssql_python.row import Row
from mssql_python import get_settings

Expand Down Expand Up @@ -287,6 +287,51 @@ def _get_numeric_data(self, param: decimal.Decimal) -> Any:

numeric_data.val = bytes(byte_array)
return numeric_data

def _get_encoding_settings(self):
    """
    Fetch the encoding settings currently configured on the connection.

    The connection's ``getencoding()`` result is returned unchanged when it
    is available. A default of UTF-16LE with the SQL_WCHAR C type is returned
    instead when the connection object has no ``getencoding`` attribute, or
    when calling it raises ``OperationalError`` or ``DatabaseError``.

    Returns:
        dict: Either the value returned by ``getencoding()`` or a default
        dictionary with 'encoding' and 'ctype' keys.
    """
    connection = self._connection

    # Connection does not expose getencoding at all: use the defaults.
    if not hasattr(connection, 'getencoding'):
        return {
            'encoding': 'utf-16le',
            'ctype': ddbc_sql_const.SQL_WCHAR.value
        }

    try:
        settings = connection.getencoding()
    except (OperationalError, DatabaseError) as db_error:
        # Only database-related failures are downgraded to a warning;
        # any other exception type propagates to the caller.
        log('warning', f"Failed to get encoding settings from connection due to database error: {db_error}")
        return {
            'encoding': 'utf-16le',
            'ctype': ddbc_sql_const.SQL_WCHAR.value
        }
    else:
        return settings

def _get_decoding_settings(self, sql_type):
    """
    Look up the decoding settings the connection holds for one SQL type.

    Args:
        sql_type: SQL type constant (SQL_CHAR, SQL_WCHAR, etc.) that is
            forwarded to the connection's ``getdecoding``.

    Returns:
        The value returned by ``getdecoding(sql_type)``. If that call raises
        ``OperationalError`` or ``DatabaseError``, a fallback dictionary is
        returned instead: UTF-16LE / SQL_WCHAR when ``sql_type`` is SQL_WCHAR,
        otherwise UTF-8 / SQL_CHAR.
    """
    try:
        # Ask the connection for the settings registered for this SQL type.
        settings = self._connection.getdecoding(sql_type)
    except (OperationalError, DatabaseError) as db_error:
        # Only expected database-related errors are handled here; anything
        # else propagates to the caller.
        log('warning', f"Failed to get decoding settings for SQL type {sql_type} due to database error: {db_error}")
        encoding, ctype = (
            ('utf-16le', ddbc_sql_const.SQL_WCHAR.value)
            if sql_type == ddbc_sql_const.SQL_WCHAR.value
            else ('utf-8', ddbc_sql_const.SQL_CHAR.value)
        )
        return {'encoding': encoding, 'ctype': ctype}
    return settings

def _map_sql_type( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-return-statements,too-many-branches
self,
Expand Down Expand Up @@ -1028,6 +1073,9 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state
# Clear any previous messages
self.messages = []

# Getting encoding setting
encoding_settings = self._get_encoding_settings()

# Apply timeout if set (non-zero)
if self._timeout > 0:
try:
Expand Down Expand Up @@ -1100,6 +1148,7 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state
parameters_type,
self.is_stmt_prepared,
use_prepare,
encoding_settings
)
# Check return code
try:
Expand Down Expand Up @@ -1897,9 +1946,10 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
processed_parameters.append(processed_row)

# Now transpose the processed parameters
columnwise_params, row_count = self._transpose_rowwise_to_columnwise(
processed_parameters
)
columnwise_params, row_count = self._transpose_rowwise_to_columnwise(processed_parameters)

# Get encoding settings
encoding_settings = self._get_encoding_settings()

# Add debug logging
log(
Expand All @@ -1913,7 +1963,12 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
)

ret = ddbc_bindings.SQLExecuteMany(
self.hstmt, operation, columnwise_params, parameters_type, row_count
self.hstmt,
operation,
columnwise_params,
parameters_type,
row_count,
encoding_settings
)

# Capture any diagnostic messages after execution
Expand Down Expand Up @@ -1945,11 +2000,14 @@ def fetchone(self) -> Union[None, Row]:
"""
self._check_closed() # Check if the cursor is closed

char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
row_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)

ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

Expand Down Expand Up @@ -1997,10 +2055,13 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]:
if size <= 0:
return []

char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
_ = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down Expand Up @@ -2039,10 +2100,13 @@ def fetchall(self) -> List[Row]:
if not self._has_result_set and self.description:
self._reset_rownumber()

char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
_ = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down
Loading
Loading