opentelemetry-python-contrib/instrumentation/opentelemetry-instrumentati.../tests/test_starlette_instrumentat...

731 lines
28 KiB
Python

# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from timeit import default_timer
from unittest.mock import patch
from starlette import applications
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.websockets import WebSocket
import opentelemetry.instrumentation.starlette as otel_starlette
from opentelemetry.sdk.metrics.export import (
HistogramDataPoint,
NumberDataPoint,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.globals_test import reset_trace_globals
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import (
NoOpTracerProvider,
SpanKind,
get_tracer,
set_tracer_provider,
)
from opentelemetry.util.http import (
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST,
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE,
_active_requests_count_attrs,
_duration_attrs,
get_excluded_urls,
)
_expected_metric_names = [
"http.server.active_requests",
"http.server.duration",
"http.server.response.size",
]
_recommended_attrs = {
"http.server.active_requests": _active_requests_count_attrs,
"http.server.duration": _duration_attrs,
"http.server.response.size": _duration_attrs,
}
class TestStarletteManualInstrumentation(TestBase):
def _create_app(self):
app = self._create_starlette_app()
self._instrumentor.instrument_app(
app=app,
server_request_hook=getattr(self, "server_request_hook", None),
client_request_hook=getattr(self, "client_request_hook", None),
client_response_hook=getattr(self, "client_response_hook", None),
)
return app
def setUp(self):
super().setUp()
self.env_patch = patch.dict(
"os.environ",
{"OTEL_PYTHON_STARLETTE_EXCLUDED_URLS": "/exclude/123,healthzz"},
)
self.env_patch.start()
self.exclude_patch = patch(
"opentelemetry.instrumentation.starlette._excluded_urls",
get_excluded_urls("STARLETTE"),
)
self.exclude_patch.start()
self._instrumentor = otel_starlette.StarletteInstrumentor()
self._app = self._create_app()
self._client = TestClient(self._app)
def tearDown(self):
super().tearDown()
self.env_patch.stop()
self.exclude_patch.stop()
def test_basic_starlette_call(self):
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
for span in spans:
self.assertIn("GET /foobar", span.name)
def test_starlette_route_attribute_added(self):
"""Ensure that starlette routes are used as the span name."""
self._client.get("/user/123")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
for span in spans:
self.assertIn("GET /user/{username}", span.name)
self.assertEqual(
spans[-1].attributes[SpanAttributes.HTTP_ROUTE], "/user/{username}"
)
# ensure that at least one attribute that is populated by
# the asgi instrumentation is successfully feeding though.
self.assertEqual(
spans[-1].attributes[SpanAttributes.HTTP_FLAVOR], "1.1"
)
def test_starlette_excluded_urls(self):
"""Ensure that given starlette routes are excluded."""
self._client.get("/healthzz")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)
def test_starlette_metrics(self):
self._client.get("/foobar")
self._client.get("/foobar")
self._client.get("/foobar")
metrics_list = self.memory_metrics_reader.get_metrics_data()
number_data_point_seen = False
histogram_data_point_seen = False
self.assertTrue(len(metrics_list.resource_metrics) == 1)
for resource_metric in metrics_list.resource_metrics:
self.assertTrue(len(resource_metric.scope_metrics) == 1)
for scope_metric in resource_metric.scope_metrics:
self.assertTrue(len(scope_metric.metrics) == 3)
for metric in scope_metric.metrics:
self.assertIn(metric.name, _expected_metric_names)
data_points = list(metric.data.data_points)
self.assertEqual(len(data_points), 1)
for point in data_points:
if isinstance(point, HistogramDataPoint):
self.assertEqual(point.count, 3)
histogram_data_point_seen = True
if isinstance(point, NumberDataPoint):
number_data_point_seen = True
for attr in point.attributes:
self.assertIn(
attr, _recommended_attrs[metric.name]
)
self.assertTrue(number_data_point_seen and histogram_data_point_seen)
def test_basic_post_request_metric_success(self):
start = default_timer()
expected_duration_attributes = {
"http.flavor": "1.1",
"http.host": "testserver",
"http.method": "POST",
"http.scheme": "http",
"http.server_name": "testserver",
"http.status_code": 405,
"net.host.port": 80,
}
expected_requests_count_attributes = {
"http.flavor": "1.1",
"http.host": "testserver",
"http.method": "POST",
"http.scheme": "http",
"http.server_name": "testserver",
}
self._client.post("/foobar")
duration = max(round((default_timer() - start) * 1000), 0)
metrics_list = self.memory_metrics_reader.get_metrics_data()
for metric in (
metrics_list.resource_metrics[0].scope_metrics[0].metrics
):
for point in list(metric.data.data_points):
if isinstance(point, HistogramDataPoint):
self.assertEqual(point.count, 1)
self.assertAlmostEqual(duration, point.sum, delta=30)
self.assertDictEqual(
dict(point.attributes), expected_duration_attributes
)
if isinstance(point, NumberDataPoint):
self.assertDictEqual(
expected_requests_count_attributes,
dict(point.attributes),
)
self.assertEqual(point.value, 0)
def test_metric_for_uninstrment_app_method(self):
self._client.get("/foobar")
# uninstrumenting the existing client app
self._instrumentor.uninstrument_app(self._app)
self._client.get("/foobar")
self._client.get("/foobar")
metrics_list = self.memory_metrics_reader.get_metrics_data()
for metric in (
metrics_list.resource_metrics[0].scope_metrics[0].metrics
):
for point in list(metric.data.data_points):
if isinstance(point, HistogramDataPoint):
self.assertEqual(point.count, 1)
if isinstance(point, NumberDataPoint):
self.assertEqual(point.value, 0)
def test_metric_uninstrument_inherited_by_base(self):
# instrumenting class and creating app to send request
self._instrumentor.instrument()
app = self._create_starlette_app()
client = TestClient(app)
client.get("/foobar")
# calling uninstrument and checking for telemetry data
self._instrumentor.uninstrument()
client.get("/foobar")
client.get("/foobar")
client.get("/foobar")
metrics_list = self.memory_metrics_reader.get_metrics_data()
for metric in (
metrics_list.resource_metrics[0].scope_metrics[0].metrics
):
for point in list(metric.data.data_points):
if isinstance(point, HistogramDataPoint):
self.assertEqual(point.count, 1)
if isinstance(point, NumberDataPoint):
self.assertEqual(point.value, 0)
@staticmethod
def _create_starlette_app():
def home(_):
return PlainTextResponse("hi")
def health(_):
return PlainTextResponse("ok")
app = applications.Starlette(
routes=[
Route("/foobar", home),
Route("/user/{username}", home),
Route("/healthzz", health),
]
)
return app
class TestStarletteManualInstrumentationHooks(
TestStarletteManualInstrumentation
):
_server_request_hook = None
_client_request_hook = None
_client_response_hook = None
def server_request_hook(self, span, scope):
if self._server_request_hook is not None:
self._server_request_hook(span, scope)
def client_request_hook(self, receive_span, request):
if self._client_request_hook is not None:
self._client_request_hook(receive_span, request)
def client_response_hook(self, send_span, response):
if self._client_response_hook is not None:
self._client_response_hook(send_span, response)
def test_hooks(self):
def server_request_hook(span, scope):
span.update_name("name from server hook")
def client_request_hook(receive_span, request):
receive_span.update_name("name from client hook")
receive_span.set_attribute("attr-from-request-hook", "set")
def client_response_hook(send_span, response):
send_span.update_name("name from response hook")
send_span.set_attribute("attr-from-response-hook", "value")
self._server_request_hook = server_request_hook
self._client_request_hook = client_request_hook
self._client_response_hook = client_response_hook
self._client.get("/foobar")
spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
self.assertEqual(
len(spans), 3
) # 1 server span and 2 response spans (response start and body)
server_span = spans[2]
self.assertEqual(server_span.name, "name from server hook")
response_spans = spans[:2]
for span in response_spans:
self.assertEqual(span.name, "name from response hook")
self.assertSpanHasAttributes(
span, {"attr-from-response-hook": "value"}
)
class TestAutoInstrumentation(TestStarletteManualInstrumentation):
"""Test the auto-instrumented variant
Extending the manual instrumentation as most test cases apply
to both.
"""
def _create_app(self):
# instrumentation is handled by the instrument call
resource = Resource.create({"key1": "value1", "key2": "value2"})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
self.memory_exporter = exporter
self._instrumentor.instrument(tracer_provider=tracer_provider)
return self._create_starlette_app()
def tearDown(self):
self._instrumentor.uninstrument()
super().tearDown()
def test_request(self):
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
for span in spans:
self.assertEqual(span.resource.attributes["key1"], "value1")
self.assertEqual(span.resource.attributes["key2"], "value2")
def test_uninstrument(self):
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
self.memory_exporter.clear()
self._instrumentor.uninstrument()
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)
class TestAutoInstrumentationHooks(TestStarletteManualInstrumentationHooks):
"""
Test the auto-instrumented variant for request and response hooks
"""
def _create_app(self):
# instrumentation is handled by the instrument call
self._instrumentor.instrument(
server_request_hook=getattr(self, "server_request_hook", None),
client_request_hook=getattr(self, "client_request_hook", None),
client_response_hook=getattr(self, "client_response_hook", None),
)
return self._create_starlette_app()
def tearDown(self):
self._instrumentor.uninstrument()
super().tearDown()
class TestAutoInstrumentationLogic(unittest.TestCase):
def test_instrumentation(self):
"""Verify that instrumentation methods are instrumenting and
removing as expected.
"""
instrumentor = otel_starlette.StarletteInstrumentor()
original = applications.Starlette
instrumentor.instrument()
try:
instrumented = applications.Starlette
self.assertIsNot(original, instrumented)
finally:
instrumentor.uninstrument()
should_be_original = applications.Starlette
self.assertIs(original, should_be_original)
class TestConditonalServerSpanCreation(TestStarletteManualInstrumentation):
def test_mark_span_internal_in_presence_of_another_span(self):
tracer = get_tracer(__name__)
with tracer.start_as_current_span(
"test", kind=SpanKind.SERVER
) as parent_span:
self._client.get("/foobar")
spans = self.sorted_spans(
self.memory_exporter.get_finished_spans()
)
starlette_span = spans[2]
self.assertEqual(SpanKind.INTERNAL, starlette_span.kind)
self.assertEqual(SpanKind.SERVER, parent_span.kind)
self.assertEqual(
parent_span.context.span_id, starlette_span.parent.span_id
)
class TestBaseWithCustomHeaders(TestBase):
def create_app(self):
app = self.create_starlette_app()
self._instrumentor.instrument_app(app=app)
return app
def setUp(self):
super().setUp()
self._instrumentor = otel_starlette.StarletteInstrumentor()
self._app = self.create_app()
self._client = TestClient(self._app)
def tearDown(self) -> None:
super().tearDown()
with self.disable_logging():
self._instrumentor.uninstrument()
@staticmethod
def create_starlette_app():
app = applications.Starlette()
@app.route("/foobar")
def _(request):
return PlainTextResponse(
content="hi",
headers={
"custom-test-header-1": "test-header-value-1",
"custom-test-header-2": "test-header-value-2",
"my-custom-regex-header-1": "my-custom-regex-value-1,my-custom-regex-value-2",
"My-Custom-Regex-Header-2": "my-custom-regex-value-3,my-custom-regex-value-4",
"my-secret-header": "my-secret-value",
},
)
@app.websocket_route("/foobar_web")
async def _(websocket: WebSocket) -> None:
message = await websocket.receive()
if message.get("type") == "websocket.connect":
await websocket.send(
{
"type": "websocket.accept",
"headers": [
(b"custom-test-header-1", b"test-header-value-1"),
(b"custom-test-header-2", b"test-header-value-2"),
(
b"my-custom-regex-header-1",
b"my-custom-regex-value-1,my-custom-regex-value-2",
),
(
b"My-Custom-Regex-Header-2",
b"my-custom-regex-value-3,my-custom-regex-value-4",
),
(b"my-secret-header", b"my-secret-value"),
],
}
)
await websocket.send_json({"message": "hello world"})
await websocket.close()
if message.get("type") == "websocket.disconnect":
pass
return app
@patch.dict(
"os.environ",
{
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,my-custom-regex-header-.*,invalid-regex-header-.*,.*my-secret.*",
},
)
class TestHTTPAppWithCustomHeaders(TestBaseWithCustomHeaders):
def test_custom_request_headers_in_span_attributes(self):
expected = {
"http.request.header.custom_test_header_1": (
"test-header-value-1",
),
"http.request.header.custom_test_header_2": (
"test-header-value-2",
),
"http.request.header.regex_test_header_1": ("Regex Test Value 1",),
"http.request.header.regex_test_header_2": (
"RegexTestValue2,RegexTestValue3",
),
"http.request.header.my_secret_header": ("[REDACTED]",),
}
resp = self._client.get(
"/foobar",
headers={
"custom-test-header-1": "test-header-value-1",
"custom-test-header-2": "test-header-value-2",
"Regex-Test-Header-1": "Regex Test Value 1",
"regex-test-header-2": "RegexTestValue2,RegexTestValue3",
"My-Secret-Header": "My Secret Value",
},
)
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
self.assertSpanHasAttributes(server_span, expected)
@patch.dict(
"os.environ",
{
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*",
},
)
def test_custom_request_headers_not_in_span_attributes(self):
not_expected = {
"http.request.header.custom_test_header_3": (
"test-header-value-3",
),
}
resp = self._client.get(
"/foobar",
headers={
"custom-test-header-1": "test-header-value-1",
"custom-test-header-2": "test-header-value-2",
"Regex-Test-Header-1": "Regex Test Value 1",
"regex-test-header-2": "RegexTestValue2,RegexTestValue3",
"My-Secret-Header": "My Secret Value",
},
)
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
for key in not_expected:
self.assertNotIn(key, server_span.attributes)
def test_custom_response_headers_in_span_attributes(self):
expected = {
"http.response.header.custom_test_header_1": (
"test-header-value-1",
),
"http.response.header.custom_test_header_2": (
"test-header-value-2",
),
"http.response.header.my_custom_regex_header_1": (
"my-custom-regex-value-1,my-custom-regex-value-2",
),
"http.response.header.my_custom_regex_header_2": (
"my-custom-regex-value-3,my-custom-regex-value-4",
),
"http.response.header.my_secret_header": ("[REDACTED]",),
}
resp = self._client.get("/foobar")
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
self.assertSpanHasAttributes(server_span, expected)
def test_custom_response_headers_not_in_span_attributes(self):
not_expected = {
"http.response.header.custom_test_header_3": (
"test-header-value-3",
),
}
resp = self._client.get("/foobar")
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
for key in not_expected:
self.assertNotIn(key, server_span.attributes)
@patch.dict(
"os.environ",
{
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,my-custom-regex-header-.*,invalid-regex-header-.*,.*my-secret.*",
},
)
class TestWebSocketAppWithCustomHeaders(TestBaseWithCustomHeaders):
def test_custom_request_headers_in_span_attributes(self):
expected = {
"http.request.header.custom_test_header_1": (
"test-header-value-1",
),
"http.request.header.custom_test_header_2": (
"test-header-value-2",
),
"http.request.header.regex_test_header_1": ("Regex Test Value 1",),
"http.request.header.regex_test_header_2": (
"RegexTestValue2,RegexTestValue3",
),
"http.request.header.my_secret_header": ("[REDACTED]",),
}
with self._client.websocket_connect(
"/foobar_web",
headers={
"custom-test-header-1": "test-header-value-1",
"custom-test-header-2": "test-header-value-2",
"Regex-Test-Header-1": "Regex Test Value 1",
"regex-test-header-2": "RegexTestValue2,RegexTestValue3",
"My-Secret-Header": "My Secret Value",
},
) as websocket:
data = websocket.receive_json()
self.assertEqual(data, {"message": "hello world"})
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 5)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
self.assertSpanHasAttributes(server_span, expected)
def test_custom_request_headers_not_in_span_attributes(self):
not_expected = {
"http.request.header.custom_test_header_3": (
"test-header-value-3",
),
}
with self._client.websocket_connect(
"/foobar_web",
headers={
"custom-test-header-1": "test-header-value-1",
"custom-test-header-2": "test-header-value-2",
"Regex-Test-Header-1": "Regex Test Value 1",
"regex-test-header-2": "RegexTestValue2,RegexTestValue3",
"My-Secret-Header": "My Secret Value",
},
) as websocket:
data = websocket.receive_json()
self.assertEqual(data, {"message": "hello world"})
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 5)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
for key, _ in not_expected.items():
self.assertNotIn(key, server_span.attributes)
def test_custom_response_headers_in_span_attributes(self):
expected = {
"http.response.header.custom_test_header_1": (
"test-header-value-1",
),
"http.response.header.custom_test_header_2": (
"test-header-value-2",
),
"http.response.header.my_custom_regex_header_1": (
"my-custom-regex-value-1,my-custom-regex-value-2",
),
"http.response.header.my_custom_regex_header_2": (
"my-custom-regex-value-3,my-custom-regex-value-4",
),
"http.response.header.my_secret_header": ("[REDACTED]",),
}
with self._client.websocket_connect("/foobar_web") as websocket:
data = websocket.receive_json()
self.assertEqual(data, {"message": "hello world"})
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 5)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
self.assertSpanHasAttributes(server_span, expected)
def test_custom_response_headers_not_in_span_attributes(self):
not_expected = {
"http.response.header.custom_test_header_3": (
"test-header-value-3",
),
}
with self._client.websocket_connect("/foobar_web") as websocket:
data = websocket.receive_json()
self.assertEqual(data, {"message": "hello world"})
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 5)
server_span = [
span for span in span_list if span.kind == SpanKind.SERVER
][0]
for key, _ in not_expected.items():
self.assertNotIn(key, server_span.attributes)
@patch.dict(
"os.environ",
{
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS: ".*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,Regex-Test-Header-.*,Regex-Invalid-Test-Header-.*,.*my-secret.*",
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_RESPONSE: "Custom-Test-Header-1,Custom-Test-Header-2,Custom-Test-Header-3,my-custom-regex-header-.*,invalid-regex-header-.*,.*my-secret.*",
},
)
class TestNonRecordingSpanWithCustomHeaders(TestBaseWithCustomHeaders):
def setUp(self):
super().setUp()
reset_trace_globals()
set_tracer_provider(tracer_provider=NoOpTracerProvider())
self._app = self.create_app()
self._client = TestClient(self._app)
def test_custom_header_not_present_in_non_recording_span(self):
resp = self._client.get(
"/foobar",
headers={
"custom-test-header-1": "test-header-value-1",
},
)
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 0)