grpc-js: Implement weighted_round_robin LB policy

This commit is contained in:
Michael Lumish 2025-08-14 13:44:22 -07:00
parent 83ece61c88
commit 13065ad5a1
11 changed files with 862 additions and 59 deletions

View File

@ -64,7 +64,7 @@
"pretest": "npm run generate-types && npm run generate-test-types && npm run compile",
"posttest": "npm run check && madge -c ./build/src",
"generate-types": "proto-loader-gen-types --keepCase --longs String --enums String --defaults --oneofs --includeComments --includeDirs proto/ --include-dirs proto/ proto/xds/ proto/protoc-gen-validate/ -O src/generated/ --grpcLib ../index channelz.proto xds/service/orca/v3/orca.proto",
"generate-test-types": "proto-loader-gen-types --keepCase --longs String --enums String --defaults --oneofs --includeComments --include-dirs test/fixtures/ -O test/generated/ --grpcLib ../../src/index test_service.proto",
"generate-test-types": "proto-loader-gen-types --keepCase --longs String --enums String --defaults --oneofs --includeComments --include-dirs test/fixtures/ -O test/generated/ --grpcLib ../../src/index test_service.proto echo_service.proto",
"copy-protos": "node ./copy-protos"
},
"dependencies": {

View File

@ -588,6 +588,11 @@ export class PickFirstLoadBalancer implements LoadBalancer {
destroy() {
this.resetSubchannelList();
this.removeCurrentPick();
this.metricsCall?.cancel();
this.metricsCall = null;
this.orcaClient?.close();
this.orcaClient = null;
this.metricsBackoffTimer.stop();
}
getTypeName(): string {

View File

@ -71,22 +71,24 @@ function validateFieldType(
function parseDurationField(obj: any, fieldName: string): number | null {
if (fieldName in obj && obj[fieldName] !== undefined) {
let durationObject: Duration | null;
let durationObject: Duration;
if (isDuration(obj[fieldName])) {
durationObject = obj[fieldName];
} else if (typeof obj[fieldName] === 'string') {
durationObject = parseDuration(obj[fieldName]);
const parsedDuration = parseDuration(obj[fieldName]);
if (!parsedDuration) {
throw new Error(`weighted round robin config ${fieldName}: failed to parse duration string ${obj[fieldName]}`);
}
durationObject = parsedDuration;
} else {
durationObject = null;
}
if (durationObject) {
return durationToMs(durationObject);
throw new Error(`weighted round robin config ${fieldName}: expected duration, got ${typeof obj[fieldName]}`);
}
return durationToMs(durationObject);
}
return null;
}
class WeightedRoundRobinLoadBalancingConfig implements TypedLoadBalancingConfig {
export class WeightedRoundRobinLoadBalancingConfig implements TypedLoadBalancingConfig {
private readonly enableOobLoadReport: boolean;
private readonly oobLoadReportingPeriodMs: number;
private readonly blackoutPeriodMs: number;
@ -106,7 +108,7 @@ class WeightedRoundRobinLoadBalancingConfig implements TypedLoadBalancingConfig
this.oobLoadReportingPeriodMs = oobLoadReportingPeriodMs ?? DEFAULT_OOB_REPORTING_PERIOD_MS;
this.blackoutPeriodMs = blackoutPeriodMs ?? DEFAULT_BLACKOUT_PERIOD_MS;
this.weightExpirationPeriodMs = weightExpirationPeriodMs ?? DEFAULT_WEIGHT_EXPIRATION_PERIOD_MS;
this.weightUpdatePeriodMs = Math.min(weightUpdatePeriodMs ?? DEFAULT_WEIGHT_UPDATE_PERIOD_MS, 100);
this.weightUpdatePeriodMs = Math.max(weightUpdatePeriodMs ?? DEFAULT_WEIGHT_UPDATE_PERIOD_MS, 100);
this.errorUtilizationPenalty = errorUtilizationPenalty ?? DEFAULT_ERROR_UTILIZATION_PENALTY;
}
@ -177,8 +179,19 @@ type MetricsHandler = (loadReport: OrcaLoadReport__Output, endpointName: string)
class WeightedRoundRobinPicker implements Picker {
private queue: PriorityQueue<QueueEntry> = new PriorityQueue((a, b) => a.deadline < b.deadline);
constructor(children: WeightedPicker[], private readonly metricsHandler: MetricsHandler | null) {
const positiveWeight = children.filter(picker => picker.weight > 0);
let averageWeight: number;
if (positiveWeight.length < 2) {
averageWeight = 1;
} else {
let weightSum: number = 0;
for (const { weight } of positiveWeight) {
weightSum += weight;
}
averageWeight = weightSum / positiveWeight.length;
}
for (const child of children) {
const period = 1 / child.weight;
const period = child.weight > 0 ? 1 / child.weight : averageWeight;
this.queue.push({
endpointName: child.endpointName,
picker: child.picker,
@ -188,7 +201,7 @@ class WeightedRoundRobinPicker implements Picker {
}
}
pick(pickArgs: PickArgs): PickResult {
const entry = this.queue.pop();
const entry = this.queue.pop()!;
this.queue.push({
...entry,
deadline: entry.deadline + entry.period
@ -224,6 +237,8 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer {
private lastError: string | null = null;
private weightUpdateTimer: NodeJS.Timeout | null = null;
constructor(private readonly channelControlHelper: ChannelControlHelper) {}
private countChildrenWithState(state: ConnectivityState) {
@ -280,16 +295,15 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer {
if (entry.child.getConnectivityState() !== ConnectivityState.READY) {
continue;
}
if (entry.weight > 0) {
weightedPickers.push({
endpointName: endpoint,
picker: entry.child.getPicker(),
weight: this.getWeight(entry)
});
}
weightedPickers.push({
endpointName: endpoint,
picker: entry.child.getPicker(),
weight: this.getWeight(entry)
});
}
trace('Created picker with weights: ' + weightedPickers.map(entry => entry.endpointName + ':' + entry.weight).join(','));
let metricsHandler: MetricsHandler | null;
if (this.latestConfig.getEnableOobLoadReport()) {
if (!this.latestConfig.getEnableOobLoadReport()) {
metricsHandler = (loadReport, endpointName) => {
const childEntry = this.children.get(endpointName);
if (childEntry) {
@ -358,6 +372,16 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer {
}
return true;
}
if (maybeEndpointList.value.length === 0) {
const errorMessage = `No addresses resolved. Resolution note: ${resolutionNote}`;
this.updateState(
ConnectivityState.TRANSIENT_FAILURE,
new UnavailablePicker({details: errorMessage}),
errorMessage
);
return false;
}
trace('Connect to endpoint list ' + maybeEndpointList.value.map(endpointToString));
const now = new Date();
const seenEndpointNames = new Set<string>();
this.updatesPaused = true;
@ -412,16 +436,28 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer {
};
entry.child.addMetricsSubscription(entry.oobMetricsListener, lbConfig.getOobLoadReportingPeriodMs());
}
this.children.set(name, entry);
}
}
for (const [endpointName, entry] of this.children) {
if (!seenEndpointNames.has(endpointName)) {
if (seenEndpointNames.has(endpointName)) {
entry.child.startConnecting();
} else {
entry.child.destroy();
this.children.delete(endpointName);
}
}
this.latestConfig = lbConfig;
this.updatesPaused = false;
this.calculateAndUpdateState();
if (this.weightUpdateTimer) {
clearInterval(this.weightUpdateTimer);
}
this.weightUpdateTimer = setInterval(() => {
if (this.currentState === ConnectivityState.READY) {
this.calculateAndUpdateState();
}
}, lbConfig.getWeightUpdatePeriodMs()).unref?.();
return true;
}
exitIdle(): void {
@ -437,6 +473,9 @@ class WeightedRoundRobinLoadBalancer implements LoadBalancer {
entry.child.destroy();
}
this.children.clear();
if (this.weightUpdateTimer) {
clearInterval(this.weightUpdateTimer);
}
}
getTypeName(): string {
return TYPE_NAME;

View File

@ -15,34 +15,61 @@
*
*/
// Implementation adapted from https://stackoverflow.com/a/42919752/159388
const top = 0;
const parent = (i: number) => Math.floor(i / 2);
const left = (i: number) => i * 2 + 1;
const right = (i: number) => i * 2 + 2;
export class PriorityQueue<T> {
/**
* A generic priority queue implemented as an array-based binary heap.
* Adapted from https://stackoverflow.com/a/42919752/159388
*/
export class PriorityQueue<T=number> {
private readonly heap: T[] = [];
/**
*
* @param comparator Returns true if the first argument should precede the
* second in the queue. Defaults to `(a, b) => a > b`
*/
constructor(private readonly comparator = (a: T, b: T) => a > b) {}
/**
* @returns The number of items currently in the queue
*/
size(): number {
return this.heap.length;
}
/**
* @returns True if there are no items in the queue, false otherwise
*/
isEmpty(): boolean {
return this.size() == 0;
}
peek(): T {
/**
* Look at the front item that would be popped, without modifying the contents
* of the queue
* @returns The front item in the queue, or undefined if the queue is empty
*/
peek(): T | undefined {
return this.heap[top];
}
push(...values: T[]) {
/**
* Add the items to the queue
* @param values The items to add
* @returns The new size of the queue after adding the items
*/
push(...values: T[]): number {
values.forEach(value => {
this.heap.push(value);
this.siftUp();
});
return this.size();
}
pop(): T {
/**
* Remove the front item in the queue and return it
* @returns The front item in the queue, or undefined if the queue is empty
*/
pop(): T | undefined {
const poppedValue = this.peek();
const bottom = this.size() - 1;
if (bottom > top) {
@ -52,7 +79,13 @@ export class PriorityQueue<T> {
this.siftDown();
return poppedValue;
}
replace(value: T): T {
/**
* Simultaneously remove the front item in the queue and add the provided
* item.
* @param value The item to add
* @returns The front item in the queue, or undefined if the queue is empty
*/
replace(value: T): T | undefined {
const replacedValue = this.peek();
this.heap[top] = value;
this.siftDown();

View File

@ -20,7 +20,7 @@ import { ChannelOptions } from './channel-options';
import { LogVerbosity, Status } from './constants';
import { Metadata } from './metadata';
import { registerResolver, Resolver, ResolverListener } from './resolver';
import { Endpoint, SubchannelAddress } from './subchannel-address';
import { Endpoint, SubchannelAddress, subchannelAddressToString } from './subchannel-address';
import { GrpcUri, splitHostPort, uriToString } from './uri-parser';
import * as logging from './logging';
@ -85,7 +85,7 @@ class IpResolver implements Resolver {
});
}
this.endpoints = addresses.map(address => ({ addresses: [address] }));
trace('Parsed ' + target.scheme + ' address list ' + addresses);
trace('Parsed ' + target.scheme + ' address list ' + addresses.map(subchannelAddressToString));
}
updateResolution(): void {
if (!this.hasReturnedResult) {

View File

@ -22,10 +22,12 @@ import { getNextCallNumber } from "./call-number";
import { Channel } from "./channel";
import { ChannelOptions } from "./channel-options";
import { ChannelRef, ChannelzCallTracker, ChannelzChildrenTracker, ChannelzTrace, registerChannelzChannel, unregisterChannelzRef } from "./channelz";
import { CompressionFilterFactory } from "./compression-filter";
import { ConnectivityState } from "./connectivity-state";
import { Propagate, Status } from "./constants";
import { restrictControlPlaneStatusCode } from "./control-plane-status";
import { Deadline, getRelativeTimeout } from "./deadline";
import { FilterStack, FilterStackFactory } from "./filter-stack";
import { Metadata } from "./metadata";
import { getDefaultAuthority } from "./resolver";
import { Subchannel } from "./subchannel";
@ -40,7 +42,10 @@ class SubchannelCallWrapper implements Call {
private halfClosePending = false;
private pendingStatus: StatusObject | null = null;
private serviceUrl: string;
constructor(private subchannel: Subchannel, private method: string, private options: CallStreamOptions, private callNumber: number) {
private filterStack: FilterStack;
private readFilterPending = false;
private writeFilterPending = false;
constructor(private subchannel: Subchannel, private method: string, filterStackFactory: FilterStackFactory, private options: CallStreamOptions, private callNumber: number) {
const splitPath: string[] = this.method.split('/');
let serviceName = '';
/* The standard path format is "/{serviceName}/{methodName}", so if we split
@ -63,6 +68,7 @@ class SubchannelCallWrapper implements Call {
}, timeout);
}
}
this.filterStack = filterStackFactory.createFilter();
}
cancelWithStatus(status: Status, details: string): void {
@ -80,7 +86,7 @@ class SubchannelCallWrapper implements Call {
getPeer(): string {
return this.childCall?.getPeer() ?? this.subchannel.getAddress();
}
start(metadata: Metadata, listener: InterceptingListener): void {
async start(metadata: Metadata, listener: InterceptingListener): Promise<void> {
if (this.pendingStatus) {
listener.onReceiveStatus(this.pendingStatus);
return;
@ -93,38 +99,71 @@ class SubchannelCallWrapper implements Call {
});
return;
}
this.subchannel.getCallCredentials()
.generateMetadata({method_name: this.method, service_url: this.serviceUrl})
.then(credsMetadata => {
this.childCall = this.subchannel.createCall(credsMetadata, this.options.host, this.method, listener);
if (this.readPending) {
this.childCall.startRead();
const filteredMetadata = await this.filterStack.sendMetadata(Promise.resolve(metadata));
let credsMetadata: Metadata;
try {
credsMetadata = await this.subchannel.getCallCredentials()
.generateMetadata({method_name: this.method, service_url: this.serviceUrl});
} catch (e) {
const error = e as (Error & { code: number });
const { code, details } = restrictControlPlaneStatusCode(
typeof error.code === 'number' ? error.code : Status.UNKNOWN,
`Getting metadata from plugin failed with error: ${error.message}`
);
listener.onReceiveStatus(
{
code: code,
details: details,
metadata: new Metadata(),
}
if (this.pendingMessage) {
this.childCall.sendMessageWithContext(this.pendingMessage.context, this.pendingMessage.message);
);
return;
}
credsMetadata.merge(filteredMetadata);
const childListener: InterceptingListener = {
onReceiveMetadata: async metadata => {
listener.onReceiveMetadata(await this.filterStack.receiveMetadata(metadata));
},
onReceiveMessage: async message => {
this.readFilterPending = true;
const filteredMessage = await this.filterStack.receiveMessage(message);
this.readFilterPending = false;
listener.onReceiveMessage(filteredMessage);
if (this.pendingStatus) {
listener.onReceiveStatus(this.pendingStatus);
}
if (this.halfClosePending) {
this.childCall.halfClose();
},
onReceiveStatus: async status => {
const filteredStatus = await this.filterStack.receiveTrailers(status);
if (this.readFilterPending) {
this.pendingStatus = filteredStatus;
} else {
listener.onReceiveStatus(filteredStatus);
}
}, (error: Error & { code: number }) => {
const { code, details } = restrictControlPlaneStatusCode(
typeof error.code === 'number' ? error.code : Status.UNKNOWN,
`Getting metadata from plugin failed with error: ${error.message}`
);
listener.onReceiveStatus(
{
code: code,
details: details,
metadata: new Metadata(),
}
);
});
}
}
this.childCall = this.subchannel.createCall(credsMetadata, this.options.host, this.method, childListener);
if (this.readPending) {
this.childCall.startRead();
}
if (this.pendingMessage) {
this.childCall.sendMessageWithContext(this.pendingMessage.context, this.pendingMessage.message);
}
if (this.halfClosePending && !this.writeFilterPending) {
this.childCall.halfClose();
}
}
sendMessageWithContext(context: MessageContext, message: Buffer): void {
async sendMessageWithContext(context: MessageContext, message: Buffer): Promise<void> {
this.writeFilterPending = true;
const filteredMessage = await this.filterStack.sendMessage(Promise.resolve({message: message, flags: context.flags}));
this.writeFilterPending = false;
if (this.childCall) {
this.childCall.sendMessageWithContext(context, message);
this.childCall.sendMessageWithContext(context, filteredMessage.message);
if (this.halfClosePending) {
this.childCall.halfClose();
}
} else {
this.pendingMessage = { context, message };
this.pendingMessage = { context, message: filteredMessage.message };
}
}
startRead(): void {
@ -135,7 +174,7 @@ class SubchannelCallWrapper implements Call {
}
}
halfClose(): void {
if (this.childCall) {
if (this.childCall && !this.writeFilterPending) {
this.childCall.halfClose();
} else {
this.halfClosePending = true;
@ -162,6 +201,7 @@ export class SingleSubchannelChannel implements Channel {
private channelzTrace = new ChannelzTrace();
private callTracker = new ChannelzCallTracker();
private childrenTracker = new ChannelzChildrenTracker();
private filterStackFactory: FilterStackFactory;
constructor(private subchannel: Subchannel, private target: GrpcUri, options: ChannelOptions) {
this.channelzEnabled = options['grpc.enable_channelz'] !== 0;
this.channelzRef = registerChannelzChannel(uriToString(target), () => ({
@ -174,6 +214,7 @@ export class SingleSubchannelChannel implements Channel {
if (this.channelzEnabled) {
this.childrenTracker.refChild(subchannel.getChannelzRef());
}
this.filterStackFactory = new FilterStackFactory([new CompressionFilterFactory(this, options)]);
}
close(): void {
@ -202,6 +243,6 @@ export class SingleSubchannelChannel implements Channel {
flags: Propagate.DEFAULTS,
parentCall: null
};
return new SubchannelCallWrapper(this.subchannel, method, callOptions, getNextCallNumber());
return new SubchannelCallWrapper(this.subchannel, method, this.filterStackFactory, callOptions, getNextCallNumber());
}
}

View File

@ -0,0 +1,12 @@
// Original file: test/fixtures/echo_service.proto
export interface EchoMessage {
'value'?: (string);
'value2'?: (number);
}
export interface EchoMessage__Output {
'value': (string);
'value2': (number);
}

View File

@ -0,0 +1,54 @@
// Original file: test/fixtures/echo_service.proto
import type * as grpc from './../../src/index'
import type { MethodDefinition } from '@grpc/proto-loader'
import type { EchoMessage as _EchoMessage, EchoMessage__Output as _EchoMessage__Output } from './EchoMessage';
export interface EchoServiceClient extends grpc.Client {
Echo(argument: _EchoMessage, metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
Echo(argument: _EchoMessage, metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
Echo(argument: _EchoMessage, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
Echo(argument: _EchoMessage, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
echo(argument: _EchoMessage, metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
echo(argument: _EchoMessage, metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
echo(argument: _EchoMessage, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
echo(argument: _EchoMessage, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientUnaryCall;
EchoBidiStream(metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>;
EchoBidiStream(options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>;
echoBidiStream(metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>;
echoBidiStream(options?: grpc.CallOptions): grpc.ClientDuplexStream<_EchoMessage, _EchoMessage__Output>;
EchoClientStream(metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
EchoClientStream(metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
EchoClientStream(options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
EchoClientStream(callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
echoClientStream(metadata: grpc.Metadata, options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
echoClientStream(metadata: grpc.Metadata, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
echoClientStream(options: grpc.CallOptions, callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
echoClientStream(callback: grpc.requestCallback<_EchoMessage__Output>): grpc.ClientWritableStream<_EchoMessage>;
EchoServerStream(argument: _EchoMessage, metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>;
EchoServerStream(argument: _EchoMessage, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>;
echoServerStream(argument: _EchoMessage, metadata: grpc.Metadata, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>;
echoServerStream(argument: _EchoMessage, options?: grpc.CallOptions): grpc.ClientReadableStream<_EchoMessage__Output>;
}
export interface EchoServiceHandlers extends grpc.UntypedServiceImplementation {
Echo: grpc.handleUnaryCall<_EchoMessage__Output, _EchoMessage>;
EchoBidiStream: grpc.handleBidiStreamingCall<_EchoMessage__Output, _EchoMessage>;
EchoClientStream: grpc.handleClientStreamingCall<_EchoMessage__Output, _EchoMessage>;
EchoServerStream: grpc.handleServerStreamingCall<_EchoMessage__Output, _EchoMessage>;
}
export interface EchoServiceDefinition extends grpc.ServiceDefinition {
Echo: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output>
EchoBidiStream: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output>
EchoClientStream: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output>
EchoServerStream: MethodDefinition<_EchoMessage, _EchoMessage, _EchoMessage__Output, _EchoMessage__Output>
}

View File

@ -0,0 +1,15 @@
import type * as grpc from '../../src/index';
import type { MessageTypeDefinition } from '@grpc/proto-loader';
import type { EchoMessage as _EchoMessage, EchoMessage__Output as _EchoMessage__Output } from './EchoMessage';
import type { EchoServiceClient as _EchoServiceClient, EchoServiceDefinition as _EchoServiceDefinition } from './EchoService';
type SubtypeConstructor<Constructor extends new (...args: any) => any, Subtype> = {
new(...args: ConstructorParameters<Constructor>): Subtype;
};
export interface ProtoGrpcType {
EchoMessage: MessageTypeDefinition<_EchoMessage, _EchoMessage__Output>
EchoService: SubtypeConstructor<typeof grpc.Client, _EchoServiceClient> & { service: _EchoServiceDefinition }
}

View File

@ -0,0 +1,168 @@
/*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
import * as assert from 'assert';
import { PriorityQueue } from '../src/priority-queue';
describe('PriorityQueue', () => {
describe('size', () => {
it('Should be 0 initially', () => {
const queue = new PriorityQueue();
assert.strictEqual(queue.size(), 0);
});
it('Should be 1 after pushing one item', () => {
const queue = new PriorityQueue();
queue.push(1);
assert.strictEqual(queue.size(), 1);
});
it('Should be 0 after pushing and popping one item', () => {
const queue = new PriorityQueue();
queue.push(1);
queue.pop();
assert.strictEqual(queue.size(), 0);
});
});
describe('isEmpty', () => {
it('Should be true initially', () => {
const queue = new PriorityQueue();
assert.strictEqual(queue.isEmpty(), true);
});
it('Should be false after pushing one item', () => {
const queue = new PriorityQueue();
queue.push(1);
assert.strictEqual(queue.isEmpty(), false);
});
it('Should be 0 after pushing and popping one item', () => {
const queue = new PriorityQueue();
queue.push(1);
queue.pop();
assert.strictEqual(queue.isEmpty(), true);
});
});
describe('peek', () => {
it('Should return undefined initially', () => {
const queue = new PriorityQueue();
assert.strictEqual(queue.peek(), undefined);
});
it('Should return the same value multiple times', () => {
const queue = new PriorityQueue();
queue.push(1);
assert.strictEqual(queue.peek(), 1);
assert.strictEqual(queue.peek(), 1);
});
it('Should return the maximum of multiple values', () => {
const queue = new PriorityQueue();
queue.push(1, 3, 8, 5, 6);
assert.strictEqual(queue.peek(), 8);
});
it('Should return undefined after popping the last item', () => {
const queue = new PriorityQueue();
queue.push(1);
queue.pop();
assert.strictEqual(queue.peek(), undefined);
});
});
describe('pop', () => {
it('Should return undefined initially', () => {
const queue = new PriorityQueue();
assert.strictEqual(queue.pop(), undefined);
});
it('Should return a pushed item', () => {
const queue = new PriorityQueue();
queue.push(1);
assert.strictEqual(queue.pop(), 1);
});
it('Should return pushed items in decreasing order', () => {
const queue = new PriorityQueue();
queue.push(1, 3, 8, 5, 6);
assert.strictEqual(queue.pop(), 8);
assert.strictEqual(queue.pop(), 6);
assert.strictEqual(queue.pop(), 5);
assert.strictEqual(queue.pop(), 3);
assert.strictEqual(queue.pop(), 1);
});
it('Should return undefined after popping the last item', () => {
const queue = new PriorityQueue();
queue.push(1);
queue.pop();
assert.strictEqual(queue.pop(), undefined);
});
});
describe('replace', () => {
it('should return undefined initially', () => {
const queue = new PriorityQueue();
assert.strictEqual(queue.replace(1), undefined);
});
it('Should return a pushed item', () => {
const queue = new PriorityQueue();
queue.push(1);
assert.strictEqual(queue.replace(2), 1);
});
it('Should replace the max value if providing the new max', () => {
const queue = new PriorityQueue();
queue.push(1, 3, 8, 5, 6);
assert.strictEqual(queue.replace(10), 8);
assert.strictEqual(queue.peek(), 10);
});
it('Should not replace the max value if providing a lower value', () => {
const queue = new PriorityQueue();
queue.push(1, 3, 8, 5, 6);
assert.strictEqual(queue.replace(4), 8);
assert.strictEqual(queue.peek(), 6);
});
});
describe('push', () => {
it('Should would the same with one call or multiple', () => {
const queue1 = new PriorityQueue();
queue1.push(1, 3, 8, 5, 6);
assert.strictEqual(queue1.pop(), 8);
assert.strictEqual(queue1.pop(), 6);
assert.strictEqual(queue1.pop(), 5);
assert.strictEqual(queue1.pop(), 3);
assert.strictEqual(queue1.pop(), 1);
const queue2 = new PriorityQueue();
queue2.push(1);
queue2.push(3);
queue2.push(8);
queue2.push(5);
queue2.push(6);
assert.strictEqual(queue2.pop(), 8);
assert.strictEqual(queue2.pop(), 6);
assert.strictEqual(queue2.pop(), 5);
assert.strictEqual(queue2.pop(), 3);
assert.strictEqual(queue2.pop(), 1);
});
});
describe('custom comparator', () => {
it('Should produce items in the reverse order with a reversed comparator', () => {
const queue = new PriorityQueue((a, b) => a < b);
queue.push(1, 3, 8, 5, 6);
assert.strictEqual(queue.pop(), 1);
assert.strictEqual(queue.pop(), 3);
assert.strictEqual(queue.pop(), 5);
assert.strictEqual(queue.pop(), 6);
assert.strictEqual(queue.pop(), 8);
});
it('Should support other types', () => {
const queue = new PriorityQueue<string>((a, b) => a.localeCompare(b) > 0);
queue.push('a', 'c', 'b');
assert.strictEqual(queue.pop(), 'c');
assert.strictEqual(queue.pop(), 'b');
assert.strictEqual(queue.pop(), 'a');
});
});
});

View File

@ -0,0 +1,436 @@
/*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
import * as assert from 'assert';
import * as path from 'path';
import * as grpc from '../src';
import { loadProtoFile } from './common';
import { EchoServiceClient } from './generated/EchoService';
import { ProtoGrpcType } from './generated/echo_service'
import { WeightedRoundRobinLoadBalancingConfig } from '../src/load-balancer-weighted-round-robin';
const protoFile = path.join(__dirname, 'fixtures', 'echo_service.proto');
const EchoService = (loadProtoFile(protoFile) as unknown as ProtoGrpcType).EchoService;
function makeNCalls(client: EchoServiceClient, count: number): Promise<{[serverId: string]: number}> {
return new Promise((resolve, reject) => {
const result: {[serverId: string]: number} = {};
function makeOneCall(callsLeft: number) {
if (callsLeft <= 0) {
resolve(result);
} else {
const deadline = new Date();
deadline.setMilliseconds(deadline.getMilliseconds() + 100);
const call= client.echo({}, {deadline}, (error, value) => {
if (error) {
reject(error);
return;
}
makeOneCall(callsLeft - 1);
});
call.on('metadata', metadata => {
const serverEntry = metadata.get('server');
if (serverEntry.length > 0) {
const serverId = serverEntry[0] as string;
if (!(serverId in result)) {
result[serverId] = 0;
}
result[serverId] += 1;
}
});
}
}
makeOneCall(count);
});
}
function createServiceConfig(wrrConfig: object): grpc.ServiceConfig {
return {
methodConfig: [],
loadBalancingConfig: [
{'weighted_round_robin': wrrConfig}
]
};
}
function createClient(ports: number[], serviceConfig: grpc.ServiceConfig) {
return new EchoService(`ipv4:${ports.map(port => `127.0.0.1:${port}`).join(',')}`, grpc.credentials.createInsecure(), {'grpc.service_config': JSON.stringify(serviceConfig)});
}
function asyncTimeout(delay: number): Promise<void> {
return new Promise(resolve => {
setTimeout(resolve, delay);
});
}
describe('Weighted round robin LB policy', () => {
describe('Config parsing', () => {
it('Should have default values with an empty object', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({});
assert.strictEqual(config.getEnableOobLoadReport(), false);
assert.strictEqual(config.getBlackoutPeriodMs(), 10_000);
assert.strictEqual(config.getErrorUtilizationPenalty(), 1);
assert.strictEqual(config.getOobLoadReportingPeriodMs(), 10_000);
assert.strictEqual(config.getWeightExpirationPeriodMs(), 180_000);
assert.strictEqual(config.getWeightUpdatePeriodMs(), 1_000);
});
it('Should handle enable_oob_load_report', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
enable_oob_load_report: true
});
assert.strictEqual(config.getEnableOobLoadReport(), true);
});
it('Should handle error_utilization_penalty', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
error_utilization_penalty: 0.5
});
assert.strictEqual(config.getErrorUtilizationPenalty(), 0.5);
});
it('Should reject negative error_utilization_penalty', () => {
const loadBalancingConfig = {
error_utilization_penalty: -1
};
assert.throws(() => {
WeightedRoundRobinLoadBalancingConfig.createFromJson(loadBalancingConfig);
}, /error_utilization_penalty < 0/);
});
it('Should handle blackout_period as a string', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
blackout_period: '1s'
});
assert.strictEqual(config.getBlackoutPeriodMs(), 1_000);
});
it('Should handle blackout_period as an object', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
blackout_period: {
seconds: 1,
nanos: 0
}
});
assert.strictEqual(config.getBlackoutPeriodMs(), 1_000);
});
it('Should handle oob_load_reporting_period as a string', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
oob_load_reporting_period: '1s'
});
assert.strictEqual(config.getOobLoadReportingPeriodMs(), 1_000);
});
it('Should handle oob_load_reporting_period as an object', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
oob_load_reporting_period: {
seconds: 1,
nanos: 0
}
});
assert.strictEqual(config.getOobLoadReportingPeriodMs(), 1_000);
});
it('Should handle weight_expiration_period as a string', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
weight_expiration_period: '1s'
});
assert.strictEqual(config.getWeightExpirationPeriodMs(), 1_000);
});
it('Should handle weight_expiration_period as an object', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
weight_expiration_period: {
seconds: 1,
nanos: 0
}
});
assert.strictEqual(config.getWeightExpirationPeriodMs(), 1_000);
});
it('Should handle weight_update_period as a string', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
weight_update_period: '2s'
});
assert.strictEqual(config.getWeightUpdatePeriodMs(), 2_000);
});
it('Should handle weight_update_period as an object', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
weight_update_period: {
seconds: 2,
nanos: 0
}
});
assert.strictEqual(config.getWeightUpdatePeriodMs(), 2_000);
});
it('Should cap weight_update_period to a minimum of 0.1s', () => {
const config = WeightedRoundRobinLoadBalancingConfig.createFromJson({
weight_update_period: '0.01s'
});
assert.strictEqual(config.getWeightUpdatePeriodMs(), 100);
});
});
describe('Per-call metrics', () => {
const server1Metrics = {
qps: 0,
utilization: 0,
eps: 0
};
const server2Metrics = {
qps: 0,
utilization: 0,
eps: 0
};
const server1 = new grpc.Server({'grpc.server_call_metric_recording': 1});
const server2 = new grpc.Server({'grpc.server_call_metric_recording': 1});
const server1Impl = {
echo: (
call: grpc.ServerUnaryCall<any, any>,
callback: grpc.sendUnaryData<any>
) => {
const metricsRecorder = call.getMetricsRecorder();
metricsRecorder.recordQpsMetric(server1Metrics.qps);
metricsRecorder.recordApplicationUtilizationMetric(server1Metrics.utilization);
metricsRecorder.recordEpsMetric(server1Metrics.eps);
const metadata = new grpc.Metadata();
metadata.set('server', '1');
call.sendMetadata(metadata);
callback(null, call.request);
},
};
const server2Impl = {
echo: (
call: grpc.ServerUnaryCall<any, any>,
callback: grpc.sendUnaryData<any>
) => {
const metricsRecorder = call.getMetricsRecorder();
metricsRecorder.recordQpsMetric(server2Metrics.qps);
metricsRecorder.recordApplicationUtilizationMetric(server2Metrics.utilization);
metricsRecorder.recordEpsMetric(server2Metrics.eps);
const metadata = new grpc.Metadata();
metadata.set('server', '2');
call.sendMetadata(metadata);
callback(null, call.request);
},
};
let port1: number;
let port2: number;
let client: EchoServiceClient | null = null;
before(done => {
const creds = grpc.ServerCredentials.createInsecure();
server1.addService(EchoService.service, server1Impl);
server2.addService(EchoService.service, server2Impl);
server1.bindAsync('localhost:0', creds, (error1, server1Port) => {
if (error1) {
done(error1);
return;
}
port1 = server1Port;
server2.bindAsync('localhost:0', creds, (error2, server2Port) => {
if (error2) {
done(error2);
return;
}
port2 = server2Port;
done();
});
});
});
beforeEach(() => {
server1Metrics.qps = 0;
server1Metrics.utilization = 0;
server1Metrics.eps = 0;
server2Metrics.qps = 0;
server2Metrics.utilization = 0;
server2Metrics.eps = 0;
});
afterEach(() => {
client?.close();
client = null;
});
after(() => {
server1.forceShutdown();
server2.forceShutdown();
});
it('Should evenly balance among endpoints with no weight', async () => {
const serviceConfig = createServiceConfig({});
client = createClient([port1, port2], serviceConfig);
await makeNCalls(client, 10);
const result = await makeNCalls(client, 30);
assert(Math.abs(result['1'] - result['2']) < 3, `server1: ${result['1']}, server2: ${result[2]}`);
});
it('Should send more requests to endpoints with higher QPS', async () => {
const serviceConfig = createServiceConfig({
blackout_period: '0.01s',
weight_update_period: '0.1s'
});
client = createClient([port1, port2], serviceConfig);
server1Metrics.qps = 3;
server1Metrics.utilization = 1;
server2Metrics.qps = 1;
server2Metrics.utilization = 1;
await makeNCalls(client, 10);
await asyncTimeout(100);
const result = await makeNCalls(client, 40);
assert(Math.abs(result['1'] - 30) < 2, `server1: ${result['1']}, server2: ${result['2']}`);
});
// Calls aren't fast enough for this to work consistently
it.skip('Should wait for the blackout period to apply weights', async () => {
const serviceConfig = createServiceConfig({
blackout_period: '0.5s'
});
client = createClient([port1, port2], serviceConfig);
server1Metrics.qps = 3;
server1Metrics.utilization = 1;
server2Metrics.qps = 1;
server2Metrics.utilization = 1;
await makeNCalls(client, 10);
await asyncTimeout(100);
const result1 = await makeNCalls(client, 20);
assert(Math.abs(result1['1'] - result1['2']) < 3, `result1: server1: ${result1['1']}, server2: ${result1[2]}`);
await asyncTimeout(400);
const result2 = await makeNCalls(client, 40);
assert(Math.abs(result2['1'] - 30) < 2, `result2: server1: ${result2['1']}, server2: ${result2['2']}`);
}).timeout(3000);
// Calls aren't fast enough for this to work consistently
it.skip('Should wait for the weight update period to apply weights', async () => {
const serviceConfig = createServiceConfig({
blackout_period: '0.01s',
weight_update_period: '1s'
});
client = createClient([port1, port2], serviceConfig);
server1Metrics.qps = 3;
server1Metrics.utilization = 1;
server2Metrics.qps = 1;
server2Metrics.utilization = 1;
await makeNCalls(client, 10);
await asyncTimeout(100);
const result1 = await makeNCalls(client, 20);
assert(Math.abs(result1['1'] - result1['2']) < 3, `result1: server1: ${result1['1']}, server2: ${result1[2]}`);
await asyncTimeout(400);
const result2 = await makeNCalls(client, 40);
assert(Math.abs(result2['1'] - 30) < 2, `result2: server1: ${result2['1']}, server2: ${result2['2']}`);
}).timeout(3000);
it('Should send more requests to endpoints with lower EPS', async () => {
const serviceConfig = createServiceConfig({
blackout_period: '0.01s',
weight_update_period: '0.1s',
error_utilization_penalty: 1
});
client = createClient([port1, port2], serviceConfig);
server1Metrics.qps = 2;
server1Metrics.utilization = 1;
server1Metrics.eps = 0;
server2Metrics.qps = 2;
server2Metrics.utilization = 1;
server2Metrics.eps = 2;
await makeNCalls(client, 10);
await asyncTimeout(100);
const result = await makeNCalls(client, 30);
assert(Math.abs(result['1'] - 20) < 3, `server1: ${result['1']}, server2: ${result['2']}`);
});
});
describe('Out of band metrics', () => {
const server1MetricRecorder = new grpc.ServerMetricRecorder();
const server2MetricRecorder = new grpc.ServerMetricRecorder();
const server1 = new grpc.Server();
const server2 = new grpc.Server();
const server1Impl = {
echo: (
call: grpc.ServerUnaryCall<any, any>,
callback: grpc.sendUnaryData<any>
) => {
const metadata = new grpc.Metadata();
metadata.set('server', '1');
call.sendMetadata(metadata);
callback(null, call.request);
},
};
const server2Impl = {
echo: (
call: grpc.ServerUnaryCall<any, any>,
callback: grpc.sendUnaryData<any>
) => {
const metadata = new grpc.Metadata();
metadata.set('server', '2');
call.sendMetadata(metadata);
callback(null, call.request);
},
};
let port1: number;
let port2: number;
let client: EchoServiceClient | null = null;
before(done => {
const creds = grpc.ServerCredentials.createInsecure();
server1.addService(EchoService.service, server1Impl);
server2.addService(EchoService.service, server2Impl);
server1MetricRecorder.addToServer(server1);
server2MetricRecorder.addToServer(server2);
server1.bindAsync('localhost:0', creds, (error1, server1Port) => {
if (error1) {
done(error1);
return;
}
port1 = server1Port;
server2.bindAsync('localhost:0', creds, (error2, server2Port) => {
if (error2) {
done(error2);
return;
}
port2 = server2Port;
done();
});
});
});
beforeEach(() => {
server1MetricRecorder.deleteQpsMetric();
server1MetricRecorder.deleteEpsMetric();
server1MetricRecorder.deleteApplicationUtilizationMetric();
server2MetricRecorder.deleteQpsMetric();
server2MetricRecorder.deleteEpsMetric();
server2MetricRecorder.deleteApplicationUtilizationMetric();
});
afterEach(() => {
client?.close();
client = null;
});
after(() => {
server1.forceShutdown();
server2.forceShutdown();
});
it('Should evenly balance among endpoints with no weight', async () => {
const serviceConfig = createServiceConfig({
enable_oob_load_report: true,
oob_load_reporting_period: '0.01s',
blackout_period: '0.01s'
});
client = createClient([port1, port2], serviceConfig);
await makeNCalls(client, 10);
const result = await makeNCalls(client, 30);
assert(Math.abs(result['1'] - result['2']) < 3, `server1: ${result['1']}, server2: ${result[2]}`);
});
it('Should send more requests to endpoints with higher QPS', async () => {
const serviceConfig = createServiceConfig({
enable_oob_load_report: true,
oob_load_reporting_period: '0.01s',
blackout_period: '0.01s',
weight_update_period: '0.1s'
});
client = createClient([port1, port2], serviceConfig);
server1MetricRecorder.setQpsMetric(3);
server1MetricRecorder.setApplicationUtilizationMetric(1);
server2MetricRecorder.setQpsMetric(1);
server2MetricRecorder.setApplicationUtilizationMetric(1);
await makeNCalls(client, 10);
await asyncTimeout(100);
const result = await makeNCalls(client, 40);
assert(Math.abs(result['1'] - 30) < 2, `server1: ${result['1']}, server2: ${result['2']}`);
});
});
});