diff --git a/Cargo.lock b/Cargo.lock index f03684c2..93b5d328 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd2405b3ac1faab2990b74d728624cd9fd115651fcecc7c2d8daf01376275ba" +checksum = "628a8f9bd1e24b4e0db2b4bc2d000b001e7dd032d54afa60a68836aeec5aa54a" dependencies = [ "anstyle", "anstyle-parse", @@ -253,9 +253,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" [[package]] name = "block-buffer" @@ -348,9 +348,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.17" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80932e03c33999b9235edb8655bc9df3204adc9887c2f95b50cb1deb9fd54253" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" dependencies = [ "clap_builder", "clap_derive", @@ -358,9 +358,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.17" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c0db58c659eef1c73e444d298c27322a1b52f6927d2ad470c0c0f96fa7b8fa" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" dependencies = [ "anstream", "anstyle", @@ -574,10 +574,12 @@ dependencies = [ "prometheus", "prost-wkt-types", "rand", + "regex", "reqwest", "rocksdb", "serde", "serde_json", + "serde_regex", "serde_yaml", "sha2", "sysinfo", @@ -1373,9 +1375,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "local-ip-address" @@ -1626,7 +1628,7 @@ version = "0.10.62" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8cde4d2d9200ad5909f8dac647e29482e07c3a35de8a13fce7c9c7747ad9f671" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "cfg-if", "foreign-types", "libc", @@ -2264,7 +2266,7 @@ version = "0.38.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", "linux-raw-sys", @@ -2421,6 +2423,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_regex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8136f1a4ea815d7eac4101cfd0b16dc0cb5e1fe1b8609dfd728058656b7badf" +dependencies = [ + "regex", + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" diff --git a/Cargo.toml b/Cargo.toml index ab927114..3df44974 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ tracing-appender = "0.2.3" tracing-opentelemetry = "0.18.0" humantime = "2.1.0" serde = { version = "1.0", features = ["derive"] } +serde_regex = "1.1.0" serde_yaml = "0.9" serde_json = "1.0" validator = { version = "0.16", features = ["derive"] } @@ -82,3 +83,4 @@ hyper-util = { version = "0.1.2", features = ["tokio"] } tokio-rustls = "0.25" hyper-rustls = "0.25" http-body-util = "0.1.0" +regex = "1.10.2" diff --git a/src/bin/dfdaemon/main.rs b/src/bin/dfdaemon/main.rs index 038da4a1..da49e3bb 100644 --- a/src/bin/dfdaemon/main.rs +++ b/src/bin/dfdaemon/main.rs @@ -25,6 +25,7 @@ use dragonfly_client::grpc::{ manager::ManagerClient, scheduler::SchedulerClient, }; use dragonfly_client::metrics::Metrics; +use dragonfly_client::proxy::Proxy; use dragonfly_client::shutdown; use dragonfly_client::storage::Storage; use dragonfly_client::task::Task; @@ -183,7 +184,16 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize metrics server. let metrics = Metrics::new( - SocketAddr::new(config.metrics.ip.unwrap(), config.metrics.port), + SocketAddr::new( + config.metrics.server.ip.unwrap(), + config.metrics.server.port, + ), + shutdown.clone(), + shutdown_complete_tx.clone(), + ); + + let proxy = Proxy::new( + SocketAddr::new(config.proxy.server.ip.unwrap(), config.proxy.server.port), shutdown.clone(), shutdown_complete_tx.clone(), ); @@ -248,6 +258,10 @@ async fn main() -> Result<(), anyhow::Error> { info!("metrics server exited"); }, + _ = tokio::spawn(async move { proxy.run().await }) => { + info!("proxy server exited"); + }, + _ = tokio::spawn(async move { manager_announcer.run().await }) => { info!("announcer manager exited"); }, diff --git a/src/config/dfdaemon.rs b/src/config/dfdaemon.rs index ea4acb9e..e18b9b2e 100644 --- a/src/config/dfdaemon.rs +++ b/src/config/dfdaemon.rs @@ -15,6 +15,7 @@ */ use crate::Result; use local_ip_address::{local_ip, local_ipv6}; +use regex::Regex; use serde::Deserialize; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::path::PathBuf; @@ -82,7 +83,7 @@ fn default_upload_rate_limit() -> u64 { // default_metrics_server_port is the default port of the metrics server. #[inline] fn default_metrics_server_port() -> u16 { - 4001 + 4002 } // default_download_rate_limit is the default rate limit of the download speed in bps(bytes per second). @@ -170,6 +171,12 @@ fn default_gc_policy_dist_low_threshold_percent() -> u8 { 60 } +// default_proxy_server_port is the default port of the proxy server. +#[inline] +fn default_proxy_server_port() -> u16 { + 4001 +} + // Host is the host configuration for dfdaemon. #[derive(Debug, Clone, Validate, Deserialize)] #[serde(default, rename_all = "camelCase")] @@ -536,12 +543,64 @@ impl Default for GC { } } +// ProxyServer is the proxy server configuration for dfdaemon. +#[derive(Debug, Clone, Validate, Deserialize)] +#[serde(default, rename_all = "camelCase")] +pub struct ProxyServer { + // ip is the listen ip of the proxy server. + pub ip: Option, + + // port is the port to the proxy server. + #[serde(default = "default_proxy_server_port")] + pub port: u16, +} + +// ProxyServer implements Default. +impl Default for ProxyServer { + fn default() -> Self { + Self { + ip: None, + port: default_proxy_server_port(), + } + } +} + +// Rule is the proxy rule. +#[derive(Debug, Clone, Validate, Deserialize)] +#[serde(default, rename_all = "camelCase")] +pub struct Rule { + // regex is the regex of the request url. + #[serde(with = "serde_regex")] + pub regex: Regex, + + // use_tls indicates whether use tls for the proxy backend. + #[serde(rename = "useTLS")] + pub use_tls: bool, + + // redirect is the redirect url. + pub redirect: Option, +} + +// Rule implements Default. +impl Default for Rule { + fn default() -> Self { + Self { + regex: Regex::new(r".*").unwrap(), + use_tls: false, + redirect: None, + } + } +} + // Proxy is the proxy configuration for dfdaemon. #[derive(Debug, Clone, Default, Validate, Deserialize)] #[serde(default, rename_all = "camelCase")] pub struct Proxy { - // enable indicates whether enable proxy. - pub enable: bool, + // server is the proxy server configuration for dfdaemon. + pub server: ProxyServer, + + // rules is the proxy rules. + pub rules: Option>, } // Security is the security configuration for dfdaemon. @@ -560,10 +619,10 @@ pub struct Network { pub enable_ipv6: bool, } -// Metrics is the metrics configuration for dfdaemon. +// MetricsServer is the metrics server configuration for dfdaemon. #[derive(Debug, Clone, Validate, Deserialize)] #[serde(default, rename_all = "camelCase")] -pub struct Metrics { +pub struct MetricsServer { // ip is the listen ip of the metrics server. pub ip: Option, @@ -572,16 +631,24 @@ pub struct Metrics { pub port: u16, } -// Metrics implements Default. -impl Default for Metrics { +// MetricsServer implements Default. +impl Default for MetricsServer { fn default() -> Self { - Metrics { + Self { ip: None, port: default_metrics_server_port(), } } } +// Metrics is the metrics configuration for dfdaemon. +#[derive(Debug, Clone, Default, Validate, Deserialize)] +#[serde(default, rename_all = "camelCase")] +pub struct Metrics { + // server is the metrics server configuration for dfdaemon. + pub server: MetricsServer, +} + // Tracing is the tracing configuration for dfdaemon. #[derive(Debug, Clone, Default, Validate, Deserialize)] #[serde(default, rename_all = "camelCase")] @@ -693,8 +760,17 @@ impl Config { } // Convert metrics server listen ip. - if self.metrics.ip.is_none() { - self.metrics.ip = if self.network.enable_ipv6 { + if self.metrics.server.ip.is_none() { + self.metrics.server.ip = if self.network.enable_ipv6 { + Some(Ipv6Addr::UNSPECIFIED.into()) + } else { + Some(Ipv4Addr::UNSPECIFIED.into()) + } + } + + // Convert proxy server listen ip. + if self.proxy.server.ip.is_none() { + self.proxy.server.ip = if self.network.enable_ipv6 { Some(Ipv6Addr::UNSPECIFIED.into()) } else { Some(Ipv4Addr::UNSPECIFIED.into()) diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 65edc025..48fcbe4c 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -28,7 +28,7 @@ use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::pin; use tokio::sync::mpsc; -use tracing::{info, instrument, Span}; +use tracing::{error, info, instrument, Span}; // Proxy is the proxy server. #[derive(Debug)] @@ -70,35 +70,29 @@ impl Proxy { // Clone the shutdown channel. let mut shutdown = self.shutdown.clone(); + // Wait for a client connection. let (tcp, remote_address) = listener.accept().await?; + // Spawn a task to handle the connection. let io = TokioIo::new(tcp); info!("accepted connection from {}", remote_address); tokio::task::spawn(async move { - // Pin the connection object so we can use tokio::select! below. let conn = http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) .serve_connection(io, service_fn(handler)); pin!(conn); - // Iterate the timeouts. Use tokio::select! to wait on the - // result of polling the connection itself, - // and also on tokio::time::sleep for the current timeout duration. tokio::select! { res = conn.as_mut() => { - // Polling the connection returned a result. - // In this case print either the successful or error result for the connection - // and break out of the loop. match res { Ok(()) => println!("after polling conn, no error"), Err(e) => println!("error serving connection: {:?}", e), }; } _ = shutdown.recv() => { - // tokio::time::sleep returned a result. - // Call graceful_shutdown on the connection and continue the loop. + info!("shutdown signal received"); conn.as_mut().graceful_shutdown(); } } @@ -120,15 +114,15 @@ pub async fn handler( // Handle CONNECT request. if Method::CONNECT == request.method() { - return handle_https(request).await; + return https_handler(request).await; } - return handle_http(request).await; + return http_handler(request).await; } -// handle_http handles the http request. +// http_handler handles the http request. #[instrument(skip_all)] -pub async fn handle_http( +pub async fn http_handler( request: Request, ) -> Result>, hyper::Error> { info!("handle http request: {:?}", request); @@ -138,9 +132,9 @@ pub async fn handle_http( Ok(resp) } -// handle_https handles the https request. +// https_handler handles the https request. #[instrument(skip_all)] -pub async fn handle_https( +pub async fn https_handler( request: Request, ) -> Result>, hyper::Error> { if let Some(addr) = host_addr(request.uri()) { @@ -148,24 +142,25 @@ pub async fn handle_https( match hyper::upgrade::on(request).await { Ok(upgraded) => { if let Err(e) = tunnel(upgraded, addr).await { - eprintln!("server io error: {}", e); + error!("server io error: {}", e); }; } - Err(e) => eprintln!("upgrade error: {}", e), + Err(e) => error!("upgrade error: {}", e), } }); Ok(Response::new(empty())) } else { - eprintln!("CONNECT host is not socket addr: {:?}", request.uri()); - let mut resp = Response::new(full("CONNECT must be to a socket address")); - *resp.status_mut() = http::StatusCode::BAD_REQUEST; + error!("CONNECT host is not socket addr: {:?}", request.uri()); + let mut response = Response::new(full("CONNECT must be to a socket address")); + *response.status_mut() = http::StatusCode::BAD_REQUEST; - Ok(resp) + Ok(response) } } // empty returns an empty body. +#[instrument(skip_all)] fn empty() -> BoxBody { Empty::::new() .map_err(|never| match never {}) @@ -173,6 +168,7 @@ fn empty() -> BoxBody { } // full returns a body with the given chunk. +#[instrument(skip_all)] fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()) .map_err(|never| match never {}) @@ -186,6 +182,7 @@ fn host_addr(uri: &hyper::Uri) -> Option { } // tunnel proxies the data between the client and the remote server. +#[instrument(skip_all)] async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { // Connect to remote server let mut server = TcpStream::connect(addr).await?;