Skip to content

Commit dd13fe1

Browse files
committed
Resolving comments
1 parent de8abbd commit dd13fe1

File tree

2 files changed

+71
-19
lines changed

2 files changed

+71
-19
lines changed

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ struct NumericData {
7272
if (valueBytes.size() > SQL_MAX_NUMERIC_LEN) {
7373
throw std::runtime_error("NumericData valueBytes size exceeds SQL_MAX_NUMERIC_LEN (16)");
7474
}
75-
// Copy binary data to buffer, remaining bytes stay zero-padded
76-
std::memcpy(&val[0], valueBytes.data(), valueBytes.size());
75+
// Secure copy: bounds already validated, but using std::copy_n for safety
76+
if (valueBytes.size() > 0) {
77+
std::copy_n(valueBytes.data(), valueBytes.size(), &val[0]);
78+
}
7779
}
7880
};
7981

@@ -768,8 +770,9 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
768770
// Convert the integer decimalParam.val to char array
769771
std::memset(static_cast<void*>(decimalPtr->val), 0, sizeof(decimalPtr->val));
770772
size_t copyLen = std::min(decimalParam.val.size(), sizeof(decimalPtr->val));
773+
// Secure copy: bounds already validated with std::min
771774
if (copyLen > 0) {
772-
std::memcpy(decimalPtr->val, decimalParam.val.data(), copyLen);
775+
std::copy_n(decimalParam.val.data(), copyLen, decimalPtr->val);
773776
}
774777
dataPtr = static_cast<void*>(decimalPtr);
775778
break;
@@ -796,7 +799,8 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
796799
guid_data_ptr->Data3 =
797800
(static_cast<uint16_t>(uuid_data[7]) << 8) |
798801
(static_cast<uint16_t>(uuid_data[6]));
799-
std::memcpy(guid_data_ptr->Data4, &uuid_data[8], 8);
802+
// Secure copy: Fixed 8-byte copy for GUID Data4 field
803+
std::copy_n(&uuid_data[8], 8, guid_data_ptr->Data4);
800804
dataPtr = static_cast<void*>(guid_data_ptr);
801805
bufferLength = sizeof(SQLGUID);
802806
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
@@ -1992,15 +1996,34 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
19921996
ThrowStdException("Input string UTF-16 length exceeds allowed column size at parameter index " + std::to_string(paramIndex) +
19931997
". UTF-16 length: " + std::to_string(utf16Buf.size() - 1) + ", Column size: " + std::to_string(info.columnSize));
19941998
}
1995-
// If we reach here, the UTF-16 string fits - copy it completely
1996-
std::memcpy(wcharArray + i * (info.columnSize + 1), utf16Buf.data(), utf16Buf.size() * sizeof(SQLWCHAR));
1999+
// Secure copy: use validated bounds for defense-in-depth
2000+
size_t copyBytes = utf16Buf.size() * sizeof(SQLWCHAR);
2001+
size_t bufferBytes = (info.columnSize + 1) * sizeof(SQLWCHAR);
2002+
SQLWCHAR* destPtr = wcharArray + i * (info.columnSize + 1);
2003+
2004+
if (copyBytes > bufferBytes) {
2005+
ThrowStdException("Buffer overflow prevented in WCHAR array binding at parameter index " + std::to_string(paramIndex) +
2006+
", array element " + std::to_string(i));
2007+
}
2008+
if (copyBytes > 0) {
2009+
std::copy_n(reinterpret_cast<const char*>(utf16Buf.data()), copyBytes, reinterpret_cast<char*>(destPtr));
2010+
}
19972011
#else
19982012
// On Windows, wchar_t is already UTF-16, so the original check is sufficient
19992013
if (wstr.length() > info.columnSize) {
20002014
std::string offending = WideToUTF8(wstr);
20012015
ThrowStdException("Input string exceeds allowed column size at parameter index " + std::to_string(paramIndex));
20022016
}
2003-
std::memcpy(wcharArray + i * (info.columnSize + 1), wstr.c_str(), (wstr.length() + 1) * sizeof(SQLWCHAR));
2017+
// Secure copy with bounds checking
2018+
size_t copyBytes = (wstr.length() + 1) * sizeof(SQLWCHAR);
2019+
size_t bufferBytes = (info.columnSize + 1) * sizeof(SQLWCHAR);
2020+
SQLWCHAR* destPtr = wcharArray + i * (info.columnSize + 1);
2021+
2022+
errno_t err = memcpy_s(destPtr, bufferBytes, wstr.c_str(), copyBytes);
2023+
if (err != 0) {
2024+
ThrowStdException("Secure memory copy failed in WCHAR array binding at parameter index " + std::to_string(paramIndex) +
2025+
", array element " + std::to_string(i) + ", error code: " + std::to_string(err));
2026+
}
20042027
#endif
20052028
strLenOrIndArray[i] = SQL_NTS;
20062029
}
@@ -2097,8 +2120,30 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
20972120
ThrowStdException("Input exceeds column size at index " + std::to_string(i));
20982121
}
20992122

2100-
std::memcpy(charArray + i * (info.columnSize + 1), str.c_str(), str.size());
2101-
strLenOrIndArray[i] = static_cast<SQLLEN>(str.size());
2123+
// SECURITY: Use secure copy with bounds checking
2124+
size_t destOffset = i * (info.columnSize + 1);
2125+
size_t destBufferSize = info.columnSize + 1;
2126+
size_t copyLength = str.size();
2127+
2128+
// Validate bounds to prevent buffer overflow
2129+
if (copyLength >= destBufferSize) {
2130+
ThrowStdException("Buffer overflow prevented at parameter array index " + std::to_string(i));
2131+
}
2132+
2133+
#ifdef _WIN32
2134+
// Windows: Use memcpy_s for secure copy
2135+
errno_t err = memcpy_s(charArray + destOffset, destBufferSize, str.data(), copyLength);
2136+
if (err != 0) {
2137+
ThrowStdException("Secure memory copy failed with error code " + std::to_string(err) + " at array index " + std::to_string(i));
2138+
}
2139+
#else
2140+
// POSIX: Use std::copy_n with explicit bounds checking
2141+
if (copyLength > 0) {
2142+
std::copy_n(str.data(), copyLength, charArray + destOffset);
2143+
}
2144+
#endif
2145+
2146+
strLenOrIndArray[i] = static_cast<SQLLEN>(copyLength);
21022147
}
21032148
}
21042149
dataPtr = charArray;
@@ -2303,8 +2348,9 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
23032348
target.scale = decimalParam.scale;
23042349
target.sign = decimalParam.sign;
23052350
size_t copyLen = std::min(decimalParam.val.size(), sizeof(target.val));
2351+
// Secure copy: bounds already validated with std::min
23062352
if (copyLen > 0) {
2307-
std::memcpy(target.val, decimalParam.val.data(), copyLen);
2353+
std::copy_n(decimalParam.val.data(), copyLen, target.val);
23082354
}
23092355
strLenOrIndArray[i] = sizeof(SQL_NUMERIC_STRUCT);
23102356
}
@@ -2333,11 +2379,13 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
23332379
if (PyBytes_GET_SIZE(b.ptr()) != 16) {
23342380
ThrowStdException("UUID binary data must be exactly 16 bytes long.");
23352381
}
2336-
std::memcpy(uuid_bytes.data(), PyBytes_AS_STRING(b.ptr()), 16);
2382+
// Secure copy: Fixed 16-byte copy, size validated above
2383+
std::copy_n(reinterpret_cast<const unsigned char*>(PyBytes_AS_STRING(b.ptr())), 16, uuid_bytes.data());
23372384
}
23382385
else if (py::isinstance(element, uuid_class)) {
23392386
py::bytes b = element.attr("bytes_le").cast<py::bytes>();
2340-
std::memcpy(uuid_bytes.data(), PyBytes_AS_STRING(b.ptr()), 16);
2387+
// Secure copy: Fixed 16-byte copy from UUID bytes_le attribute
2388+
std::copy_n(reinterpret_cast<const unsigned char*>(PyBytes_AS_STRING(b.ptr())), 16, uuid_bytes.data());
23412389
}
23422390
else {
23432391
ThrowStdException(MakeParamMismatchErrorStr(info.paramCType, paramIndex));
@@ -2350,7 +2398,8 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
23502398
(static_cast<uint16_t>(uuid_bytes[4]));
23512399
guidArray[i].Data3 = (static_cast<uint16_t>(uuid_bytes[7]) << 8) |
23522400
(static_cast<uint16_t>(uuid_bytes[6]));
2353-
std::memcpy(guidArray[i].Data4, uuid_bytes.data() + 8, 8);
2401+
// Secure copy: Fixed 8-byte copy for GUID Data4 field
2402+
std::copy_n(uuid_bytes.data() + 8, 8, guidArray[i].Data4);
23542403
strLenOrIndArray[i] = sizeof(SQLGUID);
23552404
}
23562405
dataPtr = guidArray;
@@ -3181,7 +3230,8 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
31813230
guid_bytes[5] = ((char*)&guidValue.Data2)[0];
31823231
guid_bytes[6] = ((char*)&guidValue.Data3)[1];
31833232
guid_bytes[7] = ((char*)&guidValue.Data3)[0];
3184-
std::memcpy(&guid_bytes[8], guidValue.Data4, sizeof(guidValue.Data4));
3233+
// Secure copy: Fixed 8-byte copy for GUID Data4 field
3234+
std::copy_n(guidValue.Data4, sizeof(guidValue.Data4), &guid_bytes[8]);
31853235

31863236
py::bytes py_guid_bytes(guid_bytes.data(), guid_bytes.size());
31873237
py::object uuid_module = py::module_::import("uuid");
@@ -3655,7 +3705,8 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
36553705
reordered[5] = ((char*)&guidValue->Data2)[0];
36563706
reordered[6] = ((char*)&guidValue->Data3)[1];
36573707
reordered[7] = ((char*)&guidValue->Data3)[0];
3658-
std::memcpy(reordered + 8, guidValue->Data4, 8);
3708+
// Secure copy: Fixed 8-byte copy for GUID Data4 field
3709+
std::copy_n(guidValue->Data4, 8, reordered + 8);
36593710

36603711
py::bytes py_guid_bytes(reinterpret_cast<char*>(reordered), 16);
36613712
py::dict kwargs;

tests/test_011_encoding_decoding.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,9 +2961,10 @@ def safe_display(text, max_len=50):
29612961
if text is None:
29622962
return "NULL"
29632963
try:
2964+
# Use ascii() to ensure CP1252 console compatibility on Windows
29642965
display = text[:max_len] if len(text) > max_len else text
2965-
return display.encode('ascii', 'replace').decode('ascii')
2966-
except (UnicodeError, AttributeError):
2966+
return ascii(display)
2967+
except (AttributeError, TypeError):
29672968
return repr(text)[:max_len]
29682969

29692970

@@ -3505,8 +3506,8 @@ def test_utf16_unicode_preservation(db_connection):
35053506
cursor.execute("SELECT data FROM #test_utf16 WHERE id = 1")
35063507
result = cursor.fetchone()
35073508
match = "PASS" if result[0] == text else "FAIL"
3508-
# Use repr() to avoid console encoding issues on Windows
3509-
print(f" {match} {description:10} | {text!r} -> {result[0]!r}")
3509+
# Use ascii() to force ASCII-safe output on Windows CP1252 console
3510+
print(f" {match} {description:10} | {ascii(text)} -> {ascii(result[0])}")
35103511
assert result[0] == text, f"UTF-16 should preserve {description}"
35113512

35123513
print("="*60)

0 commit comments

Comments
 (0)