diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index b99114bb50..dbacf35178 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -18,42 +18,39 @@ package io.grpc.binder.internal; import static com.google.common.truth.Truth.assertThat; -import android.app.Service; import android.content.Context; -import android.content.Intent; -import android.os.IBinder; import android.os.Parcel; import androidx.core.content.ContextCompat; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; -import com.google.common.util.concurrent.testing.TestingExecutors; import com.google.protobuf.Empty; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; -import io.grpc.Server; import io.grpc.ServerCallHandler; import io.grpc.ServerServiceDefinition; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.BinderServerBuilder; import io.grpc.binder.BindServiceFlags; +import io.grpc.binder.BinderServerBuilder; import io.grpc.binder.HostServices; -import io.grpc.binder.IBinderReceiver; import io.grpc.binder.InboundParcelablePolicy; import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.SecurityPolicy; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; -import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ObjectPool; import io.grpc.internal.StreamListener; import io.grpc.protobuf.lite.ProtoLiteUtils; import io.grpc.stub.ServerCalls; -import java.io.IOException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; import org.junit.After; @@ -94,7 +91,7 @@ public final class BinderClientTransportTest { BinderTransport.BinderClientTransport transport; private final ObjectPool executorServicePool = - new FixedObjectPool<>(TestingExecutors.sameThreadScheduledExecutor()); + new FixedObjectPool<>(Executors.newScheduledThreadPool(1)); private final TestTransportListener transportListener = new TestTransportListener(); private final TestStreamListener streamListener = new TestStreamListener(); @@ -127,39 +124,50 @@ public final class BinderClientTransportTest { .build(); serverAddress = HostServices.allocateService(appContext); - HostServices.configureService(serverAddress, + HostServices.configureService( + serverAddress, HostServices.serviceParamsBuilder() - .setServerFactory((service, receiver) -> - BinderServerBuilder.forAddress(serverAddress, receiver) - .addService(serviceDef) - .build()) - .build()); + .setServerFactory( + (service, receiver) -> + BinderServerBuilder.forAddress(serverAddress, receiver) + .addService(serviceDef) + .build()) + .build()); + } - transport = - new BinderTransport.BinderClientTransport( - appContext, - serverAddress, - BindServiceFlags.DEFAULTS, - ContextCompat.getMainExecutor(appContext), - executorServicePool, - executorServicePool, - SecurityPolicies.internalOnly(), - InboundParcelablePolicy.DEFAULT, - Attributes.EMPTY); + private class BinderClientTransportBuilder { + private SecurityPolicy securityPolicy = SecurityPolicies.internalOnly(); - Runnable r = transport.start(transportListener); - r.run(); - transportListener.awaitReady(); + public BinderClientTransportBuilder setSecurityPolicy(SecurityPolicy securityPolicy) { + this.securityPolicy = securityPolicy; + return this; + } + + public BinderTransport.BinderClientTransport build() { + return new BinderTransport.BinderClientTransport( + appContext, + serverAddress, + BindServiceFlags.DEFAULTS, + ContextCompat.getMainExecutor(appContext), + executorServicePool, + executorServicePool, + securityPolicy, + InboundParcelablePolicy.DEFAULT, + Attributes.EMPTY); + } } @After public void tearDown() throws Exception { transport.shutdownNow(Status.OK); HostServices.awaitServiceShutdown(); + executorServicePool.getObject().shutdownNow(); } @Test public void testShutdownBeforeStreamStart_b153326034() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream( methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); @@ -170,6 +178,8 @@ public final class BinderClientTransportTest { @Test public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -188,6 +198,8 @@ public final class BinderClientTransportTest { @Test public void testTransactionForDiscardedCall_b155244043() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -206,6 +218,8 @@ public final class BinderClientTransportTest { @Test public void testBadTransactionStreamThroughput_b163053382() throws Exception { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -225,6 +239,8 @@ public final class BinderClientTransportTest { @Test public void testMessageProducerClosedAfterStream_b169313545() { + transport = new BinderClientTransportBuilder().build(); + startAndAwaitReady(transport, transportListener); ClientStream stream = transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -243,6 +259,22 @@ public final class BinderClientTransportTest { streamListener.drainMessages(); } + @Test + public void testNewStreamBeforeTransportReadyFails() throws InterruptedException { + // Use a special SecurityPolicy that lets us act before the transport is setup/ready. + BlockingSecurityPolicy bsp = new BlockingSecurityPolicy(); + transport = new BinderClientTransportBuilder().setSecurityPolicy(bsp).build(); + transport.start(transportListener).run(); + ClientStream stream = + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); + assertThat(streamListener.awaitClose().getCode()).isEqualTo(Code.INTERNAL); + + // Unblock the SETUP_TRANSPORT handshake and make sure it becomes ready in the usual way. + bsp.provideNextCheckAuthorizationResult(Status.OK); + transportListener.awaitReady(); + } + private synchronized void awaitServerCallsCompleted(int calls) { while (serverCallsCompleted < calls) { try { @@ -253,6 +285,12 @@ public final class BinderClientTransportTest { } } + private static void startAndAwaitReady( + BinderTransport.BinderClientTransport transport, TestTransportListener transportListener) { + transport.start(transportListener).run(); + transportListener.awaitReady(); + } + private static final class TestTransportListener implements ManagedClientTransport.Listener { public boolean ready; public boolean inUse; @@ -313,6 +351,17 @@ public final class BinderClientTransportTest { } } + public synchronized Status awaitClose() { + while (closedStatus == null) { + try { + wait(100); + } catch (InterruptedException inte) { + throw new AssertionError("Interrupted waiting for close"); + } + } + return closedStatus; + } + public int drainMessages() { int n = 0; while (messageProducer.next() != null) { @@ -336,4 +385,24 @@ public final class BinderClientTransportTest { this.closedStatus = status; } } + + /** + * A SecurityPolicy that blocks the transport authorization check until a test sets the outcome. + */ + static class BlockingSecurityPolicy extends SecurityPolicy { + private final BlockingQueue results = new LinkedBlockingQueue<>(); + + public void provideNextCheckAuthorizationResult(Status status) { + results.add(status); + } + + @Override + public Status checkAuthorization(int uid) { + try { + return results.take(); + } catch (InterruptedException e) { + return Status.fromThrowable(e); + } + } + } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 219651a8b6..84a5e17fa5 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -635,33 +635,39 @@ public abstract class BinderTransport final Metadata headers, final CallOptions callOptions, ClientStreamTracer[] tracers) { - if (isShutdown()) { - return newFailingClientStream(shutdownStatus, attributes, headers, tracers); + if (!inState(TransportState.READY)) { + return newFailingClientStream( + isShutdown() + ? shutdownStatus + : Status.INTERNAL.withDescription("newStream() before transportReady()"), + attributes, + headers, + tracers); + } + + int callId = latestCallId++; + if (latestCallId == LAST_CALL_ID) { + latestCallId = FIRST_CALL_ID; + } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); + Inbound.ClientInbound inbound = + new Inbound.ClientInbound( + this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); + if (ongoingCalls.putIfAbsent(callId, inbound) != null) { + Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); + shutdownInternal(failure, true); + return newFailingClientStream(failure, attributes, headers, tracers); } else { - int callId = latestCallId++; - if (latestCallId == LAST_CALL_ID) { - latestCallId = FIRST_CALL_ID; + if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { + clientTransportListener.transportInUse(true); } - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(tracers, attributes, headers); - Inbound.ClientInbound inbound = - new Inbound.ClientInbound( - this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); - if (ongoingCalls.putIfAbsent(callId, inbound) != null) { - Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); - shutdownInternal(failure, true); - return newFailingClientStream(failure, attributes, headers, tracers); + Outbound.ClientOutbound outbound = + new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); + if (method.getType().clientSendsOneMessage()) { + return new SingleMessageClientStream(inbound, outbound, attributes); } else { - if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { - clientTransportListener.transportInUse(true); - } - Outbound.ClientOutbound outbound = - new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); - if (method.getType().clientSendsOneMessage()) { - return new SingleMessageClientStream(inbound, outbound, attributes); - } else { - return new MultiMessageClientStream(inbound, outbound, attributes); - } + return new MultiMessageClientStream(inbound, outbound, attributes); } } } diff --git a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java index 47cf53a725..d38721af78 100644 --- a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java @@ -94,6 +94,8 @@ public interface ManagedClientTransport extends ClientTransport { /** * The transport is ready to accept traffic, because the connection is established. This is * called at most once. + * + *

Streams created before this milestone are not guaranteed to function. */ void transportReady(); diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index cd52218131..06dfef5daa 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -1145,7 +1145,7 @@ public abstract class AbstractTransportTest { public void earlyServerClose_serverFailure_withClientCancelOnListenerClosed() throws Exception { server.start(serverListener); client = newClientTransport(server); - runIfNotNull(client.start(mockClientTransportListener)); + startTransport(client, mockClientTransportListener); MockServerTransportListener serverTransportListener = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport;