Call Cipher APIs with non-direct ByteBuffers and perform copies in the ALTS code. (cl/308901367)

This commit is contained in:
Esun Kim 2020-04-30 16:33:40 -07:00 committed by Eric Anderson
parent d71161432a
commit a7bca23053
1 changed files with 51 additions and 40 deletions

View File

@ -17,10 +17,10 @@
package io.grpc.alts.internal; package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.util.List; import java.util.List;
@ -56,61 +56,72 @@ final class AltsChannelCrypter implements ChannelCrypterNetty {
@Override @Override
public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException { public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
checkArgument(outBuf.nioBufferCount() == 1); byte[] tempArr = new byte[outBuf.writableBytes()];
// Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes()); // Copy plaintext into tempArr.
plainBuf.writerIndex(0); {
for (ByteBuf inBuf : plainBufs) { ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr, 0, tempArr.length - TAG_LENGTH);
plainBuf.writeBytes(inBuf); tempBuf.resetWriterIndex();
for (ByteBuf plainBuf : plainBufs) {
tempBuf.writeBytes(plainBuf);
}
} }
verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH); // Encrypt into tempArr.
ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes()); {
ByteBuffer plain = out.duplicate(); ByteBuffer out = ByteBuffer.wrap(tempArr);
plain.limit(out.limit() - TAG_LENGTH); ByteBuffer plain = ByteBuffer.wrap(tempArr, 0, tempArr.length - TAG_LENGTH);
byte[] counter = incrementOutCounter(); byte[] counter = incrementOutCounter();
int outPosition = out.position(); aeadCrypter.encrypt(out, plain, counter);
aeadCrypter.encrypt(out, plain, counter); }
int bytesWritten = out.position() - outPosition; outBuf.writeBytes(tempArr);
outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
verify(!outBuf.isWritable());
} }
@Override @Override
public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs) public void decrypt(ByteBuf outBuf, ByteBuf tagBuf, List<ByteBuf> ciphertextBufs)
throws GeneralSecurityException { throws GeneralSecurityException {
// There is enough space for the ciphertext including the tag in outBuf.
byte[] tempArr = new byte[outBuf.writableBytes()];
ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes()); // Copy ciphertext and tag into tempArr.
cipherTextAndTag.writerIndex(0); {
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr);
for (ByteBuf inBuf : ciphertextBufs) { tempBuf.resetWriterIndex();
cipherTextAndTag.writeBytes(inBuf); for (ByteBuf ciphertextBuf : ciphertextBufs) {
tempBuf.writeBytes(ciphertextBuf);
}
tempBuf.writeBytes(tagBuf);
} }
cipherTextAndTag.writeBytes(tag);
decrypt(out, cipherTextAndTag); decryptInternal(outBuf, tempArr);
} }
@Override @Override
public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException { public void decrypt(
int bytesRead = ciphertextAndTag.readableBytes(); ByteBuf outBuf, ByteBuf ciphertextAndTagDirect) throws GeneralSecurityException {
checkArgument(bytesRead == out.writableBytes()); byte[] tempArr = new byte[ciphertextAndTagDirect.readableBytes()];
checkArgument(out.nioBufferCount() == 1); // Copy ciphertext and tag into tempArr.
ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes()); {
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr);
tempBuf.resetWriterIndex();
tempBuf.writeBytes(ciphertextAndTagDirect);
}
checkArgument(ciphertextAndTag.nioBufferCount() == 1); decryptInternal(outBuf, tempArr);
ByteBuffer ciphertextAndTagBuffer = }
ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);
byte[] counter = incrementInCounter(); private void decryptInternal(ByteBuf outBuf, byte[] tempArr) throws GeneralSecurityException {
int outPosition = outBuffer.position(); // Perform in-place decryption on tempArr.
aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter); {
int bytesWritten = outBuffer.position() - outPosition; ByteBuffer ciphertextAndTag = ByteBuffer.wrap(tempArr);
out.writerIndex(out.writerIndex() + bytesWritten); ByteBuffer out = ByteBuffer.wrap(tempArr);
ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead); byte[] counter = incrementInCounter();
verify(out.writableBytes() == TAG_LENGTH); aeadCrypter.decrypt(out, ciphertextAndTag, counter);
}
outBuf.writeBytes(tempArr, 0, tempArr.length - TAG_LENGTH);
} }
@Override @Override