grpc-js: Add max message size enforcement

This commit is contained in:
Michael Lumish 2020-04-08 14:37:03 -07:00
parent 7eef1efd65
commit 9221fdea24
8 changed files with 288 additions and 7 deletions

View File

@ -1,6 +1,6 @@
{ {
"name": "@grpc/grpc-js", "name": "@grpc/grpc-js",
"version": "0.7.7", "version": "0.8.0",
"description": "gRPC Library for Node - pure JS implementation", "description": "gRPC Library for Node - pure JS implementation",
"homepage": "https://grpc.io/", "homepage": "https://grpc.io/",
"repository": "https://github.com/grpc/grpc-node/tree/master/packages/grpc-js", "repository": "https://github.com/grpc/grpc-node/tree/master/packages/grpc-js",

View File

@ -30,6 +30,8 @@ export interface ChannelOptions {
'grpc.initial_reconnect_backoff_ms'?: number; 'grpc.initial_reconnect_backoff_ms'?: number;
'grpc.max_reconnect_backoff_ms'?: number; 'grpc.max_reconnect_backoff_ms'?: number;
'grpc.use_local_subchannel_pool'?: number; 'grpc.use_local_subchannel_pool'?: number;
'grpc.max_send_message_length'?: number;
'grpc.max_receive_message_length'?: number;
[key: string]: string | number | undefined; [key: string]: string | number | undefined;
} }
@ -49,6 +51,8 @@ export const recognizedOptions = {
'grpc.initial_reconnect_backoff_ms': true, 'grpc.initial_reconnect_backoff_ms': true,
'grpc.max_reconnect_backoff_ms': true, 'grpc.max_reconnect_backoff_ms': true,
'grpc.use_local_subchannel_pool': true, 'grpc.use_local_subchannel_pool': true,
'grpc.max_send_message_length': true,
'grpc.max_receive_message_length': true,
}; };
export function channelOptionsEqual( export function channelOptionsEqual(

View File

@ -38,6 +38,7 @@ import { LoadBalancingConfig } from './load-balancing-config';
import { ServiceConfig, validateServiceConfig } from './service-config'; import { ServiceConfig, validateServiceConfig } from './service-config';
import { trace, log } from './logging'; import { trace, log } from './logging';
import { SubchannelAddress } from './subchannel'; import { SubchannelAddress } from './subchannel';
import { MaxMessageSizeFilterFactory } from './max-message-size-filter';
export enum ConnectivityState { export enum ConnectivityState {
CONNECTING, CONNECTING,
@ -202,6 +203,7 @@ export class ChannelImplementation implements Channel {
this.filterStackFactory = new FilterStackFactory([ this.filterStackFactory = new FilterStackFactory([
new CallCredentialsFilterFactory(this), new CallCredentialsFilterFactory(this),
new DeadlineFilterFactory(this), new DeadlineFilterFactory(this),
new MaxMessageSizeFilterFactory(this.options),
new CompressionFilterFactory(this), new CompressionFilterFactory(this),
]); ]);
// TODO(murgatroid99): Add more centralized handling of channel options // TODO(murgatroid99): Add more centralized handling of channel options

View File

@ -0,0 +1,81 @@
/*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
import { BaseFilter, Filter, FilterFactory } from "./filter";
import { Call, WriteObject } from "./call-stream";
import { Status } from "./constants";
import { ChannelOptions } from "./channel-options";
// The default max message size for sending or receiving is 4 MB
const DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024;
export class MaxMessageSizeFilter extends BaseFilter implements Filter {
private maxSendMessageSize: number = DEFAULT_MAX_MESSAGE_SIZE;
private maxReceiveMessageSize: number = DEFAULT_MAX_MESSAGE_SIZE;
constructor(
private readonly options: ChannelOptions,
private readonly callStream: Call
) {
super();
if ('grpc.max_send_message_length' in options) {
this.maxSendMessageSize = options['grpc.max_send_message_length']!;
}
if ('grpc.max_receive_message_length' in options) {
this.maxReceiveMessageSize = options['grpc.max_receive_message_length']!;
}
}
async sendMessage(message: Promise<WriteObject>): Promise<WriteObject> {
/* A configured size of -1 means that there is no limit, so skip the check
* entirely */
if (this.maxSendMessageSize === -1) {
return message;
} else {
const concreteMessage = await message;
if (concreteMessage.message.length > this.maxSendMessageSize) {
this.callStream.cancelWithStatus(Status.RESOURCE_EXHAUSTED, `Failed to send message of size ${concreteMessage.message.length} > max size ${this.maxSendMessageSize}`);
return Promise.reject<WriteObject>('Message too large');
} else {
return concreteMessage;
}
}
}
async receiveMessage(message: Promise<Buffer>): Promise<Buffer> {
/* A configured size of -1 means that there is no limit, so skip the check
* entirely */
if (this.maxReceiveMessageSize === -1) {
return message;
} else {
const concreteMessage = await message;
if (concreteMessage.length > this.maxReceiveMessageSize) {
this.callStream.cancelWithStatus(Status.RESOURCE_EXHAUSTED, `Received message of size ${concreteMessage.length} > max size ${this.maxReceiveMessageSize}`);
return Promise.reject<Buffer>('Message too large');
} else {
return concreteMessage;
}
}
}
}
export class MaxMessageSizeFilterFactory implements FilterFactory<MaxMessageSizeFilter> {
constructor(private readonly options: ChannelOptions) {}
createFilter(callStream: Call): MaxMessageSizeFilter {
return new MaxMessageSizeFilter(this.options, callStream);
}
}

View File

@ -25,6 +25,7 @@ import { Deserialize, Serialize } from './make-client';
import { Metadata } from './metadata'; import { Metadata } from './metadata';
import { StreamDecoder } from './stream-decoder'; import { StreamDecoder } from './stream-decoder';
import { ObjectReadable, ObjectWritable } from './object-stream'; import { ObjectReadable, ObjectWritable } from './object-stream';
import { ChannelOptions } from './channel-options';
interface DeadlineUnitIndexSignature { interface DeadlineUnitIndexSignature {
[name: string]: number; [name: string]: number;
@ -325,6 +326,9 @@ export type HandlerType = 'bidi' | 'clientStream' | 'serverStream' | 'unary';
const noopTimer: NodeJS.Timer = setTimeout(() => {}, 0); const noopTimer: NodeJS.Timer = setTimeout(() => {}, 0);
// The default max message size for sending or receiving is 4 MB
const DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024;
// Internal class that wraps the HTTP2 request. // Internal class that wraps the HTTP2 request.
export class Http2ServerCallStream< export class Http2ServerCallStream<
RequestType, RequestType,
@ -338,10 +342,13 @@ export class Http2ServerCallStream<
private isPushPending = false; private isPushPending = false;
private bufferedMessages: Array<Buffer | null> = []; private bufferedMessages: Array<Buffer | null> = [];
private messagesToPush: Array<RequestType | null> = []; private messagesToPush: Array<RequestType | null> = [];
private maxSendMessageSize: number = DEFAULT_MAX_MESSAGE_SIZE;
private maxReceiveMessageSize: number = DEFAULT_MAX_MESSAGE_SIZE;
constructor( constructor(
private stream: http2.ServerHttp2Stream, private stream: http2.ServerHttp2Stream,
private handler: Handler<RequestType, ResponseType> private handler: Handler<RequestType, ResponseType>,
private options: ChannelOptions
) { ) {
super(); super();
@ -361,6 +368,13 @@ export class Http2ServerCallStream<
this.stream.on('drain', () => { this.stream.on('drain', () => {
this.emit('drain'); this.emit('drain');
}); });
if ('grpc.max_send_message_length' in options) {
this.maxSendMessageSize = options['grpc.max_send_message_length']!;
}
if ('grpc.max_receive_message_length' in options) {
this.maxReceiveMessageSize = options['grpc.max_receive_message_length']!;
}
} }
private checkCancelled(): boolean { private checkCancelled(): boolean {
@ -435,6 +449,9 @@ export class Http2ServerCallStream<
stream.once('end', async () => { stream.once('end', async () => {
try { try {
const requestBytes = Buffer.concat(chunks, totalLength); const requestBytes = Buffer.concat(chunks, totalLength);
if (this.maxReceiveMessageSize !== -1 && requestBytes.length > this.maxReceiveMessageSize) {
this.cancelWithStatus(Status.RESOURCE_EXHAUSTED, `Server received message of size ${requestBytes.length} > max size ${this.maxReceiveMessageSize}`);
}
resolve(await this.deserializeMessage(requestBytes)); resolve(await this.deserializeMessage(requestBytes));
} catch (err) { } catch (err) {
@ -550,11 +567,20 @@ export class Http2ServerCallStream<
this.sendStatus(status); this.sendStatus(status);
} }
cancelWithStatus(code: Status, details: string) {
this.cancelled = true;
this.sendStatus({code, details, metadata: new Metadata()});
}
write(chunk: Buffer) { write(chunk: Buffer) {
if (this.checkCancelled()) { if (this.checkCancelled()) {
return; return;
} }
if (this.maxSendMessageSize !== -1 && chunk.length > this.maxSendMessageSize) {
this.cancelWithStatus(Status.RESOURCE_EXHAUSTED, `Server failed to send message of size ${chunk.length} > max size ${this.maxSendMessageSize}`);
}
this.sendMetadata(); this.sendMetadata();
return this.stream.write(chunk); return this.stream.write(chunk);
} }
@ -581,6 +607,9 @@ export class Http2ServerCallStream<
const messages = decoder.write(data); const messages = decoder.write(data);
for (const message of messages) { for (const message of messages) {
if (this.maxReceiveMessageSize !== -1 && message.length > this.maxReceiveMessageSize) {
this.cancelWithStatus(Status.RESOURCE_EXHAUSTED, `Server received message of size ${message.length} > max size ${this.maxReceiveMessageSize}`);
}
this.pushOrBufferMessage(readable, message); this.pushOrBufferMessage(readable, message);
} }
}); });

View File

@ -485,7 +485,7 @@ export class Server {
throw getUnimplementedStatusResponse(path); throw getUnimplementedStatusResponse(path);
} }
const call = new Http2ServerCallStream(stream, handler); const call = new Http2ServerCallStream(stream, handler, this.options);
const metadata: Metadata = call.receiveMetadata(headers) as Metadata; const metadata: Metadata = call.receiveMetadata(headers) as Metadata;
switch (handler.type) { switch (handler.type) {
case 'unary': case 'unary':
@ -516,7 +516,7 @@ export class Server {
throw new Error(`Unknown handler type: ${handler.type}`); throw new Error(`Unknown handler type: ${handler.type}`);
} }
} catch (err) { } catch (err) {
const call = new Http2ServerCallStream(stream, null!); const call = new Http2ServerCallStream(stream, null!, this.options);
if (err.code === undefined) { if (err.code === undefined) {
err.code = Status.INTERNAL; err.code = Status.INTERNAL;

View File

@ -55,12 +55,16 @@ describe(`${anyGrpc.clientName} client -> ${anyGrpc.serverName} server`, functio
describe('Interop-adjacent tests', function() { describe('Interop-adjacent tests', function() {
let server; let server;
let client; let client;
let port;
before(function(done) { before(function(done) {
/* The default server has no limits on max message size to make those
* tests easier to write */
interopServer.getServer(0, true, (err, serverObj) => { interopServer.getServer(0, true, (err, serverObj) => {
if (err) { if (err) {
done(err); done(err);
} else { } else {
server = serverObj.server; server = serverObj.server;
port = serverObj.port;
server.start(); server.start();
const ca_path = path.join(__dirname, '../data/ca.pem'); const ca_path = path.join(__dirname, '../data/ca.pem');
const ca_data = fs.readFileSync(ca_path); const ca_data = fs.readFileSync(ca_path);
@ -69,9 +73,12 @@ describe(`${anyGrpc.clientName} client -> ${anyGrpc.serverName} server`, functio
'grpc.ssl_target_name_override': 'foo.test.google.fr', 'grpc.ssl_target_name_override': 'foo.test.google.fr',
'grpc.default_authority': 'foo.test.google.fr' 'grpc.default_authority': 'foo.test.google.fr'
}; };
client = new testProto.TestService(`localhost:${serverObj.port}`, creds, options); client = new testProto.TestService(`localhost:${port}`, creds, options);
done(); done();
} }
}, {
'grpc.max_send_message_length': -1,
'grpc.max_receive_message_length': -1
}); });
}); });
after(function() { after(function() {
@ -133,5 +140,159 @@ describe(`${anyGrpc.clientName} client -> ${anyGrpc.serverName} server`, functio
done(); done();
}); });
}); });
describe.only('max message size', function() {
// Note: the main server has these checks disabled
// A size that is larger than the default limit
const largeMessageSize = 32 * 1024 * 1024;
const largeMessage = Buffer.alloc(largeMessageSize);
it('should get an error when sending a large message', function(done) {
done = multiDone(done, 2);
const unaryMessage = {payload: {body: largeMessage}}
console.log(client.unaryCall.requestSerialize(unaryMessage).length);
client.unaryCall(unaryMessage, (error, result) => {
assert(error);
assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
const stream = client.fullDuplexCall();
stream.write({payload: {body: largeMessage}});
stream.end();
stream.on('data', () => {});
stream.on('status', (status) => {
assert.strictEqual(status.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
stream.on('error', (error) => {
});
});
it('should get an error when receiving a large message', function(done) {
done = multiDone(done, 2);
client.unaryCall({response_size: largeMessageSize}, (error, result) => {
assert(error);
assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
const stream = client.fullDuplexCall();
stream.write({response_parameters: [{size: largeMessageSize}]});
stream.end();
stream.on('data', () => {});
stream.on('status', (status) => {
assert.strictEqual(status.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
stream.on('error', (error) => {
});
});
describe('with a client with no message size limits', function() {
let unrestrictedClient;
before(function() {
const ca_path = path.join(__dirname, '../data/ca.pem');
const ca_data = fs.readFileSync(ca_path);
const creds = grpc.credentials.createSsl(ca_data);
const options = {
'grpc.ssl_target_name_override': 'foo.test.google.fr',
'grpc.default_authority': 'foo.test.google.fr',
'grpc.max_send_message_length': -1,
'grpc.max_receive_message_length': -1
};
unrestrictedClient = new testProto.TestService(`localhost:${port}`, creds, options);
});
it('should not get an error when sending or receiving a large message', function(done) {
done = multiDone(done, 2);
const unaryRequestMessage = {
response_size: largeMessageSize,
payload: {
body: largeMessage
}
};
unrestrictedClient.unaryCall(unaryRequestMessage, (error, result) => {
assert.ifError(error);
assert.strictEqual(result.payload.body.length, largeMessageSize);
done();
});
const streamingRequestMessage = {
response_parameters: [{size: largeMessageSize}],
payload: {body: largeMessage}
};
const stream = unrestrictedClient.fullDuplexCall();
stream.write(streamingRequestMessage);
stream.end();
stream.on('data', (result) => {
assert.strictEqual(result.payload.body.length, largeMessageSize);
});
stream.on('status', () => {
done();
});
stream.on('error', (error) => {
assert.ifError(error);
});
});
});
describe('with a server with message size limits', function() {
let restrictedServer;
let restrictedServerClient;
before(function(done) {
interopServer.getServer(0, true, (err, serverObj) => {
if (err) {
done(err);
} else {
restrictedServer = serverObj.server;
restrictedServer.start();
const ca_path = path.join(__dirname, '../data/ca.pem');
const ca_data = fs.readFileSync(ca_path);
const creds = grpc.credentials.createSsl(ca_data);
const options = {
'grpc.ssl_target_name_override': 'foo.test.google.fr',
'grpc.default_authority': 'foo.test.google.fr',
'grpc.max_send_message_length': -1,
'grpc.max_receive_message_length': -1
};
restrictedServerClient = new testProto.TestService(`localhost:${serverObj.port}`, creds, options);
done();
}
});
});
after(function() {
restrictedServer.forceShutdown();
});
it('should get an error when sending a large message', function(done) {
done = multiDone(done, 2);
restrictedServerClient.unaryCall({payload: {body: largeMessage}}, (error, result) => {
assert(error);
assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
const stream = restrictedServerClient.fullDuplexCall();
stream.write({payload: {body: largeMessage}});
stream.end();
stream.on('data', () => {});
stream.on('status', (status) => {
assert.strictEqual(status.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
stream.on('error', (error) => {
});
});
it('should get an error when requesting a large message', function(done) {
done = multiDone(done, 2);
restrictedServerClient.unaryCall({response_size: largeMessageSize}, (error, result) => {
console.log(result.payload.body.length);
assert(error);
assert.strictEqual(error.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
const stream = restrictedServerClient.fullDuplexCall();
stream.write({response_parameters: [{size: largeMessageSize}]});
stream.end();
stream.on('data', () => {});
stream.on('status', (status) => {
assert.strictEqual(status.code, grpc.status.RESOURCE_EXHAUSTED);
done();
});
stream.on('error', (error) => {
});
});
});
});
}); });
}); });

View File

@ -202,10 +202,14 @@ function handleHalfDuplex(call) {
* @param {boolean} tls Indicates that the bound port should use TLS * @param {boolean} tls Indicates that the bound port should use TLS
* @param {function(Error, {{server: Server, port: number}})} callback Callback * @param {function(Error, {{server: Server, port: number}})} callback Callback
* to call with result or error * to call with result or error
* @param {object?} options Optional additional options to use when
* constructing the server
*/ */
function getServer(port, tls, callback) { function getServer(port, tls, callback, options) {
// TODO(mlumish): enable TLS functionality // TODO(mlumish): enable TLS functionality
var options = {}; if (!options) {
options = {};
}
var server_creds; var server_creds;
if (tls) { if (tls) {
var key_path = path.join(__dirname, '../data/server1.key'); var key_path = path.join(__dirname, '../data/server1.key');