Only set aws header attribute name once. (#2509)

This commit is contained in:
Anuraag Agrawal 2021-03-06 04:00:07 +09:00 committed by GitHub
parent 3a206db032
commit 96606711ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 2 deletions

View File

@ -10,6 +10,8 @@ import static java.lang.invoke.MethodType.methodType;
import com.amazonaws.AmazonWebServiceRequest;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
/**
@ -24,6 +26,7 @@ import javax.annotation.Nullable;
final class SqsReceiveMessageRequestAccess {
@Nullable private static final MethodHandle WITH_ATTRIBUTE_NAMES;
@Nullable private static final MethodHandle GET_ATTRIBUTE_NAMES;
static {
Class<?> receiveMessageRequestClass = null;
@ -46,8 +49,19 @@ final class SqsReceiveMessageRequestAccess {
// Ignore
}
WITH_ATTRIBUTE_NAMES = withAttributeNames;
MethodHandle getAttributeNames = null;
try {
getAttributeNames =
lookup.findVirtual(
receiveMessageRequestClass, "getAttributeNames", methodType(List.class));
} catch (NoSuchMethodException | IllegalAccessException e) {
// Ignore
}
GET_ATTRIBUTE_NAMES = getAttributeNames;
} else {
WITH_ATTRIBUTE_NAMES = null;
GET_ATTRIBUTE_NAMES = null;
}
}
@ -69,5 +83,16 @@ final class SqsReceiveMessageRequestAccess {
}
}
static List<String> getAttributeNames(AmazonWebServiceRequest request) {
if (GET_ATTRIBUTE_NAMES == null) {
return Collections.emptyList();
}
try {
return (List<String>) GET_ATTRIBUTE_NAMES.invoke(request);
} catch (Throwable t) {
return Collections.emptyList();
}
}
private SqsReceiveMessageRequestAccess() {}
}

View File

@ -53,8 +53,11 @@ public class TracingRequestHandler extends RequestHandler2 {
@Override
public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request) {
if (SqsReceiveMessageRequestAccess.isInstance(request)) {
SqsReceiveMessageRequestAccess.withAttributeNames(
request, SqsParentContext.AWS_TRACE_SYSTEM_ATTRIBUTE);
if (!SqsReceiveMessageRequestAccess.getAttributeNames(request)
.contains(SqsParentContext.AWS_TRACE_SYSTEM_ATTRIBUTE)) {
SqsReceiveMessageRequestAccess.withAttributeNames(
request, SqsParentContext.AWS_TRACE_SYSTEM_ATTRIBUTE);
}
}
return request;
}

View File

@ -11,6 +11,7 @@ import com.amazonaws.auth.AWSStaticCredentialsProvider
import com.amazonaws.auth.BasicAWSCredentials
import com.amazonaws.client.builder.AwsClientBuilder
import com.amazonaws.services.sqs.AmazonSQSAsyncClient
import com.amazonaws.services.sqs.model.ReceiveMessageRequest
import com.amazonaws.services.sqs.model.SendMessageRequest
import io.opentelemetry.instrumentation.test.AgentInstrumentationSpecification
import io.opentelemetry.instrumentation.test.utils.PortUtils
@ -144,4 +145,20 @@ class SqsTracingTest extends AgentInstrumentationSpecification {
}
}
}
def "only adds attribute name once when request reused"() {
setup:
client.createQueue("testSdkSqs2")
when:
SendMessageRequest send = new SendMessageRequest("http://localhost:$sqsPort/000000000000/testSdkSqs2", "{\"type\": \"hello\"}")
client.sendMessage(send)
ReceiveMessageRequest receive = new ReceiveMessageRequest("http://localhost:$sqsPort/000000000000/testSdkSqs2")
client.receiveMessage(receive)
client.sendMessage(send)
client.receiveMessage(receive)
then:
receive.getAttributeNames() == ["AWSTraceHeader"]
}
}