Only set aws header attribute name once. (#2509)
This commit is contained in:
parent
3a206db032
commit
96606711ac
|
@ -10,6 +10,8 @@ import static java.lang.invoke.MethodType.methodType;
|
|||
import com.amazonaws.AmazonWebServiceRequest;
|
||||
import java.lang.invoke.MethodHandle;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
|
@ -24,6 +26,7 @@ import javax.annotation.Nullable;
|
|||
final class SqsReceiveMessageRequestAccess {
|
||||
|
||||
@Nullable private static final MethodHandle WITH_ATTRIBUTE_NAMES;
|
||||
@Nullable private static final MethodHandle GET_ATTRIBUTE_NAMES;
|
||||
|
||||
static {
|
||||
Class<?> receiveMessageRequestClass = null;
|
||||
|
@ -46,8 +49,19 @@ final class SqsReceiveMessageRequestAccess {
|
|||
// Ignore
|
||||
}
|
||||
WITH_ATTRIBUTE_NAMES = withAttributeNames;
|
||||
|
||||
MethodHandle getAttributeNames = null;
|
||||
try {
|
||||
getAttributeNames =
|
||||
lookup.findVirtual(
|
||||
receiveMessageRequestClass, "getAttributeNames", methodType(List.class));
|
||||
} catch (NoSuchMethodException | IllegalAccessException e) {
|
||||
// Ignore
|
||||
}
|
||||
GET_ATTRIBUTE_NAMES = getAttributeNames;
|
||||
} else {
|
||||
WITH_ATTRIBUTE_NAMES = null;
|
||||
GET_ATTRIBUTE_NAMES = null;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,5 +83,16 @@ final class SqsReceiveMessageRequestAccess {
|
|||
}
|
||||
}
|
||||
|
||||
static List<String> getAttributeNames(AmazonWebServiceRequest request) {
|
||||
if (GET_ATTRIBUTE_NAMES == null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
try {
|
||||
return (List<String>) GET_ATTRIBUTE_NAMES.invoke(request);
|
||||
} catch (Throwable t) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
private SqsReceiveMessageRequestAccess() {}
|
||||
}
|
||||
|
|
|
@ -53,8 +53,11 @@ public class TracingRequestHandler extends RequestHandler2 {
|
|||
@Override
|
||||
public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request) {
|
||||
if (SqsReceiveMessageRequestAccess.isInstance(request)) {
|
||||
SqsReceiveMessageRequestAccess.withAttributeNames(
|
||||
request, SqsParentContext.AWS_TRACE_SYSTEM_ATTRIBUTE);
|
||||
if (!SqsReceiveMessageRequestAccess.getAttributeNames(request)
|
||||
.contains(SqsParentContext.AWS_TRACE_SYSTEM_ATTRIBUTE)) {
|
||||
SqsReceiveMessageRequestAccess.withAttributeNames(
|
||||
request, SqsParentContext.AWS_TRACE_SYSTEM_ATTRIBUTE);
|
||||
}
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import com.amazonaws.auth.AWSStaticCredentialsProvider
|
|||
import com.amazonaws.auth.BasicAWSCredentials
|
||||
import com.amazonaws.client.builder.AwsClientBuilder
|
||||
import com.amazonaws.services.sqs.AmazonSQSAsyncClient
|
||||
import com.amazonaws.services.sqs.model.ReceiveMessageRequest
|
||||
import com.amazonaws.services.sqs.model.SendMessageRequest
|
||||
import io.opentelemetry.instrumentation.test.AgentInstrumentationSpecification
|
||||
import io.opentelemetry.instrumentation.test.utils.PortUtils
|
||||
|
@ -144,4 +145,20 @@ class SqsTracingTest extends AgentInstrumentationSpecification {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
def "only adds attribute name once when request reused"() {
|
||||
setup:
|
||||
client.createQueue("testSdkSqs2")
|
||||
|
||||
when:
|
||||
SendMessageRequest send = new SendMessageRequest("http://localhost:$sqsPort/000000000000/testSdkSqs2", "{\"type\": \"hello\"}")
|
||||
client.sendMessage(send)
|
||||
ReceiveMessageRequest receive = new ReceiveMessageRequest("http://localhost:$sqsPort/000000000000/testSdkSqs2")
|
||||
client.receiveMessage(receive)
|
||||
client.sendMessage(send)
|
||||
client.receiveMessage(receive)
|
||||
|
||||
then:
|
||||
receive.getAttributeNames() == ["AWSTraceHeader"]
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue