feat: add proxy config to dfdaemon (#216)

Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
Gaius 2024-01-17 15:19:50 +08:00 committed by GitHub
parent 0e4035707d
commit 14633c2653
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 146 additions and 45 deletions

36
Cargo.lock generated
View File

@ -61,9 +61,9 @@ dependencies = [
[[package]]
name = "anstream"
version = "0.6.7"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd2405b3ac1faab2990b74d728624cd9fd115651fcecc7c2d8daf01376275ba"
checksum = "628a8f9bd1e24b4e0db2b4bc2d000b001e7dd032d54afa60a68836aeec5aa54a"
dependencies = [
"anstyle",
"anstyle-parse",
@ -253,9 +253,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.4.1"
version = "2.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf"
[[package]]
name = "block-buffer"
@ -348,9 +348,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.4.17"
version = "4.4.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80932e03c33999b9235edb8655bc9df3204adc9887c2f95b50cb1deb9fd54253"
checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c"
dependencies = [
"clap_builder",
"clap_derive",
@ -358,9 +358,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.4.17"
version = "4.4.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6c0db58c659eef1c73e444d298c27322a1b52f6927d2ad470c0c0f96fa7b8fa"
checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7"
dependencies = [
"anstream",
"anstyle",
@ -574,10 +574,12 @@ dependencies = [
"prometheus",
"prost-wkt-types",
"rand",
"regex",
"reqwest",
"rocksdb",
"serde",
"serde_json",
"serde_regex",
"serde_yaml",
"sha2",
"sysinfo",
@ -1373,9 +1375,9 @@ dependencies = [
[[package]]
name = "linux-raw-sys"
version = "0.4.12"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "local-ip-address"
@ -1626,7 +1628,7 @@ version = "0.10.62"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cde4d2d9200ad5909f8dac647e29482e07c3a35de8a13fce7c9c7747ad9f671"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.4.2",
"cfg-if",
"foreign-types",
"libc",
@ -2264,7 +2266,7 @@ version = "0.38.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.4.2",
"errno",
"libc",
"linux-raw-sys",
@ -2421,6 +2423,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_regex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8136f1a4ea815d7eac4101cfd0b16dc0cb5e1fe1b8609dfd728058656b7badf"
dependencies = [
"regex",
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"

View File

@ -32,6 +32,7 @@ tracing-appender = "0.2.3"
tracing-opentelemetry = "0.18.0"
humantime = "2.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_regex = "1.1.0"
serde_yaml = "0.9"
serde_json = "1.0"
validator = { version = "0.16", features = ["derive"] }
@ -82,3 +83,4 @@ hyper-util = { version = "0.1.2", features = ["tokio"] }
tokio-rustls = "0.25"
hyper-rustls = "0.25"
http-body-util = "0.1.0"
regex = "1.10.2"

View File

@ -25,6 +25,7 @@ use dragonfly_client::grpc::{
manager::ManagerClient, scheduler::SchedulerClient,
};
use dragonfly_client::metrics::Metrics;
use dragonfly_client::proxy::Proxy;
use dragonfly_client::shutdown;
use dragonfly_client::storage::Storage;
use dragonfly_client::task::Task;
@ -183,7 +184,16 @@ async fn main() -> Result<(), anyhow::Error> {
// Initialize metrics server.
let metrics = Metrics::new(
SocketAddr::new(config.metrics.ip.unwrap(), config.metrics.port),
SocketAddr::new(
config.metrics.server.ip.unwrap(),
config.metrics.server.port,
),
shutdown.clone(),
shutdown_complete_tx.clone(),
);
let proxy = Proxy::new(
SocketAddr::new(config.proxy.server.ip.unwrap(), config.proxy.server.port),
shutdown.clone(),
shutdown_complete_tx.clone(),
);
@ -248,6 +258,10 @@ async fn main() -> Result<(), anyhow::Error> {
info!("metrics server exited");
},
_ = tokio::spawn(async move { proxy.run().await }) => {
info!("proxy server exited");
},
_ = tokio::spawn(async move { manager_announcer.run().await }) => {
info!("announcer manager exited");
},

View File

@ -15,6 +15,7 @@
*/
use crate::Result;
use local_ip_address::{local_ip, local_ipv6};
use regex::Regex;
use serde::Deserialize;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::PathBuf;
@ -82,7 +83,7 @@ fn default_upload_rate_limit() -> u64 {
// default_metrics_server_port is the default port of the metrics server.
#[inline]
fn default_metrics_server_port() -> u16 {
4001
4002
}
// default_download_rate_limit is the default rate limit of the download speed in bps(bytes per second).
@ -170,6 +171,12 @@ fn default_gc_policy_dist_low_threshold_percent() -> u8 {
60
}
// default_proxy_server_port is the default port of the proxy server.
#[inline]
fn default_proxy_server_port() -> u16 {
4001
}
// Host is the host configuration for dfdaemon.
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
@ -536,12 +543,64 @@ impl Default for GC {
}
}
// ProxyServer is the proxy server configuration for dfdaemon.
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct ProxyServer {
// ip is the listen ip of the proxy server.
pub ip: Option<IpAddr>,
// port is the port to the proxy server.
#[serde(default = "default_proxy_server_port")]
pub port: u16,
}
// ProxyServer implements Default.
impl Default for ProxyServer {
fn default() -> Self {
Self {
ip: None,
port: default_proxy_server_port(),
}
}
}
// Rule is the proxy rule.
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Rule {
// regex is the regex of the request url.
#[serde(with = "serde_regex")]
pub regex: Regex,
// use_tls indicates whether use tls for the proxy backend.
#[serde(rename = "useTLS")]
pub use_tls: bool,
// redirect is the redirect url.
pub redirect: Option<String>,
}
// Rule implements Default.
impl Default for Rule {
fn default() -> Self {
Self {
regex: Regex::new(r".*").unwrap(),
use_tls: false,
redirect: None,
}
}
}
// Proxy is the proxy configuration for dfdaemon.
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Proxy {
// enable indicates whether enable proxy.
pub enable: bool,
// server is the proxy server configuration for dfdaemon.
pub server: ProxyServer,
// rules is the proxy rules.
pub rules: Option<Vec<Rule>>,
}
// Security is the security configuration for dfdaemon.
@ -560,10 +619,10 @@ pub struct Network {
pub enable_ipv6: bool,
}
// Metrics is the metrics configuration for dfdaemon.
// MetricsServer is the metrics server configuration for dfdaemon.
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Metrics {
pub struct MetricsServer {
// ip is the listen ip of the metrics server.
pub ip: Option<IpAddr>,
@ -572,16 +631,24 @@ pub struct Metrics {
pub port: u16,
}
// Metrics implements Default.
impl Default for Metrics {
// MetricsServer implements Default.
impl Default for MetricsServer {
fn default() -> Self {
Metrics {
Self {
ip: None,
port: default_metrics_server_port(),
}
}
}
// Metrics is the metrics configuration for dfdaemon.
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Metrics {
// server is the metrics server configuration for dfdaemon.
pub server: MetricsServer,
}
// Tracing is the tracing configuration for dfdaemon.
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
@ -693,8 +760,17 @@ impl Config {
}
// Convert metrics server listen ip.
if self.metrics.ip.is_none() {
self.metrics.ip = if self.network.enable_ipv6 {
if self.metrics.server.ip.is_none() {
self.metrics.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())
}
}
// Convert proxy server listen ip.
if self.proxy.server.ip.is_none() {
self.proxy.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())

View File

@ -28,7 +28,7 @@ use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::pin;
use tokio::sync::mpsc;
use tracing::{info, instrument, Span};
use tracing::{error, info, instrument, Span};
// Proxy is the proxy server.
#[derive(Debug)]
@ -70,35 +70,29 @@ impl Proxy {
// Clone the shutdown channel.
let mut shutdown = self.shutdown.clone();
// Wait for a client connection.
let (tcp, remote_address) = listener.accept().await?;
// Spawn a task to handle the connection.
let io = TokioIo::new(tcp);
info!("accepted connection from {}", remote_address);
tokio::task::spawn(async move {
// Pin the connection object so we can use tokio::select! below.
let conn = http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(io, service_fn(handler));
pin!(conn);
// Iterate the timeouts. Use tokio::select! to wait on the
// result of polling the connection itself,
// and also on tokio::time::sleep for the current timeout duration.
tokio::select! {
res = conn.as_mut() => {
// Polling the connection returned a result.
// In this case print either the successful or error result for the connection
// and break out of the loop.
match res {
Ok(()) => println!("after polling conn, no error"),
Err(e) => println!("error serving connection: {:?}", e),
};
}
_ = shutdown.recv() => {
// tokio::time::sleep returned a result.
// Call graceful_shutdown on the connection and continue the loop.
info!("shutdown signal received");
conn.as_mut().graceful_shutdown();
}
}
@ -120,15 +114,15 @@ pub async fn handler(
// Handle CONNECT request.
if Method::CONNECT == request.method() {
return handle_https(request).await;
return https_handler(request).await;
}
return handle_http(request).await;
return http_handler(request).await;
}
// handle_http handles the http request.
// http_handler handles the http request.
#[instrument(skip_all)]
pub async fn handle_http(
pub async fn http_handler(
request: Request<hyper::body::Incoming>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
info!("handle http request: {:?}", request);
@ -138,9 +132,9 @@ pub async fn handle_http(
Ok(resp)
}
// handle_https handles the https request.
// https_handler handles the https request.
#[instrument(skip_all)]
pub async fn handle_https(
pub async fn https_handler(
request: Request<hyper::body::Incoming>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
if let Some(addr) = host_addr(request.uri()) {
@ -148,24 +142,25 @@ pub async fn handle_https(
match hyper::upgrade::on(request).await {
Ok(upgraded) => {
if let Err(e) = tunnel(upgraded, addr).await {
eprintln!("server io error: {}", e);
error!("server io error: {}", e);
};
}
Err(e) => eprintln!("upgrade error: {}", e),
Err(e) => error!("upgrade error: {}", e),
}
});
Ok(Response::new(empty()))
} else {
eprintln!("CONNECT host is not socket addr: {:?}", request.uri());
let mut resp = Response::new(full("CONNECT must be to a socket address"));
*resp.status_mut() = http::StatusCode::BAD_REQUEST;
error!("CONNECT host is not socket addr: {:?}", request.uri());
let mut response = Response::new(full("CONNECT must be to a socket address"));
*response.status_mut() = http::StatusCode::BAD_REQUEST;
Ok(resp)
Ok(response)
}
}
// empty returns an empty body.
#[instrument(skip_all)]
fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
@ -173,6 +168,7 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
}
// full returns a body with the given chunk.
#[instrument(skip_all)]
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
@ -186,6 +182,7 @@ fn host_addr(uri: &hyper::Uri) -> Option<String> {
}
// tunnel proxies the data between the client and the remote server.
#[instrument(skip_all)]
async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
// Connect to remote server
let mut server = TcpStream::connect(addr).await?;