diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py index a5897f8b6..d225e6bd0 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py @@ -67,7 +67,7 @@ from opentelemetry import propagators, trace from opentelemetry.instrumentation.celery import utils from opentelemetry.instrumentation.celery.version import __version__ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor -from opentelemetry.trace.propagation import get_current_span +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.trace.status import Status, StatusCode logger = logging.getLogger(__name__) @@ -84,6 +84,20 @@ _TASK_NAME_KEY = "celery.task_name" _MESSAGE_ID_ATTRIBUTE_NAME = "messaging.message_id" +class CarrierGetter(DictGetter): + def get(self, carrier, key): + value = getattr(carrier, key, []) + if isinstance(value, str) or not isinstance(value, Iterable): + value = (value,) + return value + + def keys(self, carrier): + return [] + + +carrier_getter = CarrierGetter() + + class CeleryInstrumentor(BaseInstrumentor): def _instrument(self, **kwargs): tracer_provider = kwargs.get("tracer_provider") @@ -118,7 +132,7 @@ class CeleryInstrumentor(BaseInstrumentor): return request = task.request - tracectx = propagators.extract(carrier_extractor, request) or None + tracectx = propagators.extract(carrier_getter, request) or None logger.debug("prerun signal start task_id=%s", task_id) @@ -246,10 +260,3 @@ class CeleryInstrumentor(BaseInstrumentor): # Use `str(reason)` instead of `reason.message` in case we get # something that isn't an `Exception` span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason)) - - -def carrier_extractor(carrier, key): - value = getattr(carrier, key, []) - if isinstance(value, str) or not isinstance(value, Iterable): - value = (value,) - return value