diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 1a0bb47a6..1c442889a 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -29,21 +29,31 @@ from asgiref.compatibility import guarantee_single_callable from opentelemetry import context, propagators, trace from opentelemetry.instrumentation.asgi.version import __version__ # noqa from opentelemetry.instrumentation.utils import http_status_to_status_code +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.trace.status import Status, StatusCode -def get_header_from_scope(scope: dict, header_name: str) -> typing.List[str]: - """Retrieve a HTTP header value from the ASGI scope. +class CarrierGetter(DictGetter): + def get(self, carrier: dict, key: str) -> typing.List[str]: + """Getter implementation to retrieve a HTTP header value from the ASGI + scope. - Returns: - A list with a single string with the header value if it exists, else an empty list. - """ - headers = scope.get("headers") - return [ - value.decode("utf8") - for (key, value) in headers - if key.decode("utf8") == header_name - ] + Args: + carrier: ASGI scope object + key: header name in scope + Returns: + A list with a single string with the header value if it exists, + else an empty list. + """ + headers = carrier.get("headers") + return [ + _value.decode("utf8") + for (_key, _value) in headers + if _key.decode("utf8") == key + ] + + +carrier_getter = CarrierGetter() def collect_request_attributes(scope): @@ -72,10 +82,10 @@ def collect_request_attributes(scope): http_method = scope.get("method") if http_method: result["http.method"] = http_method - http_host_value = ",".join(get_header_from_scope(scope, "host")) + http_host_value = ",".join(carrier_getter.get(scope, "host")) if http_host_value: result["http.server_name"] = http_host_value - http_user_agent = get_header_from_scope(scope, "user-agent") + http_user_agent = carrier_getter.get(scope, "user-agent") if len(http_user_agent) > 0: result["http.user_agent"] = http_user_agent[0] @@ -154,9 +164,7 @@ class OpenTelemetryMiddleware: if scope["type"] not in ("http", "websocket"): return await self.app(scope, receive, send) - token = context.attach( - propagators.extract(get_header_from_scope, scope) - ) + token = context.attach(propagators.extract(carrier_getter, scope)) span_name, additional_attributes = self.span_details_callback(scope) try: