diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsChannelCrypter.java b/alts/src/main/java/io/grpc/alts/internal/AltsChannelCrypter.java index 5e4c6fec30..e47433ff03 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsChannelCrypter.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsChannelCrypter.java @@ -17,10 +17,10 @@ package io.grpc.alts.internal; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Verify.verify; import com.google.common.annotations.VisibleForTesting; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.util.List; @@ -56,61 +56,72 @@ final class AltsChannelCrypter implements ChannelCrypterNetty { @Override public void encrypt(ByteBuf outBuf, List plainBufs) throws GeneralSecurityException { - checkArgument(outBuf.nioBufferCount() == 1); - // Copy plaintext buffers into outBuf for in-place encryption on single direct buffer. - ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes()); - plainBuf.writerIndex(0); - for (ByteBuf inBuf : plainBufs) { - plainBuf.writeBytes(inBuf); + byte[] tempArr = new byte[outBuf.writableBytes()]; + + // Copy plaintext into tempArr. + { + ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr, 0, tempArr.length - TAG_LENGTH); + tempBuf.resetWriterIndex(); + for (ByteBuf plainBuf : plainBufs) { + tempBuf.writeBytes(plainBuf); + } } - verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH); - ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes()); - ByteBuffer plain = out.duplicate(); - plain.limit(out.limit() - TAG_LENGTH); + // Encrypt into tempArr. + { + ByteBuffer out = ByteBuffer.wrap(tempArr); + ByteBuffer plain = ByteBuffer.wrap(tempArr, 0, tempArr.length - TAG_LENGTH); - byte[] counter = incrementOutCounter(); - int outPosition = out.position(); - aeadCrypter.encrypt(out, plain, counter); - int bytesWritten = out.position() - outPosition; - outBuf.writerIndex(outBuf.writerIndex() + bytesWritten); - verify(!outBuf.isWritable()); + byte[] counter = incrementOutCounter(); + aeadCrypter.encrypt(out, plain, counter); + } + outBuf.writeBytes(tempArr); } @Override - public void decrypt(ByteBuf out, ByteBuf tag, List ciphertextBufs) + public void decrypt(ByteBuf outBuf, ByteBuf tagBuf, List ciphertextBufs) throws GeneralSecurityException { + // There is enough space for the ciphertext including the tag in outBuf. + byte[] tempArr = new byte[outBuf.writableBytes()]; - ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes()); - cipherTextAndTag.writerIndex(0); - - for (ByteBuf inBuf : ciphertextBufs) { - cipherTextAndTag.writeBytes(inBuf); + // Copy ciphertext and tag into tempArr. + { + ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr); + tempBuf.resetWriterIndex(); + for (ByteBuf ciphertextBuf : ciphertextBufs) { + tempBuf.writeBytes(ciphertextBuf); + } + tempBuf.writeBytes(tagBuf); } - cipherTextAndTag.writeBytes(tag); - decrypt(out, cipherTextAndTag); + decryptInternal(outBuf, tempArr); } @Override - public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException { - int bytesRead = ciphertextAndTag.readableBytes(); - checkArgument(bytesRead == out.writableBytes()); + public void decrypt( + ByteBuf outBuf, ByteBuf ciphertextAndTagDirect) throws GeneralSecurityException { + byte[] tempArr = new byte[ciphertextAndTagDirect.readableBytes()]; - checkArgument(out.nioBufferCount() == 1); - ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes()); + // Copy ciphertext and tag into tempArr. + { + ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr); + tempBuf.resetWriterIndex(); + tempBuf.writeBytes(ciphertextAndTagDirect); + } - checkArgument(ciphertextAndTag.nioBufferCount() == 1); - ByteBuffer ciphertextAndTagBuffer = - ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead); + decryptInternal(outBuf, tempArr); + } - byte[] counter = incrementInCounter(); - int outPosition = outBuffer.position(); - aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter); - int bytesWritten = outBuffer.position() - outPosition; - out.writerIndex(out.writerIndex() + bytesWritten); - ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead); - verify(out.writableBytes() == TAG_LENGTH); + private void decryptInternal(ByteBuf outBuf, byte[] tempArr) throws GeneralSecurityException { + // Perform in-place decryption on tempArr. + { + ByteBuffer ciphertextAndTag = ByteBuffer.wrap(tempArr); + ByteBuffer out = ByteBuffer.wrap(tempArr); + byte[] counter = incrementInCounter(); + aeadCrypter.decrypt(out, ciphertextAndTag, counter); + } + + outBuf.writeBytes(tempArr, 0, tempArr.length - TAG_LENGTH); } @Override