Skip to content

Commit

Permalink
psycopg: fix running async tests
Browse files Browse the repository at this point in the history
Tests were not run because the test case did not await on async tests.
Split tests in two testcases: one for sync tests and one for async tests.
  • Loading branch information
xrmx committed May 22, 2024
1 parent 74f8a00 commit ae4bc8e
Showing 1 changed file with 109 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import types
from unittest import mock
from unittest import IsolatedAsyncioTestCase, mock

import psycopg

Expand Down Expand Up @@ -124,7 +123,7 @@ async def __aexit__(self, *args):
return mock.MagicMock(spec=types.MethodType)


class TestPostgresqlIntegration(TestBase):
class PostgresqlIntegrationTestMixin:
def setUp(self):
super().setUp()
self.cursor_mock = mock.patch(
Expand Down Expand Up @@ -159,6 +158,8 @@ def tearDown(self):
with self.disable_logging():
PsycopgInstrumentor().uninstrument()


class TestPostgresqlIntegration(PostgresqlIntegrationTestMixin, TestBase):
# pylint: disable=unused-argument
def test_instrumentor(self):
PsycopgInstrumentor().instrument()
Expand Down Expand Up @@ -221,60 +222,6 @@ def test_instrumentor_with_connection_class(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

async def test_wrap_async_connection_class_with_cursor(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect(database="test")
async with acnx as cnx:
async with cnx.cursor() as cursor:
await cursor.execute("SELECT * FROM test")

asyncio.run(test_async_connection())
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

asyncio.run(test_async_connection())

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
async def test_instrumentor_with_async_connection_class(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect(database="test")
async with acnx as cnx:
await cnx.execute("SELECT * FROM test")

asyncio.run(test_async_connection())

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()
asyncio.run(test_async_connection())

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_span_name(self):
PsycopgInstrumentor().instrument()

Expand All @@ -301,33 +248,6 @@ def test_span_name(self):
self.assertEqual(spans_list[4].name, "query")
self.assertEqual(spans_list[5].name, "query")

async def test_span_name_async(self):
PsycopgInstrumentor().instrument()

cnx = psycopg.AsyncConnection.connect(database="test")
async with cnx.cursor() as cursor:
await cursor.execute("Test query", ("param1Value", False))
await cursor.execute(
"""multi
line
query"""
)
await cursor.execute("tab\tseparated query")
await cursor.execute("/* leading comment */ query")
await cursor.execute(
"/* leading comment */ query /* trailing comment */"
)
await cursor.execute("query /* trailing comment */")

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 6)
self.assertEqual(spans_list[0].name, "Test")
self.assertEqual(spans_list[1].name, "multi")
self.assertEqual(spans_list[2].name, "tab")
self.assertEqual(spans_list[3].name, "query")
self.assertEqual(spans_list[4].name, "query")
self.assertEqual(spans_list[5].name, "query")

# pylint: disable=unused-argument
def test_not_recording(self):
mock_tracer = mock.Mock()
Expand All @@ -348,26 +268,6 @@ def test_not_recording(self):

PsycopgInstrumentor().uninstrument()

# pylint: disable=unused-argument
async def test_not_recording_async(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
PsycopgInstrumentor().instrument()
with mock.patch("opentelemetry.trace.get_tracer") as tracer:
tracer.return_value = mock_tracer
cnx = psycopg.AsyncConnection.connect(database="test")
async with cnx.cursor() as cursor:
query = "SELECT * FROM test"
cursor.execute(query)
self.assertFalse(mock_span.is_recording())
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)

PsycopgInstrumentor().uninstrument()

# pylint: disable=unused-argument
def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
Expand Down Expand Up @@ -477,3 +377,108 @@ def test_sqlcommenter_disabled(self, event_mocked):
cursor.execute(query)
kwargs = event_mocked.call_args[1]
self.assertEqual(kwargs["enable_commenter"], False)


class TestPostgresqlIntegrationAsync(
PostgresqlIntegrationTestMixin, TestBase, IsolatedAsyncioTestCase
):
async def test_wrap_async_connection_class_with_cursor(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect("test")
async with acnx as cnx:
async with cnx.cursor() as cursor:
await cursor.execute("SELECT * FROM test")

await test_async_connection()
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
async def test_instrumentor_with_async_connection_class(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect("test")
async with acnx as cnx:
await cnx.execute("SELECT * FROM test")

await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()
await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

async def test_span_name_async(self):
PsycopgInstrumentor().instrument()

cnx = await psycopg.AsyncConnection.connect("test")
async with cnx.cursor() as cursor:
await cursor.execute("Test query", ("param1Value", False))
await cursor.execute(
"""multi
line
query"""
)
await cursor.execute("tab\tseparated query")
await cursor.execute("/* leading comment */ query")
await cursor.execute(
"/* leading comment */ query /* trailing comment */"
)
await cursor.execute("query /* trailing comment */")

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 6)
self.assertEqual(spans_list[0].name, "Test")
self.assertEqual(spans_list[1].name, "multi")
self.assertEqual(spans_list[2].name, "tab")
self.assertEqual(spans_list[3].name, "query")
self.assertEqual(spans_list[4].name, "query")
self.assertEqual(spans_list[5].name, "query")

# pylint: disable=unused-argument
async def test_not_recording_async(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
PsycopgInstrumentor().instrument()
with mock.patch("opentelemetry.trace.get_tracer") as tracer:
tracer.return_value = mock_tracer
cnx = psycopg.AsyncConnection.connect("test")
async with cnx.cursor() as cursor:
query = "SELECT * FROM test"
cursor.execute(query)
self.assertFalse(mock_span.is_recording())
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)

PsycopgInstrumentor().uninstrument()

0 comments on commit ae4bc8e

Please sign in to comment.