grpc-js: make client interceptors tests pass mostly unmodified

This commit is contained in:
murgatroid99 2019-11-14 15:02:24 -08:00
parent 17126e4640
commit 33875dce4a
7 changed files with 1880 additions and 1704 deletions

View File

@ -152,6 +152,7 @@ export class Http2CallStream implements Call {
filterStack: Filter; filterStack: Filter;
private http2Stream: http2.ClientHttp2Stream | null = null; private http2Stream: http2.ClientHttp2Stream | null = null;
private pendingRead = false; private pendingRead = false;
private isWriteFilterPending = false;
private pendingWrite: Buffer | null = null; private pendingWrite: Buffer | null = null;
private pendingWriteCallback: WriteCallback | null = null; private pendingWriteCallback: WriteCallback | null = null;
private writesClosed = false; private writesClosed = false;
@ -160,12 +161,16 @@ export class Http2CallStream implements Call {
private isReadFilterPending = false; private isReadFilterPending = false;
private canPush = false; private canPush = false;
/**
* Indicates that an 'end' event has come from the http2 stream, so there
* will be no more data events.
*/
private readsClosed = false; private readsClosed = false;
private statusOutput = false; private statusOutput = false;
private unpushedReadMessages: Array<Buffer | null> = []; private unpushedReadMessages: Buffer[] = [];
private unfilteredReadMessages: Array<Buffer | null> = []; private unfilteredReadMessages: Buffer[] = [];
// Status code mapped from :status. To be used if grpc-status is not received // Status code mapped from :status. To be used if grpc-status is not received
private mappedStatusCode: Status = Status.UNKNOWN; private mappedStatusCode: Status = Status.UNKNOWN;
@ -200,16 +205,7 @@ export class Http2CallStream implements Call {
/* Precondition: this.finalStatus !== null */ /* Precondition: this.finalStatus !== null */
if (!this.statusOutput) { if (!this.statusOutput) {
this.statusOutput = true; this.statusOutput = true;
/* We do this asynchronously to ensure that no async function is in the this.listener!.onReceiveStatus(this.finalStatus!);
* call stack when we return control to the application. If an async
* function is in the call stack, any exception thrown by the application
* (or our tests) will bubble up and turn into promise rejection, which
* will result in an UnhandledPromiseRejectionWarning. Because that is
* a warning, the error will be effectively swallowed and execution will
* continue */
process.nextTick(() => {
this.listener!.onReceiveStatus(this.finalStatus!);
});
if (this.subchannel) { if (this.subchannel) {
this.subchannel.callUnref(); this.subchannel.callUnref();
this.subchannel.removeDisconnectListener(this.disconnectListener); this.subchannel.removeDisconnectListener(this.disconnectListener);
@ -227,30 +223,24 @@ export class Http2CallStream implements Call {
* deserialization failure), that new status takes priority */ * deserialization failure), that new status takes priority */
if (this.finalStatus === null || this.finalStatus.code === Status.OK) { if (this.finalStatus === null || this.finalStatus.code === Status.OK) {
this.finalStatus = status; this.finalStatus = status;
/* Then, if an incoming message is still being handled or the status code this.maybeOutputStatus();
* is OK, hold off on emitting the status until that is done */ }
if (this.readsClosed || this.finalStatus.code !== Status.OK) { }
private maybeOutputStatus() {
if (this.finalStatus !== null) {
/* The combination check of readsClosed and that the two message buffer
* arrays are empty checks that there all incoming data has been fully
* processed */
if (this.finalStatus.code !== Status.OK || (this.readsClosed && this.unpushedReadMessages.length === 0 && this.unfilteredReadMessages.length === 0 && !this.isReadFilterPending)) {
this.outputStatus(); this.outputStatus();
} }
} }
} }
private push(message: Buffer | null): void { private push(message: Buffer): void {
if (message === null) { this.listener!.onReceiveMessage(message);
this.readsClosed = true; this.maybeOutputStatus();
if (this.finalStatus) {
this.outputStatus();
}
} else {
this.listener!.onReceiveMessage(message);
/* Don't wait for the upper layer to ask for a read before pushing null
* to close out the call, because pushing null doesn't actually push
* another message up to the upper layer */
if (this.unpushedReadMessages.length > 0 && this.unpushedReadMessages[0] === null) {
this.unpushedReadMessages.shift();
this.push(null);
}
}
} }
private handleFilterError(error: Error) { private handleFilterError(error: Error) {
@ -261,7 +251,7 @@ export class Http2CallStream implements Call {
/* If we the call has already ended, we don't want to do anything with /* If we the call has already ended, we don't want to do anything with
* this message. Dropping it on the floor is correct behavior */ * this message. Dropping it on the floor is correct behavior */
if (this.finalStatus !== null) { if (this.finalStatus !== null) {
this.push(null); this.maybeOutputStatus();
return; return;
} }
this.isReadFilterPending = false; this.isReadFilterPending = false;
@ -275,24 +265,16 @@ export class Http2CallStream implements Call {
if (this.unfilteredReadMessages.length > 0) { if (this.unfilteredReadMessages.length > 0) {
/* nextMessage is guaranteed not to be undefined because /* nextMessage is guaranteed not to be undefined because
unfilteredReadMessages is non-empty */ unfilteredReadMessages is non-empty */
const nextMessage = this.unfilteredReadMessages.shift() as Buffer | null; const nextMessage = this.unfilteredReadMessages.shift()!;
this.filterReceivedMessage(nextMessage); this.filterReceivedMessage(nextMessage);
} }
} }
private filterReceivedMessage(framedMessage: Buffer | null) { private filterReceivedMessage(framedMessage: Buffer) {
/* If we the call has already ended, we don't want to do anything with /* If we the call has already ended, we don't want to do anything with
* this message. Dropping it on the floor is correct behavior */ * this message. Dropping it on the floor is correct behavior */
if (this.finalStatus !== null) { if (this.finalStatus !== null) {
this.push(null); this.maybeOutputStatus();
return;
}
if (framedMessage === null) {
if (this.canPush) {
this.push(null);
} else {
this.unpushedReadMessages.push(null);
}
return; return;
} }
this.isReadFilterPending = true; this.isReadFilterPending = true;
@ -304,7 +286,7 @@ export class Http2CallStream implements Call {
); );
} }
private tryPush(messageBytes: Buffer | null): void { private tryPush(messageBytes: Buffer): void {
if (this.isReadFilterPending) { if (this.isReadFilterPending) {
this.unfilteredReadMessages.push(messageBytes); this.unfilteredReadMessages.push(messageBytes);
} else { } else {
@ -411,12 +393,23 @@ export class Http2CallStream implements Call {
} }
}); });
stream.on('end', () => { stream.on('end', () => {
this.tryPush(null); this.readsClosed = true;
this.maybeOutputStatus();
}); });
stream.on('close', async () => { stream.on('close', () => {
let code: Status; let code: Status;
let details = ''; let details = '';
switch (stream.rstCode) { switch (stream.rstCode) {
case http2.constants.NGHTTP2_NO_ERROR:
/* If we get a NO_ERROR code and we already have a status, the
* stream completed properly and we just haven't fully processed
* it yet */
if (this.finalStatus !== null) {
return;
}
code = Status.INTERNAL;
details = `Received RST_STREAM with code ${stream.rstCode}`;
break;
case http2.constants.NGHTTP2_REFUSED_STREAM: case http2.constants.NGHTTP2_REFUSED_STREAM:
code = Status.UNAVAILABLE; code = Status.UNAVAILABLE;
details = 'Stream refused by server'; details = 'Stream refused by server';
@ -435,6 +428,7 @@ export class Http2CallStream implements Call {
break; break;
default: default:
code = Status.INTERNAL; code = Status.INTERNAL;
details = `Received RST_STREAM with code ${stream.rstCode}`;
} }
// This is a no-op if trailers were received at all. // This is a no-op if trailers were received at all.
// This is OK, because status codes emitted here correspond to more // This is OK, because status codes emitted here correspond to more
@ -456,9 +450,7 @@ export class Http2CallStream implements Call {
} }
stream.write(this.pendingWrite, this.pendingWriteCallback); stream.write(this.pendingWrite, this.pendingWriteCallback);
} }
if (this.writesClosed) { this.maybeCloseWrites();
stream.end();
}
} }
} }
@ -514,7 +506,7 @@ export class Http2CallStream implements Call {
/* If we have already emitted a status, we should not emit any more /* If we have already emitted a status, we should not emit any more
* messages and we should communicate that the stream has ended */ * messages and we should communicate that the stream has ended */
if (this.finalStatus !== null) { if (this.finalStatus !== null) {
this.push(null); this.maybeOutputStatus();
return; return;
} }
this.canPush = true; this.canPush = true;
@ -522,7 +514,7 @@ export class Http2CallStream implements Call {
this.pendingRead = true; this.pendingRead = true;
} else { } else {
if (this.unpushedReadMessages.length > 0) { if (this.unpushedReadMessages.length > 0) {
const nextMessage: Buffer | null = this.unpushedReadMessages.shift() as Buffer | null; const nextMessage: Buffer = this.unpushedReadMessages.shift()!;
this.push(nextMessage); this.push(nextMessage);
this.canPush = false; this.canPush = false;
return; return;
@ -534,26 +526,33 @@ export class Http2CallStream implements Call {
} }
} }
private maybeCloseWrites() {
if (this.writesClosed && !this.isWriteFilterPending && this.http2Stream !== null) {
this.http2Stream.end();
}
}
sendMessageWithContext(context: MessageContext, message: Buffer) { sendMessageWithContext(context: MessageContext, message: Buffer) {
const writeObj: WriteObject = { const writeObj: WriteObject = {
message: message, message: message,
flags: context.flags flags: context.flags
}; };
const cb: WriteCallback = context.callback || (() => {}); const cb: WriteCallback = context.callback || (() => {});
this.isWriteFilterPending = true;
this.filterStack.sendMessage(Promise.resolve(writeObj)).then(message => { this.filterStack.sendMessage(Promise.resolve(writeObj)).then(message => {
this.isWriteFilterPending = false;
if (this.http2Stream === null) { if (this.http2Stream === null) {
this.pendingWrite = message.message; this.pendingWrite = message.message;
this.pendingWriteCallback = cb; this.pendingWriteCallback = cb;
} else { } else {
this.http2Stream.write(message.message, cb); this.http2Stream.write(message.message, cb);
this.maybeCloseWrites();
} }
}, this.handleFilterError.bind(this)); }, this.handleFilterError.bind(this));
} }
halfClose() { halfClose() {
this.writesClosed = true; this.writesClosed = true;
if (this.http2Stream !== null) { this.maybeCloseWrites();
this.http2Stream.end();
}
} }
} }

View File

@ -18,11 +18,12 @@
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import { Duplex, Readable, Writable } from 'stream'; import { Duplex, Readable, Writable } from 'stream';
import { Call, StatusObject, WriteObject } from './call-stream'; import { StatusObject, MessageContext } from './call-stream';
import { Status } from './constants'; import { Status } from './constants';
import { EmitterAugmentation1 } from './events'; import { EmitterAugmentation1 } from './events';
import { Metadata } from './metadata'; import { Metadata } from './metadata';
import { ObjectReadable, ObjectWritable, WriteCallback } from './object-stream'; import { ObjectReadable, ObjectWritable, WriteCallback } from './object-stream';
import { InterceptingCallInterface } from './client-interceptors';
/** /**
* A type extending the built-in Error object with additional fields. * A type extending the built-in Error object with additional fields.
@ -81,7 +82,7 @@ export function callErrorFromStatus(status: StatusObject): ServiceError {
export class ClientUnaryCallImpl extends EventEmitter export class ClientUnaryCallImpl extends EventEmitter
implements ClientUnaryCall { implements ClientUnaryCall {
constructor(private readonly call: Call) { constructor(private readonly call: InterceptingCallInterface) {
super(); super();
} }
@ -97,7 +98,7 @@ export class ClientUnaryCallImpl extends EventEmitter
export class ClientReadableStreamImpl<ResponseType> extends Readable export class ClientReadableStreamImpl<ResponseType> extends Readable
implements ClientReadableStream<ResponseType> { implements ClientReadableStream<ResponseType> {
constructor( constructor(
private readonly call: Call, private readonly call: InterceptingCallInterface,
readonly deserialize: (chunk: Buffer) => ResponseType readonly deserialize: (chunk: Buffer) => ResponseType
) { ) {
super({ objectMode: true }); super({ objectMode: true });
@ -116,33 +117,10 @@ export class ClientReadableStreamImpl<ResponseType> extends Readable
} }
} }
function tryWrite<RequestType>(
call: Call,
serialize: (value: RequestType) => Buffer,
chunk: RequestType,
encoding: string,
cb: WriteCallback
) {
let message: Buffer;
const flags: number = Number(encoding);
try {
message = serialize(chunk);
} catch (e) {
call.cancelWithStatus(Status.INTERNAL, 'Serialization failure');
cb(e);
return;
}
const writeObj: WriteObject = { message };
if (!Number.isNaN(flags)) {
writeObj.flags = flags;
}
call.write(writeObj, cb);
}
export class ClientWritableStreamImpl<RequestType> extends Writable export class ClientWritableStreamImpl<RequestType> extends Writable
implements ClientWritableStream<RequestType> { implements ClientWritableStream<RequestType> {
constructor( constructor(
private readonly call: Call, private readonly call: InterceptingCallInterface,
readonly serialize: (value: RequestType) => Buffer readonly serialize: (value: RequestType) => Buffer
) { ) {
super({ objectMode: true }); super({ objectMode: true });
@ -157,12 +135,14 @@ export class ClientWritableStreamImpl<RequestType> extends Writable
} }
_write(chunk: RequestType, encoding: string, cb: WriteCallback) { _write(chunk: RequestType, encoding: string, cb: WriteCallback) {
const writeObj: WriteObject = { message: chunk }; const context: MessageContext = {
callback: cb
}
const flags: number = Number(encoding); const flags: number = Number(encoding);
if (!Number.isNaN(flags)) { if (!Number.isNaN(flags)) {
writeObj.flags = flags; context.flags = flags;
} }
this.call.write(writeObj, cb); this.call.sendMessageWithContext(context, chunk);
} }
_final(cb: Function) { _final(cb: Function) {
@ -174,7 +154,7 @@ export class ClientWritableStreamImpl<RequestType> extends Writable
export class ClientDuplexStreamImpl<RequestType, ResponseType> extends Duplex export class ClientDuplexStreamImpl<RequestType, ResponseType> extends Duplex
implements ClientDuplexStream<RequestType, ResponseType> { implements ClientDuplexStream<RequestType, ResponseType> {
constructor( constructor(
private readonly call: Call, private readonly call: InterceptingCallInterface,
readonly serialize: (value: RequestType) => Buffer, readonly serialize: (value: RequestType) => Buffer,
readonly deserialize: (chunk: Buffer) => ResponseType readonly deserialize: (chunk: Buffer) => ResponseType
) { ) {
@ -194,12 +174,14 @@ export class ClientDuplexStreamImpl<RequestType, ResponseType> extends Duplex
} }
_write(chunk: RequestType, encoding: string, cb: WriteCallback) { _write(chunk: RequestType, encoding: string, cb: WriteCallback) {
const writeObj: WriteObject = { message: chunk }; const context: MessageContext = {
callback: cb
}
const flags: number = Number(encoding); const flags: number = Number(encoding);
if (!Number.isNaN(flags)) { if (!Number.isNaN(flags)) {
writeObj.flags = flags; context.flags = flags;
} }
this.call.write(writeObj, cb); this.call.sendMessageWithContext(context, chunk);
} }
_final(cb: Function) { _final(cb: Function) {

View File

@ -155,8 +155,9 @@ export interface InterceptorOptions extends CallOptions {
export interface InterceptingCallInterface { export interface InterceptingCallInterface {
cancelWithStatus(status: Status, details: string): void; cancelWithStatus(status: Status, details: string): void;
getPeer(): string; getPeer(): string;
start(metadata: Metadata, listener: InterceptingListener): void; start(metadata: Metadata, listener?: Partial<InterceptingListener>): void;
sendMessageWithContext(context: MessageContext, message: any): void; sendMessageWithContext(context: MessageContext, message: any): void;
sendMessage(message: any): void;
startRead(): void; startRead(): void;
halfClose(): void; halfClose(): void;
@ -194,18 +195,23 @@ export class InterceptingCall implements InterceptingCallInterface {
getPeer() { getPeer() {
return this.nextCall.getPeer(); return this.nextCall.getPeer();
} }
start(metadata: Metadata, interceptingListener: InterceptingListener): void { start(metadata: Metadata, interceptingListener?: Partial<InterceptingListener>): void {
this.requester.start(metadata, interceptingListener, (md, listener) => { const fullInterceptingListener: InterceptingListener = {
onReceiveMetadata: interceptingListener?.onReceiveMetadata?.bind(interceptingListener) ?? (metadata => {}),
onReceiveMessage: interceptingListener?.onReceiveMessage?.bind(interceptingListener) ?? (message => {}),
onReceiveStatus: interceptingListener?.onReceiveStatus?.bind(interceptingListener) ?? (status => {})
}
this.requester.start(metadata, fullInterceptingListener, (md, listener) => {
let finalInterceptingListener: InterceptingListener; let finalInterceptingListener: InterceptingListener;
if (isInterceptingListener(listener)) { if (isInterceptingListener(listener)) {
finalInterceptingListener = listener; finalInterceptingListener = listener;
} else { } else {
const fullListener: FullListener = { const fullListener: FullListener = {
onReceiveMetadata: listener.onReceiveMetadata || defaultListener.onReceiveMetadata, onReceiveMetadata: listener.onReceiveMetadata ?? defaultListener.onReceiveMetadata,
onReceiveMessage: listener.onReceiveMessage || defaultListener.onReceiveMessage, onReceiveMessage: listener.onReceiveMessage ?? defaultListener.onReceiveMessage,
onReceiveStatus: listener.onReceiveStatus || defaultListener.onReceiveStatus onReceiveStatus: listener.onReceiveStatus ?? defaultListener.onReceiveStatus
}; };
finalInterceptingListener = new InterceptingListenerImpl(fullListener, interceptingListener); finalInterceptingListener = new InterceptingListenerImpl(fullListener, fullInterceptingListener);
} }
this.nextCall.start(md, finalInterceptingListener); this.nextCall.start(md, finalInterceptingListener);
}); });
@ -218,7 +224,7 @@ export class InterceptingCall implements InterceptingCallInterface {
if (this.pendingHalfClose) { if (this.pendingHalfClose) {
this.nextCall.halfClose(); this.nextCall.halfClose();
} }
}) });
} }
sendMessage(message: any): void { sendMessage(message: any): void {
this.sendMessageWithContext({}, message); this.sendMessageWithContext({}, message);
@ -308,17 +314,20 @@ class BaseInterceptingCall implements InterceptingCallInterface {
this.call.cancelWithStatus(Status.INTERNAL, 'Serialization failure'); this.call.cancelWithStatus(Status.INTERNAL, 'Serialization failure');
} }
} }
start(metadata: Metadata, listener: InterceptingListener): void { sendMessage(message: any) {
this.sendMessageWithContext({}, message);
}
start(metadata: Metadata, interceptingListener?: Partial<InterceptingListener>): void {
let readError: StatusObject | null = null; let readError: StatusObject | null = null;
this.call.start(metadata, { this.call.start(metadata, {
onReceiveMetadata: (metadata) => { onReceiveMetadata: (metadata) => {
listener.onReceiveMetadata(metadata); interceptingListener?.onReceiveMetadata?.(metadata);
}, },
onReceiveMessage: (message) => { onReceiveMessage: (message) => {
let deserialized: any; let deserialized: any;
try { try {
deserialized = this.methodDefinition.responseDeserialize(message); deserialized = this.methodDefinition.responseDeserialize(message);
listener.onReceiveMessage(deserialized); interceptingListener?.onReceiveMessage?.(deserialized);
} catch (e) { } catch (e) {
readError = {code: Status.INTERNAL, details: 'Failed to parse server response', metadata: new Metadata()}; readError = {code: Status.INTERNAL, details: 'Failed to parse server response', metadata: new Metadata()};
this.call.cancelWithStatus(readError.code, readError.details); this.call.cancelWithStatus(readError.code, readError.details);
@ -326,9 +335,9 @@ class BaseInterceptingCall implements InterceptingCallInterface {
}, },
onReceiveStatus: (status) => { onReceiveStatus: (status) => {
if (readError) { if (readError) {
listener.onReceiveStatus(readError); interceptingListener?.onReceiveStatus?.(readError);
} else { } else {
listener.onReceiveStatus(status); interceptingListener?.onReceiveStatus?.(status);
} }
} }
}); });
@ -345,8 +354,22 @@ class BaseUnaryInterceptingCall extends BaseInterceptingCall implements Intercep
constructor(call: Call, methodDefinition: ClientMethodDefinition<any, any>) { constructor(call: Call, methodDefinition: ClientMethodDefinition<any, any>) {
super(call, methodDefinition); super(call, methodDefinition);
} }
start(metadata: Metadata, listener: InterceptingListener): void { start(metadata: Metadata, listener?: Partial<InterceptingListener>): void {
super.start(metadata, listener); let receivedMessage = false;
const wrapperListener: InterceptingListener = {
onReceiveMetadata: listener?.onReceiveMetadata?.bind(listener) ?? (metadata => {}),
onReceiveMessage: (message: any) => {
receivedMessage = true;
listener?.onReceiveMessage?.(message);
},
onReceiveStatus: (status: StatusObject) => {
if (!receivedMessage) {
listener?.onReceiveMessage?.(null);
}
listener?.onReceiveStatus?.(status);
}
}
super.start(metadata, wrapperListener);
this.call.startRead(); this.call.startRead();
} }
} }
@ -416,8 +439,8 @@ export function getInterceptingCall(interceptorArgs: InterceptorArguments, metho
* initialValue, which is effectively at the end of the list, is a nextCall * initialValue, which is effectively at the end of the list, is a nextCall
* function that invokes getBottomInterceptingCall, which handles * function that invokes getBottomInterceptingCall, which handles
* (de)serialization and also gets the underlying call from the channel */ * (de)serialization and also gets the underlying call from the channel */
const getCall: NextCall = interceptors.reduceRight<NextCall>((previousValue: NextCall, currentValue: Interceptor) => { const getCall: NextCall = interceptors.reduceRight<NextCall>((nextCall: NextCall, nextInterceptor: Interceptor) => {
return currentOptions => currentValue(currentOptions, previousValue); return currentOptions => nextInterceptor(currentOptions, nextCall);
}, (finalOptions: InterceptorOptions) => getBottomInterceptingCall(channel, methodDefinition.path, finalOptions, methodDefinition)); }, (finalOptions: InterceptorOptions) => getBottomInterceptingCall(channel, methodDefinition.path, finalOptions, methodDefinition));
return getCall(interceptorOptions); return getCall(interceptorOptions);
} }

View File

@ -29,14 +29,14 @@ import {
SurfaceCall, SurfaceCall,
} from './call'; } from './call';
import { CallCredentials } from './call-credentials'; import { CallCredentials } from './call-credentials';
import { Call, Deadline, StatusObject, WriteObject, InterceptingListener } from './call-stream'; import { Deadline, StatusObject, WriteObject, InterceptingListener } from './call-stream';
import { Channel, ConnectivityState, ChannelImplementation } from './channel'; import { Channel, ConnectivityState, ChannelImplementation } from './channel';
import { ChannelCredentials } from './channel-credentials'; import { ChannelCredentials } from './channel-credentials';
import { ChannelOptions } from './channel-options'; import { ChannelOptions } from './channel-options';
import { Status } from './constants'; import { Status } from './constants';
import { Metadata } from './metadata'; import { Metadata } from './metadata';
import { ClientMethodDefinition } from './make-client'; import { ClientMethodDefinition } from './make-client';
import { getInterceptingCall, Interceptor, InterceptorProvider, InterceptorArguments } from './client-interceptors'; import { getInterceptingCall, Interceptor, InterceptorProvider, InterceptorArguments, InterceptingCallInterface } from './client-interceptors';
const CHANNEL_SYMBOL = Symbol(); const CHANNEL_SYMBOL = Symbol();
const INTERCEPTOR_SYMBOL = Symbol(); const INTERCEPTOR_SYMBOL = Symbol();
@ -231,13 +231,13 @@ export class Client {
callInterceptors: options.interceptors || [], callInterceptors: options.interceptors || [],
callInterceptorProviders: options.interceptor_providers || [] callInterceptorProviders: options.interceptor_providers || []
}; };
const call: Call = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]); const call: InterceptingCallInterface = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]);
if (options.credentials) { if (options.credentials) {
call.setCredentials(options.credentials); call.setCredentials(options.credentials);
} }
const writeObj: WriteObject = { message: argument };
const emitter = new ClientUnaryCallImpl(call); const emitter = new ClientUnaryCallImpl(call);
let responseMessage: ResponseType | null = null; let responseMessage: ResponseType | null = null;
let receivedStatus = false;
call.start(metadata, { call.start(metadata, {
onReceiveMetadata: (metadata) => { onReceiveMetadata: (metadata) => {
emitter.emit('metadata', metadata); emitter.emit('metadata', metadata);
@ -247,9 +247,12 @@ export class Client {
call.cancelWithStatus(Status.INTERNAL, 'Too many responses received'); call.cancelWithStatus(Status.INTERNAL, 'Too many responses received');
} }
responseMessage = message; responseMessage = message;
call.startRead();
}, },
onReceiveStatus(status: StatusObject) { onReceiveStatus(status: StatusObject) {
if (receivedStatus) {
return;
}
receivedStatus = true;
if (status.code === Status.OK) { if (status.code === Status.OK) {
callback!(null, responseMessage!); callback!(null, responseMessage!);
} else { } else {
@ -258,8 +261,8 @@ export class Client {
emitter.emit('status', status); emitter.emit('status', status);
} }
}); });
call.write(writeObj, () => {call.halfClose();}); call.sendMessage(argument);
call.startRead(); call.halfClose();
return emitter; return emitter;
} }
@ -315,12 +318,13 @@ export class Client {
callInterceptors: options.interceptors || [], callInterceptors: options.interceptors || [],
callInterceptorProviders: options.interceptor_providers || [] callInterceptorProviders: options.interceptor_providers || []
}; };
const call: Call = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]); const call: InterceptingCallInterface = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]);
if (options.credentials) { if (options.credentials) {
call.setCredentials(options.credentials); call.setCredentials(options.credentials);
} }
const emitter = new ClientWritableStreamImpl<RequestType>(call, serialize); const emitter = new ClientWritableStreamImpl<RequestType>(call, serialize);
let responseMessage: ResponseType | null = null; let responseMessage: ResponseType | null = null;
let receivedStatus = false;
call.start(metadata, { call.start(metadata, {
onReceiveMetadata: (metadata) => { onReceiveMetadata: (metadata) => {
emitter.emit('metadata', metadata); emitter.emit('metadata', metadata);
@ -330,9 +334,12 @@ export class Client {
call.cancelWithStatus(Status.INTERNAL, 'Too many responses received'); call.cancelWithStatus(Status.INTERNAL, 'Too many responses received');
} }
responseMessage = message; responseMessage = message;
call.startRead();
}, },
onReceiveStatus(status: StatusObject) { onReceiveStatus(status: StatusObject) {
if (receivedStatus) {
return;
}
receivedStatus = true;
if (status.code === Status.OK) { if (status.code === Status.OK) {
callback!(null, responseMessage!); callback!(null, responseMessage!);
} else { } else {
@ -341,7 +348,6 @@ export class Client {
emitter.emit('status', status); emitter.emit('status', status);
} }
}); });
call.startRead();
return emitter; return emitter;
} }
@ -406,12 +412,12 @@ export class Client {
callInterceptors: options.interceptors || [], callInterceptors: options.interceptors || [],
callInterceptorProviders: options.interceptor_providers || [] callInterceptorProviders: options.interceptor_providers || []
}; };
const call: Call = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]); const call: InterceptingCallInterface = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]);
if (options.credentials) { if (options.credentials) {
call.setCredentials(options.credentials); call.setCredentials(options.credentials);
} }
const writeObj: WriteObject = { message: argument };
const stream = new ClientReadableStreamImpl<ResponseType>(call, deserialize); const stream = new ClientReadableStreamImpl<ResponseType>(call, deserialize);
let receivedStatus = false;
call.start(metadata, { call.start(metadata, {
onReceiveMetadata(metadata: Metadata) { onReceiveMetadata(metadata: Metadata) {
stream.emit('metadata', metadata); stream.emit('metadata', metadata);
@ -422,6 +428,10 @@ export class Client {
} }
}, },
onReceiveStatus(status: StatusObject) { onReceiveStatus(status: StatusObject) {
if (receivedStatus) {
return;
}
receivedStatus = true;
stream.push(null); stream.push(null);
if (status.code !== Status.OK) { if (status.code !== Status.OK) {
stream.emit('error', callErrorFromStatus(status)); stream.emit('error', callErrorFromStatus(status));
@ -429,7 +439,8 @@ export class Client {
stream.emit('status', status); stream.emit('status', status);
} }
}); });
call.write(writeObj, () => {call.halfClose();}); call.sendMessage(argument);
call.halfClose();
return stream; return stream;
} }
@ -467,7 +478,7 @@ export class Client {
callInterceptors: options.interceptors || [], callInterceptors: options.interceptors || [],
callInterceptorProviders: options.interceptor_providers || [] callInterceptorProviders: options.interceptor_providers || []
}; };
const call: Call = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]); const call: InterceptingCallInterface = getInterceptingCall(interceptorArgs, methodDefinition, options, this[CHANNEL_SYMBOL]);
if (options.credentials) { if (options.credentials) {
call.setCredentials(options.credentials); call.setCredentials(options.credentials);
} }
@ -476,6 +487,7 @@ export class Client {
serialize, serialize,
deserialize deserialize
); );
let receivedStatus = false;
call.start(metadata, { call.start(metadata, {
onReceiveMetadata(metadata: Metadata) { onReceiveMetadata(metadata: Metadata) {
stream.emit('metadata', metadata); stream.emit('metadata', metadata);
@ -486,6 +498,10 @@ export class Client {
} }
}, },
onReceiveStatus(status: StatusObject) { onReceiveStatus(status: StatusObject) {
if (receivedStatus) {
return;
}
receivedStatus = true;
stream.push(null); stream.push(null);
if (status.code !== Status.OK) { if (status.code !== Status.OK) {
stream.emit('error', callErrorFromStatus(status)); stream.emit('error', callErrorFromStatus(status));

View File

@ -351,10 +351,8 @@ export class Http2ServerCallStream<
}); });
this.stream.once('close', () => { this.stream.once('close', () => {
if (this.stream.rstCode === http2.constants.NGHTTP2_CANCEL) { this.cancelled = true;
this.cancelled = true; this.emit('cancelled', 'cancelled');
this.emit('cancelled', 'cancelled');
}
}); });
this.stream.on('drain', () => { this.stream.on('drain', () => {
@ -362,7 +360,20 @@ export class Http2ServerCallStream<
}); });
} }
private checkCancelled(): boolean {
/* In some cases the stream can become destroyed before the close event
* fires. That creates a race condition that this check works around */
if (this.stream.destroyed) {
this.cancelled = true;
}
return this.cancelled;
}
sendMetadata(customMetadata?: Metadata) { sendMetadata(customMetadata?: Metadata) {
if (this.checkCancelled()) {
return;
}
if (this.metadataSent) { if (this.metadataSent) {
return; return;
} }
@ -397,6 +408,13 @@ export class Http2ServerCallStream<
metadata.remove(GRPC_TIMEOUT_HEADER); metadata.remove(GRPC_TIMEOUT_HEADER);
} }
// Remove several headers that should not be propagated to the application
metadata.remove(http2.constants.HTTP2_HEADER_ACCEPT_ENCODING);
metadata.remove(http2.constants.HTTP2_HEADER_TE);
metadata.remove(http2.constants.HTTP2_HEADER_CONTENT_TYPE);
metadata.remove('grpc-encoding');
metadata.remove('grpc-accept-encoding');
return metadata; return metadata;
} }
@ -450,6 +468,9 @@ export class Http2ServerCallStream<
metadata?: Metadata, metadata?: Metadata,
flags?: number flags?: number
) { ) {
if (this.checkCancelled()) {
return;
}
if (!metadata) { if (!metadata) {
metadata = new Metadata(); metadata = new Metadata();
} }
@ -472,7 +493,7 @@ export class Http2ServerCallStream<
} }
sendStatus(statusObj: StatusObject) { sendStatus(statusObj: StatusObject) {
if (this.cancelled) { if (this.checkCancelled()) {
return; return;
} }
@ -497,6 +518,9 @@ export class Http2ServerCallStream<
} }
sendError(error: ServerErrorResponse | ServerStatusResponse) { sendError(error: ServerErrorResponse | ServerStatusResponse) {
if (this.checkCancelled()) {
return;
}
const status: StatusObject = { const status: StatusObject = {
code: Status.UNKNOWN, code: Status.UNKNOWN,
details: 'message' in error ? error.message : 'Unknown Error', details: 'message' in error ? error.message : 'Unknown Error',
@ -522,7 +546,7 @@ export class Http2ServerCallStream<
} }
write(chunk: Buffer) { write(chunk: Buffer) {
if (this.cancelled) { if (this.checkCancelled()) {
return; return;
} }

View File

@ -346,7 +346,6 @@ export class Server {
const call = new Http2ServerCallStream(stream, handler); const call = new Http2ServerCallStream(stream, handler);
const metadata: Metadata = call.receiveMetadata(headers) as Metadata; const metadata: Metadata = call.receiveMetadata(headers) as Metadata;
switch (handler.type) { switch (handler.type) {
case 'unary': case 'unary':
handleUnary(call, handler as UntypedUnaryHandler, metadata); handleUnary(call, handler as UntypedUnaryHandler, metadata);

File diff suppressed because it is too large Load Diff