Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,12 @@ def setencoding(self, encoding=None, ctype=None):
ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
)

# Enforce UTF-16 encoding restriction for SQL_WCHAR
if ctype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
log('warning', "SQL_WCHAR only supports UTF-16 encodings. Attempted encoding '%s' is not allowed. Using default 'utf-16le' instead.",
sanitize_user_input(encoding))
encoding = 'utf-16le'

# Store the encoding settings
self._encoding_settings = {
'encoding': encoding,
Expand Down Expand Up @@ -489,13 +495,26 @@ def setdecoding(self, sqltype, encoding=None, ctype=None):
# Normalize encoding to lowercase for consistency
encoding = encoding.lower()

# Enforce UTF-16 encoding restriction for SQL_WCHAR and SQL_WMETADATA
if (sqltype == ConstantsDDBC.SQL_WCHAR.value or sqltype == SQL_WMETADATA) and encoding not in UTF16_ENCODINGS:
sqltype_name = "SQL_WCHAR" if sqltype == ConstantsDDBC.SQL_WCHAR.value else "SQL_WMETADATA"
log('warning', "%s only supports UTF-16 encodings. Attempted encoding '%s' is not allowed. Using default 'utf-16le' instead.",
sqltype_name, sanitize_user_input(encoding))
encoding = 'utf-16le'

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
ctype = ConstantsDDBC.SQL_CHAR.value

# Additional validation: if user explicitly sets ctype to SQL_WCHAR but encoding is not UTF-16
if ctype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
log('warning', "SQL_WCHAR ctype only supports UTF-16 encodings. Attempted encoding '%s' is not compatible. Using default 'utf-16le' instead.",
sanitize_user_input(encoding))
encoding = 'utf-16le'

# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
Expand Down
Loading