diff --git a/src/pyodbcmodule.cpp b/src/pyodbcmodule.cpp index 771fe257..f30dfd18 100644 --- a/src/pyodbcmodule.cpp +++ b/src/pyodbcmodule.cpp @@ -1327,7 +1327,6 @@ BOOL WINAPI DllMain( } #endif - static PyObject* MakeConnectionString(PyObject* existing, PyObject* parts) { // Creates a connection string from an optional existing connection string plus a dictionary of keyword value @@ -1344,13 +1343,13 @@ static PyObject* MakeConnectionString(PyObject* existing, PyObject* parts) I(PyUnicode_Check(existing)); - Py_ssize_t length = 0; // length in *characters* - if (existing) - length = Text_Size(existing) + 1; // + 1 to add a trailing semicolon - Py_ssize_t pos = 0; PyObject* key = 0; PyObject* value = 0; + Py_ssize_t length = 0; // length in *characters* +#if PY_MAJOR_VERSION < 3 + if (existing) + length = Text_Size(existing) + 1; // + 1 to add a trailing semicolon while (PyDict_Next(parts, &pos, &key, &value)) { @@ -1379,7 +1378,66 @@ static PyObject* MakeConnectionString(PyObject* existing, PyObject* parts) offset += TextCopyToUnicode(&buffer[offset], value); buffer[offset++] = (Py_UNICODE)';'; } +#else // >= Python 3.3 + int result_kind = PyUnicode_1BYTE_KIND; + if (existing) { + length = PyUnicode_GET_LENGTH(existing) + 1; // + 1 to add a trailing semicolon + int kind = PyUnicode_KIND(existing); + if (result_kind < kind) + result_kind = kind; + } + + while (PyDict_Next(parts, &pos, &key, &value)) + { + // key=value; + length += PyUnicode_GET_LENGTH(key) + 1; + length += PyUnicode_GET_LENGTH(value) + 1; + int kind = PyUnicode_KIND(key); + if (result_kind < kind) + result_kind = kind; + kind = PyUnicode_KIND(value); + if (result_kind < kind) + result_kind = kind; + } + + Py_UCS4 maxchar = 0x10ffff; + if (result_kind == PyUnicode_2BYTE_KIND) + maxchar = 0xffff; + else if (result_kind == PyUnicode_1BYTE_KIND) + maxchar = 0xff; + PyObject* result = PyUnicode_New(length, maxchar); + if (!result) + return 0; + + Py_ssize_t offset = 0; + if (existing) + { + Py_ssize_t count = PyUnicode_GET_LENGTH(existing); + Py_ssize_t n = PyUnicode_CopyCharacters(result, offset, existing, 0, + count); + if (n < 0) + return 0; + offset += count; + PyUnicode_WriteChar(result, offset++, (Py_UCS4)';'); + } + pos = 0; + while (PyDict_Next(parts, &pos, &key, &value)) + { + Py_ssize_t count = PyUnicode_GET_LENGTH(key); + Py_ssize_t n = PyUnicode_CopyCharacters(result, offset, key, 0, count); + if (n < 0) + return 0; + offset += count; + PyUnicode_WriteChar(result, offset++, (Py_UCS4)'='); + count = PyUnicode_GET_LENGTH(value); + n = PyUnicode_CopyCharacters(result, offset, value, 0, count); + if (n < 0) + return 0; + offset += count; + PyUnicode_WriteChar(result, offset++, (Py_UCS4)';'); + } +#endif I(offset == length); return result; diff --git a/src/row.cpp b/src/row.cpp index 657da814..476ce744 100644 --- a/src/row.cpp +++ b/src/row.cpp @@ -256,6 +256,7 @@ static int Row_setattro(PyObject* o, PyObject *name, PyObject* v) } +#if PY_MAJOR_VERSION < 3 static PyObject* Row_repr(PyObject* o) { Row* self = (Row*)o; @@ -310,7 +311,74 @@ static PyObject* Row_repr(PyObject* o) return result; } +#else // >= Python 3.3 +static PyObject* Row_repr(PyObject* o) +{ + Row* self = (Row*)o; + + if (self->cValues == 0) + return PyUnicode_FromString("()"); + + Object pieces(PyTuple_New(self->cValues)); + if (!pieces) + return 0; + Py_ssize_t length = 2 + (2 * (self->cValues-1)); // parens + ', ' separators + int result_kind = PyUnicode_1BYTE_KIND; + + for (Py_ssize_t i = 0; i < self->cValues; i++) + { + PyObject* piece = PyObject_Repr(self->apValues[i]); + if (!piece) + return 0; + + length += PyUnicode_GET_LENGTH(piece); + int kind = PyUnicode_KIND(piece); + if (result_kind < kind) + result_kind = kind; + + PyTuple_SET_ITEM(pieces.Get(), i, piece); + } + + if (self->cValues == 1) + { + // Need a trailing comma: (value,) + length += 2; + } + Py_UCS4 maxchar = 0x10ffff; + if (result_kind == PyUnicode_2BYTE_KIND) + maxchar = 0xffff; + else if (result_kind == PyUnicode_1BYTE_KIND) + maxchar = 0xff; + PyObject* result = PyUnicode_New(length, maxchar); + if (!result) + return 0; + Py_ssize_t offset = 0; + PyUnicode_WriteChar(result, offset++, (Py_UCS4)'('); + for (Py_ssize_t i = 0; i < self->cValues; i++) + { + PyObject* item = PyTuple_GET_ITEM(pieces.Get(), i); + Py_ssize_t count = PyUnicode_GET_LENGTH(item); + Py_ssize_t n = PyUnicode_CopyCharacters(result, offset, item, 0, count); + if (n < 0) + return 0; + offset += count; + + if (i != self->cValues-1 || self->cValues == 1) + { + PyUnicode_WriteChar(result, offset++, (Py_UCS4)','); + PyUnicode_WriteChar(result, offset++, (Py_UCS4)' '); + } + } + PyUnicode_WriteChar(result, offset++, (Py_UCS4)')'); + if (PyUnicode_READY(result) < 0) + return 0; + + I(offset == length); + + return result; +} +#endif static PyObject* Row_richcompare(PyObject* olhs, PyObject* orhs, int op) { diff --git a/tests3/issue998.py b/tests3/issue998.py new file mode 100644 index 00000000..24e58e2a --- /dev/null +++ b/tests3/issue998.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +""" +Verify that no warning is emitted for `PyUnicode_FromUnicode(NULL, size)`. + +See https://github.com/mkleehammer/pyodbc/issues/998. +See also https://bugs.python.org/issue36346. +""" + +import io +import os +import sys +import unittest + +# pylint: disable-next=import-error +from tests3.testutils import add_to_path, load_setup_connection_string + +add_to_path() +import pyodbc # pylint: disable=wrong-import-position + +KB = 1024 +MB = KB * 1024 + +CONNECTION_STRING = None + +CONNECTION_STRING_ERROR_MESSAGE = ( + "Please create tmp/setup.cfg file or " + "set a valid value to CONNECTION_STRING." +) +NO_ERROR = None + + +class SQLPutDataUnicodeToBytesMemoryLeakTestCase(unittest.TestCase): + """Test case for issue998 bug fix.""" + + driver = pyodbc + + @classmethod + def setUpClass(cls): + """Set the connection string.""" + + filename = os.path.splitext(os.path.basename(__file__))[0] + cls.connection_string = ( + load_setup_connection_string(filename) or CONNECTION_STRING + ) + + if cls.connection_string: + return NO_ERROR + return ValueError(CONNECTION_STRING_ERROR_MESSAGE) + + def test_use_correct_unicode_factory_function(self): + """Verify that the obsolete function call has been replaced.""" + + # Create a results set. + with pyodbc.connect(self.connection_string, autocommit=True) as cnxn: + cursor = cnxn.cursor() + cursor.execute("SELECT 1 AS a, 2 AS b") + rows = cursor.fetchall() + + # Redirect stderr so we can detect the warning. + sys.stderr = redirected_stderr = io.StringIO() + + # Convert the results object to a string. + self.assertGreater(len(str(rows)), 0) + + # Restore stderr to the original stream. + sys.stderr = sys.__stderr__ + + # If the bug has been fixed, nothing will have been written to stderr. + self.assertEqual(len(redirected_stderr.getvalue()), 0) + + +def main(): + """Top-level driver for the test.""" + unittest.main() + + +if __name__ == "__main__": + main()