237 lines
8.0 KiB
C#
237 lines
8.0 KiB
C#
// Licensed to the Apache Software Foundation (ASF) under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing,
|
|
// software distributed under the License is distributed on an
|
|
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, either express or implied. See the License for the
|
|
// specific language governing permissions and limitations
|
|
// under the License.
|
|
|
|
using System;
|
|
using System.Net;
|
|
using System.Net.Security;
|
|
using System.Net.Sockets;
|
|
using System.Security.Authentication;
|
|
using System.Security.Cryptography.X509Certificates;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
|
|
namespace Thrift.Transports.Client
|
|
{
|
|
// TODO: verify that this transport works correctly
|
|
|
|
// ReSharper disable once InconsistentNaming
|
|
/// <summary>
/// Stream transport that wraps a <see cref="System.Net.Sockets.TcpClient"/> in an
/// <see cref="SslStream"/>. It can either connect to a remote host/port itself
/// (<see cref="OpenAsync"/>) or be handed an existing <see cref="System.Net.Sockets.TcpClient"/>,
/// in which case <see cref="SetupTlsAsync"/> performs the TLS handshake on that connection.
/// </summary>
public class TTlsSocketClientTransport : TStreamClientTransport
{
    private readonly X509Certificate2 _certificate;
    private readonly RemoteCertificateValidationCallback _certValidator;
    private readonly IPAddress _host;
    private readonly bool _isServer;
    private readonly LocalCertificateSelectionCallback _localCertificateSelectionCallback;
    private readonly int _port;
    private readonly SslProtocols _sslProtocols;
    private TcpClient _client;
    private SslStream _secureStream;
    private int _timeout;

    /// <summary>
    /// Wraps an existing TCP client. No host or port is recorded by this constructor, so
    /// <see cref="Host"/> stays null and <see cref="Port"/> stays 0.
    /// </summary>
    /// <param name="client">The TCP client to wrap.</param>
    /// <param name="certificate">Certificate used for the handshake; mandatory when <paramref name="isServer"/> is true.</param>
    /// <param name="isServer">When true, <see cref="SetupTlsAsync"/> authenticates as the server side.</param>
    /// <param name="certValidator">Optional remote certificate validator; a default that accepts only error-free certificates is used when null.</param>
    /// <param name="localCertificateSelectionCallback">Optional local certificate selection callback.</param>
    /// <param name="sslProtocols">TLS protocol versions passed to the handshake.</param>
    /// <exception cref="ArgumentException"><paramref name="isServer"/> is true and <paramref name="certificate"/> is null.</exception>
    public TTlsSocketClientTransport(TcpClient client, X509Certificate2 certificate, bool isServer = false,
        RemoteCertificateValidationCallback certValidator = null,
        LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
        SslProtocols sslProtocols = SslProtocols.Tls12)
    {
        // Validate before touching any state.
        if (isServer && certificate == null)
        {
            throw new ArgumentException("TTlsSocketClientTransport needs certificate to be used for server",
                nameof(certificate));
        }

        _client = client;
        _certificate = certificate;
        _certValidator = certValidator;
        _localCertificateSelectionCallback = localCertificateSelectionCallback;
        _sslProtocols = sslProtocols;
        _isServer = isServer;

        // Check the socket directly instead of going through the virtual IsOpen property:
        // a virtual call from a constructor would run an override before the derived
        // constructor body has executed.
        if (client != null && client.Connected)
        {
            var stream = client.GetStream();
            InputStream = stream;
            OutputStream = stream;
        }
    }

    /// <summary>
    /// Creates an unconnected transport for <paramref name="host"/>:<paramref name="port"/>, loading the
    /// certificate from <paramref name="certificatePath"/>. The timeout is 0.
    /// </summary>
    /// <param name="host">Remote address to connect to.</param>
    /// <param name="port">Remote port to connect to.</param>
    /// <param name="certificatePath">Path handed to the <see cref="X509Certificate2"/> constructor.</param>
    /// <param name="certValidator">Optional remote certificate validator.</param>
    /// <param name="localCertificateSelectionCallback">Optional local certificate selection callback.</param>
    /// <param name="sslProtocols">TLS protocol versions passed to the handshake.</param>
    public TTlsSocketClientTransport(IPAddress host, int port, string certificatePath,
        RemoteCertificateValidationCallback certValidator = null,
        LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
        SslProtocols sslProtocols = SslProtocols.Tls12)
        : this(host, port, 0,
            new X509Certificate2(certificatePath),
            certValidator,
            localCertificateSelectionCallback,
            sslProtocols)
    {
    }

    /// <summary>
    /// Creates an unconnected transport for <paramref name="host"/>:<paramref name="port"/> with a timeout of 0.
    /// </summary>
    /// <param name="host">Remote address to connect to.</param>
    /// <param name="port">Remote port to connect to.</param>
    /// <param name="certificate">Optional client certificate.</param>
    /// <param name="certValidator">Optional remote certificate validator.</param>
    /// <param name="localCertificateSelectionCallback">Optional local certificate selection callback.</param>
    /// <param name="sslProtocols">TLS protocol versions passed to the handshake.</param>
    public TTlsSocketClientTransport(IPAddress host, int port,
        X509Certificate2 certificate = null,
        RemoteCertificateValidationCallback certValidator = null,
        LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
        SslProtocols sslProtocols = SslProtocols.Tls12)
        : this(host, port, 0,
            certificate,
            certValidator,
            localCertificateSelectionCallback,
            sslProtocols)
    {
    }

    /// <summary>
    /// Creates an unconnected transport for <paramref name="host"/>:<paramref name="port"/> and prepares
    /// the underlying socket. The connection itself is established by <see cref="OpenAsync"/>.
    /// </summary>
    /// <param name="host">Remote address to connect to.</param>
    /// <param name="port">Remote port to connect to.</param>
    /// <param name="timeout">Value applied to both the send and the receive timeout of the socket.</param>
    /// <param name="certificate">Client certificate; may be null.</param>
    /// <param name="certValidator">Optional remote certificate validator.</param>
    /// <param name="localCertificateSelectionCallback">Optional local certificate selection callback.</param>
    /// <param name="sslProtocols">TLS protocol versions passed to the handshake.</param>
    public TTlsSocketClientTransport(IPAddress host, int port, int timeout,
        X509Certificate2 certificate,
        RemoteCertificateValidationCallback certValidator = null,
        LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
        SslProtocols sslProtocols = SslProtocols.Tls12)
    {
        _host = host;
        _port = port;
        _timeout = timeout;
        _certificate = certificate;
        _certValidator = certValidator;
        _localCertificateSelectionCallback = localCertificateSelectionCallback;
        _sslProtocols = sslProtocols;

        InitSocket();
    }

    /// <summary>
    /// Sets the send and receive timeout. The value is always remembered; it is applied to the
    /// current socket when one exists, and to any socket created later by this transport.
    /// </summary>
    public int Timeout
    {
        set
        {
            _timeout = value;

            // _client is null after Close(); InitSocket() applies _timeout to the next socket.
            if (_client != null)
            {
                _client.ReceiveTimeout = value;
                _client.SendTimeout = value;
            }
        }
    }

    /// <summary>The underlying TCP client, or null once the transport has been closed.</summary>
    public TcpClient TcpClient => _client;

    /// <summary>The remote address given at construction; null when the transport wraps an existing client.</summary>
    public IPAddress Host => _host;

    /// <summary>The remote port given at construction; 0 when the transport wraps an existing client.</summary>
    public int Port => _port;

    /// <summary>True when a TCP client exists and reports itself as connected.</summary>
    public override bool IsOpen
    {
        get { return _client != null && _client.Connected; }
    }

    /// <summary>Creates a fresh TCP client configured with the stored timeout and with Nagle's algorithm disabled.</summary>
    private void InitSocket()
    {
        _client = new TcpClient();
        _client.ReceiveTimeout = _client.SendTimeout = _timeout;
        _client.Client.NoDelay = true;
    }

    /// <summary>Fallback validator used when no callback was supplied: accepts the certificate only if no policy errors were reported.</summary>
    private static bool DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain,
        SslPolicyErrors sslValidationErrors)
    {
        return sslValidationErrors == SslPolicyErrors.None;
    }

    /// <summary>
    /// Connects to the configured host and port and performs the TLS handshake.
    /// </summary>
    /// <param name="cancellationToken">Checked once before any work starts; an already-cancelled token aborts the call.</param>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already cancelled.</exception>
    /// <exception cref="TTransportException">The transport is already open, or no host/port was configured.</exception>
    public override async Task OpenAsync(CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();

        if (IsOpen)
        {
            throw new TTransportException(TTransportException.ExceptionType.AlreadyOpen, "Socket already connected");
        }

        if (_host == null)
        {
            throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open null host");
        }

        if (_port <= 0)
        {
            throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open without port");
        }

        // Close() drops the socket, so re-create it when the transport is reopened.
        if (_client == null)
        {
            InitSocket();
        }

        await _client.ConnectAsync(_host, _port);
        await SetupTlsAsync();
    }

    /// <summary>
    /// Wraps the TCP stream in an <see cref="SslStream"/>, runs the server- or client-side handshake
    /// (depending on how the transport was constructed) and installs the secure stream as both the
    /// input and the output stream. If the handshake throws, the transport is closed and the
    /// exception is rethrown.
    /// </summary>
    /// <exception cref="TTransportException">There is no underlying TCP client (for example after <see cref="Close"/>).</exception>
    public async Task SetupTlsAsync()
    {
        if (_client == null)
        {
            throw new TTransportException(TTransportException.ExceptionType.NotOpen,
                "Cannot set up TLS without a socket");
        }

        var validator = _certValidator ?? DefaultCertificateValidator;

        _secureStream = _localCertificateSelectionCallback != null
            ? new SslStream(_client.GetStream(), false, validator, _localCertificateSelectionCallback)
            : new SslStream(_client.GetStream(), false, validator);

        try
        {
            if (_isServer)
            {
                // Server authentication; a client certificate is demanded only when the
                // caller supplied its own validator.
                await _secureStream.AuthenticateAsServerAsync(_certificate, _certValidator != null,
                    _sslProtocols, true);
            }
            else
            {
                // Client authentication
                var certs = new X509CertificateCollection();
                if (_certificate != null)
                {
                    certs.Add(_certificate);
                }

                // NOTE(review): _host is only assigned by the host/port constructors; a transport
                // built from an existing TcpClient with isServer == false fails here.
                var targetHost = _host.ToString();
                await _secureStream.AuthenticateAsClientAsync(targetHost, certs, _sslProtocols, true);
            }
        }
        catch (Exception)
        {
            Close();
            throw;
        }

        InputStream = _secureStream;
        OutputStream = _secureStream;
    }

    /// <summary>
    /// Closes the transport: delegates to the base class first, then disposes and forgets the
    /// TCP client and the secure stream. Safe to call more than once.
    /// </summary>
    public override void Close()
    {
        base.Close();

        _client?.Dispose();
        _client = null;

        _secureStream?.Dispose();
        _secureStream = null;
    }
}
|
|
} |