From 40ffab8da5feab110fe9c46003f7f9ef267dddc6 Mon Sep 17 00:00:00 2001 From: Jakob Andersen Date: Tue, 27 Feb 2018 10:10:44 +0100 Subject: [PATCH] Split out TLS credentials to a separate class. (#60) Add a 'bad certificate handler' to the new ChannelCredentials, which can be used to override certificate validation (for example, to allow auto-generated self-signed certificates during development). Also fixed a bug in Server.shutdown(). --- example/googleapis/bin/logging.dart | 3 +- example/helloworld/bin/client.dart | 4 +- example/metadata/lib/src/client.dart | 4 +- example/route_guide/lib/src/client.dart | 4 +- example/route_guide/lib/src/server.dart | 14 ++++-- interop/lib/src/client.dart | 9 ++-- lib/src/client/call.dart | 2 +- lib/src/client/channel.dart | 2 +- lib/src/client/connection.dart | 18 +++++-- lib/src/client/options.dart | 65 +++++++++++++++---------- lib/src/server/server.dart | 4 +- test/options_test.dart | 12 ++--- test/src/client_utils.dart | 5 +- 13 files changed, 87 insertions(+), 59 deletions(-) diff --git a/example/googleapis/bin/logging.dart b/example/googleapis/bin/logging.dart index 4841b67..0273e7b 100644 --- a/example/googleapis/bin/logging.dart +++ b/example/googleapis/bin/logging.dart @@ -40,8 +40,7 @@ Future main() async { serviceAccountFile.readAsStringSync(), scopes); final projectId = authenticator.projectId; - final channel = new ClientChannel('logging.googleapis.com', - options: const ChannelOptions.secure()); + final channel = new ClientChannel('logging.googleapis.com'); final logging = new LoggingServiceV2Client(channel, options: authenticator.toCallOptions); diff --git a/example/helloworld/bin/client.dart b/example/helloworld/bin/client.dart index 5e09f03..d37e0e9 100644 --- a/example/helloworld/bin/client.dart +++ b/example/helloworld/bin/client.dart @@ -23,7 +23,9 @@ import 'package:helloworld/src/generated/helloworld.pbgrpc.dart'; Future main(List args) async { final channel = new ClientChannel('localhost', - port: 50051, options: const ChannelOptions.insecure()); + port: 50051, + options: const ChannelOptions( + credentials: const ChannelCredentials.insecure())); final stub = new GreeterClient(channel); final name = args.isNotEmpty ? args[0] : 'world'; diff --git a/example/metadata/lib/src/client.dart b/example/metadata/lib/src/client.dart index b8b0092..fbb47a1 100644 --- a/example/metadata/lib/src/client.dart +++ b/example/metadata/lib/src/client.dart @@ -25,7 +25,9 @@ class Client { Future main(List args) async { channel = new ClientChannel('127.0.0.1', - port: 8080, options: const ChannelOptions.insecure()); + port: 8080, + options: const ChannelOptions( + credentials: const ChannelCredentials.insecure())); stub = new MetadataClient(channel); // Run all of the demos in order. await runEcho(); diff --git a/example/route_guide/lib/src/client.dart b/example/route_guide/lib/src/client.dart index e31d118..397c84d 100644 --- a/example/route_guide/lib/src/client.dart +++ b/example/route_guide/lib/src/client.dart @@ -28,7 +28,9 @@ class Client { Future main(List args) async { channel = new ClientChannel('127.0.0.1', - port: 8080, options: const ChannelOptions.insecure()); + port: 8080, + options: const ChannelOptions( + credentials: const ChannelCredentials.insecure())); stub = new RouteGuideClient(channel, options: new CallOptions(timeout: new Duration(seconds: 30))); // Run all of the demos in order. diff --git a/example/route_guide/lib/src/server.dart b/example/route_guide/lib/src/server.dart index 8feba6f..9e156c3 100644 --- a/example/route_guide/lib/src/server.dart +++ b/example/route_guide/lib/src/server.dart @@ -25,8 +25,9 @@ import 'generated/route_guide.pbgrpc.dart'; class RouteGuideService extends RouteGuideServiceBase { final routeNotes = >{}; - // getFeature handler. Returns a feature for the given location. - // The [context] object provides access to client metadata, cancellation, etc. + /// GetFeature handler. Returns a feature for the given location. + /// The [context] object provides access to client metadata, cancellation, etc. + @override Future getFeature(grpc.ServiceCall call, Point request) async { return featuresDb.firstWhere((f) => f.location == request, orElse: () => new Feature()..location = request); @@ -53,8 +54,9 @@ class RouteGuideService extends RouteGuideServiceBase { p.latitude <= r.hi.latitude; } - /// listFeatures handler. Returns a stream of features within the given + /// ListFeatures handler. Returns a stream of features within the given /// rectangle. + @override Stream listFeatures( grpc.ServiceCall call, Rectangle request) async* { final normalizedRectangle = _normalize(request); @@ -68,9 +70,10 @@ class RouteGuideService extends RouteGuideServiceBase { } } - /// recordRoute handler. Gets a stream of points, and responds with statistics + /// RecordRoute handler. Gets a stream of points, and responds with statistics /// about the "trip": number of points, number of known features visited, /// total distance traveled, and total time spent. + @override Future recordRoute( grpc.ServiceCall call, Stream request) async { int pointCount = 0; @@ -100,9 +103,10 @@ class RouteGuideService extends RouteGuideServiceBase { ..elapsedTime = timer.elapsed.inSeconds; } - /// routeChat handler. Receives a stream of message/location pairs, and + /// RouteChat handler. Receives a stream of message/location pairs, and /// responds with a stream of all previous messages at each of those /// locations. + @override Stream routeChat( grpc.ServiceCall call, Stream request) async* { await for (var note in request) { diff --git a/interop/lib/src/client.dart b/interop/lib/src/client.dart index c7a6225..47602d4 100644 --- a/interop/lib/src/client.dart +++ b/interop/lib/src/client.dart @@ -89,18 +89,19 @@ class Tester { } Future runTest() async { - ChannelOptions options; + ChannelCredentials credentials; if (_useTls) { List trustedRoot; if (_useTestCA) { trustedRoot = new File('ca.pem').readAsBytesSync(); } - options = new ChannelOptions.secure( - certificate: trustedRoot, authority: serverHostOverride); + credentials = new ChannelCredentials.secure( + certificates: trustedRoot, authority: serverHostOverride); } else { - options = new ChannelOptions.insecure(); + credentials = const ChannelCredentials.insecure(); } + final options = new ChannelOptions(credentials: credentials); channel = new ClientChannel(serverHost, port: _serverPort, options: options); client = new TestServiceClient(channel); diff --git a/lib/src/client/call.dart b/lib/src/client/call.dart index cb5a4d2..8832734 100644 --- a/lib/src/client/call.dart +++ b/lib/src/client/call.dart @@ -93,7 +93,7 @@ class ClientCall implements Response { } else { final metadata = new Map.from(options.metadata); String audience; - if (connection.options.isSecure) { + if (connection.options.credentials.isSecure) { final port = connection.port != 443 ? ':${connection.port}' : ''; final lastSlashPos = path.lastIndexOf('/'); final audiencePath = diff --git a/lib/src/client/channel.dart b/lib/src/client/channel.dart index 4685133..95a7318 100644 --- a/lib/src/client/channel.dart +++ b/lib/src/client/channel.dart @@ -38,7 +38,7 @@ class ClientChannel { bool _isShutdown = false; ClientChannel(this.host, - {this.port = 443, this.options = const ChannelOptions.secure()}); + {this.port = 443, this.options = const ChannelOptions()}); /// Shuts down this channel. /// diff --git a/lib/src/client/connection.dart b/lib/src/client/connection.dart index 4d5262f..6386dcd 100644 --- a/lib/src/client/connection.dart +++ b/lib/src/client/connection.dart @@ -96,11 +96,11 @@ class ClientConnection { return headers; } - String get authority => options.authority ?? host; + String get authority => options.credentials.authority ?? host; @visibleForTesting Future connectTransport() async { - final securityContext = options.securityContext; + final securityContext = options.credentials.securityContext; var socket = await Socket.connect(host, port); if (_state == ConnectionState.shutdown) { @@ -109,7 +109,9 @@ class ClientConnection { } if (securityContext != null) { socket = await SecureSocket.secure(socket, - host: authority, context: securityContext); + host: authority, + context: securityContext, + onBadCertificate: _validateBadCertificate); if (_state == ConnectionState.shutdown) { socket.destroy(); throw 'Shutting down'; @@ -119,6 +121,12 @@ class ClientConnection { return new ClientTransportConnection.viaSocket(socket); } + bool _validateBadCertificate(X509Certificate certificate) { + final validator = options.credentials.onBadCertificate; + if (validator == null) return false; + return validator(certificate, authority); + } + void _connect() { if (_state != ConnectionState.idle && _state != ConnectionState.transientFailure) { @@ -153,8 +161,8 @@ class ClientConnection { ClientTransportStream makeRequest( String path, Duration timeout, Map metadata) { - final headers = - createCallHeaders(options.isSecure, authority, path, timeout, metadata); + final headers = createCallHeaders( + options.credentials.isSecure, authority, path, timeout, metadata); return _transport.makeRequest(headers); } diff --git a/lib/src/client/options.dart b/lib/src/client/options.dart index a4600b6..a286e80 100644 --- a/lib/src/client/options.dart +++ b/lib/src/client/options.dart @@ -39,43 +39,40 @@ Duration defaultBackoffStrategy(Duration lastBackoff) { return nextBackoff < _maxBackoff ? nextBackoff : _maxBackoff; } -/// Options controlling how connections are made on a [ClientChannel]. -class ChannelOptions { +/// Handler for checking certificates that fail validation. If this handler +/// returns `true`, the bad certificate is allowed, and the TLS handshake can +/// continue. If the handler returns `false`, the TLS handshake fails, and the +/// connection is aborted. +typedef bool BadCertificateHandler(X509Certificate certificate, String host); + +/// Bad certificate handler that disables all certificate checks. +/// DO NOT USE IN PRODUCTION! +/// Can be used during development and testing to accept self-signed +/// certificates, etc. +bool allowBadCertificates(X509Certificate certificate, String host) => true; + +/// Options controlling TLS security settings on a [ClientChannel]. +class ChannelCredentials { final bool isSecure; final List _certificateBytes; final String _certificatePassword; final String authority; - final Duration idleTimeout; - final BackoffStrategy backoffStrategy; + final BadCertificateHandler onBadCertificate; - const ChannelOptions._( - this.isSecure, - this._certificateBytes, - this._certificatePassword, - this.authority, - Duration idleTimeout, - BackoffStrategy backoffStrategy) - : this.idleTimeout = idleTimeout ?? defaultIdleTimeout, - this.backoffStrategy = backoffStrategy ?? defaultBackoffStrategy; + const ChannelCredentials._(this.isSecure, this._certificateBytes, + this._certificatePassword, this.authority, this.onBadCertificate); /// Disable TLS. RPCs are sent in clear text. - const ChannelOptions.insecure( - {Duration idleTimeout, - BackoffStrategy backoffStrategy = - defaultBackoffStrategy}) // Remove when dart-lang/sdk#31066 is fixed. - : this._(false, null, null, null, idleTimeout, backoffStrategy); + const ChannelCredentials.insecure() : this._(false, null, null, null, null); - /// Enable TLS and optionally specify the [certificate]s to trust. If + /// Enable TLS and optionally specify the [certificates] to trust. If /// [certificates] is not provided, the default trust store is used. - const ChannelOptions.secure( - {List certificate, + const ChannelCredentials.secure( + {List certificates, String password, String authority, - Duration idleTimeout, - BackoffStrategy backoffStrategy = - defaultBackoffStrategy}) // Remove when dart-lang/sdk#31066 is fixed. - : this._(true, certificate, password, authority, idleTimeout, - backoffStrategy); + BadCertificateHandler onBadCertificate}) + : this._(true, certificates, password, authority, onBadCertificate); SecurityContext get securityContext { if (!isSecure) return null; @@ -92,6 +89,22 @@ class ChannelOptions { } } +/// Options controlling how connections are made on a [ClientChannel]. +class ChannelOptions { + final ChannelCredentials credentials; + final Duration idleTimeout; + final BackoffStrategy backoffStrategy; + + const ChannelOptions( + {ChannelCredentials credentials, + Duration idleTimeout, + BackoffStrategy backoffStrategy = + defaultBackoffStrategy}) // Remove when dart-lang/sdk#31066 is fixed. + : this.credentials = credentials ?? const ChannelCredentials.secure(), + this.idleTimeout = idleTimeout ?? defaultIdleTimeout, + this.backoffStrategy = backoffStrategy ?? defaultBackoffStrategy; +} + /// Provides per-RPC metadata. /// /// Metadata providers will be invoked for every RPC, and can add their own diff --git a/lib/src/server/server.dart b/lib/src/server/server.dart index 10cfba1..8870097 100644 --- a/lib/src/server/server.dart +++ b/lib/src/server/server.dart @@ -100,7 +100,7 @@ class Server { new ServerHandler(lookupService, stream).handle(); } - Future shutdown() { + Future shutdown() async { final done = _connections.map((connection) => connection.finish()).toList(); if (_insecureServer != null) { done.add(_insecureServer.close()); @@ -108,6 +108,6 @@ class Server { if (_secureServer != null) { done.add(_secureServer.close()); } - return Future.wait(done); + await Future.wait(done); } } diff --git a/test/options_test.dart b/test/options_test.dart index 22867ef..484e407 100644 --- a/test/options_test.dart +++ b/test/options_test.dart @@ -23,19 +23,19 @@ const isTlsException = const isInstanceOf(); void main() { group('Certificates', () { test('report password errors correctly', () async { - final certificate = + final certificates = await new File('test/data/certstore.p12').readAsBytes(); final missingPassword = - new ChannelOptions.secure(certificate: certificate); + new ChannelCredentials.secure(certificates: certificates); expect(() => missingPassword.securityContext, throwsA(isTlsException)); - final wrongPassword = new ChannelOptions.secure( - certificate: certificate, password: 'wrong'); + final wrongPassword = new ChannelCredentials.secure( + certificates: certificates, password: 'wrong'); expect(() => wrongPassword.securityContext, throwsA(isTlsException)); - final correctPassword = new ChannelOptions.secure( - certificate: certificate, password: 'correct'); + final correctPassword = new ChannelCredentials.secure( + certificates: certificates, password: 'correct'); expect(correctPassword.securityContext, isNotNull); }); }); diff --git a/test/src/client_utils.dart b/test/src/client_utils.dart index f7468a9..a36a9cd 100644 --- a/test/src/client_utils.dart +++ b/test/src/client_utils.dart @@ -15,7 +15,6 @@ import 'dart:async'; -import 'dart:io'; import 'package:grpc/src/shared/streams.dart'; import 'package:http2/transport.dart'; import 'package:test/test.dart'; @@ -47,11 +46,9 @@ class FakeConnection extends ClientConnection { Duration testBackoff(Duration lastBackoff) => const Duration(milliseconds: 1); class FakeChannelOptions implements ChannelOptions { - String authority; + ChannelCredentials credentials = const ChannelCredentials.secure(); Duration idleTimeout = const Duration(seconds: 1); BackoffStrategy backoffStrategy = testBackoff; - SecurityContext securityContext = new SecurityContext(); - bool isSecure = true; } class FakeChannel extends ClientChannel {