diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java index a2c1a1d7c6..264541223b 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java @@ -47,6 +47,16 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann private TsiFrameProtector protector; private PendingWriteQueue pendingUnprotectedWrites; + private State state = State.HANDSHAKE_NOT_FINISHED; + private boolean closeInitiated = false; + + @VisibleForTesting + enum State { + HANDSHAKE_NOT_FINISHED, + PROTECTED, + CLOSED, + HANDSHAKE_FAILED + } public TsiFrameHandler() {} @@ -67,6 +77,8 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event; if (tsiEvent.isSuccess()) { setProtector(tsiEvent.protector()); + } else { + state = State.HANDSHAKE_FAILED; } // Ignore errors. Another handler in the pipeline must handle TSI Errors. } @@ -79,18 +91,23 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann logger.finest("TsiFrameHandler protector set"); checkState(this.protector == null); this.protector = checkNotNull(protector); + this.state = State.PROTECTED; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - checkState(protector != null, "Cannot read frames while the TSI handshake is in progress"); + checkState( + state == State.PROTECTED, + "Cannot read frames while the TSI handshake is %s", state); protector.unprotect(in, out, ctx.alloc()); } @Override public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) throws Exception { - checkState(protector != null, "Cannot write frames while the TSI handshake is in progress"); + checkState( + state == State.PROTECTED, + "Cannot write frames while the TSI handshake state is %s", state); ByteBuf msg = (ByteBuf) message; if (!msg.isReadable()) { // Nothing to encode. @@ -104,8 +121,7 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann } @Override - public void handlerRemoved0(ChannelHandlerContext ctx) { - logger.finest("TsiFrameHandler removed"); + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { if (!pendingUnprotectedWrites.isEmpty()) { pendingUnprotectedWrites.removeAndFailAll( new ChannelException("Pending write on removal of TSI handler")); @@ -134,19 +150,37 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann @Override public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { - release(); + doClose(ctx); ctx.disconnect(promise); } + private void doClose(ChannelHandlerContext ctx) { + if (closeInitiated) { + return; + } + closeInitiated = true; + try { + // flush any remaining writes before close + if (!pendingUnprotectedWrites.isEmpty()) { + flush(ctx); + } + } catch (GeneralSecurityException e) { + logger.log(Level.FINE, "Ignoring error on flush before close", e); + } finally { + state = State.CLOSED; + release(); + } + } + @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) { - release(); + doClose(ctx); ctx.close(promise); } @Override public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { - release(); + doClose(ctx); ctx.deregister(promise); } @@ -157,7 +191,14 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann @Override public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException { - checkState(protector != null, "Cannot write frames while the TSI handshake is in progress"); + if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) { + logger.fine( + String.format("FrameHandler is inactive(%s), channel id: %s", + state, ctx.channel().id().asShortText())); + return; + } + checkState( + state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state); final ProtectedPromise aggregatePromise = new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size()); diff --git a/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java new file mode 100644 index 0000000000..df33f07946 --- /dev/null +++ b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java @@ -0,0 +1,147 @@ +/* + * Copyright 2018 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.alts.internal; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.fail; + +import io.grpc.alts.internal.TsiFrameHandler.State; +import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent; +import io.grpc.alts.internal.TsiPeer.Property; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import java.security.GeneralSecurityException; +import java.util.ArrayList; +import java.util.List; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.DisableOnDebug; +import org.junit.rules.TestRule; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link TsiFrameHandler}. */ +@RunWith(JUnit4.class) +public class TsiFrameHandlerTest { + + @Rule + public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5)); + + private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler(); + private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler); + + @Test + public void writeAndFlush_beforeHandshakeEventShouldBeIgnored() { + ByteBuf msg = Unpooled.copiedBuffer("message before handshake finished", CharsetUtil.UTF_8); + + channel.writeAndFlush(msg); + + assertThat(channel.outboundMessages()).isEmpty(); + try { + channel.checkException(); + fail(); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().contains(State.HANDSHAKE_NOT_FINISHED.name()); + } + } + + @Test + public void writeAndFlush_handshakeSucceed() throws InterruptedException { + channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent()); + ByteBuf msg = Unpooled.copiedBuffer("message after handshake finished", CharsetUtil.UTF_8); + + channel.writeAndFlush(msg); + + assertThat(channel.readOutbound()).isEqualTo(msg); + channel.close().sync(); + channel.checkException(); + } + + @Test + public void writeAndFlush_shouldBeIgnoredAfterClose() throws InterruptedException { + channel.close().sync(); + ByteBuf msg = Unpooled.copiedBuffer("message after closed", CharsetUtil.UTF_8); + + channel.writeAndFlush(msg); + + assertThat(channel.outboundMessages()).isEmpty(); + try { + channel.checkException(); + } catch (Exception e) { + throw new AssertionError("Any attempt after close should be ignored without out exception"); + } + } + + @Test + public void writeAndFlush_handshakeFailed() throws InterruptedException { + channel.pipeline().fireUserEventTriggered(new TsiHandshakeCompletionEvent(new Exception())); + ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8); + + channel.writeAndFlush(msg); + + assertThat(channel.outboundMessages()).isEmpty(); + channel.close().sync(); + channel.checkException(); + } + + @Test + public void close_shouldFlushRemainingMessage() throws InterruptedException { + channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent()); + + ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8); + channel.write(msg); + + assertThat(channel.outboundMessages()).isEmpty(); + + channel.close().sync(); + + assertWithMessage("pending write should be flushed on close") + .that(channel.readOutbound()).isEqualTo(msg); + channel.checkException(); + } + + private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() { + TsiFrameProtector protector = new IdentityFrameProtector(); + TsiPeer peer = new TsiPeer(new ArrayList>()); + return new TsiHandshakeCompletionEvent(protector, peer, new Object()); + } + + private static final class IdentityFrameProtector implements TsiFrameProtector { + + @Override + public void protectFlush(List unprotectedBufs, Consumer ctxWrite, + ByteBufAllocator alloc) throws GeneralSecurityException { + for (ByteBuf unprotectedBuf : unprotectedBufs) { + ctxWrite.accept(unprotectedBuf); + } + } + + @Override + public void unprotect(ByteBuf in, List out, ByteBufAllocator alloc) + throws GeneralSecurityException { + out.add(in.toString(CharsetUtil.UTF_8)); + } + + @Override + public void destroy() {} + } +}