mirror of https://github.com/grpc/grpc-java.git
alts: TsiFrameHandler doesn't throw exception when flush after closed (#5180)
also, error / log messages will contain state of FrameHandler
This commit is contained in:
parent
87cf40437c
commit
9eeceab597
|
|
@ -47,6 +47,16 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
|
||||
private TsiFrameProtector protector;
|
||||
private PendingWriteQueue pendingUnprotectedWrites;
|
||||
private State state = State.HANDSHAKE_NOT_FINISHED;
|
||||
private boolean closeInitiated = false;
|
||||
|
||||
@VisibleForTesting
|
||||
enum State {
|
||||
HANDSHAKE_NOT_FINISHED,
|
||||
PROTECTED,
|
||||
CLOSED,
|
||||
HANDSHAKE_FAILED
|
||||
}
|
||||
|
||||
public TsiFrameHandler() {}
|
||||
|
||||
|
|
@ -67,6 +77,8 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
|
||||
if (tsiEvent.isSuccess()) {
|
||||
setProtector(tsiEvent.protector());
|
||||
} else {
|
||||
state = State.HANDSHAKE_FAILED;
|
||||
}
|
||||
// Ignore errors. Another handler in the pipeline must handle TSI Errors.
|
||||
}
|
||||
|
|
@ -79,18 +91,23 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
logger.finest("TsiFrameHandler protector set");
|
||||
checkState(this.protector == null);
|
||||
this.protector = checkNotNull(protector);
|
||||
this.state = State.PROTECTED;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
|
||||
checkState(protector != null, "Cannot read frames while the TSI handshake is in progress");
|
||||
checkState(
|
||||
state == State.PROTECTED,
|
||||
"Cannot read frames while the TSI handshake is %s", state);
|
||||
protector.unprotect(in, out, ctx.alloc());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
|
||||
throws Exception {
|
||||
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
||||
checkState(
|
||||
state == State.PROTECTED,
|
||||
"Cannot write frames while the TSI handshake state is %s", state);
|
||||
ByteBuf msg = (ByteBuf) message;
|
||||
if (!msg.isReadable()) {
|
||||
// Nothing to encode.
|
||||
|
|
@ -104,8 +121,7 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved0(ChannelHandlerContext ctx) {
|
||||
logger.finest("TsiFrameHandler removed");
|
||||
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||
if (!pendingUnprotectedWrites.isEmpty()) {
|
||||
pendingUnprotectedWrites.removeAndFailAll(
|
||||
new ChannelException("Pending write on removal of TSI handler"));
|
||||
|
|
@ -134,19 +150,37 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
|
||||
@Override
|
||||
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
release();
|
||||
doClose(ctx);
|
||||
ctx.disconnect(promise);
|
||||
}
|
||||
|
||||
private void doClose(ChannelHandlerContext ctx) {
|
||||
if (closeInitiated) {
|
||||
return;
|
||||
}
|
||||
closeInitiated = true;
|
||||
try {
|
||||
// flush any remaining writes before close
|
||||
if (!pendingUnprotectedWrites.isEmpty()) {
|
||||
flush(ctx);
|
||||
}
|
||||
} catch (GeneralSecurityException e) {
|
||||
logger.log(Level.FINE, "Ignoring error on flush before close", e);
|
||||
} finally {
|
||||
state = State.CLOSED;
|
||||
release();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
release();
|
||||
doClose(ctx);
|
||||
ctx.close(promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
release();
|
||||
doClose(ctx);
|
||||
ctx.deregister(promise);
|
||||
}
|
||||
|
||||
|
|
@ -157,7 +191,14 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
|||
|
||||
@Override
|
||||
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
|
||||
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
||||
if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) {
|
||||
logger.fine(
|
||||
String.format("FrameHandler is inactive(%s), channel id: %s",
|
||||
state, ctx.channel().id().asShortText()));
|
||||
return;
|
||||
}
|
||||
checkState(
|
||||
state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state);
|
||||
final ProtectedPromise aggregatePromise =
|
||||
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,147 @@
|
|||
/*
|
||||
* Copyright 2018 The gRPC Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.internal;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static com.google.common.truth.Truth.assertWithMessage;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import io.grpc.alts.internal.TsiFrameHandler.State;
|
||||
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
|
||||
import io.grpc.alts.internal.TsiPeer.Property;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.util.CharsetUtil;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.DisableOnDebug;
|
||||
import org.junit.rules.TestRule;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link TsiFrameHandler}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class TsiFrameHandlerTest {
|
||||
|
||||
@Rule
|
||||
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5));
|
||||
|
||||
private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler();
|
||||
private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler);
|
||||
|
||||
@Test
|
||||
public void writeAndFlush_beforeHandshakeEventShouldBeIgnored() {
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message before handshake finished", CharsetUtil.UTF_8);
|
||||
|
||||
channel.writeAndFlush(msg);
|
||||
|
||||
assertThat(channel.outboundMessages()).isEmpty();
|
||||
try {
|
||||
channel.checkException();
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
assertThat(e).hasMessageThat().contains(State.HANDSHAKE_NOT_FINISHED.name());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeAndFlush_handshakeSucceed() throws InterruptedException {
|
||||
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message after handshake finished", CharsetUtil.UTF_8);
|
||||
|
||||
channel.writeAndFlush(msg);
|
||||
|
||||
assertThat(channel.readOutbound()).isEqualTo(msg);
|
||||
channel.close().sync();
|
||||
channel.checkException();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeAndFlush_shouldBeIgnoredAfterClose() throws InterruptedException {
|
||||
channel.close().sync();
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message after closed", CharsetUtil.UTF_8);
|
||||
|
||||
channel.writeAndFlush(msg);
|
||||
|
||||
assertThat(channel.outboundMessages()).isEmpty();
|
||||
try {
|
||||
channel.checkException();
|
||||
} catch (Exception e) {
|
||||
throw new AssertionError("Any attempt after close should be ignored without out exception");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeAndFlush_handshakeFailed() throws InterruptedException {
|
||||
channel.pipeline().fireUserEventTriggered(new TsiHandshakeCompletionEvent(new Exception()));
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
|
||||
|
||||
channel.writeAndFlush(msg);
|
||||
|
||||
assertThat(channel.outboundMessages()).isEmpty();
|
||||
channel.close().sync();
|
||||
channel.checkException();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void close_shouldFlushRemainingMessage() throws InterruptedException {
|
||||
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
|
||||
|
||||
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
|
||||
channel.write(msg);
|
||||
|
||||
assertThat(channel.outboundMessages()).isEmpty();
|
||||
|
||||
channel.close().sync();
|
||||
|
||||
assertWithMessage("pending write should be flushed on close")
|
||||
.that(channel.readOutbound()).isEqualTo(msg);
|
||||
channel.checkException();
|
||||
}
|
||||
|
||||
private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() {
|
||||
TsiFrameProtector protector = new IdentityFrameProtector();
|
||||
TsiPeer peer = new TsiPeer(new ArrayList<Property<?>>());
|
||||
return new TsiHandshakeCompletionEvent(protector, peer, new Object());
|
||||
}
|
||||
|
||||
private static final class IdentityFrameProtector implements TsiFrameProtector {
|
||||
|
||||
@Override
|
||||
public void protectFlush(List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite,
|
||||
ByteBufAllocator alloc) throws GeneralSecurityException {
|
||||
for (ByteBuf unprotectedBuf : unprotectedBufs) {
|
||||
ctxWrite.accept(unprotectedBuf);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
out.add(in.toString(CharsetUtil.UTF_8));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue