From b272632450498f1c688fe4b0988fa595a0004124 Mon Sep 17 00:00:00 2001 From: EPNW Date: Tue, 11 May 2021 13:35:36 +0200 Subject: [PATCH] Make clientCertificate available in ServiceCall (#472) Co-authored-by: Vyacheslav Egorov --- lib/src/server/call.dart | 5 ++ lib/src/server/handler.dart | 26 ++++--- lib/src/server/server.dart | 52 +++++++++---- test/client_certificate_test.dart | 124 ++++++++++++++++++++++++++++++ 4 files changed, 179 insertions(+), 28 deletions(-) create mode 100644 test/client_certificate_test.dart diff --git a/lib/src/server/call.dart b/lib/src/server/call.dart index 4605bdc..f6f18f8 100644 --- a/lib/src/server/call.dart +++ b/lib/src/server/call.dart @@ -13,6 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +import 'dart:io'; + /// Server-side context for a gRPC call. /// /// Gives the method handler access to custom metadata from the client, and @@ -39,6 +41,9 @@ abstract class ServiceCall { /// Returns [true] if the client has canceled this call. bool get isCanceled; + /// Returns the client certificate if it is requested and available + X509Certificate? get clientCertificate; + /// Send response headers. This is done automatically before sending the first /// response message, but can be done manually before the first response is /// ready, if necessary. diff --git a/lib/src/server/handler.dart b/lib/src/server/handler.dart index 4239e88..c5d036e 100644 --- a/lib/src/server/handler.dart +++ b/lib/src/server/handler.dart @@ -15,6 +15,7 @@ import 'dart:async'; import 'dart:convert'; +import 'dart:io'; import 'package:http2/transport.dart'; @@ -58,13 +59,11 @@ class ServerHandler_ extends ServiceCall { bool _isCanceled = false; bool _isTimedOut = false; Timer? _timeoutTimer; + final X509Certificate? _clientCertificate; - ServerHandler_( - this._serviceLookup, - this._stream, - this._interceptors, - this._codecRegistry, - ); + ServerHandler_(this._serviceLookup, this._stream, this._interceptors, + this._codecRegistry, + [this._clientCertificate]); @override DateTime? get deadline => _deadline; @@ -84,6 +83,9 @@ class ServerHandler_ extends ServiceCall { @override Map? get trailers => _customTrailers; + @override + X509Certificate? get clientCertificate => _clientCertificate; + void handle() { _stream.onTerminated = (_) => cancel(); @@ -410,10 +412,10 @@ class ServerHandler_ extends ServiceCall { } class ServerHandler extends ServerHandler_ { - ServerHandler( - Service Function(String service) serviceLookup, - stream, [ - List interceptors = const [], - CodecRegistry? codecRegistry, - ]) : super(serviceLookup, stream, interceptors, codecRegistry); + ServerHandler(Service Function(String service) serviceLookup, stream, + [List interceptors = const [], + CodecRegistry? codecRegistry, + X509Certificate? clientCertificate]) + : super(serviceLookup, stream, interceptors, codecRegistry, + clientCertificate); } diff --git a/lib/src/server/server.dart b/lib/src/server/server.dart index 71ee1d5..7793628 100644 --- a/lib/src/server/server.dart +++ b/lib/src/server/server.dart @@ -103,13 +103,14 @@ class ConnectionServer { Service? lookupService(String service) => _services[service]; - Future serveConnection(ServerTransportConnection connection) async { + Future serveConnection(ServerTransportConnection connection, + [X509Certificate? clientCertificate]) async { _connections.add(connection); ServerHandler_? handler; // TODO(jakobr): Set active state handlers, close connection after idle // timeout. connection.incomingStreams.listen((stream) { - handler = serveStream_(stream); + handler = serveStream_(stream, clientCertificate); }, onError: (error, stackTrace) { if (error is Error) { Zone.current.handleUncaughtError(error, stackTrace); @@ -125,8 +126,10 @@ class ConnectionServer { } @visibleForTesting - ServerHandler_ serveStream_(ServerTransportStream stream) { - return ServerHandler_(lookupService, stream, _interceptors, _codecRegistry) + ServerHandler_ serveStream_(ServerTransportStream stream, + [X509Certificate? clientCertificate]) { + return ServerHandler_( + lookupService, stream, _interceptors, _codecRegistry, clientCertificate) ..handle(); } } @@ -159,21 +162,32 @@ class Server extends ConnectionServer { /// Starts the [Server] with the given options. /// [address] can be either a [String] or an [InternetAddress], in the latter /// case it can be a Unix Domain Socket address. - Future serve( - {dynamic address, - int? port, - ServerCredentials? security, - ServerSettings? http2ServerSettings, - int backlog = 0, - bool v6Only = false, - bool shared = false}) async { + /// + /// If [port] is [null] then it defaults to `80` for non-secure and `443` for + /// secure variants. Pass `0` for [port] to let OS select a port for the + /// server. + Future serve({ + dynamic address, + int? port, + ServerCredentials? security, + ServerSettings? http2ServerSettings, + int backlog = 0, + bool v6Only = false, + bool shared = false, + bool requestClientCertificate = false, + bool requireClientCertificate = false, + }) async { // TODO(dart-lang/grpc-dart#9): Handle HTTP/1.1 upgrade to h2c, if allowed. Stream? server; final securityContext = security?.securityContext; if (securityContext != null) { _secureServer = await SecureServerSocket.bind( address ?? InternetAddress.anyIPv4, port ?? 443, securityContext, - backlog: backlog, shared: shared, v6Only: v6Only); + backlog: backlog, + shared: shared, + v6Only: v6Only, + requestClientCertificate: requestClientCertificate, + requireClientCertificate: requireClientCertificate); server = _secureServer; } else { _insecureServer = await ServerSocket.bind( @@ -190,9 +204,13 @@ class Server extends ConnectionServer { if (socket.address.type != InternetAddressType.unix) { socket.setOption(SocketOption.tcpNoDelay, true); } + X509Certificate? clientCertificate; + if (socket is SecureSocket) { + clientCertificate = socket.peerCertificate; + } final connection = ServerTransportConnection.viaSocket(socket, settings: http2ServerSettings); - serveConnection(connection); + serveConnection(connection, clientCertificate); }, onError: (error, stackTrace) { if (error is Error) { Zone.current.handleUncaughtError(error, stackTrace); @@ -202,8 +220,10 @@ class Server extends ConnectionServer { @override @visibleForTesting - ServerHandler_ serveStream_(ServerTransportStream stream) { - return ServerHandler_(lookupService, stream, _interceptors, _codecRegistry) + ServerHandler_ serveStream_(ServerTransportStream stream, + [X509Certificate? clientCertificate]) { + return ServerHandler_( + lookupService, stream, _interceptors, _codecRegistry, clientCertificate) ..handle(); } diff --git a/test/client_certificate_test.dart b/test/client_certificate_test.dart new file mode 100644 index 0000000..0a78ab8 --- /dev/null +++ b/test/client_certificate_test.dart @@ -0,0 +1,124 @@ +// TODO(dartbug.com/26057) currently Mac OS X seems to have some issues with +// client certificates so we disable the test. +@TestOn('vm && !mac-os') +import 'dart:async'; +import 'dart:io'; + +import 'package:grpc/grpc.dart'; +import 'package:test/test.dart'; + +import 'src/generated/echo.pbgrpc.dart'; + +class EchoService extends EchoServiceBase { + @override + Future echo(ServiceCall call, EchoRequest request) async { + final subject = call.clientCertificate?.subject; + return (EchoResponse()..message = subject ?? 'NO CERT'); + } + + @override + Stream serverStreamingEcho( + ServiceCall call, ServerStreamingEchoRequest request) { + // TODO: implement serverStreamingEcho + throw UnimplementedError(); + } +} + +const String address = 'localhost'; +Future main() async { + test('Client certificate required', () async { + // Server + final server = await _setUpServer(true); + + // Client + final channelContext = + SecurityContextChannelCredentials.baseSecurityContext(); + channelContext.useCertificateChain('test/data/localhost.crt'); + channelContext.usePrivateKey('test/data/localhost.key'); + final channelCredentials = SecurityContextChannelCredentials(channelContext, + onBadCertificate: (cert, s) { + return true; + }); + final channel = ClientChannel(address, + port: server.port ?? 443, + options: ChannelOptions(credentials: channelCredentials)); + final client = EchoServiceClient(channel); + + // Test + expect((await client.echo(EchoRequest())).message, '/CN=localhost'); + + // Clean up + await channel.shutdown(); + await server.shutdown(); + }); + + test('Client certificate not required', () async { + // Server + final server = await _setUpServer(); + + // Client + final channelContext = + SecurityContextChannelCredentials.baseSecurityContext(); + channelContext.useCertificateChain('test/data/localhost.crt'); + channelContext.usePrivateKey('test/data/localhost.key'); + final channelCredentials = SecurityContextChannelCredentials(channelContext, + onBadCertificate: (cert, s) { + return true; + }); + final channel = ClientChannel(address, + port: server.port ?? 443, + options: ChannelOptions(credentials: channelCredentials)); + final client = EchoServiceClient(channel); + + // Test + expect((await client.echo(EchoRequest())).message, 'NO CERT'); + + // Clean up + await channel.shutdown(); + await server.shutdown(); + }); +} + +Future _setUpServer([bool requireClientCertificate = false]) async { + final server = Server([EchoService()]); + final serverContext = SecurityContextChannelCredentials.baseSecurityContext(); + serverContext.useCertificateChain('test/data/localhost.crt'); + serverContext.usePrivateKey('test/data/localhost.key'); + serverContext.setTrustedCertificates('test/data/localhost.crt'); + final ServerCredentials serverCredentials = + SecurityContextServerCredentials(serverContext); + await server.serve( + address: address, + port: 0, + security: serverCredentials, + requireClientCertificate: requireClientCertificate); + return server; +} + +class SecurityContextChannelCredentials extends ChannelCredentials { + final SecurityContext _securityContext; + + SecurityContextChannelCredentials(SecurityContext securityContext, + {String? authority, BadCertificateHandler? onBadCertificate}) + : _securityContext = securityContext, + super.secure(authority: authority, onBadCertificate: onBadCertificate); + @override + SecurityContext get securityContext => _securityContext; + + static SecurityContext baseSecurityContext() { + return createSecurityContext(false); + } +} + +class SecurityContextServerCredentials extends ServerTlsCredentials { + final SecurityContext _securityContext; + + SecurityContextServerCredentials(SecurityContext securityContext) + : _securityContext = securityContext, + super(); + @override + SecurityContext get securityContext => _securityContext; + static SecurityContext baseSecurityContext() { + return createSecurityContext(true); + } +}