alts: TsiFrameHandler doesn't throw exception when flush after closed (#5180)

also, error / log messages will contain state of FrameHandler
This commit is contained in:
Jihun Cho 2018-12-20 10:12:37 -08:00 committed by GitHub
parent 87cf40437c
commit 9eeceab597
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 196 additions and 8 deletions

View File

@ -47,6 +47,16 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
private TsiFrameProtector protector; private TsiFrameProtector protector;
private PendingWriteQueue pendingUnprotectedWrites; private PendingWriteQueue pendingUnprotectedWrites;
private State state = State.HANDSHAKE_NOT_FINISHED;
private boolean closeInitiated = false;
@VisibleForTesting
enum State {
HANDSHAKE_NOT_FINISHED,
PROTECTED,
CLOSED,
HANDSHAKE_FAILED
}
public TsiFrameHandler() {} public TsiFrameHandler() {}
@ -67,6 +77,8 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event; TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
if (tsiEvent.isSuccess()) { if (tsiEvent.isSuccess()) {
setProtector(tsiEvent.protector()); setProtector(tsiEvent.protector());
} else {
state = State.HANDSHAKE_FAILED;
} }
// Ignore errors. Another handler in the pipeline must handle TSI Errors. // Ignore errors. Another handler in the pipeline must handle TSI Errors.
} }
@ -79,18 +91,23 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
logger.finest("TsiFrameHandler protector set"); logger.finest("TsiFrameHandler protector set");
checkState(this.protector == null); checkState(this.protector == null);
this.protector = checkNotNull(protector); this.protector = checkNotNull(protector);
this.state = State.PROTECTED;
} }
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
checkState(protector != null, "Cannot read frames while the TSI handshake is in progress"); checkState(
state == State.PROTECTED,
"Cannot read frames while the TSI handshake is %s", state);
protector.unprotect(in, out, ctx.alloc()); protector.unprotect(in, out, ctx.alloc());
} }
@Override @Override
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
throws Exception { throws Exception {
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress"); checkState(
state == State.PROTECTED,
"Cannot write frames while the TSI handshake state is %s", state);
ByteBuf msg = (ByteBuf) message; ByteBuf msg = (ByteBuf) message;
if (!msg.isReadable()) { if (!msg.isReadable()) {
// Nothing to encode. // Nothing to encode.
@ -104,8 +121,7 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
} }
@Override @Override
public void handlerRemoved0(ChannelHandlerContext ctx) { public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
logger.finest("TsiFrameHandler removed");
if (!pendingUnprotectedWrites.isEmpty()) { if (!pendingUnprotectedWrites.isEmpty()) {
pendingUnprotectedWrites.removeAndFailAll( pendingUnprotectedWrites.removeAndFailAll(
new ChannelException("Pending write on removal of TSI handler")); new ChannelException("Pending write on removal of TSI handler"));
@ -134,19 +150,37 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
@Override @Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
release(); doClose(ctx);
ctx.disconnect(promise); ctx.disconnect(promise);
} }
private void doClose(ChannelHandlerContext ctx) {
if (closeInitiated) {
return;
}
closeInitiated = true;
try {
// flush any remaining writes before close
if (!pendingUnprotectedWrites.isEmpty()) {
flush(ctx);
}
} catch (GeneralSecurityException e) {
logger.log(Level.FINE, "Ignoring error on flush before close", e);
} finally {
state = State.CLOSED;
release();
}
}
@Override @Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) { public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
release(); doClose(ctx);
ctx.close(promise); ctx.close(promise);
} }
@Override @Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
release(); doClose(ctx);
ctx.deregister(promise); ctx.deregister(promise);
} }
@ -157,7 +191,14 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
@Override @Override
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException { public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress"); if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) {
logger.fine(
String.format("FrameHandler is inactive(%s), channel id: %s",
state, ctx.channel().id().asShortText()));
return;
}
checkState(
state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state);
final ProtectedPromise aggregatePromise = final ProtectedPromise aggregatePromise =
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size()); new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());

View File

@ -0,0 +1,147 @@
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static org.junit.Assert.fail;
import io.grpc.alts.internal.TsiFrameHandler.State;
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.grpc.alts.internal.TsiPeer.Property;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.DisableOnDebug;
import org.junit.rules.TestRule;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link TsiFrameHandler}. */
@RunWith(JUnit4.class)
public class TsiFrameHandlerTest {
@Rule
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5));
private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler();
private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler);
@Test
public void writeAndFlush_beforeHandshakeEventShouldBeIgnored() {
ByteBuf msg = Unpooled.copiedBuffer("message before handshake finished", CharsetUtil.UTF_8);
channel.writeAndFlush(msg);
assertThat(channel.outboundMessages()).isEmpty();
try {
channel.checkException();
fail();
} catch (IllegalStateException e) {
assertThat(e).hasMessageThat().contains(State.HANDSHAKE_NOT_FINISHED.name());
}
}
@Test
public void writeAndFlush_handshakeSucceed() throws InterruptedException {
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
ByteBuf msg = Unpooled.copiedBuffer("message after handshake finished", CharsetUtil.UTF_8);
channel.writeAndFlush(msg);
assertThat(channel.readOutbound()).isEqualTo(msg);
channel.close().sync();
channel.checkException();
}
@Test
public void writeAndFlush_shouldBeIgnoredAfterClose() throws InterruptedException {
channel.close().sync();
ByteBuf msg = Unpooled.copiedBuffer("message after closed", CharsetUtil.UTF_8);
channel.writeAndFlush(msg);
assertThat(channel.outboundMessages()).isEmpty();
try {
channel.checkException();
} catch (Exception e) {
throw new AssertionError("Any attempt after close should be ignored without out exception");
}
}
@Test
public void writeAndFlush_handshakeFailed() throws InterruptedException {
channel.pipeline().fireUserEventTriggered(new TsiHandshakeCompletionEvent(new Exception()));
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
channel.writeAndFlush(msg);
assertThat(channel.outboundMessages()).isEmpty();
channel.close().sync();
channel.checkException();
}
@Test
public void close_shouldFlushRemainingMessage() throws InterruptedException {
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
channel.write(msg);
assertThat(channel.outboundMessages()).isEmpty();
channel.close().sync();
assertWithMessage("pending write should be flushed on close")
.that(channel.readOutbound()).isEqualTo(msg);
channel.checkException();
}
private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() {
TsiFrameProtector protector = new IdentityFrameProtector();
TsiPeer peer = new TsiPeer(new ArrayList<Property<?>>());
return new TsiHandshakeCompletionEvent(protector, peer, new Object());
}
private static final class IdentityFrameProtector implements TsiFrameProtector {
@Override
public void protectFlush(List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite,
ByteBufAllocator alloc) throws GeneralSecurityException {
for (ByteBuf unprotectedBuf : unprotectedBufs) {
ctxWrite.accept(unprotectedBuf);
}
}
@Override
public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
throws GeneralSecurityException {
out.add(in.toString(CharsetUtil.UTF_8));
}
@Override
public void destroy() {}
}
}