grpc-js: Add getAuthContext call method

This commit is contained in:
Michael Lumish 2025-03-07 11:10:24 -08:00
parent 6c7abfe4a8
commit 78f194be6e
12 changed files with 147 additions and 0 deletions

View File

@ -0,0 +1,23 @@
/*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
import { PeerCertificate } from "tls";
export interface AuthContext {
transportSecurityType?: string;
sslPeerCertificate?: PeerCertificate;
}

View File

@ -15,6 +15,7 @@
*
*/
import { AuthContext } from './auth-context';
import { CallCredentials } from './call-credentials';
import { Status } from './constants';
import { Deadline } from './deadline';
@ -170,6 +171,7 @@ export interface Call {
halfClose(): void;
getCallNumber(): number;
setCredentials(credentials: CallCredentials): void;
getAuthContext(): AuthContext | null;
}
export interface DeadlineInfoProvider {

View File

@ -24,6 +24,7 @@ import { EmitterAugmentation1 } from './events';
import { Metadata } from './metadata';
import { ObjectReadable, ObjectWritable, WriteCallback } from './object-stream';
import { InterceptingCallInterface } from './client-interceptors';
import { AuthContext } from './auth-context';
/**
* A type extending the built-in Error object with additional fields.
@ -37,6 +38,7 @@ export type SurfaceCall = {
call?: InterceptingCallInterface;
cancel(): void;
getPeer(): string;
getAuthContext(): AuthContext | null;
} & EmitterAugmentation1<'metadata', Metadata> &
EmitterAugmentation1<'status', StatusObject> &
EventEmitter;
@ -100,6 +102,10 @@ export class ClientUnaryCallImpl
getPeer(): string {
return this.call?.getPeer() ?? 'unknown';
}
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
}
export class ClientReadableStreamImpl<ResponseType>
@ -119,6 +125,10 @@ export class ClientReadableStreamImpl<ResponseType>
return this.call?.getPeer() ?? 'unknown';
}
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
_read(_size: number): void {
this.call?.startRead();
}
@ -141,6 +151,10 @@ export class ClientWritableStreamImpl<RequestType>
return this.call?.getPeer() ?? 'unknown';
}
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
_write(chunk: RequestType, encoding: string, cb: WriteCallback) {
const context: MessageContext = {
callback: cb,
@ -178,6 +192,10 @@ export class ClientDuplexStreamImpl<RequestType, ResponseType>
return this.call?.getPeer() ?? 'unknown';
}
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
_read(_size: number): void {
this.call?.startRead();
}

View File

@ -34,6 +34,7 @@ import { Channel } from './channel';
import { CallOptions } from './client';
import { ClientMethodDefinition } from './make-client';
import { getErrorMessage } from './error';
import { AuthContext } from './auth-context';
/**
* Error class associated with passing both interceptors and interceptor
@ -198,6 +199,7 @@ export interface InterceptingCallInterface {
sendMessage(message: any): void;
startRead(): void;
halfClose(): void;
getAuthContext(): AuthContext | null;
}
export class InterceptingCall implements InterceptingCallInterface {
@ -338,6 +340,9 @@ export class InterceptingCall implements InterceptingCallInterface {
}
});
}
getAuthContext(): AuthContext | null {
return this.nextCall.getAuthContext();
}
}
function getCall(channel: Channel, path: string, options: CallOptions): Call {
@ -427,6 +432,9 @@ class BaseInterceptingCall implements InterceptingCallInterface {
halfClose(): void {
this.call.halfClose();
}
getAuthContext(): AuthContext | null {
return this.call.getAuthContext();
}
}
/**

View File

@ -35,6 +35,7 @@ import { splitHostPort } from './uri-parser';
import * as logging from './logging';
import { restrictControlPlaneStatusCode } from './control-plane-status';
import * as http2 from 'http2';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'load_balancing_call';
@ -375,4 +376,12 @@ export class LoadBalancingCall implements Call, DeadlineInfoProvider {
getCallNumber(): number {
return this.callNumber;
}
getAuthContext(): AuthContext | null {
if (this.child) {
return this.child.getAuthContext();
} else {
return null;
}
}
}

View File

@ -37,6 +37,7 @@ import { InternalChannel } from './internal-channel';
import { Metadata } from './metadata';
import * as logging from './logging';
import { restrictControlPlaneStatusCode } from './control-plane-status';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'resolving_call';
@ -367,4 +368,12 @@ export class ResolvingCall implements Call {
getCallNumber(): number {
return this.callNumber;
}
getAuthContext(): AuthContext | null {
if (this.child) {
return this.child.getAuthContext();
} else {
return null;
}
}
}

View File

@ -35,6 +35,7 @@ import {
StatusObjectWithProgress,
} from './load-balancing-call';
import { InternalChannel } from './internal-channel';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'retrying_call';
@ -859,4 +860,11 @@ export class RetryingCall implements Call, DeadlineInfoProvider {
getHost(): string {
return this.host;
}
getAuthContext(): AuthContext | null {
if (this.committedCallIndex !== null) {
return this.underlyingCalls[this.committedCallIndex].call.getAuthContext();
} else {
return null;
}
}
}

View File

@ -25,6 +25,7 @@ import type { ObjectReadable, ObjectWritable } from './object-stream';
import type { StatusObject, PartialStatusObject } from './call-interface';
import type { Deadline } from './deadline';
import type { ServerInterceptingCallInterface } from './server-interceptors';
import { AuthContext } from './auth-context';
export type ServerStatusResponse = Partial<StatusObject>;
@ -38,6 +39,7 @@ export type ServerSurfaceCall = {
getDeadline(): Deadline;
getPath(): string;
getHost(): string;
getAuthContext(): AuthContext;
} & EventEmitter;
export type ServerUnaryCall<RequestType, ResponseType> = ServerSurfaceCall & {
@ -114,6 +116,10 @@ export class ServerUnaryCallImpl<RequestType, ResponseType>
getHost(): string {
return this.call.getHost();
}
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
}
export class ServerReadableStreamImpl<RequestType, ResponseType>
@ -154,6 +160,10 @@ export class ServerReadableStreamImpl<RequestType, ResponseType>
getHost(): string {
return this.call.getHost();
}
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
}
export class ServerWritableStreamImpl<RequestType, ResponseType>
@ -203,6 +213,10 @@ export class ServerWritableStreamImpl<RequestType, ResponseType>
return this.call.getHost();
}
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
_write(
chunk: ResponseType,
encoding: string,
@ -276,6 +290,10 @@ export class ServerDuplexStreamImpl<RequestType, ResponseType>
return this.call.getHost();
}
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
_read(size: number) {
this.call.startRead();
}

View File

@ -33,6 +33,8 @@ import * as zlib from 'zlib';
import { StreamDecoder } from './stream-decoder';
import { CallEventTracker } from './transport';
import * as logging from './logging';
import { AuthContext } from './auth-context';
import { TLSSocket } from 'tls';
const TRACER_NAME = 'server_call';
@ -332,6 +334,10 @@ export interface ServerInterceptingCallInterface {
* Return the host requested by the client in the ":authority" header.
*/
getHost(): string;
/**
* Return the auth context of the connection the call is associated with.
*/
getAuthContext(): AuthContext;
}
export class ServerInterceptingCall implements ServerInterceptingCallInterface {
@ -440,6 +446,9 @@ export class ServerInterceptingCall implements ServerInterceptingCallInterface {
getHost(): string {
return this.nextCall.getHost();
}
getAuthContext(): AuthContext {
return this.nextCall.getAuthContext();
}
}
export interface ServerInterceptor {
@ -971,6 +980,16 @@ export class BaseServerInterceptingCall
getHost(): string {
return this.host;
}
getAuthContext(): AuthContext {
if (this.stream.session?.socket instanceof TLSSocket) {
return {
transportSecurityType: 'ssl',
sslPeerCertificate: this.stream.session.socket.getPeerCertificate()
}
} else {
return {};
}
}
}
export function getServerInterceptingCall(

View File

@ -30,6 +30,7 @@ import {
WriteCallback,
} from './call-interface';
import { CallEventTracker, Transport } from './transport';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'subchannel_call';
@ -71,6 +72,7 @@ export interface SubchannelCall {
halfClose(): void;
getCallNumber(): number;
getDeadlineInfo(): string[];
getAuthContext(): AuthContext;
}
export interface StatusObjectWithRstCode extends StatusObject {
@ -556,6 +558,10 @@ export class Http2SubchannelCall implements SubchannelCall {
return this.callId;
}
getAuthContext(): AuthContext {
return this.transport.getAuthContext();
}
startRead() {
/* If the stream has ended with an error, we should not emit any more
* messages and we should communicate that the stream has ended */

View File

@ -51,6 +51,7 @@ import {
import { Metadata } from './metadata';
import { getNextCallNumber } from './call-number';
import { Socket } from 'net';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'transport';
const FLOW_CONTROL_TRACER_NAME = 'transport_flowctrl';
@ -83,6 +84,7 @@ export interface Transport {
getChannelzRef(): SocketRef;
getPeerName(): string;
getOptions(): ChannelOptions;
getAuthContext(): AuthContext;
createCall(
metadata: Metadata,
host: string,
@ -129,6 +131,8 @@ class Http2Transport implements Transport {
private disconnectHandled = false;
private authContext: AuthContext;
// Channelz info
private channelzRef: SocketRef;
private readonly channelzEnabled: boolean = true;
@ -254,6 +258,15 @@ class Http2Transport implements Transport {
if (this.keepaliveWithoutCalls) {
this.maybeStartKeepalivePingTimer();
}
if (session.socket instanceof TLSSocket) {
this.authContext = {
transportSecurityType: 'ssl',
sslPeerCertificate: session.socket.getPeerCertificate()
};
} else {
this.authContext = {};
}
}
private getChannelzInfo(): SocketInfo {
@ -622,6 +635,10 @@ class Http2Transport implements Transport {
return this.options;
}
getAuthContext(): AuthContext {
return this.authContext;
}
shutdown() {
this.session.close();
unregisterChannelzRef(this.channelzRef);

View File

@ -218,6 +218,16 @@ describe('ChannelCredentials usage', () => {
}
);
});
it('Should provide certificates in getAuthContext', done => {
const call = client.echo({ value: 'test value', value2: 3 }, (error: ServiceError, response: any) => {
assert.ifError(error);
const authContext = call.getAuthContext();
assert(authContext);
assert.strictEqual(authContext.transportSecurityType, 'ssl');
assert(authContext.sslPeerCertificate);
done();
});
})
});
describe('Channel credentials mtls', () => {