From 13065ad5a18e0519cfec2c2ba66ef0c8fbc33971 Mon Sep 17 00:00:00 2001 From: Michael Lumish Date: Thu, 14 Aug 2025 13:44:22 -0700 Subject: [PATCH] grpc-js: Implement weighted_round_robin LB policy --- packages/grpc-js/package.json | 2 +- .../grpc-js/src/load-balancer-pick-first.ts | 5 + .../src/load-balancer-weighted-round-robin.ts | 77 +++- packages/grpc-js/src/priority-queue.ts | 47 +- packages/grpc-js/src/resolver-ip.ts | 4 +- .../grpc-js/src/single-subchannel-channel.ts | 101 ++-- .../grpc-js/test/generated/EchoMessage.ts | 12 + .../grpc-js/test/generated/EchoService.ts | 54 +++ .../grpc-js/test/generated/echo_service.ts | 15 + packages/grpc-js/test/test-priority-queue.ts | 168 +++++++ .../grpc-js/test/test-weighted-round-robin.ts | 436 ++++++++++++++++++ 11 files changed, 862 insertions(+), 59 deletions(-) create mode 100644 packages/grpc-js/test/generated/EchoMessage.ts create mode 100644 packages/grpc-js/test/generated/EchoService.ts create mode 100644 packages/grpc-js/test/generated/echo_service.ts create mode 100644 packages/grpc-js/test/test-priority-queue.ts create mode 100644 packages/grpc-js/test/test-weighted-round-robin.ts diff --git a/packages/grpc-js/package.json b/packages/grpc-js/package.json index ebd0b3d0..8f2b9d43 100644 --- a/packages/grpc-js/package.json +++ b/packages/grpc-js/package.json @@ -64,7 +64,7 @@ "pretest": "npm run generate-types && npm run generate-test-types && npm run compile", "posttest": "npm run check && madge -c ./build/src", "generate-types": "proto-loader-gen-types --keepCase --longs String --enums String --defaults --oneofs --includeComments --includeDirs proto/ --include-dirs proto/ proto/xds/ proto/protoc-gen-validate/ -O src/generated/ --grpcLib ../index channelz.proto xds/service/orca/v3/orca.proto", - "generate-test-types": "proto-loader-gen-types --keepCase --longs String --enums String --defaults --oneofs --includeComments --include-dirs test/fixtures/ -O test/generated/ --grpcLib ../../src/index test_service.proto", + "generate-test-types": "proto-loader-gen-types --keepCase --longs String --enums String --defaults --oneofs --includeComments --include-dirs test/fixtures/ -O test/generated/ --grpcLib ../../src/index test_service.proto echo_service.proto", "copy-protos": "node ./copy-protos" }, "dependencies": { diff --git a/packages/grpc-js/src/load-balancer-pick-first.ts b/packages/grpc-js/src/load-balancer-pick-first.ts index 5eaf31c2..7a9fd48c 100644 --- a/packages/grpc-js/src/load-balancer-pick-first.ts +++ b/packages/grpc-js/src/load-balancer-pick-first.ts @@ -588,6 +588,11 @@ export class PickFirstLoadBalancer implements LoadBalancer { destroy() { this.resetSubchannelList(); this.removeCurrentPick(); + this.metricsCall?.cancel(); + this.metricsCall = null; + this.orcaClient?.close(); + this.orcaClient = null; + this.metricsBackoffTimer.stop(); } getTypeName(): string { diff --git a/packages/grpc-js/src/load-balancer-weighted-round-robin.ts b/packages/grpc-js/src/load-balancer-weighted-round-robin.ts index fd6d161c..05c934ae 100644 --- a/packages/grpc-js/src/load-balancer-weighted-round-robin.ts +++ b/packages/grpc-js/src/load-balancer-weighted-round-robin.ts @@ -71,22 +71,24 @@ function validateFieldType( function parseDurationField(obj: any, fieldName: string): number | null { if (fieldName in obj && obj[fieldName] !== undefined) { - let durationObject: Duration | null; + let durationObject: Duration; if (isDuration(obj[fieldName])) { durationObject = obj[fieldName]; } else if (typeof obj[fieldName] === 'string') { - durationObject = parseDuration(obj[fieldName]); + const parsedDuration = parseDuration(obj[fieldName]); + if (!parsedDuration) { + throw new Error(`weighted round robin config ${fieldName}: failed to parse duration string ${obj[fieldName]}`); + } + durationObject = parsedDuration; } else { - durationObject = null; - } - if (durationObject) { - return durationToMs(durationObject); + throw new Error(`weighted round robin config ${fieldName}: expected duration, got ${typeof obj[fieldName]}`); } + return durationToMs(durationObject); } return null; } -class WeightedRoundRobinLoadBalancingConfig implements TypedLoadBalancingConfig { +export class WeightedRoundRobinLoadBalancingConfig implements TypedLoadBalancingConfig { private readonly enableOobLoadReport: boolean; private readonly oobLoadReportingPeriodMs: number; private readonly blackoutPeriodMs: number; @@ -106,7 +108,7 @@ class WeightedRoundRobinLoadBalancingConfig implements TypedLoadBalancingConfig this.oobLoadReportingPeriodMs = oobLoadReportingPeriodMs ?? DEFAULT_OOB_REPORTING_PERIOD_MS; this.blackoutPeriodMs = blackoutPeriodMs ?? DEFAULT_BLACKOUT_PERIOD_MS; this.weightExpirationPeriodMs = weightExpirationPeriodMs ?? DEFAULT_WEIGHT_EXPIRATION_PERIOD_MS; - this.weightUpdatePeriodMs = Math.min(weightUpdatePeriodMs ?? DEFAULT_WEIGHT_UPDATE_PERIOD_MS, 100); + this.weightUpdatePeriodMs = Math.max(weightUpdatePeriodMs ?? DEFAULT_WEIGHT_UPDATE_PERIOD_MS, 100); this.errorUtilizationPenalty = errorUtilizationPenalty ?? DEFAULT_ERROR_UTILIZATION_PENALTY; } @@ -177,8 +179,19 @@ type MetricsHandler = (loadReport: OrcaLoadReport__Output, endpointName: string) class WeightedRoundRobinPicker implements Picker { private queue: PriorityQueue = new PriorityQueue((a, b) => a.deadline < b.deadline); constructor(children: WeightedPicker[], private readonly metricsHandler: MetricsHandler | null) { + const positiveWeight = children.filter(picker => picker.weight > 0); + let averageWeight: number; + if (positiveWeight.length < 2) { + averageWeight = 1; + } else { + let weightSum: number = 0; + for (const { weight } of positiveWeight) { + weightSum += weight; + } + averageWeight = weightSum / positiveWeight.length; + } for (const child of children) { - const period = 1 / child.weight; + const period = child.weight > 0 ? 1 / child.weight : averageWeight; this.queue.push({ endpointName: child.endpointName, picker: child.picker, @@ -188,7 +201,7 @@ class WeightedRoundRobinPicker implements Picker { } } pick(pickArgs: PickArgs): PickResult { - const entry = this.queue.pop(); + const entry = this.queue.pop()!; this.queue.push({ ...entry, deadline: entry.deadline + entry.period @@ -224,6 +237,8 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer { private lastError: string | null = null; + private weightUpdateTimer: NodeJS.Timeout | null = null; + constructor(private readonly channelControlHelper: ChannelControlHelper) {} private countChildrenWithState(state: ConnectivityState) { @@ -280,16 +295,15 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer { if (entry.child.getConnectivityState() !== ConnectivityState.READY) { continue; } - if (entry.weight > 0) { - weightedPickers.push({ - endpointName: endpoint, - picker: entry.child.getPicker(), - weight: this.getWeight(entry) - }); - } + weightedPickers.push({ + endpointName: endpoint, + picker: entry.child.getPicker(), + weight: this.getWeight(entry) + }); } + trace('Created picker with weights: ' + weightedPickers.map(entry => entry.endpointName + ':' + entry.weight).join(',')); let metricsHandler: MetricsHandler | null; - if (this.latestConfig.getEnableOobLoadReport()) { + if (!this.latestConfig.getEnableOobLoadReport()) { metricsHandler = (loadReport, endpointName) => { const childEntry = this.children.get(endpointName); if (childEntry) { @@ -358,6 +372,16 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer { } return true; } + if (maybeEndpointList.value.length === 0) { + const errorMessage = `No addresses resolved. Resolution note: ${resolutionNote}`; + this.updateState( + ConnectivityState.TRANSIENT_FAILURE, + new UnavailablePicker({details: errorMessage}), + errorMessage + ); + return false; + } + trace('Connect to endpoint list ' + maybeEndpointList.value.map(endpointToString)); const now = new Date(); const seenEndpointNames = new Set(); this.updatesPaused = true; @@ -412,16 +436,28 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer { }; entry.child.addMetricsSubscription(entry.oobMetricsListener, lbConfig.getOobLoadReportingPeriodMs()); } + this.children.set(name, entry); } } for (const [endpointName, entry] of this.children) { - if (!seenEndpointNames.has(endpointName)) { + if (seenEndpointNames.has(endpointName)) { + entry.child.startConnecting(); + } else { entry.child.destroy(); this.children.delete(endpointName); } } + this.latestConfig = lbConfig; this.updatesPaused = false; this.calculateAndUpdateState(); + if (this.weightUpdateTimer) { + clearInterval(this.weightUpdateTimer); + } + this.weightUpdateTimer = setInterval(() => { + if (this.currentState === ConnectivityState.READY) { + this.calculateAndUpdateState(); + } + }, lbConfig.getWeightUpdatePeriodMs()).unref?.(); return true; } exitIdle(): void { @@ -437,6 +473,9 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer { entry.child.destroy(); } this.children.clear(); + if (this.weightUpdateTimer) { + clearInterval(this.weightUpdateTimer); + } } getTypeName(): string { return TYPE_NAME; diff --git a/packages/grpc-js/src/priority-queue.ts b/packages/grpc-js/src/priority-queue.ts index 6973cf38..3ddf8f82 100644 --- a/packages/grpc-js/src/priority-queue.ts +++ b/packages/grpc-js/src/priority-queue.ts @@ -15,34 +15,61 @@ * */ -// Implementation adapted from https://stackoverflow.com/a/42919752/159388 - const top = 0; const parent = (i: number) => Math.floor(i / 2); const left = (i: number) => i * 2 + 1; const right = (i: number) => i * 2 + 2; -export class PriorityQueue { +/** + * A generic priority queue implemented as an array-based binary heap. + * Adapted from https://stackoverflow.com/a/42919752/159388 + */ +export class PriorityQueue { private readonly heap: T[] = []; + /** + * + * @param comparator Returns true if the first argument should precede the + * second in the queue. Defaults to `(a, b) => a > b` + */ constructor(private readonly comparator = (a: T, b: T) => a > b) {} + /** + * @returns The number of items currently in the queue + */ size(): number { return this.heap.length; } + /** + * @returns True if there are no items in the queue, false otherwise + */ isEmpty(): boolean { return this.size() == 0; } - peek(): T { + /** + * Look at the front item that would be popped, without modifying the contents + * of the queue + * @returns The front item in the queue, or undefined if the queue is empty + */ + peek(): T | undefined { return this.heap[top]; } - push(...values: T[]) { + /** + * Add the items to the queue + * @param values The items to add + * @returns The new size of the queue after adding the items + */ + push(...values: T[]): number { values.forEach(value => { this.heap.push(value); this.siftUp(); }); return this.size(); } - pop(): T { + /** + * Remove the front item in the queue and return it + * @returns The front item in the queue, or undefined if the queue is empty + */ + pop(): T | undefined { const poppedValue = this.peek(); const bottom = this.size() - 1; if (bottom > top) { @@ -52,7 +79,13 @@ export class PriorityQueue { this.siftDown(); return poppedValue; } - replace(value: T): T { + /** + * Simultaneously remove the front item in the queue and add the provided + * item. + * @param value The item to add + * @returns The front item in the queue, or undefined if the queue is empty + */ + replace(value: T): T | undefined { const replacedValue = this.peek(); this.heap[top] = value; this.siftDown(); diff --git a/packages/grpc-js/src/resolver-ip.ts b/packages/grpc-js/src/resolver-ip.ts index 80837f86..76e13e3d 100644 --- a/packages/grpc-js/src/resolver-ip.ts +++ b/packages/grpc-js/src/resolver-ip.ts @@ -20,7 +20,7 @@ import { ChannelOptions } from './channel-options'; import { LogVerbosity, Status } from './constants'; import { Metadata } from './metadata'; import { registerResolver, Resolver, ResolverListener } from './resolver'; -import { Endpoint, SubchannelAddress } from './subchannel-address'; +import { Endpoint, SubchannelAddress, subchannelAddressToString } from './subchannel-address'; import { GrpcUri, splitHostPort, uriToString } from './uri-parser'; import * as logging from './logging'; @@ -85,7 +85,7 @@ class IpResolver implements Resolver { }); } this.endpoints = addresses.map(address => ({ addresses: [address] })); - trace('Parsed ' + target.scheme + ' address list ' + addresses); + trace('Parsed ' + target.scheme + ' address list ' + addresses.map(subchannelAddressToString)); } updateResolution(): void { if (!this.hasReturnedResult) { diff --git a/packages/grpc-js/src/single-subchannel-channel.ts b/packages/grpc-js/src/single-subchannel-channel.ts index 3b1e8629..c1a1fd1b 100644 --- a/packages/grpc-js/src/single-subchannel-channel.ts +++ b/packages/grpc-js/src/single-subchannel-channel.ts @@ -22,10 +22,12 @@ import { getNextCallNumber } from "./call-number"; import { Channel } from "./channel"; import { ChannelOptions } from "./channel-options"; import { ChannelRef, ChannelzCallTracker, ChannelzChildrenTracker, ChannelzTrace, registerChannelzChannel, unregisterChannelzRef } from "./channelz"; +import { CompressionFilterFactory } from "./compression-filter"; import { ConnectivityState } from "./connectivity-state"; import { Propagate, Status } from "./constants"; import { restrictControlPlaneStatusCode } from "./control-plane-status"; import { Deadline, getRelativeTimeout } from "./deadline"; +import { FilterStack, FilterStackFactory } from "./filter-stack"; import { Metadata } from "./metadata"; import { getDefaultAuthority } from "./resolver"; import { Subchannel } from "./subchannel"; @@ -40,7 +42,10 @@ class SubchannelCallWrapper implements Call { private halfClosePending = false; private pendingStatus: StatusObject | null = null; private serviceUrl: string; - constructor(private subchannel: Subchannel, private method: string, private options: CallStreamOptions, private callNumber: number) { + private filterStack: FilterStack; + private readFilterPending = false; + private writeFilterPending = false; + constructor(private subchannel: Subchannel, private method: string, filterStackFactory: FilterStackFactory, private options: CallStreamOptions, private callNumber: number) { const splitPath: string[] = this.method.split('/'); let serviceName = ''; /* The standard path format is "/{serviceName}/{methodName}", so if we split @@ -63,6 +68,7 @@ class SubchannelCallWrapper implements Call { }, timeout); } } + this.filterStack = filterStackFactory.createFilter(); } cancelWithStatus(status: Status, details: string): void { @@ -80,7 +86,7 @@ class SubchannelCallWrapper implements Call { getPeer(): string { return this.childCall?.getPeer() ?? this.subchannel.getAddress(); } - start(metadata: Metadata, listener: InterceptingListener): void { + async start(metadata: Metadata, listener: InterceptingListener): Promise { if (this.pendingStatus) { listener.onReceiveStatus(this.pendingStatus); return; @@ -93,38 +99,71 @@ class SubchannelCallWrapper implements Call { }); return; } - this.subchannel.getCallCredentials() - .generateMetadata({method_name: this.method, service_url: this.serviceUrl}) - .then(credsMetadata => { - this.childCall = this.subchannel.createCall(credsMetadata, this.options.host, this.method, listener); - if (this.readPending) { - this.childCall.startRead(); + const filteredMetadata = await this.filterStack.sendMetadata(Promise.resolve(metadata)); + let credsMetadata: Metadata; + try { + credsMetadata = await this.subchannel.getCallCredentials() + .generateMetadata({method_name: this.method, service_url: this.serviceUrl}); + } catch (e) { + const error = e as (Error & { code: number }); + const { code, details } = restrictControlPlaneStatusCode( + typeof error.code === 'number' ? error.code : Status.UNKNOWN, + `Getting metadata from plugin failed with error: ${error.message}` + ); + listener.onReceiveStatus( + { + code: code, + details: details, + metadata: new Metadata(), } - if (this.pendingMessage) { - this.childCall.sendMessageWithContext(this.pendingMessage.context, this.pendingMessage.message); + ); + return; + } + credsMetadata.merge(filteredMetadata); + const childListener: InterceptingListener = { + onReceiveMetadata: async metadata => { + listener.onReceiveMetadata(await this.filterStack.receiveMetadata(metadata)); + }, + onReceiveMessage: async message => { + this.readFilterPending = true; + const filteredMessage = await this.filterStack.receiveMessage(message); + this.readFilterPending = false; + listener.onReceiveMessage(filteredMessage); + if (this.pendingStatus) { + listener.onReceiveStatus(this.pendingStatus); } - if (this.halfClosePending) { - this.childCall.halfClose(); + }, + onReceiveStatus: async status => { + const filteredStatus = await this.filterStack.receiveTrailers(status); + if (this.readFilterPending) { + this.pendingStatus = filteredStatus; + } else { + listener.onReceiveStatus(filteredStatus); } - }, (error: Error & { code: number }) => { - const { code, details } = restrictControlPlaneStatusCode( - typeof error.code === 'number' ? error.code : Status.UNKNOWN, - `Getting metadata from plugin failed with error: ${error.message}` - ); - listener.onReceiveStatus( - { - code: code, - details: details, - metadata: new Metadata(), - } - ); - }); + } + } + this.childCall = this.subchannel.createCall(credsMetadata, this.options.host, this.method, childListener); + if (this.readPending) { + this.childCall.startRead(); + } + if (this.pendingMessage) { + this.childCall.sendMessageWithContext(this.pendingMessage.context, this.pendingMessage.message); + } + if (this.halfClosePending && !this.writeFilterPending) { + this.childCall.halfClose(); + } } - sendMessageWithContext(context: MessageContext, message: Buffer): void { + async sendMessageWithContext(context: MessageContext, message: Buffer): Promise { + this.writeFilterPending = true; + const filteredMessage = await this.filterStack.sendMessage(Promise.resolve({message: message, flags: context.flags})); + this.writeFilterPending = false; if (this.childCall) { - this.childCall.sendMessageWithContext(context, message); + this.childCall.sendMessageWithContext(context, filteredMessage.message); + if (this.halfClosePending) { + this.childCall.halfClose(); + } } else { - this.pendingMessage = { context, message }; + this.pendingMessage = { context, message: filteredMessage.message }; } } startRead(): void { @@ -135,7 +174,7 @@ class SubchannelCallWrapper implements Call { } } halfClose(): void { - if (this.childCall) { + if (this.childCall && !this.writeFilterPending) { this.childCall.halfClose(); } else { this.halfClosePending = true; @@ -162,6 +201,7 @@ export class SingleSubchannelChannel implements Channel { private channelzTrace = new ChannelzTrace(); private callTracker = new ChannelzCallTracker(); private childrenTracker = new ChannelzChildrenTracker(); + private filterStackFactory: FilterStackFactory; constructor(private subchannel: Subchannel, private target: GrpcUri, options: ChannelOptions) { this.channelzEnabled = options['grpc.enable_channelz'] !== 0; this.channelzRef = registerChannelzChannel(uriToString(target), () => ({ @@ -174,6 +214,7 @@ export class SingleSubchannelChannel implements Channel { if (this.channelzEnabled) { this.childrenTracker.refChild(subchannel.getChannelzRef()); } + this.filterStackFactory = new FilterStackFactory([new CompressionFilterFactory(this, options)]); } close(): void { @@ -202,6 +243,6 @@ export class SingleSubchannelChannel implements Channel { flags: Propagate.DEFAULTS, parentCall: null }; - return new SubchannelCallWrapper(this.subchannel, method, callOptions, getNextCallNumber()); + return new SubchannelCallWrapper(this.subchannel, method, this.filterStackFactory, callOptions, getNextCallNumber()); } } diff --git a/packages/grpc-js/test/generated/EchoMessage.ts b/packages/grpc-js/test/generated/EchoMessage.ts new file mode 100644 index 00000000..f273c17a --- /dev/null +++ b/packages/grpc-js/test/generated/EchoMessage.ts @@ -0,0 +1,12 @@ +// Original file: test/fixtures/echo_service.proto + + +export interface EchoMessage { + 'value'?: (string); + 'value2'?: (number); +} + +export interface EchoMessage__Output { + 'value': (string); + 'value2': (number); +} diff --git a/packages/grpc-js/test/generated/EchoService.ts b/packages/grpc-js/test/generated/EchoService.ts new file mode 100644 index 00000000..1999687f --- /dev/null +++ b/packages/grpc-js/test/generated/EchoService.ts @@ -0,0 +1,54 @@ +// Original file: test/fixtures/echo_service.proto + +import type * as grpc from './../../src/index' +import type { MethodDefinition } from '@grpc/proto-loader' +import type { EchoMessage as _EchoMessage, EchoMessage__Output as _EchoMessage__Output } from './EchoMessage'; + +export interface EchoServiceClient extends grpc.Client { + Echo(argument: _EchoMessage, metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + Echo(argument: _EchoMessage, metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + Echo(argument: _EchoMessage, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + Echo(argument: _EchoMessage, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + echo(argument: _EchoMessage, metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + echo(argument: _EchoMessage, metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + echo(argument: _EchoMessage, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + echo(argument: _EchoMessage, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall; + + EchoBidiStream(metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>; + EchoBidiStream(options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>; + echoBidiStream(metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>; + echoBidiStream(options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>; + + EchoClientStream(metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + EchoClientStream(metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + EchoClientStream(options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + EchoClientStream(callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + echoClientStream(metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + echoClientStream(metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + echoClientStream(options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + echoClientStream(callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>; + + EchoServerStream(argument: _EchoMessage, metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>; + EchoServerStream(argument: _EchoMessage, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>; + echoServerStream(argument: _EchoMessage, metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>; + echoServerStream(argument: _EchoMessage, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>; + +} + +export interface EchoServiceHandlers extends grpc.UntypedServiceImplementation { + Echo: grpc.handleUnaryCall<_EchoMessage__Output, _EchoMessage>; + + EchoBidiStream: grpc.handleBidiStreamingCall<_EchoMessage__Output, _EchoMessage>; + + EchoClientStream: grpc.handleClientStreamingCall<_EchoMessage__Output, _EchoMessage>; + + EchoServerStream: grpc.handleServerStreamingCall<_EchoMessage__Output, _EchoMessage>; + +} + +export interface EchoServiceDefinition extends grpc.ServiceDefinition { + Echo: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output> + EchoBidiStream: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output> + EchoClientStream: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output> + EchoServerStream: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output> +} diff --git a/packages/grpc-js/test/generated/echo_service.ts b/packages/grpc-js/test/generated/echo_service.ts new file mode 100644 index 00000000..4c215dc5 --- /dev/null +++ b/packages/grpc-js/test/generated/echo_service.ts @@ -0,0 +1,15 @@ +import type * as grpc from '../../src/index'; +import type { MessageTypeDefinition } from '@grpc/proto-loader'; + +import type { EchoMessage as _EchoMessage, EchoMessage__Output as _EchoMessage__Output } from './EchoMessage'; +import type { EchoServiceClient as _EchoServiceClient, EchoServiceDefinition as _EchoServiceDefinition } from './EchoService'; + +type SubtypeConstructor any, Subtype> = { + new(...args: ConstructorParameters): Subtype; +}; + +export interface ProtoGrpcType { + EchoMessage: MessageTypeDefinition<_EchoMessage, _EchoMessage__Output> + EchoService: SubtypeConstructor & { service: _EchoServiceDefinition } +} + diff --git a/packages/grpc-js/test/test-priority-queue.ts b/packages/grpc-js/test/test-priority-queue.ts new file mode 100644 index 00000000..9c666ca0 --- /dev/null +++ b/packages/grpc-js/test/test-priority-queue.ts @@ -0,0 +1,168 @@ +/* + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import * as assert from 'assert'; +import { PriorityQueue } from '../src/priority-queue'; + +describe('PriorityQueue', () => { + describe('size', () => { + it('Should be 0 initially', () => { + const queue = new PriorityQueue(); + assert.strictEqual(queue.size(), 0); + }); + it('Should be 1 after pushing one item', () => { + const queue = new PriorityQueue(); + queue.push(1); + assert.strictEqual(queue.size(), 1); + }); + it('Should be 0 after pushing and popping one item', () => { + const queue = new PriorityQueue(); + queue.push(1); + queue.pop(); + assert.strictEqual(queue.size(), 0); + }); + }); + describe('isEmpty', () => { + it('Should be true initially', () => { + const queue = new PriorityQueue(); + assert.strictEqual(queue.isEmpty(), true); + }); + it('Should be false after pushing one item', () => { + const queue = new PriorityQueue(); + queue.push(1); + assert.strictEqual(queue.isEmpty(), false); + }); + it('Should be 0 after pushing and popping one item', () => { + const queue = new PriorityQueue(); + queue.push(1); + queue.pop(); + assert.strictEqual(queue.isEmpty(), true); + }); + }); + describe('peek', () => { + it('Should return undefined initially', () => { + const queue = new PriorityQueue(); + assert.strictEqual(queue.peek(), undefined); + }); + it('Should return the same value multiple times', () => { + const queue = new PriorityQueue(); + queue.push(1); + assert.strictEqual(queue.peek(), 1); + assert.strictEqual(queue.peek(), 1); + }); + it('Should return the maximum of multiple values', () => { + const queue = new PriorityQueue(); + queue.push(1, 3, 8, 5, 6); + assert.strictEqual(queue.peek(), 8); + }); + it('Should return undefined after popping the last item', () => { + const queue = new PriorityQueue(); + queue.push(1); + queue.pop(); + assert.strictEqual(queue.peek(), undefined); + }); + }); + describe('pop', () => { + it('Should return undefined initially', () => { + const queue = new PriorityQueue(); + assert.strictEqual(queue.pop(), undefined); + }); + it('Should return a pushed item', () => { + const queue = new PriorityQueue(); + queue.push(1); + assert.strictEqual(queue.pop(), 1); + }); + it('Should return pushed items in decreasing order', () => { + const queue = new PriorityQueue(); + queue.push(1, 3, 8, 5, 6); + assert.strictEqual(queue.pop(), 8); + assert.strictEqual(queue.pop(), 6); + assert.strictEqual(queue.pop(), 5); + assert.strictEqual(queue.pop(), 3); + assert.strictEqual(queue.pop(), 1); + }); + it('Should return undefined after popping the last item', () => { + const queue = new PriorityQueue(); + queue.push(1); + queue.pop(); + assert.strictEqual(queue.pop(), undefined); + }); + }); + describe('replace', () => { + it('should return undefined initially', () => { + const queue = new PriorityQueue(); + assert.strictEqual(queue.replace(1), undefined); + }); + it('Should return a pushed item', () => { + const queue = new PriorityQueue(); + queue.push(1); + assert.strictEqual(queue.replace(2), 1); + }); + it('Should replace the max value if providing the new max', () => { + const queue = new PriorityQueue(); + queue.push(1, 3, 8, 5, 6); + assert.strictEqual(queue.replace(10), 8); + assert.strictEqual(queue.peek(), 10); + }); + it('Should not replace the max value if providing a lower value', () => { + const queue = new PriorityQueue(); + queue.push(1, 3, 8, 5, 6); + assert.strictEqual(queue.replace(4), 8); + assert.strictEqual(queue.peek(), 6); + }); + }); + describe('push', () => { + it('Should would the same with one call or multiple', () => { + const queue1 = new PriorityQueue(); + queue1.push(1, 3, 8, 5, 6); + assert.strictEqual(queue1.pop(), 8); + assert.strictEqual(queue1.pop(), 6); + assert.strictEqual(queue1.pop(), 5); + assert.strictEqual(queue1.pop(), 3); + assert.strictEqual(queue1.pop(), 1); + const queue2 = new PriorityQueue(); + queue2.push(1); + queue2.push(3); + queue2.push(8); + queue2.push(5); + queue2.push(6); + assert.strictEqual(queue2.pop(), 8); + assert.strictEqual(queue2.pop(), 6); + assert.strictEqual(queue2.pop(), 5); + assert.strictEqual(queue2.pop(), 3); + assert.strictEqual(queue2.pop(), 1); + }); + }); + describe('custom comparator', () => { + it('Should produce items in the reverse order with a reversed comparator', () => { + const queue = new PriorityQueue((a, b) => a < b); + queue.push(1, 3, 8, 5, 6); + assert.strictEqual(queue.pop(), 1); + assert.strictEqual(queue.pop(), 3); + assert.strictEqual(queue.pop(), 5); + assert.strictEqual(queue.pop(), 6); + assert.strictEqual(queue.pop(), 8); + }); + it('Should support other types', () => { + const queue = new PriorityQueue((a, b) => a.localeCompare(b) > 0); + queue.push('a', 'c', 'b'); + assert.strictEqual(queue.pop(), 'c'); + assert.strictEqual(queue.pop(), 'b'); + assert.strictEqual(queue.pop(), 'a'); + }); + }); +}); diff --git a/packages/grpc-js/test/test-weighted-round-robin.ts b/packages/grpc-js/test/test-weighted-round-robin.ts new file mode 100644 index 00000000..6f1fb6de --- /dev/null +++ b/packages/grpc-js/test/test-weighted-round-robin.ts @@ -0,0 +1,436 @@ +/* + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import * as assert from 'assert'; +import * as path from 'path'; + +import * as grpc from '../src'; +import { loadProtoFile } from './common'; +import { EchoServiceClient } from './generated/EchoService'; +import { ProtoGrpcType } from './generated/echo_service' +import { WeightedRoundRobinLoadBalancingConfig } from '../src/load-balancer-weighted-round-robin'; + +const protoFile = path.join(__dirname, 'fixtures', 'echo_service.proto'); +const EchoService = (loadProtoFile(protoFile) as unknown as ProtoGrpcType).EchoService; + +function makeNCalls(client: EchoServiceClient, count: number): Promise<{[serverId: string]: number}> { + return new Promise((resolve, reject) => { + const result: {[serverId: string]: number} = {}; + function makeOneCall(callsLeft: number) { + if (callsLeft <= 0) { + resolve(result); + } else { + const deadline = new Date(); + deadline.setMilliseconds(deadline.getMilliseconds() + 100); + const call= client.echo({}, {deadline}, (error, value) => { + if (error) { + reject(error); + return; + } + makeOneCall(callsLeft - 1); + }); + call.on('metadata', metadata => { + const serverEntry = metadata.get('server'); + if (serverEntry.length > 0) { + const serverId = serverEntry[0] as string; + if (!(serverId in result)) { + result[serverId] = 0; + } + result[serverId] += 1; + } + }); + } + } + makeOneCall(count); + }); +} + +function createServiceConfig(wrrConfig: object): grpc.ServiceConfig { + return { + methodConfig: [], + loadBalancingConfig: [ + {'weighted_round_robin': wrrConfig} + ] + }; +} + +function createClient(ports: number[], serviceConfig: grpc.ServiceConfig) { + return new EchoService(`ipv4:${ports.map(port => `127.0.0.1:${port}`).join(',')}`, grpc.credentials.createInsecure(), {'grpc.service_config': JSON.stringify(serviceConfig)}); +} + +function asyncTimeout(delay: number): Promise { + return new Promise(resolve => { + setTimeout(resolve, delay); + }); +} + +describe('Weighted round robin LB policy', () => { + describe('Config parsing', () => { + it('Should have default values with an empty object', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({}); + assert.strictEqual(config.getEnableOobLoadReport(), false); + assert.strictEqual(config.getBlackoutPeriodMs(), 10_000); + assert.strictEqual(config.getErrorUtilizationPenalty(), 1); + assert.strictEqual(config.getOobLoadReportingPeriodMs(), 10_000); + assert.strictEqual(config.getWeightExpirationPeriodMs(), 180_000); + assert.strictEqual(config.getWeightUpdatePeriodMs(), 1_000); + }); + it('Should handle enable_oob_load_report', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + enable_oob_load_report: true + }); + assert.strictEqual(config.getEnableOobLoadReport(), true); + }); + it('Should handle error_utilization_penalty', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + error_utilization_penalty: 0.5 + }); + assert.strictEqual(config.getErrorUtilizationPenalty(), 0.5); + }); + it('Should reject negative error_utilization_penalty', () => { + const loadBalancingConfig = { + error_utilization_penalty: -1 + }; + assert.throws(() => { + WeightedRoundRobinLoadBalancingConfig.createFromJson(loadBalancingConfig); + }, /error_utilization_penalty < 0/); + }); + it('Should handle blackout_period as a string', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + blackout_period: '1s' + }); + assert.strictEqual(config.getBlackoutPeriodMs(), 1_000); + }); + it('Should handle blackout_period as an object', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + blackout_period: { + seconds: 1, + nanos: 0 + } + }); + assert.strictEqual(config.getBlackoutPeriodMs(), 1_000); + }); + it('Should handle oob_load_reporting_period as a string', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + oob_load_reporting_period: '1s' + }); + assert.strictEqual(config.getOobLoadReportingPeriodMs(), 1_000); + }); + it('Should handle oob_load_reporting_period as an object', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + oob_load_reporting_period: { + seconds: 1, + nanos: 0 + } + }); + assert.strictEqual(config.getOobLoadReportingPeriodMs(), 1_000); + }); + it('Should handle weight_expiration_period as a string', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + weight_expiration_period: '1s' + }); + assert.strictEqual(config.getWeightExpirationPeriodMs(), 1_000); + }); + it('Should handle weight_expiration_period as an object', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + weight_expiration_period: { + seconds: 1, + nanos: 0 + } + }); + assert.strictEqual(config.getWeightExpirationPeriodMs(), 1_000); + }); + it('Should handle weight_update_period as a string', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + weight_update_period: '2s' + }); + assert.strictEqual(config.getWeightUpdatePeriodMs(), 2_000); + }); + it('Should handle weight_update_period as an object', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + weight_update_period: { + seconds: 2, + nanos: 0 + } + }); + assert.strictEqual(config.getWeightUpdatePeriodMs(), 2_000); + }); + it('Should cap weight_update_period to a minimum of 0.1s', () => { + const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({ + weight_update_period: '0.01s' + }); + assert.strictEqual(config.getWeightUpdatePeriodMs(), 100); + }); + }); + describe('Per-call metrics', () => { + const server1Metrics = { + qps: 0, + utilization: 0, + eps: 0 + }; + const server2Metrics = { + qps: 0, + utilization: 0, + eps: 0 + }; + const server1 = new grpc.Server({'grpc.server_call_metric_recording': 1}); + const server2 = new grpc.Server({'grpc.server_call_metric_recording': 1}); + const server1Impl = { + echo: ( + call: grpc.ServerUnaryCall, + callback: grpc.sendUnaryData + ) => { + const metricsRecorder = call.getMetricsRecorder(); + metricsRecorder.recordQpsMetric(server1Metrics.qps); + metricsRecorder.recordApplicationUtilizationMetric(server1Metrics.utilization); + metricsRecorder.recordEpsMetric(server1Metrics.eps); + const metadata = new grpc.Metadata(); + metadata.set('server', '1'); + call.sendMetadata(metadata); + callback(null, call.request); + }, + }; + const server2Impl = { + echo: ( + call: grpc.ServerUnaryCall, + callback: grpc.sendUnaryData + ) => { + const metricsRecorder = call.getMetricsRecorder(); + metricsRecorder.recordQpsMetric(server2Metrics.qps); + metricsRecorder.recordApplicationUtilizationMetric(server2Metrics.utilization); + metricsRecorder.recordEpsMetric(server2Metrics.eps); + const metadata = new grpc.Metadata(); + metadata.set('server', '2'); + call.sendMetadata(metadata); + callback(null, call.request); + }, + }; + let port1: number; + let port2: number; + let client: EchoServiceClient | null = null; + before(done => { + const creds = grpc.ServerCredentials.createInsecure(); + server1.addService(EchoService.service, server1Impl); + server2.addService(EchoService.service, server2Impl); + server1.bindAsync('localhost:0', creds, (error1, server1Port) => { + if (error1) { + done(error1); + return; + } + port1 = server1Port; + server2.bindAsync('localhost:0', creds, (error2, server2Port) => { + if (error2) { + done(error2); + return; + } + port2 = server2Port; + done(); + }); + }); + }); + beforeEach(() => { + server1Metrics.qps = 0; + server1Metrics.utilization = 0; + server1Metrics.eps = 0; + server2Metrics.qps = 0; + server2Metrics.utilization = 0; + server2Metrics.eps = 0; + }); + afterEach(() => { + client?.close(); + client = null; + }); + after(() => { + server1.forceShutdown(); + server2.forceShutdown(); + }); + it('Should evenly balance among endpoints with no weight', async () => { + const serviceConfig = createServiceConfig({}); + client = createClient([port1, port2], serviceConfig); + await makeNCalls(client, 10); + const result = await makeNCalls(client, 30); + assert(Math.abs(result['1'] - result['2']) < 3, `server1: ${result['1']}, server2: ${result[2]}`); + }); + it('Should send more requests to endpoints with higher QPS', async () => { + const serviceConfig = createServiceConfig({ + blackout_period: '0.01s', + weight_update_period: '0.1s' + }); + client = createClient([port1, port2], serviceConfig); + server1Metrics.qps = 3; + server1Metrics.utilization = 1; + server2Metrics.qps = 1; + server2Metrics.utilization = 1; + await makeNCalls(client, 10); + await asyncTimeout(100); + const result = await makeNCalls(client, 40); + assert(Math.abs(result['1'] - 30) < 2, `server1: ${result['1']}, server2: ${result['2']}`); + }); + // Calls aren't fast enough for this to work consistently + it.skip('Should wait for the blackout period to apply weights', async () => { + const serviceConfig = createServiceConfig({ + blackout_period: '0.5s' + }); + client = createClient([port1, port2], serviceConfig); + server1Metrics.qps = 3; + server1Metrics.utilization = 1; + server2Metrics.qps = 1; + server2Metrics.utilization = 1; + await makeNCalls(client, 10); + await asyncTimeout(100); + const result1 = await makeNCalls(client, 20); + assert(Math.abs(result1['1'] - result1['2']) < 3, `result1: server1: ${result1['1']}, server2: ${result1[2]}`); + await asyncTimeout(400); + const result2 = await makeNCalls(client, 40); + assert(Math.abs(result2['1'] - 30) < 2, `result2: server1: ${result2['1']}, server2: ${result2['2']}`); + }).timeout(3000); + // Calls aren't fast enough for this to work consistently + it.skip('Should wait for the weight update period to apply weights', async () => { + const serviceConfig = createServiceConfig({ + blackout_period: '0.01s', + weight_update_period: '1s' + }); + client = createClient([port1, port2], serviceConfig); + server1Metrics.qps = 3; + server1Metrics.utilization = 1; + server2Metrics.qps = 1; + server2Metrics.utilization = 1; + await makeNCalls(client, 10); + await asyncTimeout(100); + const result1 = await makeNCalls(client, 20); + assert(Math.abs(result1['1'] - result1['2']) < 3, `result1: server1: ${result1['1']}, server2: ${result1[2]}`); + await asyncTimeout(400); + const result2 = await makeNCalls(client, 40); + assert(Math.abs(result2['1'] - 30) < 2, `result2: server1: ${result2['1']}, server2: ${result2['2']}`); + }).timeout(3000); + it('Should send more requests to endpoints with lower EPS', async () => { + const serviceConfig = createServiceConfig({ + blackout_period: '0.01s', + weight_update_period: '0.1s', + error_utilization_penalty: 1 + }); + client = createClient([port1, port2], serviceConfig); + server1Metrics.qps = 2; + server1Metrics.utilization = 1; + server1Metrics.eps = 0; + server2Metrics.qps = 2; + server2Metrics.utilization = 1; + server2Metrics.eps = 2; + await makeNCalls(client, 10); + await asyncTimeout(100); + const result = await makeNCalls(client, 30); + assert(Math.abs(result['1'] - 20) < 3, `server1: ${result['1']}, server2: ${result['2']}`); + }); + }); + describe('Out of band metrics', () => { + const server1MetricRecorder = new grpc.ServerMetricRecorder(); + const server2MetricRecorder = new grpc.ServerMetricRecorder(); + const server1 = new grpc.Server(); + const server2 = new grpc.Server(); + const server1Impl = { + echo: ( + call: grpc.ServerUnaryCall, + callback: grpc.sendUnaryData + ) => { + const metadata = new grpc.Metadata(); + metadata.set('server', '1'); + call.sendMetadata(metadata); + callback(null, call.request); + }, + }; + const server2Impl = { + echo: ( + call: grpc.ServerUnaryCall, + callback: grpc.sendUnaryData + ) => { + const metadata = new grpc.Metadata(); + metadata.set('server', '2'); + call.sendMetadata(metadata); + callback(null, call.request); + }, + }; + let port1: number; + let port2: number; + let client: EchoServiceClient | null = null; + before(done => { + const creds = grpc.ServerCredentials.createInsecure(); + server1.addService(EchoService.service, server1Impl); + server2.addService(EchoService.service, server2Impl); + server1MetricRecorder.addToServer(server1); + server2MetricRecorder.addToServer(server2); + server1.bindAsync('localhost:0', creds, (error1, server1Port) => { + if (error1) { + done(error1); + return; + } + port1 = server1Port; + server2.bindAsync('localhost:0', creds, (error2, server2Port) => { + if (error2) { + done(error2); + return; + } + port2 = server2Port; + done(); + }); + }); + }); + beforeEach(() => { + server1MetricRecorder.deleteQpsMetric(); + server1MetricRecorder.deleteEpsMetric(); + server1MetricRecorder.deleteApplicationUtilizationMetric(); + server2MetricRecorder.deleteQpsMetric(); + server2MetricRecorder.deleteEpsMetric(); + server2MetricRecorder.deleteApplicationUtilizationMetric(); + }); + afterEach(() => { + client?.close(); + client = null; + }); + after(() => { + server1.forceShutdown(); + server2.forceShutdown(); + }); + it('Should evenly balance among endpoints with no weight', async () => { + const serviceConfig = createServiceConfig({ + enable_oob_load_report: true, + oob_load_reporting_period: '0.01s', + blackout_period: '0.01s' + }); + client = createClient([port1, port2], serviceConfig); + await makeNCalls(client, 10); + const result = await makeNCalls(client, 30); + assert(Math.abs(result['1'] - result['2']) < 3, `server1: ${result['1']}, server2: ${result[2]}`); + }); + it('Should send more requests to endpoints with higher QPS', async () => { + const serviceConfig = createServiceConfig({ + enable_oob_load_report: true, + oob_load_reporting_period: '0.01s', + blackout_period: '0.01s', + weight_update_period: '0.1s' + }); + client = createClient([port1, port2], serviceConfig); + server1MetricRecorder.setQpsMetric(3); + server1MetricRecorder.setApplicationUtilizationMetric(1); + server2MetricRecorder.setQpsMetric(1); + server2MetricRecorder.setApplicationUtilizationMetric(1); + await makeNCalls(client, 10); + await asyncTimeout(100); + const result = await makeNCalls(client, 40); + assert(Math.abs(result['1'] - 30) < 2, `server1: ${result['1']}, server2: ${result['2']}`); + }); + }); +});