Add proxy functionality (#657)

* Add proxy to options

* Add proxy connect

* Works now

* Uncomment proxy line

* Revert change

* Doesn't work

* Works

* Fix bug

* Add secure test

* Refine test

* Add changelog

* Changes as per review
This commit is contained in:
Moritz 2023-08-23 10:07:15 +02:00 committed by GitHub
parent a6322db468
commit 4ccd8a0e3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 331 additions and 21 deletions

View File

@ -1,6 +1,7 @@
## 3.2.4-wip
* Forward internal `GrpcError` on when throwing while sending a request.
* Add support for proxies, see [#33](https://github.com/grpc/grpc-dart/issues/33).
## 3.2.3

View File

@ -16,7 +16,9 @@
import 'dart:async';
import 'dart:convert';
import 'dart:io';
import 'dart:typed_data';
import 'package:grpc/src/client/proxy.dart';
import 'package:http2/transport.dart';
import '../shared/codec.dart';
@ -61,7 +63,7 @@ class Http2ClientConnection implements connection.ClientConnection {
ClientKeepAlive? keepAliveManager;
Http2ClientConnection(Object host, int port, this.options)
: _transportConnector = _SocketTransportConnector(host, port, options);
: _transportConnector = SocketTransportConnector(host, port, options);
Http2ClientConnection.fromClientTransportConnector(
this._transportConnector, this.options);
@ -351,39 +353,68 @@ class Http2ClientConnection implements connection.ClientConnection {
}
}
class _SocketTransportConnector implements ClientTransportConnector {
class SocketTransportConnector implements ClientTransportConnector {
/// Either [InternetAddress] or [String].
final Object _host;
final int _port;
final ChannelOptions _options;
late Socket _socket; // ignore: close_sinks
late Socket socket;
_SocketTransportConnector(this._host, this._port, this._options)
Proxy? get proxy => _options.proxy;
Object get host => proxy == null ? _host : proxy!.host;
int get port => proxy == null ? _port : proxy!.port;
SocketTransportConnector(this._host, this._port, this._options)
: assert(_host is InternetAddress || _host is String);
@override
Future<ClientTransportConnection> connect() async {
final securityContext = _options.credentials.securityContext;
_socket =
await Socket.connect(_host, _port, timeout: _options.connectTimeout);
var incoming = await connectImpl(proxy);
// Don't wait for io buffers to fill up before sending requests.
if (_socket.address.type != InternetAddressType.unix) {
_socket.setOption(SocketOption.tcpNoDelay, true);
if (socket.address.type != InternetAddressType.unix) {
socket.setOption(SocketOption.tcpNoDelay, true);
}
if (securityContext != null) {
// Todo(sigurdm): We want to pass supportedProtocols: ['h2'].
// http://dartbug.com/37950
_socket = await SecureSocket.secure(_socket,
// This is not really the host, but the authority to verify the TLC
// connection against.
//
// We don't use `this.authority` here, as that includes the port.
host: _options.credentials.authority ?? _host,
context: securityContext,
onBadCertificate: _validateBadCertificate);
socket = await SecureSocket.secure(
socket,
// This is not really the host, but the authority to verify the TLC
// connection against.
//
// We don't use `this.authority` here, as that includes the port.
host: _options.credentials.authority ?? host,
context: securityContext,
onBadCertificate: _validateBadCertificate,
);
incoming = socket;
}
return ClientTransportConnection.viaStreams(incoming, socket);
}
return ClientTransportConnection.viaSocket(_socket);
Future<Stream<List<int>>> connectImpl(Proxy? proxy) async {
socket = await initSocket(host, port);
if (proxy == null) {
return socket;
}
return await connectToProxy(proxy);
}
Future<Socket> initSocket(Object host, int port) async {
return await Socket.connect(host, port, timeout: _options.connectTimeout);
}
void _sendConnect(Map<String, String> headers) {
const linebreak = '\r\n';
socket.write('CONNECT $_host:$_port HTTP/1.1');
socket.write(linebreak);
headers.forEach((key, value) {
socket.write('$key: $value');
socket.write(linebreak);
});
socket.write(linebreak);
}
@override
@ -409,14 +440,14 @@ class _SocketTransportConnector implements ClientTransportConnector {
@override
Future get done {
ArgumentError.checkNotNull(_socket);
return _socket.done;
ArgumentError.checkNotNull(socket);
return socket.done;
}
@override
void shutdown() {
ArgumentError.checkNotNull(_socket);
_socket.destroy();
ArgumentError.checkNotNull(socket);
socket.destroy();
}
bool _validateBadCertificate(X509Certificate certificate) {
@ -426,6 +457,52 @@ class _SocketTransportConnector implements ClientTransportConnector {
if (validator == null) return false;
return validator(certificate, authority);
}
Future<Stream<List<int>>> connectToProxy(Proxy proxy) async {
final headers = {'Host': '$_host:$_port'};
if (proxy.isAuthenticated) {
// If the proxy configuration contains user information use that
// for proxy basic authorization.
final authStr = '${proxy.username}:${proxy.password}';
final auth = base64Encode(utf8.encode(authStr));
headers[HttpHeaders.proxyAuthorizationHeader] = 'Basic $auth';
}
final completer = Completer<void>();
/// Routes the events through after connection to the proxy has been
/// established.
final intermediate = StreamController<List<int>>();
/// Route events after the successfull connect to the `intermediate`.
socket.listen(
(event) {
if (completer.isCompleted) {
intermediate.sink.add(event);
} else {
_waitForResponse(event, completer);
}
},
onDone: intermediate.close,
onError: intermediate.addError,
);
_sendConnect(headers);
await completer.future;
return intermediate.stream;
}
/// Wait for the response to the `CONNECT` request, which should be an
/// acknowledgement with a 200 status code.
void _waitForResponse(Uint8List chunk, Completer<void> completer) {
final response = ascii.decode(chunk);
print(response);
if (response.startsWith('HTTP/1.1 200')) {
completer.complete();
} else {
throw TransportException(
'Error establishing proxy connection: $response');
}
}
}
class _ShutdownException implements Exception {}

View File

@ -17,6 +17,7 @@ import 'dart:math';
import '../shared/codec_registry.dart';
import 'client_keepalive.dart';
import 'proxy.dart';
import 'transport/http2_credentials.dart';
const defaultIdleTimeout = Duration(minutes: 5);
@ -59,6 +60,7 @@ class ChannelOptions {
final BackoffStrategy backoffStrategy;
final String userAgent;
final ClientKeepAliveOptions keepAlive;
final Proxy? proxy;
const ChannelOptions({
this.credentials = const ChannelCredentials.secure(),
@ -69,5 +71,6 @@ class ChannelOptions {
this.connectionTimeout = defaultConnectionTimeOut,
this.codecRegistry,
this.keepAlive = const ClientKeepAliveOptions(),
this.proxy,
});
}

15
lib/src/client/proxy.dart Normal file
View File

@ -0,0 +1,15 @@
class Proxy {
final String host;
final int port;
final String? username;
final String? password;
const Proxy({
required this.host,
required this.port,
this.username,
this.password,
});
bool get isAuthenticated => username != null;
}

View File

@ -0,0 +1,75 @@
@TestOn('vm')
import 'dart:async';
import 'dart:io';
import 'package:grpc/grpc.dart';
import 'package:grpc/src/client/proxy.dart';
import 'package:test/test.dart';
import 'src/generated/echo.pbgrpc.dart';
void main() {
late Server server;
late EchoServiceClient fakeClient;
late ClientChannel fakeChannel;
setUp(() async {
server = Server.create(services: [FakeEchoService()]);
await server.serve(
address: 'localhost',
port: 8888,
security: ServerTlsCredentials(
certificate: File('test/data/localhost.crt').readAsBytesSync(),
privateKey: File('test/data/localhost.key').readAsBytesSync(),
),
);
final proxy = Proxy(host: 'localhost', port: 8080);
final proxyCAName = '/CN=mitmproxy/O=mitmproxy';
fakeChannel = ClientChannel(
'localhost',
port: server.port!,
options: ChannelOptions(
credentials: ChannelCredentials.secure(
certificates: File('test/data/localhost.crt').readAsBytesSync(),
authority: 'localhost',
onBadCertificate: (certificate, host) {
/// Workaround to avoid having to add the proxy to our list of
/// trusted CAs.
return certificate.issuer == proxyCAName;
},
),
proxy: proxy,
),
);
fakeClient = EchoServiceClient(fakeChannel);
});
tearDown(() async {
await fakeChannel.shutdown();
await server.shutdown();
});
test(
'Sending and receiving over secure proxy works',
() async {
final echoRequest = EchoRequest(message: 'blablablubb');
final echoResponse = await fakeClient.echo(echoRequest);
expect(echoResponse.message, 'blibliblabb');
},
skip: 'Run this test iff you have a proxy running.',
);
}
class FakeEchoService extends EchoServiceBase {
@override
Future<EchoResponse> echo(ServiceCall call, EchoRequest request) async {
expect(request.message, 'blablablubb');
return EchoResponse(message: 'blibliblabb');
}
@override
Stream<ServerStreamingEchoResponse> serverStreamingEcho(
ServiceCall call, ServerStreamingEchoRequest request) =>
throw UnimplementedError();
}

57
test/proxy_test.dart Normal file
View File

@ -0,0 +1,57 @@
@TestOn('vm')
import 'dart:async';
import 'package:grpc/grpc.dart';
import 'package:grpc/src/client/proxy.dart';
import 'package:test/test.dart';
import 'src/generated/echo.pbgrpc.dart';
void main() {
late Server server;
late EchoServiceClient fakeClient;
late ClientChannel fakeChannel;
setUp(() async {
server = Server.create(services: [FakeEchoService()]);
await server.serve(address: 'localhost', port: 8888);
final proxy = Proxy(host: 'localhost', port: 8080);
fakeChannel = ClientChannel(
'localhost',
port: server.port!,
options: ChannelOptions(
credentials: ChannelCredentials.insecure(),
proxy: proxy,
),
);
fakeClient = EchoServiceClient(fakeChannel);
});
tearDown(() async {
await fakeChannel.shutdown();
await server.shutdown();
});
test(
'Sending and receiving over proxy works',
() async {
final echoRequest = EchoRequest(message: 'blablablubb');
final echoResponse = await fakeClient.echo(echoRequest);
expect(echoResponse.message, 'blibliblabb');
},
skip: 'Run this test iff you have a proxy running.',
);
}
class FakeEchoService extends EchoServiceBase {
@override
Future<EchoResponse> echo(ServiceCall call, EchoRequest request) async =>
EchoResponse(message: 'blibliblabb');
@override
Stream<ServerStreamingEchoResponse> serverStreamingEcho(
ServiceCall call, ServerStreamingEchoRequest request) =>
throw UnimplementedError();
}

View File

@ -19,6 +19,7 @@ import 'dart:convert';
import 'package:grpc/grpc.dart';
import 'package:grpc/src/client/channel.dart' as base;
import 'package:grpc/src/client/http2_connection.dart';
import 'package:grpc/src/client/proxy.dart';
import 'package:grpc/src/shared/message.dart';
import 'package:http2/transport.dart';
import 'package:mockito/annotations.dart';
@ -79,6 +80,9 @@ class FakeChannelOptions implements ChannelOptions {
@override
ClientKeepAliveOptions get keepAlive => const ClientKeepAliveOptions();
@override
Proxy? get proxy => null;
}
class FakeChannel extends ClientChannel {

View File

@ -0,0 +1,52 @@
import 'dart:convert';
import 'package:grpc/grpc.dart';
import 'package:grpc/src/client/http2_connection.dart';
import 'package:grpc/src/client/proxy.dart';
import 'package:http2/http2.dart';
Future<void> main(List<String> args) async {
final serverPort = 5678;
final proxyPort = int.tryParse(args.first);
final proxy =
proxyPort != null ? Proxy(host: 'localhost', port: proxyPort) : null;
final port = proxyPort ?? serverPort;
final connector = SocketTransportConnector(
'localhost',
serverPort,
ChannelOptions(proxy: proxy),
);
await connector.initSocket('localhost', port);
final incoming =
proxy == null ? connector.socket : await connector.connectToProxy(proxy);
final uri = Uri.parse('http://localhost:8080');
final transport =
ClientTransportConnection.viaStreams(incoming, connector.socket);
final request = transport.makeRequest(
[
Header.ascii(':method', 'GET'),
Header.ascii(':path', uri.path),
Header.ascii(':scheme', uri.scheme),
Header.ascii(':authority', uri.host),
],
endStream: true,
);
await for (var message in request.incomingMessages) {
if (message is HeadersStreamMessage) {
for (var header in message.headers) {
final name = utf8.decode(header.name);
final value = utf8.decode(header.value);
print('Header: $name: $value');
}
} else if (message is DataStreamMessage) {
print(message.bytes);
}
}
}

View File

@ -0,0 +1,26 @@
import 'dart:convert';
import 'dart:io';
import 'package:http2/transport.dart';
void main() async {
final server = await ServerSocket.bind(InternetAddress.anyIPv4, 5678);
server.listen((client) => handleConnection(client));
}
void handleConnection(Socket client) {
final connection = ServerTransportConnection.viaSocket(client);
connection.incomingStreams.listen((stream) {
stream.incomingMessages.listen((event) {
if (event is HeadersStreamMessage) {
print(event.headers);
final headersStreamMessage = HeadersStreamMessage([
Header(utf8.encode('SomeName'), utf8.encode('SomeValue')),
]);
print('send $headersStreamMessage');
stream.outgoingMessages.add(headersStreamMessage);
}
});
});
}