Refactor TracedConnectionProxy (#1097)
* Refactor TracedConnectionProxy Fixes #1077 * Remove unecessary methods
This commit is contained in:
		
							parent
							
								
									8afbce7533
								
							
						
					
					
						commit
						6876ad857f
					
				|  | @ -214,8 +214,8 @@ def uninstrument_connection(connection): | |||
|     Returns: | ||||
|         An uninstrumented connection. | ||||
|     """ | ||||
|     if isinstance(connection, wrapt.ObjectProxy): | ||||
|         return connection.__wrapped__ | ||||
|     if isinstance(connection, _TracedConnectionProxy): | ||||
|         return connection._connection | ||||
| 
 | ||||
|     _logger.warning("Connection is not instrumented") | ||||
|     return connection | ||||
|  | @ -300,28 +300,35 @@ class DatabaseApiIntegration: | |||
|             self.span_attributes[SpanAttributes.NET_PEER_PORT] = port | ||||
| 
 | ||||
| 
 | ||||
| class _TracedConnectionProxy: | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def get_traced_connection_proxy( | ||||
|     connection, db_api_integration, *args, **kwargs | ||||
| ): | ||||
|     # pylint: disable=abstract-method | ||||
|     class TracedConnectionProxy(wrapt.ObjectProxy): | ||||
|         # pylint: disable=unused-argument | ||||
|         def __init__(self, connection, *args, **kwargs): | ||||
|             wrapt.ObjectProxy.__init__(self, connection) | ||||
|     class TracedConnectionProxy(type(connection), _TracedConnectionProxy): | ||||
|         def __init__(self, connection): | ||||
|             self._connection = connection | ||||
| 
 | ||||
|         def __getattr__(self, name): | ||||
|             return object.__getattribute__( | ||||
|                 object.__getattribute__(self, "_connection"), name | ||||
|             ) | ||||
| 
 | ||||
|         def cursor(self, *args, **kwargs): | ||||
|             return get_traced_cursor_proxy( | ||||
|                 self.__wrapped__.cursor(*args, **kwargs), db_api_integration | ||||
|                 self._connection.cursor(*args, **kwargs), db_api_integration | ||||
|             ) | ||||
| 
 | ||||
|         def __enter__(self): | ||||
|             self.__wrapped__.__enter__() | ||||
|             return self | ||||
|         # For some reason this is necessary as trying to access the close | ||||
|         # method of self._connection via __getattr__ leads to unexplained | ||||
|         # errors. | ||||
|         def close(self): | ||||
|             self._connection.close() | ||||
| 
 | ||||
|         def __exit__(self, *args, **kwargs): | ||||
|             self.__wrapped__.__exit__(*args, **kwargs) | ||||
| 
 | ||||
|     return TracedConnectionProxy(connection, *args, **kwargs) | ||||
|     return TracedConnectionProxy(connection) | ||||
| 
 | ||||
| 
 | ||||
| class CursorTracer: | ||||
|  |  | |||
|  | @ -262,14 +262,14 @@ class TestDBApiIntegration(TestBase): | |||
| 
 | ||||
|     @mock.patch("opentelemetry.instrumentation.dbapi") | ||||
|     def test_wrap_connect(self, mock_dbapi): | ||||
|         dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-") | ||||
|         dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-") | ||||
|         connection = mock_dbapi.connect() | ||||
|         self.assertEqual(mock_dbapi.connect.call_count, 1) | ||||
|         self.assertIsInstance(connection.__wrapped__, mock.Mock) | ||||
|         self.assertIsInstance(connection._connection, mock.Mock) | ||||
| 
 | ||||
|     @mock.patch("opentelemetry.instrumentation.dbapi") | ||||
|     def test_unwrap_connect(self, mock_dbapi): | ||||
|         dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-") | ||||
|         dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-") | ||||
|         connection = mock_dbapi.connect() | ||||
|         self.assertEqual(mock_dbapi.connect.call_count, 1) | ||||
| 
 | ||||
|  | @ -279,19 +279,21 @@ class TestDBApiIntegration(TestBase): | |||
|         self.assertIsInstance(connection, mock.Mock) | ||||
| 
 | ||||
|     def test_instrument_connection(self): | ||||
|         connection = mock.Mock() | ||||
|         connection = MockConnectionEmpty() | ||||
|         # Avoid get_attributes failing because can't concatenate mock | ||||
|         # pylint: disable=attribute-defined-outside-init | ||||
|         connection.database = "-" | ||||
|         connection2 = dbapi.instrument_connection(self.tracer, connection, "-") | ||||
|         self.assertIs(connection2.__wrapped__, connection) | ||||
|         self.assertIs(connection2._connection, connection) | ||||
| 
 | ||||
|     def test_uninstrument_connection(self): | ||||
|         connection = mock.Mock() | ||||
|         connection = MockConnectionEmpty() | ||||
|         # Set connection.database to avoid a failure because mock can't | ||||
|         # be concatenated | ||||
|         # pylint: disable=attribute-defined-outside-init | ||||
|         connection.database = "-" | ||||
|         connection2 = dbapi.instrument_connection(self.tracer, connection, "-") | ||||
|         self.assertIs(connection2.__wrapped__, connection) | ||||
|         self.assertIs(connection2._connection, connection) | ||||
| 
 | ||||
|         connection3 = dbapi.uninstrument_connection(connection2) | ||||
|         self.assertIs(connection3, connection) | ||||
|  | @ -307,10 +309,12 @@ def mock_connect(*args, **kwargs): | |||
|     server_host = kwargs.get("server_host") | ||||
|     server_port = kwargs.get("server_port") | ||||
|     user = kwargs.get("user") | ||||
|     return MockConnection(database, server_port, server_host, user) | ||||
|     return MockConnectionWithAttributes( | ||||
|         database, server_port, server_host, user | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class MockConnection: | ||||
| class MockConnectionWithAttributes: | ||||
|     def __init__(self, database, server_port, server_host, user): | ||||
|         self.database = database | ||||
|         self.server_port = server_port | ||||
|  | @ -343,3 +347,7 @@ class MockCursor: | |||
|     def callproc(self, query, params=None, throw_exception=False): | ||||
|         if throw_exception: | ||||
|             raise Exception("Test Exception") | ||||
| 
 | ||||
| 
 | ||||
| class MockConnectionEmpty: | ||||
|     pass | ||||
|  |  | |||
|  | @ -12,7 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from unittest import mock | ||||
| from unittest.mock import Mock, patch | ||||
| 
 | ||||
| import mysql.connector | ||||
| 
 | ||||
|  | @ -22,15 +22,24 @@ from opentelemetry.sdk import resources | |||
| from opentelemetry.test.test_base import TestBase | ||||
| 
 | ||||
| 
 | ||||
| def mock_connect(*args, **kwargs): | ||||
|     class MockConnection: | ||||
|         def cursor(self): | ||||
|             # pylint: disable=no-self-use | ||||
|             return Mock() | ||||
| 
 | ||||
|     return MockConnection() | ||||
| 
 | ||||
| 
 | ||||
| class TestMysqlIntegration(TestBase): | ||||
|     def tearDown(self): | ||||
|         super().tearDown() | ||||
|         with self.disable_logging(): | ||||
|             MySQLInstrumentor().uninstrument() | ||||
| 
 | ||||
|     @mock.patch("mysql.connector.connect") | ||||
|     @patch("mysql.connector.connect", new=mock_connect) | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_instrumentor(self, mock_connect): | ||||
|     def test_instrumentor(self): | ||||
|         MySQLInstrumentor().instrument() | ||||
| 
 | ||||
|         cnx = mysql.connector.connect(database="test") | ||||
|  | @ -58,9 +67,8 @@ class TestMysqlIntegration(TestBase): | |||
|         spans_list = self.memory_exporter.get_finished_spans() | ||||
|         self.assertEqual(len(spans_list), 1) | ||||
| 
 | ||||
|     @mock.patch("mysql.connector.connect") | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_custom_tracer_provider(self, mock_connect): | ||||
|     @patch("mysql.connector.connect", new=mock_connect) | ||||
|     def test_custom_tracer_provider(self): | ||||
|         resource = resources.Resource.create({}) | ||||
|         result = self.create_tracer_provider(resource=resource) | ||||
|         tracer_provider, exporter = result | ||||
|  | @ -77,9 +85,9 @@ class TestMysqlIntegration(TestBase): | |||
| 
 | ||||
|         self.assertIs(span.resource, resource) | ||||
| 
 | ||||
|     @mock.patch("mysql.connector.connect") | ||||
|     @patch("mysql.connector.connect", new=mock_connect) | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_instrument_connection(self, mock_connect): | ||||
|     def test_instrument_connection(self): | ||||
|         cnx = mysql.connector.connect(database="test") | ||||
|         query = "SELECT * FROM test" | ||||
|         cursor = cnx.cursor() | ||||
|  | @ -95,9 +103,9 @@ class TestMysqlIntegration(TestBase): | |||
|         spans_list = self.memory_exporter.get_finished_spans() | ||||
|         self.assertEqual(len(spans_list), 1) | ||||
| 
 | ||||
|     @mock.patch("mysql.connector.connect") | ||||
|     @patch("mysql.connector.connect", new=mock_connect) | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_uninstrument_connection(self, mock_connect): | ||||
|     def test_uninstrument_connection(self): | ||||
|         MySQLInstrumentor().instrument() | ||||
|         cnx = mysql.connector.connect(database="test") | ||||
|         query = "SELECT * FROM test" | ||||
|  |  | |||
|  | @ -12,7 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from unittest import mock | ||||
| from unittest.mock import Mock, patch | ||||
| 
 | ||||
| import pymysql | ||||
| 
 | ||||
|  | @ -22,15 +22,24 @@ from opentelemetry.sdk import resources | |||
| from opentelemetry.test.test_base import TestBase | ||||
| 
 | ||||
| 
 | ||||
| def mock_connect(*args, **kwargs): | ||||
|     class MockConnection: | ||||
|         def cursor(self): | ||||
|             # pylint: disable=no-self-use | ||||
|             return Mock() | ||||
| 
 | ||||
|     return MockConnection() | ||||
| 
 | ||||
| 
 | ||||
| class TestPyMysqlIntegration(TestBase): | ||||
|     def tearDown(self): | ||||
|         super().tearDown() | ||||
|         with self.disable_logging(): | ||||
|             PyMySQLInstrumentor().uninstrument() | ||||
| 
 | ||||
|     @mock.patch("pymysql.connect") | ||||
|     @patch("pymysql.connect", new=mock_connect) | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_instrumentor(self, mock_connect): | ||||
|     def test_instrumentor(self): | ||||
|         PyMySQLInstrumentor().instrument() | ||||
| 
 | ||||
|         cnx = pymysql.connect(database="test") | ||||
|  | @ -58,9 +67,9 @@ class TestPyMysqlIntegration(TestBase): | |||
|         spans_list = self.memory_exporter.get_finished_spans() | ||||
|         self.assertEqual(len(spans_list), 1) | ||||
| 
 | ||||
|     @mock.patch("pymysql.connect") | ||||
|     @patch("pymysql.connect", new=mock_connect) | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_custom_tracer_provider(self, mock_connect): | ||||
|     def test_custom_tracer_provider(self): | ||||
|         resource = resources.Resource.create({}) | ||||
|         result = self.create_tracer_provider(resource=resource) | ||||
|         tracer_provider, exporter = result | ||||
|  | @ -78,9 +87,9 @@ class TestPyMysqlIntegration(TestBase): | |||
| 
 | ||||
|         self.assertIs(span.resource, resource) | ||||
| 
 | ||||
|     @mock.patch("pymysql.connect") | ||||
|     @patch("pymysql.connect", new=mock_connect) | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_instrument_connection(self, mock_connect): | ||||
|     def test_instrument_connection(self): | ||||
|         cnx = pymysql.connect(database="test") | ||||
|         query = "SELECT * FROM test" | ||||
|         cursor = cnx.cursor() | ||||
|  | @ -96,9 +105,9 @@ class TestPyMysqlIntegration(TestBase): | |||
|         spans_list = self.memory_exporter.get_finished_spans() | ||||
|         self.assertEqual(len(spans_list), 1) | ||||
| 
 | ||||
|     @mock.patch("pymysql.connect") | ||||
|     @patch("pymysql.connect", new=mock_connect) | ||||
|     # pylint: disable=unused-argument | ||||
|     def test_uninstrument_connection(self, mock_connect): | ||||
|     def test_uninstrument_connection(self): | ||||
|         PyMySQLInstrumentor().instrument() | ||||
|         cnx = pymysql.connect(database="test") | ||||
|         query = "SELECT * FROM test" | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue