diff --git a/CHANGELOG.md b/CHANGELOG.md index 6676c49..8a2f5e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 4.1.0 +* Add a `serverInterceptors` argument to `ConnectionServer`. These interceptors are acting + as middleware, wrapping a `ServiceMethod` invocation. + ## 4.0.4 * Allow the latest `package:googleapis_auth`. diff --git a/lib/grpc.dart b/lib/grpc.dart index d2301a0..f84465f 100644 --- a/lib/grpc.dart +++ b/lib/grpc.dart @@ -42,7 +42,8 @@ export 'src/client/proxy.dart' show Proxy; export 'src/client/transport/http2_credentials.dart' show BadCertificateHandler, allowBadCertificates, ChannelCredentials; export 'src/server/call.dart' show ServiceCall; -export 'src/server/interceptor.dart' show Interceptor; +export 'src/server/interceptor.dart' + show Interceptor, ServerInterceptor, ServerStreamingInvoker; export 'src/server/server.dart' show ServerCredentials, diff --git a/lib/src/server/handler.dart b/lib/src/server/handler.dart index e89557e..fb285c6 100644 --- a/lib/src/server/handler.dart +++ b/lib/src/server/handler.dart @@ -37,6 +37,7 @@ class ServerHandler extends ServiceCall { final ServerTransportStream _stream; final ServiceLookup _serviceLookup; final List _interceptors; + final List _serverInterceptors; final CodecRegistry? _codecRegistry; final GrpcErrorHandler? _errorHandler; @@ -83,6 +84,7 @@ class ServerHandler extends ServiceCall { required ServerTransportStream stream, required ServiceLookup serviceLookup, required List interceptors, + required List serverInterceptors, required CodecRegistry? codecRegistry, X509Certificate? clientCertificate, InternetAddress? remoteAddress, @@ -94,7 +96,8 @@ class ServerHandler extends ServiceCall { _codecRegistry = codecRegistry, _clientCertificate = clientCertificate, _remoteAddress = remoteAddress, - _errorHandler = errorHandler; + _errorHandler = errorHandler, + _serverInterceptors = serverInterceptors; @override DateTime? get deadline => _deadline; @@ -239,7 +242,7 @@ class ServerHandler extends ServiceCall { return; } - _responses = _descriptor.handle(this, requests.stream); + _responses = _descriptor.handle(this, requests.stream, _serverInterceptors); _responseSubscription = _responses.listen(_onResponse, onError: _onResponseError, diff --git a/lib/src/server/interceptor.dart b/lib/src/server/interceptor.dart index 81060f9..2a8e74a 100644 --- a/lib/src/server/interceptor.dart +++ b/lib/src/server/interceptor.dart @@ -27,3 +27,18 @@ import 'service.dart'; /// If the interceptor returns null, the corresponding [ServiceMethod] of [Service] will be called. typedef Interceptor = FutureOr Function( ServiceCall call, ServiceMethod method); + +typedef ServerStreamingInvoker = Stream Function( + ServiceCall call, ServiceMethod method, Stream requests); + +/// A gRPC Interceptor. +/// +/// An interceptor is called around the corresponding [ServiceMethod] invocation. +/// If the interceptor throws [GrpcError], the error will be returned as a response. [ServiceMethod] wouldn't be called if the error is thrown before calling the invoker. +/// If the interceptor modifies the provided stream, the invocation will continue with the provided stream. +abstract class ServerInterceptor { + Stream intercept(ServiceCall call, ServiceMethod method, + Stream requests, ServerStreamingInvoker invoker) { + return invoker(call, method, requests); + } +} diff --git a/lib/src/server/server.dart b/lib/src/server/server.dart index bc40edb..a58a3bd 100644 --- a/lib/src/server/server.dart +++ b/lib/src/server/server.dart @@ -87,6 +87,7 @@ class ServerTlsCredentials extends ServerCredentials { class ConnectionServer { final Map _services = {}; final List _interceptors; + final List _serverInterceptors; final CodecRegistry? _codecRegistry; final GrpcErrorHandler? _errorHandler; final ServerKeepAliveOptions _keepAliveOptions; @@ -100,11 +101,13 @@ class ConnectionServer { ConnectionServer( List services, [ List interceptors = const [], + List serverInterceptors = const [], CodecRegistry? codecRegistry, GrpcErrorHandler? errorHandler, this._keepAliveOptions = const ServerKeepAliveOptions(), ]) : _codecRegistry = codecRegistry, _interceptors = interceptors, + _serverInterceptors = serverInterceptors, _errorHandler = errorHandler { for (final service in services) { _services[service.$name] = service; @@ -168,6 +171,7 @@ class ConnectionServer { stream: stream, serviceLookup: lookupService, interceptors: _interceptors, + serverInterceptors: _serverInterceptors, codecRegistry: _codecRegistry, // ignore: unnecessary_cast clientCertificate: clientCertificate as io_bits.X509Certificate?, @@ -201,11 +205,13 @@ class Server extends ConnectionServer { required List services, ServerKeepAliveOptions keepAliveOptions = const ServerKeepAliveOptions(), List interceptors = const [], + List serverInterceptors = const [], CodecRegistry? codecRegistry, GrpcErrorHandler? errorHandler, }) : super( services, interceptors, + serverInterceptors, codecRegistry, errorHandler, keepAliveOptions, @@ -308,6 +314,7 @@ class Server extends ConnectionServer { stream: stream, serviceLookup: lookupService, interceptors: _interceptors, + serverInterceptors: _serverInterceptors, codecRegistry: _codecRegistry, // ignore: unnecessary_cast clientCertificate: clientCertificate as io_bits.X509Certificate?, diff --git a/lib/src/server/service.dart b/lib/src/server/service.dart index 7a457c3..9f1b2fa 100644 --- a/lib/src/server/service.dart +++ b/lib/src/server/service.dart @@ -17,6 +17,7 @@ import 'dart:async'; import '../shared/status.dart'; import 'call.dart'; +import 'interceptor.dart'; /// Definition of a gRPC service method. class ServiceMethod { @@ -48,19 +49,42 @@ class ServiceMethod { List serialize(dynamic response) => responseSerializer(response as R); - Stream handle(ServiceCall call, Stream requests) { - if (streamingResponse) { - if (streamingRequest) { - return handler(call, requests); - } else { - return handler(call, _toSingleFuture(requests)); - } - } else { - final response = streamingRequest - ? handler(call, requests) - : handler(call, _toSingleFuture(requests)); - return response.asStream(); + ServerStreamingInvoker _createCall() => (( + ServiceCall call, + ServiceMethod method, + Stream requests, + ) { + if (streamingResponse) { + if (streamingRequest) { + return handler(call, requests); + } else { + return handler(call, _toSingleFuture(requests)); + } + } else { + final response = streamingRequest + ? handler(call, requests) + : handler(call, _toSingleFuture(requests)); + return response.asStream(); + } + }); + + Stream handle( + ServiceCall call, + Stream requests, + List interceptors, + ) { + var invoker = _createCall(); + + for (final interceptor in interceptors.reversed) { + final delegate = invoker; + // invoker is actually reassigned in the same scope as the above function, + // reassigning invoker in delegate is required to avoid an infinite + // recursion + invoker = (call, method, requests) => + interceptor.intercept(call, method, requests, delegate); } + + return invoker(call, this, requests); } Future _toSingleFuture(Stream stream) { diff --git a/pubspec.yaml b/pubspec.yaml index 99401ff..9f9c8f6 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,5 +1,5 @@ name: grpc -version: 4.0.4 +version: 4.1.0 description: Dart implementation of gRPC, a high performance, open-source universal RPC framework. repository: https://github.com/grpc/grpc-dart diff --git a/test/server_test.dart b/test/server_test.dart index b323f77..3f1a388 100644 --- a/test/server_test.dart +++ b/test/server_test.dart @@ -384,4 +384,221 @@ void main() { await harness.fromServer.done; }); }); + + group('Server with server interceptor', () { + group('processes calls if interceptor allows request', () { + const expectedRequest = 5; + const expectedResponse = 7; + Future methodHandler(ServiceCall call, Future request) async { + expect(await request, expectedRequest); + return expectedResponse; + } + + Null interceptor(call, method, requests) { + if (method.name == 'Unary') { + return null; + } + throw GrpcError.unauthenticated('Request is unauthenticated'); + } + + Future doTest(TestServerInterceptorOnStart? handler) async { + harness + ..serverInterceptor.onStart = handler + ..service.unaryHandler = methodHandler + ..runTest('/Test/Unary', [expectedRequest], [expectedResponse]); + + await harness.fromServer.done; + } + + test('with sync interceptor', () => doTest(interceptor)); + test( + 'with async interceptor', + () => doTest((call, method, requests) async => + interceptor(call, method, requests))); + }); + + group('returns error if interceptor blocks request', () { + Null interceptor(call, method, requests) { + if (method.name == 'Unary') { + throw GrpcError.unauthenticated('Request is unauthenticated'); + } + return null; + } + + Future doTest(TestServerInterceptorOnStart handler) async { + harness + ..serverInterceptor.onStart = handler + ..expectErrorResponse( + StatusCode.unauthenticated, 'Request is unauthenticated') + ..sendRequestHeader('/Test/Unary'); + + await harness.fromServer.done; + } + + test('with sync interceptor', () => doTest(interceptor)); + test( + 'with async interceptor', + () => doTest((call, method, request) async => + interceptor(call, method, request))); + }); + + test("don't fail if interceptor await 2 times", () async { + Future interceptor(call, method, requests) async { + await Future.value(); + await Future.value(); + throw GrpcError.internal('Reason is unknown'); + } + + harness + ..serverInterceptor.onStart = interceptor + ..expectErrorResponse(StatusCode.internal, 'Reason is unknown') + ..sendRequestHeader('/Test/Unary') + ..sendData(1); + + await harness.fromServer.done; + }); + + group('serviceInterceptors are invoked', () { + const expectedRequest = 5; + const expectedResponse = 7; + Future methodHandler(ServiceCall call, Future request) async { + expect(await request, expectedRequest); + return expectedResponse; + } + + Future doTest(List interceptors) async { + harness + // ↓ mutation: Server is already built + ..serverInterceptors.addAll(interceptors) + ..service.unaryHandler = methodHandler + ..runTest('/Test/Unary', [expectedRequest], [expectedResponse]); + + await harness.fromServer.done; + } + + test('single serviceInterceptor is invoked', () async { + final invocationsOrderRecords = []; + + await doTest([ + TestServerInterceptor( + onStart: (call, method, requests) { + invocationsOrderRecords.add('Start'); + }, + onData: (call, method, requests, data) { + invocationsOrderRecords.add('Data [$data]'); + }, + onFinish: (call, method, requests) { + invocationsOrderRecords.add('Done'); + }, + ) + ]); + + expect(invocationsOrderRecords, equals(['Start', 'Data [7]', 'Done'])); + }); + + test('multiple serviceInterceptors are invoked', () async { + final invocationsOrderRecords = []; + + await doTest([ + TestServerInterceptor( + onStart: (call, method, requests) { + invocationsOrderRecords.add('Start 1'); + }, + onData: (call, method, requests, data) { + invocationsOrderRecords.add('Data 1 [$data]'); + }, + onFinish: (call, method, requests) { + invocationsOrderRecords.add('Done 1'); + }, + ), + TestServerInterceptor( + onStart: (call, method, requests) { + invocationsOrderRecords.add('Start 2'); + }, + onData: (call, method, requests, data) { + invocationsOrderRecords.add('Data 2 [$data]'); + }, + onFinish: (call, method, requests) { + invocationsOrderRecords.add('Done 2'); + }, + ) + ]); + + expect( + invocationsOrderRecords, + equals([ + 'Start 1', + 'Start 2', + 'Data 2 [7]', + 'Data 1 [7]', + 'Done 2', + 'Done 1', + ])); + }); + }); + + test('can modify response', () async { + const expectedRequest = 5; + const baseResponse = 7; + const expectedResponse = 14; + + final invocationsOrderRecords = []; + + final interceptors = [ + TestServerInterceptor( + onStart: (call, method, requests) { + invocationsOrderRecords.add('Start 1'); + }, + onData: (call, method, requests, data) { + invocationsOrderRecords.add('Data 1 [$data]'); + }, + onFinish: (call, method, requests) { + invocationsOrderRecords.add('Done 1'); + }, + ), + TestServerInterruptingInterceptor(transform: (value) { + if (value is int) { + return value * 2 as R; + } + + return value; + }), + TestServerInterceptor( + onStart: (call, method, requests) { + invocationsOrderRecords.add('Start 2'); + }, + onData: (call, method, requests, data) { + invocationsOrderRecords.add('Data 2 [$data]'); + }, + onFinish: (call, method, requests) { + invocationsOrderRecords.add('Done 2'); + }, + ) + ]; + + Future methodHandler(ServiceCall call, Future request) async { + expect(await request, expectedRequest); + return baseResponse; + } + + harness + // ↓ mutation: Server is already built + ..serverInterceptors.addAll(interceptors) + ..service.unaryHandler = methodHandler + ..runTest('/Test/Unary', [expectedRequest], [expectedResponse]); + + await harness.fromServer.done; + + expect( + invocationsOrderRecords, + equals([ + 'Start 1', + 'Start 2', + 'Data 2 [7]', + 'Data 1 [14]', + 'Done 2', + 'Done 1', + ])); + }); + }); } diff --git a/test/src/server_utils.dart b/test/src/server_utils.dart index 8f4a1d5..fa43749 100644 --- a/test/src/server_utils.dart +++ b/test/src/server_utils.dart @@ -90,6 +90,47 @@ class TestInterceptor { } } +typedef TestServerInterceptorOnStart = Function( + ServiceCall call, ServiceMethod method, Stream requests); +typedef TestServerInterceptorOnData = Function( + ServiceCall call, ServiceMethod method, Stream requests, dynamic data); +typedef TestServerInterceptorOnFinish = Function( + ServiceCall call, ServiceMethod method, Stream requests); + +class TestServerInterceptor extends ServerInterceptor { + TestServerInterceptorOnStart? onStart; + TestServerInterceptorOnData? onData; + TestServerInterceptorOnFinish? onFinish; + + TestServerInterceptor({this.onStart, this.onData, this.onFinish}); + + @override + Stream intercept(ServiceCall call, ServiceMethod method, + Stream requests, ServerStreamingInvoker invoker) async* { + await onStart?.call(call, method, requests); + + await for (final chunk + in super.intercept(call, method, requests, invoker)) { + await onData?.call(call, method, requests, chunk); + yield chunk; + } + + await onFinish?.call(call, method, requests); + } +} + +class TestServerInterruptingInterceptor extends ServerInterceptor { + final R Function(R) transform; + + TestServerInterruptingInterceptor({required this.transform}); + + @override + Stream intercept(ServiceCall call, ServiceMethod method, + Stream requests, ServerStreamingInvoker invoker) async* { + yield* super.intercept(call, method, requests, invoker).map(transform); + } +} + class TestServerStream extends ServerTransportStream { @override final Stream incomingMessages; @@ -123,6 +164,7 @@ class ServerHarness extends _Harness { ConnectionServer createServer() => Server.create( services: [service], interceptors: [interceptor.call], + serverInterceptors: serverInterceptors..insert(0, serverInterceptor), ); static ServiceMethod createMethod(String name, @@ -161,6 +203,10 @@ abstract class _Harness { final fromServer = StreamController(); final service = TestService(); final interceptor = TestInterceptor(); + final serverInterceptor = TestServerInterceptor(); + + final serverInterceptors = []; + ConnectionServer? _server; ConnectionServer createServer();