grpc-js: Add getAuthContext call method

This commit is contained in:
Michael Lumish 2025-03-07 11:10:24 -08:00
parent 6c7abfe4a8
commit 78f194be6e
12 changed files with 147 additions and 0 deletions

View File

@ -0,0 +1,23 @@
/*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
import { PeerCertificate } from "tls";
export interface AuthContext {
transportSecurityType?: string;
sslPeerCertificate?: PeerCertificate;
}

View File

@ -15,6 +15,7 @@
* *
*/ */
import { AuthContext } from './auth-context';
import { CallCredentials } from './call-credentials'; import { CallCredentials } from './call-credentials';
import { Status } from './constants'; import { Status } from './constants';
import { Deadline } from './deadline'; import { Deadline } from './deadline';
@ -170,6 +171,7 @@ export interface Call {
halfClose(): void; halfClose(): void;
getCallNumber(): number; getCallNumber(): number;
setCredentials(credentials: CallCredentials): void; setCredentials(credentials: CallCredentials): void;
getAuthContext(): AuthContext | null;
} }
export interface DeadlineInfoProvider { export interface DeadlineInfoProvider {

View File

@ -24,6 +24,7 @@ import { EmitterAugmentation1 } from './events';
import { Metadata } from './metadata'; import { Metadata } from './metadata';
import { ObjectReadable, ObjectWritable, WriteCallback } from './object-stream'; import { ObjectReadable, ObjectWritable, WriteCallback } from './object-stream';
import { InterceptingCallInterface } from './client-interceptors'; import { InterceptingCallInterface } from './client-interceptors';
import { AuthContext } from './auth-context';
/** /**
* A type extending the built-in Error object with additional fields. * A type extending the built-in Error object with additional fields.
@ -37,6 +38,7 @@ export type SurfaceCall = {
call?: InterceptingCallInterface; call?: InterceptingCallInterface;
cancel(): void; cancel(): void;
getPeer(): string; getPeer(): string;
getAuthContext(): AuthContext | null;
} & EmitterAugmentation1<'metadata', Metadata> & } & EmitterAugmentation1<'metadata', Metadata> &
EmitterAugmentation1<'status', StatusObject> & EmitterAugmentation1<'status', StatusObject> &
EventEmitter; EventEmitter;
@ -100,6 +102,10 @@ export class ClientUnaryCallImpl
getPeer(): string { getPeer(): string {
return this.call?.getPeer() ?? 'unknown'; return this.call?.getPeer() ?? 'unknown';
} }
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
} }
export class ClientReadableStreamImpl<ResponseType> export class ClientReadableStreamImpl<ResponseType>
@ -119,6 +125,10 @@ export class ClientReadableStreamImpl<ResponseType>
return this.call?.getPeer() ?? 'unknown'; return this.call?.getPeer() ?? 'unknown';
} }
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
_read(_size: number): void { _read(_size: number): void {
this.call?.startRead(); this.call?.startRead();
} }
@ -141,6 +151,10 @@ export class ClientWritableStreamImpl<RequestType>
return this.call?.getPeer() ?? 'unknown'; return this.call?.getPeer() ?? 'unknown';
} }
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
_write(chunk: RequestType, encoding: string, cb: WriteCallback) { _write(chunk: RequestType, encoding: string, cb: WriteCallback) {
const context: MessageContext = { const context: MessageContext = {
callback: cb, callback: cb,
@ -178,6 +192,10 @@ export class ClientDuplexStreamImpl<RequestType, ResponseType>
return this.call?.getPeer() ?? 'unknown'; return this.call?.getPeer() ?? 'unknown';
} }
getAuthContext(): AuthContext | null {
return this.call?.getAuthContext() ?? null;
}
_read(_size: number): void { _read(_size: number): void {
this.call?.startRead(); this.call?.startRead();
} }

View File

@ -34,6 +34,7 @@ import { Channel } from './channel';
import { CallOptions } from './client'; import { CallOptions } from './client';
import { ClientMethodDefinition } from './make-client'; import { ClientMethodDefinition } from './make-client';
import { getErrorMessage } from './error'; import { getErrorMessage } from './error';
import { AuthContext } from './auth-context';
/** /**
* Error class associated with passing both interceptors and interceptor * Error class associated with passing both interceptors and interceptor
@ -198,6 +199,7 @@ export interface InterceptingCallInterface {
sendMessage(message: any): void; sendMessage(message: any): void;
startRead(): void; startRead(): void;
halfClose(): void; halfClose(): void;
getAuthContext(): AuthContext | null;
} }
export class InterceptingCall implements InterceptingCallInterface { export class InterceptingCall implements InterceptingCallInterface {
@ -338,6 +340,9 @@ export class InterceptingCall implements InterceptingCallInterface {
} }
}); });
} }
getAuthContext(): AuthContext | null {
return this.nextCall.getAuthContext();
}
} }
function getCall(channel: Channel, path: string, options: CallOptions): Call { function getCall(channel: Channel, path: string, options: CallOptions): Call {
@ -427,6 +432,9 @@ class BaseInterceptingCall implements InterceptingCallInterface {
halfClose(): void { halfClose(): void {
this.call.halfClose(); this.call.halfClose();
} }
getAuthContext(): AuthContext | null {
return this.call.getAuthContext();
}
} }
/** /**

View File

@ -35,6 +35,7 @@ import { splitHostPort } from './uri-parser';
import * as logging from './logging'; import * as logging from './logging';
import { restrictControlPlaneStatusCode } from './control-plane-status'; import { restrictControlPlaneStatusCode } from './control-plane-status';
import * as http2 from 'http2'; import * as http2 from 'http2';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'load_balancing_call'; const TRACER_NAME = 'load_balancing_call';
@ -375,4 +376,12 @@ export class LoadBalancingCall implements Call, DeadlineInfoProvider {
getCallNumber(): number { getCallNumber(): number {
return this.callNumber; return this.callNumber;
} }
getAuthContext(): AuthContext | null {
if (this.child) {
return this.child.getAuthContext();
} else {
return null;
}
}
} }

View File

@ -37,6 +37,7 @@ import { InternalChannel } from './internal-channel';
import { Metadata } from './metadata'; import { Metadata } from './metadata';
import * as logging from './logging'; import * as logging from './logging';
import { restrictControlPlaneStatusCode } from './control-plane-status'; import { restrictControlPlaneStatusCode } from './control-plane-status';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'resolving_call'; const TRACER_NAME = 'resolving_call';
@ -367,4 +368,12 @@ export class ResolvingCall implements Call {
getCallNumber(): number { getCallNumber(): number {
return this.callNumber; return this.callNumber;
} }
getAuthContext(): AuthContext | null {
if (this.child) {
return this.child.getAuthContext();
} else {
return null;
}
}
} }

View File

@ -35,6 +35,7 @@ import {
StatusObjectWithProgress, StatusObjectWithProgress,
} from './load-balancing-call'; } from './load-balancing-call';
import { InternalChannel } from './internal-channel'; import { InternalChannel } from './internal-channel';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'retrying_call'; const TRACER_NAME = 'retrying_call';
@ -859,4 +860,11 @@ export class RetryingCall implements Call, DeadlineInfoProvider {
getHost(): string { getHost(): string {
return this.host; return this.host;
} }
getAuthContext(): AuthContext | null {
if (this.committedCallIndex !== null) {
return this.underlyingCalls[this.committedCallIndex].call.getAuthContext();
} else {
return null;
}
}
} }

View File

@ -25,6 +25,7 @@ import type { ObjectReadable, ObjectWritable } from './object-stream';
import type { StatusObject, PartialStatusObject } from './call-interface'; import type { StatusObject, PartialStatusObject } from './call-interface';
import type { Deadline } from './deadline'; import type { Deadline } from './deadline';
import type { ServerInterceptingCallInterface } from './server-interceptors'; import type { ServerInterceptingCallInterface } from './server-interceptors';
import { AuthContext } from './auth-context';
export type ServerStatusResponse = Partial<StatusObject>; export type ServerStatusResponse = Partial<StatusObject>;
@ -38,6 +39,7 @@ export type ServerSurfaceCall = {
getDeadline(): Deadline; getDeadline(): Deadline;
getPath(): string; getPath(): string;
getHost(): string; getHost(): string;
getAuthContext(): AuthContext;
} & EventEmitter; } & EventEmitter;
export type ServerUnaryCall<RequestType, ResponseType> = ServerSurfaceCall & { export type ServerUnaryCall<RequestType, ResponseType> = ServerSurfaceCall & {
@ -114,6 +116,10 @@ export class ServerUnaryCallImpl<RequestType, ResponseType>
getHost(): string { getHost(): string {
return this.call.getHost(); return this.call.getHost();
} }
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
} }
export class ServerReadableStreamImpl<RequestType, ResponseType> export class ServerReadableStreamImpl<RequestType, ResponseType>
@ -154,6 +160,10 @@ export class ServerReadableStreamImpl<RequestType, ResponseType>
getHost(): string { getHost(): string {
return this.call.getHost(); return this.call.getHost();
} }
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
} }
export class ServerWritableStreamImpl<RequestType, ResponseType> export class ServerWritableStreamImpl<RequestType, ResponseType>
@ -203,6 +213,10 @@ export class ServerWritableStreamImpl<RequestType, ResponseType>
return this.call.getHost(); return this.call.getHost();
} }
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
_write( _write(
chunk: ResponseType, chunk: ResponseType,
encoding: string, encoding: string,
@ -276,6 +290,10 @@ export class ServerDuplexStreamImpl<RequestType, ResponseType>
return this.call.getHost(); return this.call.getHost();
} }
getAuthContext(): AuthContext {
return this.call.getAuthContext();
}
_read(size: number) { _read(size: number) {
this.call.startRead(); this.call.startRead();
} }

View File

@ -33,6 +33,8 @@ import * as zlib from 'zlib';
import { StreamDecoder } from './stream-decoder'; import { StreamDecoder } from './stream-decoder';
import { CallEventTracker } from './transport'; import { CallEventTracker } from './transport';
import * as logging from './logging'; import * as logging from './logging';
import { AuthContext } from './auth-context';
import { TLSSocket } from 'tls';
const TRACER_NAME = 'server_call'; const TRACER_NAME = 'server_call';
@ -332,6 +334,10 @@ export interface ServerInterceptingCallInterface {
* Return the host requested by the client in the ":authority" header. * Return the host requested by the client in the ":authority" header.
*/ */
getHost(): string; getHost(): string;
/**
* Return the auth context of the connection the call is associated with.
*/
getAuthContext(): AuthContext;
} }
export class ServerInterceptingCall implements ServerInterceptingCallInterface { export class ServerInterceptingCall implements ServerInterceptingCallInterface {
@ -440,6 +446,9 @@ export class ServerInterceptingCall implements ServerInterceptingCallInterface {
getHost(): string { getHost(): string {
return this.nextCall.getHost(); return this.nextCall.getHost();
} }
getAuthContext(): AuthContext {
return this.nextCall.getAuthContext();
}
} }
export interface ServerInterceptor { export interface ServerInterceptor {
@ -971,6 +980,16 @@ export class BaseServerInterceptingCall
getHost(): string { getHost(): string {
return this.host; return this.host;
} }
getAuthContext(): AuthContext {
if (this.stream.session?.socket instanceof TLSSocket) {
return {
transportSecurityType: 'ssl',
sslPeerCertificate: this.stream.session.socket.getPeerCertificate()
}
} else {
return {};
}
}
} }
export function getServerInterceptingCall( export function getServerInterceptingCall(

View File

@ -30,6 +30,7 @@ import {
WriteCallback, WriteCallback,
} from './call-interface'; } from './call-interface';
import { CallEventTracker, Transport } from './transport'; import { CallEventTracker, Transport } from './transport';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'subchannel_call'; const TRACER_NAME = 'subchannel_call';
@ -71,6 +72,7 @@ export interface SubchannelCall {
halfClose(): void; halfClose(): void;
getCallNumber(): number; getCallNumber(): number;
getDeadlineInfo(): string[]; getDeadlineInfo(): string[];
getAuthContext(): AuthContext;
} }
export interface StatusObjectWithRstCode extends StatusObject { export interface StatusObjectWithRstCode extends StatusObject {
@ -556,6 +558,10 @@ export class Http2SubchannelCall implements SubchannelCall {
return this.callId; return this.callId;
} }
getAuthContext(): AuthContext {
return this.transport.getAuthContext();
}
startRead() { startRead() {
/* If the stream has ended with an error, we should not emit any more /* If the stream has ended with an error, we should not emit any more
* messages and we should communicate that the stream has ended */ * messages and we should communicate that the stream has ended */

View File

@ -51,6 +51,7 @@ import {
import { Metadata } from './metadata'; import { Metadata } from './metadata';
import { getNextCallNumber } from './call-number'; import { getNextCallNumber } from './call-number';
import { Socket } from 'net'; import { Socket } from 'net';
import { AuthContext } from './auth-context';
const TRACER_NAME = 'transport'; const TRACER_NAME = 'transport';
const FLOW_CONTROL_TRACER_NAME = 'transport_flowctrl'; const FLOW_CONTROL_TRACER_NAME = 'transport_flowctrl';
@ -83,6 +84,7 @@ export interface Transport {
getChannelzRef(): SocketRef; getChannelzRef(): SocketRef;
getPeerName(): string; getPeerName(): string;
getOptions(): ChannelOptions; getOptions(): ChannelOptions;
getAuthContext(): AuthContext;
createCall( createCall(
metadata: Metadata, metadata: Metadata,
host: string, host: string,
@ -129,6 +131,8 @@ class Http2Transport implements Transport {
private disconnectHandled = false; private disconnectHandled = false;
private authContext: AuthContext;
// Channelz info // Channelz info
private channelzRef: SocketRef; private channelzRef: SocketRef;
private readonly channelzEnabled: boolean = true; private readonly channelzEnabled: boolean = true;
@ -254,6 +258,15 @@ class Http2Transport implements Transport {
if (this.keepaliveWithoutCalls) { if (this.keepaliveWithoutCalls) {
this.maybeStartKeepalivePingTimer(); this.maybeStartKeepalivePingTimer();
} }
if (session.socket instanceof TLSSocket) {
this.authContext = {
transportSecurityType: 'ssl',
sslPeerCertificate: session.socket.getPeerCertificate()
};
} else {
this.authContext = {};
}
} }
private getChannelzInfo(): SocketInfo { private getChannelzInfo(): SocketInfo {
@ -622,6 +635,10 @@ class Http2Transport implements Transport {
return this.options; return this.options;
} }
getAuthContext(): AuthContext {
return this.authContext;
}
shutdown() { shutdown() {
this.session.close(); this.session.close();
unregisterChannelzRef(this.channelzRef); unregisterChannelzRef(this.channelzRef);

View File

@ -218,6 +218,16 @@ describe('ChannelCredentials usage', () => {
} }
); );
}); });
it('Should provide certificates in getAuthContext', done => {
const call = client.echo({ value: 'test value', value2: 3 }, (error: ServiceError, response: any) => {
assert.ifError(error);
const authContext = call.getAuthContext();
assert(authContext);
assert.strictEqual(authContext.transportSecurityType, 'ssl');
assert(authContext.sslPeerCertificate);
done();
});
})
}); });
describe('Channel credentials mtls', () => { describe('Channel credentials mtls', () => {