mirror of https://github.com/grpc/grpc-java.git
alts: TsiFrameHandler doesn't throw exception when flush after closed (#5180)
also, error / log messages will contain state of FrameHandler
This commit is contained in:
parent
87cf40437c
commit
9eeceab597
|
|
@ -47,6 +47,16 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
||||||
|
|
||||||
private TsiFrameProtector protector;
|
private TsiFrameProtector protector;
|
||||||
private PendingWriteQueue pendingUnprotectedWrites;
|
private PendingWriteQueue pendingUnprotectedWrites;
|
||||||
|
private State state = State.HANDSHAKE_NOT_FINISHED;
|
||||||
|
private boolean closeInitiated = false;
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
enum State {
|
||||||
|
HANDSHAKE_NOT_FINISHED,
|
||||||
|
PROTECTED,
|
||||||
|
CLOSED,
|
||||||
|
HANDSHAKE_FAILED
|
||||||
|
}
|
||||||
|
|
||||||
public TsiFrameHandler() {}
|
public TsiFrameHandler() {}
|
||||||
|
|
||||||
|
|
@ -67,6 +77,8 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
||||||
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
|
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
|
||||||
if (tsiEvent.isSuccess()) {
|
if (tsiEvent.isSuccess()) {
|
||||||
setProtector(tsiEvent.protector());
|
setProtector(tsiEvent.protector());
|
||||||
|
} else {
|
||||||
|
state = State.HANDSHAKE_FAILED;
|
||||||
}
|
}
|
||||||
// Ignore errors. Another handler in the pipeline must handle TSI Errors.
|
// Ignore errors. Another handler in the pipeline must handle TSI Errors.
|
||||||
}
|
}
|
||||||
|
|
@ -79,18 +91,23 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
||||||
logger.finest("TsiFrameHandler protector set");
|
logger.finest("TsiFrameHandler protector set");
|
||||||
checkState(this.protector == null);
|
checkState(this.protector == null);
|
||||||
this.protector = checkNotNull(protector);
|
this.protector = checkNotNull(protector);
|
||||||
|
this.state = State.PROTECTED;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
|
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
|
||||||
checkState(protector != null, "Cannot read frames while the TSI handshake is in progress");
|
checkState(
|
||||||
|
state == State.PROTECTED,
|
||||||
|
"Cannot read frames while the TSI handshake is %s", state);
|
||||||
protector.unprotect(in, out, ctx.alloc());
|
protector.unprotect(in, out, ctx.alloc());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
|
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
|
||||||
throws Exception {
|
throws Exception {
|
||||||
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
checkState(
|
||||||
|
state == State.PROTECTED,
|
||||||
|
"Cannot write frames while the TSI handshake state is %s", state);
|
||||||
ByteBuf msg = (ByteBuf) message;
|
ByteBuf msg = (ByteBuf) message;
|
||||||
if (!msg.isReadable()) {
|
if (!msg.isReadable()) {
|
||||||
// Nothing to encode.
|
// Nothing to encode.
|
||||||
|
|
@ -104,8 +121,7 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handlerRemoved0(ChannelHandlerContext ctx) {
|
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||||
logger.finest("TsiFrameHandler removed");
|
|
||||||
if (!pendingUnprotectedWrites.isEmpty()) {
|
if (!pendingUnprotectedWrites.isEmpty()) {
|
||||||
pendingUnprotectedWrites.removeAndFailAll(
|
pendingUnprotectedWrites.removeAndFailAll(
|
||||||
new ChannelException("Pending write on removal of TSI handler"));
|
new ChannelException("Pending write on removal of TSI handler"));
|
||||||
|
|
@ -134,19 +150,37 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
|
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||||
release();
|
doClose(ctx);
|
||||||
ctx.disconnect(promise);
|
ctx.disconnect(promise);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void doClose(ChannelHandlerContext ctx) {
|
||||||
|
if (closeInitiated) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
closeInitiated = true;
|
||||||
|
try {
|
||||||
|
// flush any remaining writes before close
|
||||||
|
if (!pendingUnprotectedWrites.isEmpty()) {
|
||||||
|
flush(ctx);
|
||||||
|
}
|
||||||
|
} catch (GeneralSecurityException e) {
|
||||||
|
logger.log(Level.FINE, "Ignoring error on flush before close", e);
|
||||||
|
} finally {
|
||||||
|
state = State.CLOSED;
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
|
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||||
release();
|
doClose(ctx);
|
||||||
ctx.close(promise);
|
ctx.close(promise);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
|
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||||
release();
|
doClose(ctx);
|
||||||
ctx.deregister(promise);
|
ctx.deregister(promise);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -157,7 +191,14 @@ public final class TsiFrameHandler extends ByteToMessageDecoder implements Chann
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
|
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
|
||||||
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) {
|
||||||
|
logger.fine(
|
||||||
|
String.format("FrameHandler is inactive(%s), channel id: %s",
|
||||||
|
state, ctx.channel().id().asShortText()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
checkState(
|
||||||
|
state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state);
|
||||||
final ProtectedPromise aggregatePromise =
|
final ProtectedPromise aggregatePromise =
|
||||||
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,147 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2018 The gRPC Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package io.grpc.alts.internal;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static com.google.common.truth.Truth.assertWithMessage;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
import io.grpc.alts.internal.TsiFrameHandler.State;
|
||||||
|
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
|
||||||
|
import io.grpc.alts.internal.TsiPeer.Property;
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
|
import io.netty.buffer.Unpooled;
|
||||||
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
|
import io.netty.util.CharsetUtil;
|
||||||
|
import java.security.GeneralSecurityException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import org.junit.Rule;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.rules.DisableOnDebug;
|
||||||
|
import org.junit.rules.TestRule;
|
||||||
|
import org.junit.rules.Timeout;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.JUnit4;
|
||||||
|
|
||||||
|
/** Unit tests for {@link TsiFrameHandler}. */
|
||||||
|
@RunWith(JUnit4.class)
|
||||||
|
public class TsiFrameHandlerTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5));
|
||||||
|
|
||||||
|
private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler();
|
||||||
|
private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void writeAndFlush_beforeHandshakeEventShouldBeIgnored() {
|
||||||
|
ByteBuf msg = Unpooled.copiedBuffer("message before handshake finished", CharsetUtil.UTF_8);
|
||||||
|
|
||||||
|
channel.writeAndFlush(msg);
|
||||||
|
|
||||||
|
assertThat(channel.outboundMessages()).isEmpty();
|
||||||
|
try {
|
||||||
|
channel.checkException();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalStateException e) {
|
||||||
|
assertThat(e).hasMessageThat().contains(State.HANDSHAKE_NOT_FINISHED.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void writeAndFlush_handshakeSucceed() throws InterruptedException {
|
||||||
|
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
|
||||||
|
ByteBuf msg = Unpooled.copiedBuffer("message after handshake finished", CharsetUtil.UTF_8);
|
||||||
|
|
||||||
|
channel.writeAndFlush(msg);
|
||||||
|
|
||||||
|
assertThat(channel.readOutbound()).isEqualTo(msg);
|
||||||
|
channel.close().sync();
|
||||||
|
channel.checkException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void writeAndFlush_shouldBeIgnoredAfterClose() throws InterruptedException {
|
||||||
|
channel.close().sync();
|
||||||
|
ByteBuf msg = Unpooled.copiedBuffer("message after closed", CharsetUtil.UTF_8);
|
||||||
|
|
||||||
|
channel.writeAndFlush(msg);
|
||||||
|
|
||||||
|
assertThat(channel.outboundMessages()).isEmpty();
|
||||||
|
try {
|
||||||
|
channel.checkException();
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new AssertionError("Any attempt after close should be ignored without out exception");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void writeAndFlush_handshakeFailed() throws InterruptedException {
|
||||||
|
channel.pipeline().fireUserEventTriggered(new TsiHandshakeCompletionEvent(new Exception()));
|
||||||
|
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
|
||||||
|
|
||||||
|
channel.writeAndFlush(msg);
|
||||||
|
|
||||||
|
assertThat(channel.outboundMessages()).isEmpty();
|
||||||
|
channel.close().sync();
|
||||||
|
channel.checkException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void close_shouldFlushRemainingMessage() throws InterruptedException {
|
||||||
|
channel.pipeline().fireUserEventTriggered(getHandshakeSuccessEvent());
|
||||||
|
|
||||||
|
ByteBuf msg = Unpooled.copiedBuffer("message after handshake failed", CharsetUtil.UTF_8);
|
||||||
|
channel.write(msg);
|
||||||
|
|
||||||
|
assertThat(channel.outboundMessages()).isEmpty();
|
||||||
|
|
||||||
|
channel.close().sync();
|
||||||
|
|
||||||
|
assertWithMessage("pending write should be flushed on close")
|
||||||
|
.that(channel.readOutbound()).isEqualTo(msg);
|
||||||
|
channel.checkException();
|
||||||
|
}
|
||||||
|
|
||||||
|
private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() {
|
||||||
|
TsiFrameProtector protector = new IdentityFrameProtector();
|
||||||
|
TsiPeer peer = new TsiPeer(new ArrayList<Property<?>>());
|
||||||
|
return new TsiHandshakeCompletionEvent(protector, peer, new Object());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class IdentityFrameProtector implements TsiFrameProtector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void protectFlush(List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite,
|
||||||
|
ByteBufAllocator alloc) throws GeneralSecurityException {
|
||||||
|
for (ByteBuf unprotectedBuf : unprotectedBufs) {
|
||||||
|
ctxWrite.accept(unprotectedBuf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
|
||||||
|
throws GeneralSecurityException {
|
||||||
|
out.add(in.toString(CharsetUtil.UTF_8));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void destroy() {}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue