refactor(dragonfly-client-util): request and pool logic (#1370)

Signed-off-by: Gaius <gaius.qi@gmail.com>
Gaius 2025-09-23 18:18:16 +08:00 committed by GitHub
parent 826aeabf08
commit 02eddfbd10
11 changed files with 460 additions and 320 deletions

Cargo.lock (generated)

@ -1227,6 +1227,8 @@ dependencies = [
"protobuf 3.7.2",
"rcgen",
"reqwest",
"reqwest-middleware",
"reqwest-tracing",
"rustix 1.1.2",
"rustls 0.22.4",
"rustls-pemfile 2.2.0",


@ -116,6 +116,7 @@ dashmap = "6.1.0"
hostname = "^0.4"
tonic-health = "0.12.3"
hashring = "0.3.6"
reqwest-tracing = "0.5"
[profile.release]
opt-level = 3


@ -25,16 +25,16 @@ tracing.workspace = true
opendal.workspace = true
percent-encoding.workspace = true
futures.workspace = true
reqwest-tracing.workspace = true
reqwest-retry = "0.7"
reqwest-tracing = "0.5"
libloading = "0.8.8"
[dev-dependencies]
tempfile.workspace = true
wiremock = "0.6.4"
rustls-pki-types.workspace = true
rustls-pemfile.workspace = true
hyper.workspace = true
hyper-util.workspace = true
tokio-rustls.workspace = true
rcgen.workspace = true
wiremock = "0.6.4"


@ -41,6 +41,8 @@ tonic-health.workspace = true
local-ip-address.workspace = true
hostname.workspace = true
dashmap.workspace = true
reqwest-tracing.workspace = true
reqwest-middleware.workspace = true
rustix = { version = "1.1.2", features = ["fs"] }
base64 = "0.22.1"
pnet = "0.35.0"


@ -14,42 +14,60 @@
* limitations under the License.
*/
use hashring::HashRing;
use std::fmt;
use std::hash::Hash;
use std::net::SocketAddr;
use hashring::HashRing;
/// A virtual node (vnode) on the consistent hash ring.
/// Each physical node (SocketAddr) is represented by multiple vnodes to better
/// balance key distribution across the ring.
#[derive(Debug, Copy, Clone, Hash, PartialEq)]
pub struct VNode {
/// The replica index of this vnode for its physical node (0..replica_count-1).
id: usize,
/// The physical node address this vnode represents.
addr: SocketAddr,
}
/// VNode implements virtual node for consistent hashing.
impl VNode {
/// Creates a new virtual node with the given replica id and physical address.
fn new(id: usize, addr: SocketAddr) -> Self {
VNode { id, addr }
}
}
/// VNode implements Display trait to format.
impl fmt::Display for VNode {
/// Formats the virtual node as "address|id", which serves as its key on the hash ring.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}|{}", self.addr, self.id)
}
}
/// VNode implements methods for hash ring operations.
impl VNode {
/// Returns a reference to the physical node address associated with this vnode.
pub fn addr(&self) -> &SocketAddr {
&self.addr
}
}
/// A consistent hash ring that uses virtual nodes (vnodes) to improve key distribution.
/// When a physical node is added, replica_count vnodes are inserted into the ring.
pub struct VNodeHashRing {
/// Number of vnodes to create per physical node.
replica_count: usize,
/// The underlying hash ring that stores vnodes.
ring: HashRing<VNode>,
}
/// VNodeHashRing implements methods for managing the hash ring.
impl VNodeHashRing {
/// Creates a new vnode-based hash ring.
pub fn new(replica_count: usize) -> Self {
VNodeHashRing {
replica_count,


@ -22,6 +22,5 @@ pub mod id_generator;
pub mod net;
pub mod pool;
pub mod request;
pub mod selector;
pub mod shutdown;
pub mod tls;


@ -34,14 +34,18 @@ pub struct RequestGuard {
active_requests: Arc<AtomicUsize>,
}
/// RequestGuard implements the request guard pattern.
impl RequestGuard {
/// Create a new request guard.
fn new(active_requests: Arc<AtomicUsize>) -> Self {
active_requests.fetch_add(1, Ordering::SeqCst);
Self { active_requests }
}
}
/// RequestGuard decrements the active request count when dropped.
impl Drop for RequestGuard {
/// Decrement the active request count.
fn drop(&mut self) {
self.active_requests.fetch_sub(1, Ordering::SeqCst);
}
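The guard is RAII-style: the count goes up in new and back down in drop, so the pool's cleanup pass can tell that a client is busy just by reading the counter. A minimal usage sketch, assuming an entry already obtained from the pool (as the TCP downloader does at the end of this diff):

async fn use_entry<T>(entry: &Entry<T>) {
    // request_guard() increments active_requests via fetch_add.
    let _guard = entry.request_guard();

    // ... send the request while the guard is alive; cleanup retains
    // this client because it still has active requests ...
} // _guard dropped here: fetch_sub restores the count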
@ -60,7 +64,9 @@ pub struct Entry<T> {
actived_at: Arc<std::sync::Mutex<Instant>>,
}
/// Entry methods for managing client state.
impl<T> Entry<T> {
/// Create a new entry with the given client.
fn new(client: T) -> Self {
Self {
client,
@ -75,8 +81,8 @@ impl<T> Entry<T> {
}
/// Update the last active time.
fn update_actived_at(&self) {
*self.actived_at.lock().unwrap() = Instant::now();
fn set_actived_at(&self, actived_at: Instant) {
*self.actived_at.lock().unwrap() = actived_at;
}
/// Check if the client has active requests.
@ -96,6 +102,7 @@ impl<T> Entry<T> {
pub trait Factory<K, T> {
type Error;
/// Create a new client for the given key.
async fn make_client(&self, key: &K) -> Result<T, Self::Error>;
}
@ -119,6 +126,7 @@ pub struct Pool<K, T, F> {
cleanup_at: Arc<Mutex<Instant>>,
}
/// Builder for creating a client pool.
pub struct Builder<K, T, F> {
factory: F,
capacity: usize,
@ -126,6 +134,7 @@ pub struct Builder<K, T, F> {
_phantom: PhantomData<(K, T)>,
}
/// Builder methods for configuring and building the pool.
impl<K, T, F> Builder<K, T, F>
where
K: Clone + Eq + Hash + std::fmt::Display,
@ -166,6 +175,14 @@ where
}
}
/// Generic client pool for managing reusable client instances with automatic cleanup.
///
/// This client pool provides connection reuse, automatic cleanup, and capacity management
/// capabilities, primarily used for:
/// - Connection Reuse: Reuse existing client instances to avoid repeated creation overhead.
/// - Automatic Cleanup: Periodically remove idle clients that exceed timeout thresholds.
/// - Capacity Control: Limit maximum client count to prevent resource exhaustion.
/// - Thread Safety: Use async locks and atomic operations for high-concurrency access.
impl<K, T, F> Pool<K, T, F>
where
K: Clone + Eq + Hash + std::fmt::Display,
@ -182,7 +199,7 @@ where
let clients = self.clients.lock().await;
if let Some(entry) = clients.get(key) {
debug!("reusing client: {}", key);
entry.update_actived_at();
entry.set_actived_at(Instant::now());
return Ok(entry.clone());
}
}
@ -190,11 +207,10 @@ where
// Create new client.
debug!("creating client: {}", key);
let client = self.factory.make_client(key).await?;
let mut clients = self.clients.lock().await;
let entry = clients.entry(key.clone()).or_insert(Entry::new(client));
entry.set_actived_at(Instant::now());
entry.update_actived_at();
Ok(entry.clone())
}
@ -231,7 +247,6 @@ where
let is_recent = idle_duration <= self.idle_timeout;
let should_retain = has_active_requests || (!exceeds_capacity && is_recent);
if !should_retain {
info!(
"removing idle client: {}, exceeds_capacity: {}, idle_duration: {}s",


@ -14,25 +14,19 @@
* limitations under the License.
*/
use dragonfly_client_core::Error as DFError;
use reqwest;
use std::collections::HashMap;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Base(#[from] DFError),
#[error{"request timeout: {0}"}]
RequestTimeout(String),
#[allow(clippy::enum_variant_names)]
#[error(transparent)]
ReqwestError(#[from] reqwest::Error),
#[error{"invalid argument: {0}"}]
InvalidArgument(String),
#[allow(clippy::enum_variant_names)]
#[error(transparent)]
TonicTransportError(#[from] tonic::transport::Error),
#[error(transparent)]
InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
#[error{"request internal error: {0}"}]
Internal(String),
#[allow(clippy::enum_variant_names)]
#[error(transparent)]
@ -47,38 +41,42 @@ pub enum Error {
DfdaemonError(#[from] DfdaemonError),
}
// BackendError is error detail for Backend.
/// BackendError is error detail for Backend.
#[derive(Debug, thiserror::Error)]
#[error("error occurred in the backend server, message: {message:?}, header: {header:?}, status_code: {status_code:?}")]
#[error(
"backend server error, message: {message:?}, header: {header:?}, status_code: {status_code:?}"
)]
pub struct BackendError {
// Backend error message.
/// Backend error message.
pub message: Option<String>,
// Backend HTTP response header.
/// Backend HTTP response header.
pub header: HashMap<String, String>,
// Backend HTTP status code.
/// Backend HTTP status code.
pub status_code: Option<reqwest::StatusCode>,
}
// ProxyError is error detail for Proxy.
/// ProxyError is error detail for Proxy.
#[derive(Debug, thiserror::Error)]
#[error("error occurred in the proxy server, message: {message:?}, header: {header:?}, status_code: {status_code:?}")]
#[error(
"proxy server error, message: {message:?}, header: {header:?}, status_code: {status_code:?}"
)]
pub struct ProxyError {
// Proxy error message.
/// Proxy error message.
pub message: Option<String>,
// Proxy HTTP response header.
/// Proxy HTTP response header.
pub header: HashMap<String, String>,
// Proxy HTTP status code.
/// Proxy HTTP status code.
pub status_code: Option<reqwest::StatusCode>,
}
// DfdaemonError is error detail for Dfdaemon.
/// DfdaemonError is error detail for Dfdaemon.
#[derive(Debug, thiserror::Error)]
#[error("error occurred in the dfdaemon, message: {message:?}")]
#[error("dfdaemon error, message: {message:?}")]
pub struct DfdaemonError {
// Dfdaemon error message.
/// Dfdaemon error message.
pub message: Option<String>,
}
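With the error detail moved onto dedicated structs, callers can branch on where a failure originated. A hedged sketch; the proxy and request values are assumed to come from the request module below:

match proxy.get(request).await {
    Ok(response) => { /* stream the body from the response */ }
    Err(Error::BackendError(err)) => {
        // Failure reported by the origin server behind the seed peer.
        eprintln!("backend error, status: {:?}", err.status_code);
    }
    Err(Error::ProxyError(err)) => {
        eprintln!("proxy error, status: {:?}", err.status_code);
    }
    Err(Error::RequestTimeout(msg)) => eprintln!("request timed out: {}", msg),
    Err(other) => eprintln!("request failed: {}", other),
}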


@ -15,23 +15,24 @@
*/
mod errors;
mod selector;
use crate::http::headermap_to_hashmap;
use crate::id_generator::{IDGenerator, TaskIDParameter};
use crate::pool::{Builder as PoolBuilder, Entry, Factory, Pool};
use crate::selector::{SeedPeerSelector, Selector};
use bytes::BytesMut;
use dragonfly_api::scheduler::v2::scheduler_client::SchedulerClient;
use dragonfly_client_core::error::DFError;
use errors::{BackendError, DfdaemonError, Error, ProxyError};
use futures::TryStreamExt;
use hostname;
use local_ip_address::local_ip;
use reqwest::{header::HeaderMap, header::HeaderValue, redirect::Policy, Client};
use reqwest::{header::HeaderMap, header::HeaderValue, Client};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::TracingMiddleware;
use rustix::path::Arg;
use rustls_pki_types::CertificateDer;
use selector::{SeedPeerSelector, Selector};
use std::io::{Error as IOError, ErrorKind};
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncRead;
@ -51,54 +52,45 @@ const DEFAULT_CLIENT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);
/// DEFAULT_CLIENT_POOL_CAPACITY is the default capacity of the client pool.
const DEFAULT_CLIENT_POOL_CAPACITY: usize = 128;
/// DEFAULT_SCHEDULER_REQUEST_TIMEOUT is the default timeout (5 seconds) for requests to the
/// scheduler service.
const DEFAULT_SCHEDULER_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
/// Result is a specialized Result type for the proxy module.
pub type Result<T> = std::result::Result<T, Error>;
/// Body is the type alias for the response body reader.
pub type Body = Box<dyn AsyncRead + Send + Unpin>;
/// Defines the interface for sending requests via the Dragonfly.
///
/// This trait enables interaction with remote servers through the Dragonfly, providing methods
/// for performing GET requests with flexible response handling. It is designed for clients that
/// need to communicate with the Dragonfly seed client efficiently, supporting both streaming and
/// buffered response processing. The trait hides the complex request logic between the client and the
/// Dragonfly seed client's proxy, abstracting the underlying communication details to simplify
/// client implementation and usage.
#[tonic::async_trait]
pub trait Request {
/// get sends the get request to the remote server and returns the response reader.
/// Sends a GET request to a remote server via the Dragonfly and returns a response
/// with a streaming body.
///
/// This method is designed for scenarios where the response body is expected to be processed as a
/// stream, allowing efficient handling of large or continuous data. The response includes metadata
/// such as status codes and headers, along with a streaming `Body` for accessing the response content.
async fn get(&self, request: GetRequest) -> Result<GetResponse<Body>>;
/// get_into sends the get request to the remote server and writes the response to the input buffer.
/// Sends a GET request to a remote server via the Dragonfly and writes the response
/// body directly into the provided buffer.
///
/// This method is optimized for scenarios where the response body needs to be stored directly in
/// memory, avoiding the overhead of streaming for smaller or fixed-size responses. The provided
/// `BytesMut` buffer is used to store the response content, and the response metadata (e.g., status
/// and headers) is returned separately.
async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse>;
}
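A hedged usage sketch of the trait as refactored here. demo_request() is a hypothetical helper that fills the remaining GetRequest fields (header, timeout, certificates, and so on), since the struct is larger than the hunk shows:

let proxy = Proxy::builder()
    .scheduler_endpoint("http://scheduler:8002".to_string())
    .scheduler_request_timeout(Duration::from_secs(5))
    .build()
    .await?;

// Buffered variant: the body lands in buf, the metadata in response.
let mut buf = BytesMut::new();
let response = proxy
    .get_into(
        GetRequest {
            url: "http://origin.example.com/blob".to_string(),
            ..demo_request() // hypothetical helper, not part of this commit
        },
        &mut buf,
    )
    .await?;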
pub struct Proxy {
/// seed_peer_selector is the selector service for selecting seed peers.
seed_peer_selector: Arc<SeedPeerSelector>,
/// max_retries is the number of times to retry a request.
max_retries: u8,
/// client_pool is the pool of clients.
client_pool: Pool<String, Client, HTTPClientFactory>,
/// id_generator is the task id generator.
id_generator: Arc<IDGenerator>,
}
pub struct Builder {
/// scheduler_endpoint is the endpoint of the scheduler service.
scheduler_endpoint: String,
/// health_check_interval is the interval of health checks for the selector (seed peers).
health_check_interval: Duration,
/// max_retries is the number of times to retry a request.
max_retries: u8,
}
impl Default for Builder {
fn default() -> Self {
Self {
scheduler_endpoint: "".to_string(),
health_check_interval: Duration::from_secs(60),
max_retries: 1,
}
}
}
/// GetRequest represents a GET request to be sent via the Dragonfly.
pub struct GetRequest {
/// url is the url of the request.
pub url: String,
@ -137,6 +129,7 @@ pub struct GetRequest {
pub client_cert: Option<Vec<CertificateDer<'static>>>,
}
/// GetResponse represents a GET response received via the Dragonfly.
pub struct GetResponse<R = Body>
where
R: AsyncRead + Unpin,
@ -155,52 +148,64 @@ where
}
/// Factory for creating HTTPClient instances.
#[derive(Debug, Clone, Default)]
struct HTTPClientFactory {}
/// HTTPClientFactory implements Factory for creating middleware-wrapped reqwest clients with proxy support.
#[tonic::async_trait]
impl Factory<String, Client> for HTTPClientFactory {
impl Factory<String, ClientWithMiddleware> for HTTPClientFactory {
type Error = Error;
async fn make_client(&self, proxy_addr: &String) -> Result<Client> {
// Disable automatic compression to prevent double-decompression issues.
//
// Problem scenario:
// 1. Origin server supports gzip and returns "content-encoding: gzip" header.
// 2. Backend decompresses the response and stores uncompressed content to disk.
// 3. When the user's client downloads via the dfdaemon proxy, the original "content-encoding: gzip"
//    header is forwarded to it.
// 4. User's client attempts to decompress the already-decompressed content, causing errors.
//
// Solution: Disable all compression formats (gzip, brotli, zstd, deflate) to ensure
// we receive and store uncompressed content, eliminating the double-decompression issue.
let mut client_builder = Client::builder()
.no_gzip()
.no_brotli()
.no_zstd()
.no_deflate()
.redirect(Policy::none())
/// Creates a new middleware-wrapped reqwest client configured to use the specified proxy address.
async fn make_client(&self, proxy_addr: &String) -> Result<ClientWithMiddleware> {
// TODO(chlins): Support client certificates and set `danger_accept_invalid_certs`
// based on the certificates.
let client = Client::builder()
.hickory_dns(true)
.danger_accept_invalid_certs(true)
.pool_max_idle_per_host(POOL_MAX_IDLE_PER_HOST)
.tcp_keepalive(KEEP_ALIVE_INTERVAL);
.tcp_keepalive(KEEP_ALIVE_INTERVAL)
.proxy(reqwest::Proxy::all(proxy_addr).map_err(|err| {
Error::Internal(format!("failed to set proxy {}: {}", proxy_addr, err))
})?)
.build()
.map_err(|err| Error::Internal(format!("failed to build reqwest client: {}", err)))?;
// Add the proxy configuration if provided.
if !proxy_addr.is_empty() {
let proxy = reqwest::Proxy::all(proxy_addr)?;
client_builder = client_builder.proxy(proxy);
}
let client = client_builder.build()?;
Ok(client)
Ok(ClientBuilder::new(client)
.with(TracingMiddleware::default())
.build())
}
}
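Note that reqwest-retry = "0.7" also appears in the dependency hunks above, although this commit only installs TracingMiddleware and keeps retries hand-rolled in try_send. If retries were ever moved into the middleware stack instead, a sketch might look like:

use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use reqwest_tracing::TracingMiddleware;

// Sketch only: retry transient failures up to 3 times with exponential
// backoff, layered alongside the tracing middleware.
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
let client = ClientBuilder::new(reqwest::Client::new())
    .with(TracingMiddleware::default())
    .with(RetryTransientMiddleware::new_with_policy(retry_policy))
    .build();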
impl Proxy {
pub fn builder() -> Builder {
Builder::default()
/// Builder is the builder for Proxy.
pub struct Builder {
/// scheduler_endpoint is the endpoint of the scheduler service.
scheduler_endpoint: String,
/// scheduler_request_timeout is the timeout of the request to the scheduler service.
scheduler_request_timeout: Duration,
/// health_check_interval is the interval of health checks for the selector (seed peers).
health_check_interval: Duration,
/// max_retries is the number of times to retry a request.
max_retries: u8,
}
/// Builder implements Default trait.
impl Default for Builder {
/// default returns a default Builder.
fn default() -> Self {
Self {
scheduler_endpoint: "".to_string(),
scheduler_request_timeout: DEFAULT_SCHEDULER_REQUEST_TIMEOUT,
health_check_interval: Duration::from_secs(60),
max_retries: 1,
}
}
}
/// Builder implements the builder pattern for Proxy.
impl Builder {
/// Sets the scheduler endpoint.
pub fn scheduler_endpoint(mut self, endpoint: String) -> Self {
@ -208,6 +213,12 @@ impl Builder {
self
}
/// Sets the scheduler request timeout.
pub fn scheduler_request_timeout(mut self, timeout: Duration) -> Self {
self.scheduler_request_timeout = timeout;
self
}
/// Sets the health check interval.
pub fn health_check_interval(mut self, interval: Duration) -> Self {
self.health_check_interval = interval;
@ -220,40 +231,51 @@ impl Builder {
self
}
/// Builds and returns a Proxy instance.
pub async fn build(self) -> Result<Proxy> {
// Validate input parameters.
Self::validate(
&self.scheduler_endpoint,
self.health_check_interval,
self.max_retries,
)?;
self.validate()?;
// Create the scheduler channel.
let scheduler_channel = Endpoint::from_shared(self.scheduler_endpoint.to_string())?
let scheduler_channel = Endpoint::from_shared(self.scheduler_endpoint.to_string())
.map_err(|err| Error::InvalidArgument(err.to_string()))?
.connect_timeout(self.scheduler_request_timeout)
.timeout(self.scheduler_request_timeout)
.connect()
.await?;
.await
.map_err(|err| {
Error::Internal(format!(
"failed to connect to scheduler {}: {}",
self.scheduler_endpoint, err
))
})?;
// Create scheduler client.
let scheduler_client = SchedulerClient::new(scheduler_channel);
// Create seed peer selector.
let seed_peer_selector =
Arc::new(SeedPeerSelector::new(scheduler_client, self.health_check_interval).await?);
let seed_peer_selector_clone = Arc::clone(&seed_peer_selector);
let seed_peer_selector = Arc::new(
SeedPeerSelector::new(scheduler_client, self.health_check_interval)
.await
.map_err(|err| {
Error::Internal(format!("failed to create seed peer selector: {}", err))
})?,
);
let seed_peer_selector_clone = seed_peer_selector.clone();
tokio::spawn(async move {
// Run the selector service in the background to refresh the seed peers periodically.
seed_peer_selector_clone.run().await;
});
// Get local IP address and hostname.
let local_ip = local_ip().unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
let local_ip = local_ip().unwrap().to_string();
let hostname = hostname::get().unwrap().to_string_lossy().to_string();
let id_generator = IDGenerator::new(local_ip.to_string(), hostname, true);
let id_generator = IDGenerator::new(local_ip, hostname, true);
let proxy = Proxy {
seed_peer_selector,
max_retries: self.max_retries,
client_pool: PoolBuilder::new(HTTPClientFactory {})
client_pool: PoolBuilder::new(HTTPClientFactory::default())
.capacity(DEFAULT_CLIENT_POOL_CAPACITY)
.idle_timeout(DEFAULT_CLIENT_POOL_IDLE_TIMEOUT)
.build(),
@ -263,38 +285,73 @@ impl Builder {
Ok(proxy)
}
fn validate(
scheduler_endpoint: &str,
health_check_duration: Duration,
max_retries: u8,
) -> Result<()> {
if let Err(err) = url::Url::parse(scheduler_endpoint) {
return Err(
DFError::ValidationError(format!("invalid scheduler endpoint: {}", err)).into(),
);
/// validate validates the input parameters.
fn validate(&self) -> Result<()> {
if let Err(err) = url::Url::parse(&self.scheduler_endpoint) {
return Err(Error::InvalidArgument(err.to_string()));
};
if health_check_duration.as_secs() < 1 || health_check_duration.as_secs() > 600 {
return Err(DFError::ValidationError(
"health check duration must be between 1 and 600 seconds".to_string(),
)
.into());
if self.scheduler_request_timeout.as_millis() < 100 {
return Err(Error::InvalidArgument(
"scheduler request timeout must be at least 100 milliseconds".to_string(),
));
}
if max_retries > 10 {
return Err(DFError::ValidationError(
if self.health_check_interval.as_secs() < 1 || self.health_check_interval.as_secs() > 600 {
return Err(Error::InvalidArgument(
"health check interval must be between 1 and 600 seconds".to_string(),
));
}
if self.max_retries > 10 {
return Err(Error::InvalidArgument(
"max retries must be between 0 and 10".to_string(),
)
.into());
));
}
Ok(())
}
}
/// Proxy is the HTTP proxy client that sends requests via Dragonfly.
pub struct Proxy {
/// seed_peer_selector is the selector service for selecting seed peers.
seed_peer_selector: Arc<SeedPeerSelector>,
/// max_retries is the number of times to retry a request.
max_retries: u8,
/// client_pool is the pool of clients.
client_pool: Pool<String, ClientWithMiddleware, HTTPClientFactory>,
/// id_generator is the task id generator.
id_generator: Arc<IDGenerator>,
}
/// Proxy implements the proxy client that sends requests via Dragonfly.
impl Proxy {
/// builder returns a new Builder for Proxy.
pub fn builder() -> Builder {
Builder::default()
}
}
/// Implements the interface for sending requests via the Dragonfly.
///
/// This implementation enables interaction with remote servers through the Dragonfly, providing
/// methods for performing GET requests with flexible response handling. It is designed for clients
/// that need to communicate with the Dragonfly seed client efficiently, supporting both streaming
/// and buffered response processing. It hides the complex request logic between the client and the
/// Dragonfly seed client's proxy, abstracting the underlying communication details to simplify
/// client implementation and usage.
#[tonic::async_trait]
impl Request for Proxy {
/// Performs a GET request, handling custom errors and returning a streaming response.
/// Sends a GET request to a remote server via the Dragonfly and returns a response
/// with a streaming body.
///
/// This method is designed for scenarios where the response body is expected to be processed as a
/// stream, allowing efficient handling of large or continuous data. The response includes metadata
/// such as status codes and headers, along with a streaming `Body` for accessing the response content.
#[instrument(skip_all)]
async fn get(&self, request: GetRequest) -> Result<GetResponse> {
let response = self.try_send(&request).await?;
@ -314,16 +371,25 @@ impl Request for Proxy {
})
}
/// Performs a GET request, handling custom errors and collecting the response into a buffer.
/// Sends a GET request to a remote server via the Dragonfly and writes the response
/// body directly into the provided buffer.
///
/// This method is optimized for scenarios where the response body needs to be stored directly in
/// memory, avoiding the overhead of streaming for smaller or fixed-size responses. The provided
/// `BytesMut` buffer is used to store the response content, and the response metadata (e.g., status
/// and headers) is returned separately.
#[instrument(skip_all)]
async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse> {
let operation = async {
let get_into = async {
let response = self.try_send(&request).await?;
let status = response.status();
let headers = response.headers().clone();
if status.is_success() {
let bytes = response.bytes().await?;
let bytes = response.bytes().await.map_err(|err| {
Error::Internal(format!("failed to read response body: {}", err))
})?;
buf.extend_from_slice(&bytes);
}
@ -336,20 +402,24 @@ impl Request for Proxy {
};
// Apply timeout which will properly cancel the operation when timeout is reached.
tokio::time::timeout(request.timeout, operation)
tokio::time::timeout(request.timeout, get_into)
.await
.map_err(|_| {
DFError::Unknown(format!("request timed out after {:?}", request.timeout))
})?
.map_err(|err| Error::RequestTimeout(err.to_string()))?
}
}
/// Proxy implements proxy request logic.
impl Proxy {
/// Creates reqwest clients with proxy configuration for the given request.
#[instrument(skip_all)]
async fn client_entries(&self, request: &GetRequest) -> Result<Vec<Entry<Client>>> {
// The request should be processed by the proxy, so generate the task ID to select a seed peer as proxy server.
let parameter = match request.content_for_calculating_task_id.as_ref() {
async fn client_entries(
&self,
request: &GetRequest,
) -> Result<Vec<Entry<ClientWithMiddleware>>> {
// Generate task id for selecting seed peer.
let task_id = self
.id_generator
.task_id(match request.content_for_calculating_task_id.as_ref() {
Some(content) => TaskIDParameter::Content(content.clone()),
None => TaskIDParameter::URLBased {
url: request.url.clone(),
@ -358,22 +428,29 @@ impl Proxy {
application: request.application.clone(),
filtered_query_params: request.filtered_query_params.clone(),
},
};
let task_id = self.id_generator.task_id(parameter)?;
})
.map_err(|err| Error::Internal(format!("failed to generate task id: {}", err)))?;
// Select seed peers.
// Select seed peers for downloading.
let seed_peers = self
.seed_peer_selector
.select(task_id, self.max_retries as u32)
.await
.map_err(|err| DFError::Unknown(format!("failed to select seed peers: {:?}", err)))?;
.map_err(|err| {
Error::Internal(format!(
"failed to select seed peers from scheduler: {}",
err
))
})?;
debug!("selected seed peers: {:?}", seed_peers);
let mut client_entries = Vec::new();
let mut client_entries = Vec::with_capacity(seed_peers.len());
for peer in seed_peers.iter() {
let proxy_addr = format!("http://{}:{}", peer.ip, peer.proxy_port);
let client_entry = self.client_pool.entry(&proxy_addr).await?;
// TODO(chlins): Support client https scheme.
let client_entry = self
.client_pool
.entry(&format!("http://{}:{}", peer.ip, peer.proxy_port))
.await?;
client_entries.push(client_entry);
}
@ -386,10 +463,9 @@ impl Proxy {
// Create client and send the request.
let entries = self.client_entries(request).await?;
if entries.is_empty() {
return Err(DFError::Unknown(
return Err(Error::Internal(
"no available client entries to send request".to_string(),
)
.into());
));
}
for (index, entry) in entries.iter().enumerate() {
@ -400,6 +476,8 @@ impl Proxy {
"failed to send request to client entry {:?}: {:?}",
entry.client, err
);
// If this is the last entry, return the error.
if index == entries.len() - 1 {
return Err(err);
}
@ -407,108 +485,143 @@ impl Proxy {
}
}
Err(DFError::Unknown("all retries failed".to_string()).into())
Err(Error::Internal(
"failed to send request to any client entry".to_string(),
))
}
/// Send a request to the specified URL via client entry with the given headers.
#[instrument(skip_all)]
async fn send(&self, entry: &Entry<Client>, request: &GetRequest) -> Result<reqwest::Response> {
let headers = make_request_headers(request)?;
async fn send(
&self,
entry: &Entry<ClientWithMiddleware>,
request: &GetRequest,
) -> Result<reqwest::Response> {
let headers = self.make_request_headers(request)?;
let response = entry
.client
.get(&request.url)
.headers(headers.clone())
.timeout(request.timeout)
.send()
.await?;
.await
.map_err(|err| Error::Internal(err.to_string()))?;
// Check for custom Dragonfly error headers.
if let Some(error_type) = response
.headers()
.get("X-Dragonfly-Error-Type")
.and_then(|value| value.to_str().ok())
{
// If a known error type is found, consume the response body
// to get the message and return a specific error.
let status = response.status();
let headers = response.headers().clone();
if status.is_success() {
return Ok(response);
}
return match error_type {
"backend" => Err(Error::BackendError(BackendError {
message: response.text().await.ok(),
header: headermap_to_hashmap(&headers),
let response_headers = response.headers().clone();
let header_map = headermap_to_hashmap(&response_headers);
let message = response.text().await.ok();
let error_type = response_headers
.get("X-Dragonfly-Error-Type")
.and_then(|v| v.to_str().ok());
match error_type {
Some("backend") => Err(Error::BackendError(BackendError {
message,
header: header_map,
status_code: Some(status),
})),
"proxy" => Err(Error::ProxyError(ProxyError {
message: response.text().await.ok(),
header: headermap_to_hashmap(&headers),
Some("proxy") => Err(Error::ProxyError(ProxyError {
message,
header: header_map,
status_code: Some(status),
})),
"dfdaemon" => Err(Error::DfdaemonError(DfdaemonError {
message: response.text().await.ok(),
Some("dfdaemon") => Err(Error::DfdaemonError(DfdaemonError { message })),
Some(other) => Err(Error::ProxyError(ProxyError {
message: Some(format!("unknown error type from proxy: {}", other)),
header: header_map,
status_code: Some(status),
})),
// Any other error type is handled as an unknown error.
_ => Err(DFError::Unknown(format!("unknown error type: {}", error_type)).into()),
};
None => Err(Error::ProxyError(ProxyError {
message: Some(format!("unexpected status code from proxy: {}", status)),
header: header_map,
status_code: Some(status),
})),
}
}
if !response.status().is_success() {
return Err(
DFError::Unknown(format!("unexpected status code: {}", response.status())).into(),
);
}
Ok(response)
}
}
/// make_request_headers applies p2p related headers to the request headers.
fn make_request_headers(request: &GetRequest) -> Result<HeaderMap> {
/// make_request_headers applies p2p related headers to the request headers.
#[instrument(skip_all)]
fn make_request_headers(&self, request: &GetRequest) -> Result<HeaderMap> {
let mut headers = request.header.clone().unwrap_or_default();
// Apply the p2p related headers to the request headers.
if let Some(piece_length) = request.piece_length {
headers.insert(
"X-Dragonfly-Piece-Length",
piece_length.to_string().parse()?,
piece_length.to_string().parse().map_err(|err| {
Error::InvalidArgument(format!("invalid piece length: {}", err))
})?,
);
}
if let Some(tag) = request.tag.clone() {
headers.insert("X-Dragonfly-Tag", tag.to_string().parse()?);
headers.insert(
"X-Dragonfly-Tag",
tag.to_string()
.parse()
.map_err(|err| Error::InvalidArgument(format!("invalid tag: {}", err)))?,
);
}
if let Some(application) = request.application.clone() {
headers.insert("X-Dragonfly-Application", application.to_string().parse()?);
headers.insert(
"X-Dragonfly-Application",
application.to_string().parse().map_err(|err| {
Error::InvalidArgument(format!("invalid application: {}", err))
})?,
);
}
if let Some(content_for_calculating_task_id) = request.content_for_calculating_task_id.clone() {
if let Some(content_for_calculating_task_id) =
request.content_for_calculating_task_id.clone()
{
headers.insert(
"X-Dragonfly-Content-For-Calculating-Task-ID",
content_for_calculating_task_id.to_string().parse()?,
content_for_calculating_task_id
.to_string()
.parse()
.map_err(|err| {
Error::InvalidArgument(format!(
"invalid content for calculating task id: {}",
err
))
})?,
);
}
if let Some(priority) = request.priority {
headers.insert("X-Dragonfly-Priority", priority.to_string().parse()?);
headers.insert(
"X-Dragonfly-Priority",
priority
.to_string()
.parse()
.map_err(|err| Error::InvalidArgument(format!("invalid priority: {}", err)))?,
);
}
if !request.filtered_query_params.is_empty() {
let value = request.filtered_query_params.join(",");
headers.insert("X-Dragonfly-Filtered-Query-Params", value.parse()?);
headers.insert(
"X-Dragonfly-Filtered-Query-Params",
value.parse().map_err(|err| {
Error::InvalidArgument(format!("invalid filtered query params: {}", err))
})?,
);
}
// Always use P2P for SDK scenarios.
headers.insert("X-Dragonfly-Use-P2P", HeaderValue::from_static("true"));
Ok(headers)
}
}
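A hedged illustration of the header mapping above, from inside the crate; demo_request() is again a hypothetical helper for the remaining fields:

let request = GetRequest {
    piece_length: Some(4 * 1024 * 1024),
    tag: Some("v1".to_string()),
    filtered_query_params: vec!["X-Amz-Signature".to_string()],
    ..demo_request() // hypothetical helper, not part of this commit
};

let headers = proxy.make_request_headers(&request)?;
assert_eq!(headers["X-Dragonfly-Piece-Length"], "4194304");
assert_eq!(headers["X-Dragonfly-Tag"], "v1");
assert_eq!(headers["X-Dragonfly-Filtered-Query-Params"], "X-Amz-Signature");
assert_eq!(headers["X-Dragonfly-Use-P2P"], "true");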
#[cfg(test)]
mod tests {
use super::*;
use dragonfly_api::scheduler::v2::ListHostsResponse;
use dragonfly_client_core::error::DFError;
use mocktail::prelude::*;
use std::time::Duration;
@ -521,20 +634,16 @@ mod tests {
});
let server = MockServer::new_grpc("scheduler.v2.Scheduler").with_mocks(mocks);
match server.start().await {
Ok(_) => Ok(server),
Err(err) => Err(DFError::Unknown(format!(
"failed to start mock scheduler server: {}",
err,
))
.into()),
}
server.start().await.map_err(|err| {
Error::Internal(format!("failed to start mock scheduler server: {}", err))
})?;
Ok(server)
}
#[tokio::test]
async fn test_proxy_new_success() -> Result<()> {
let mock_server = setup_mock_scheduler().await?;
async fn test_proxy_new_success() {
let mock_server = setup_mock_scheduler().await.unwrap();
let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
let result = Proxy::builder()
.scheduler_endpoint(scheduler_endpoint)
@ -542,10 +651,7 @@ mod tests {
.await;
assert!(result.is_ok());
let proxy = result.unwrap();
assert_eq!(proxy.max_retries, 1);
Ok(())
assert_eq!(result.unwrap().max_retries, 1);
}
#[tokio::test]
@ -556,15 +662,12 @@ mod tests {
.await;
assert!(result.is_err());
assert!(matches!(
result,
Err(Error::Base(DFError::ValidationError(_)))
));
assert!(matches!(result, Err(Error::InvalidArgument(_))));
}
#[tokio::test]
async fn test_proxy_new_invalid_retry_times() -> Result<()> {
let mock_server = setup_mock_scheduler().await?;
async fn test_proxy_new_invalid_retry_times() {
let mock_server = setup_mock_scheduler().await.unwrap();
let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
let result = Proxy::builder()
@ -574,17 +677,12 @@ mod tests {
.await;
assert!(result.is_err());
assert!(matches!(
result,
Err(Error::Base(DFError::ValidationError(_)))
));
Ok(())
assert!(matches!(result, Err(Error::InvalidArgument(_))));
}
#[tokio::test]
async fn test_proxy_new_invalid_health_check_interval() -> Result<()> {
let mock_server = setup_mock_scheduler().await?;
async fn test_proxy_new_invalid_health_check_interval() {
let mock_server = setup_mock_scheduler().await.unwrap();
let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
let result = Proxy::builder()
@ -594,12 +692,7 @@ mod tests {
.await;
assert!(result.is_err());
assert!(matches!(
result,
Err(Error::Base(DFError::ValidationError(_)))
));
Ok(())
assert!(matches!(result, Err(Error::InvalidArgument(_))));
}
#[tokio::test]
@ -641,6 +734,7 @@ mod tests {
// Create another client.
let _ = pool.entry(&"http://proxy2.com".to_string()).await.unwrap();
// Still should be 1 because the proxy1 client should have been cleaned up.
assert_eq!(pool.size().await, 1);
}


@ -18,7 +18,10 @@ use crate::hashring::VNodeHashRing;
use crate::shutdown;
use dragonfly_api::common::v2::Host;
use dragonfly_api::scheduler::v2::scheduler_client::SchedulerClient;
use dragonfly_client_core::{Error, Result};
use dragonfly_client_core::{
error::{ErrorType, OrErr},
Error, Result,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Duration;
@ -33,6 +36,7 @@ use tracing::{debug, error, info, instrument};
/// Selector is the interface for selecting item from a list of items by a specific criteria.
#[tonic::async_trait]
pub trait Selector: Send + Sync {
/// select selects items based on the given task_id and number of replicas.
async fn select(&self, task_id: String, replicas: u32) -> Result<Vec<Host>>;
}
@ -42,6 +46,7 @@ const SEED_PEERS_HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(5);
/// DEFAULT_VNODES_PER_HOST is the default number of virtual nodes per host.
const DEFAULT_VNODES_PER_HOST: usize = 3;
/// SeedPeers holds the data of seed peers.
struct SeedPeers {
/// hosts is the list of seed peers.
hosts: HashMap<String, Host>,
@ -50,6 +55,7 @@ struct SeedPeers {
hashring: VNodeHashRing,
}
/// SeedPeerSelector is the implementation of Selector for seed peers.
pub struct SeedPeerSelector {
/// health_check_interval is the interval of health check for seed peers.
health_check_interval: Duration,
@ -60,10 +66,11 @@ pub struct SeedPeerSelector {
/// seed_peers is the data of the seed peer selector.
seed_peers: RwLock<SeedPeers>,
/// mutex is used to protect refresh.
/// mutex is used to protect hot refresh.
mutex: Mutex<()>,
}
/// SeedPeerSelector implements a selector that selects seed peers from the scheduler service.
impl SeedPeerSelector {
/// new creates a new seed peer selector.
pub async fn new(
@ -86,10 +93,6 @@ impl SeedPeerSelector {
/// run starts the seed peer selector service.
pub async fn run(&self) {
// Receive shutdown signal.
let mut shutdown = shutdown::Shutdown::default();
// Start the refresh loop.
let mut interval = tokio::time::interval(self.health_check_interval);
loop {
tokio::select! {
@ -99,7 +102,7 @@ impl SeedPeerSelector {
Ok(_) => debug!("succeed to refresh seed peers"),
}
}
_ = shutdown.recv() => {
_ = shutdown::shutdown_signal() => {
info!("seed peer selector service is shutting down");
return;
}
@ -135,11 +138,13 @@ impl SeedPeerSelector {
let mut hosts = HashMap::with_capacity(seed_peers_length);
let mut hashring = VNodeHashRing::new(DEFAULT_VNODES_PER_HOST);
while let Some(result) = join_set.join_next().await {
match result {
Ok(Ok(peer)) => {
let addr: SocketAddr = format!("{}:{}", peer.ip, peer.port).parse().unwrap();
let addr: SocketAddr = format!("{}:{}", peer.ip, peer.port)
.parse()
.or_err(ErrorType::ParseError)?;
hashring.add(addr);
hosts.insert(addr.to_string(), peer);
}
@ -179,16 +184,16 @@ impl SeedPeerSelector {
/// check_health checks the health of each seed peer.
#[instrument(skip_all)]
async fn check_health(addr: &str) -> Result<HealthCheckResponse> {
let endpoint =
Endpoint::from_shared(addr.to_string())?.timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT);
let channel = Endpoint::from_shared(addr.to_string())?
.connect_timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT)
.connect()
.await?;
let channel = endpoint.connect().await?;
let mut client = HealthGRPCClient::new(channel);
let request = tonic::Request::new(HealthCheckRequest {
let mut request = tonic::Request::new(HealthCheckRequest {
service: "".to_string(),
});
request.set_timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT);
let response = client.check(request).await?;
Ok(response.into_inner())
}
@ -200,21 +205,21 @@ impl Selector for SeedPeerSelector {
async fn select(&self, task_id: String, replicas: u32) -> Result<Vec<Host>> {
// Acquire a read lock and perform all logic within it.
let seed_peers = self.seed_peers.read().await;
if seed_peers.hosts.is_empty() {
return Err(Error::HostNotFound("seed peers".to_string()));
}
// The number of replicas cannot exceed the total number of seed peers.
let desired_replicas = std::cmp::min(replicas as usize, seed_peers.hashring.len());
let expected_replicas = std::cmp::min(replicas as usize, seed_peers.hashring.len());
debug!("expected replicas: {}", expected_replicas);
// Get replica nodes from the hash ring.
let vnodes = seed_peers
.hashring
.get_with_replicas(&task_id, desired_replicas)
.get_with_replicas(&task_id, expected_replicas)
.unwrap_or_default();
let seed_peers = vnodes
let seed_peers: Vec<Host> = vnodes
.into_iter()
.filter_map(|vnode| {
seed_peers
@ -224,6 +229,10 @@ impl Selector for SeedPeerSelector {
})
.collect();
if seed_peers.is_empty() {
return Err(Error::HostNotFound("selected seed peers".to_string()));
}
Ok(seed_peers)
}
}
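A short usage sketch of the selector as refactored here, asking for up to two seed peers for an illustrative task id:

let peers = seed_peer_selector
    .select("task-id".to_string(), 2)
    .await?; // Error::HostNotFound when no seed peer can be resolved

for peer in &peers {
    // The request module above dials each peer's proxy port in order.
    println!("candidate seed peer: {}:{}", peer.ip, peer.proxy_port);
}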


@ -98,9 +98,12 @@ struct DfdaemonUploadClientFactory {
config: Arc<Config>,
}
/// DfdaemonUploadClientFactory implements the Factory trait for creating DfdaemonUploadClient
/// instances.
#[tonic::async_trait]
impl Factory<String, DfdaemonUploadClient> for DfdaemonUploadClientFactory {
type Error = Error;
/// Creates a new DfdaemonUploadClient for the given address.
async fn make_client(&self, addr: &String) -> Result<DfdaemonUploadClient> {
DfdaemonUploadClient::new(self.config.clone(), format!("http://{}", addr), true).await
}
@ -121,13 +124,9 @@ pub struct GRPCDownloader {
impl GRPCDownloader {
/// new returns a new GRPCDownloader.
pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
let factory = DfdaemonUploadClientFactory {
config: config.clone(),
};
Self {
config,
client_pool: PoolBuilder::new(factory)
config: config.clone(),
client_pool: PoolBuilder::new(DfdaemonUploadClientFactory { config })
.capacity(capacity)
.idle_timeout(idle_timeout)
.build(),
@ -297,9 +296,12 @@ struct TCPClientFactory {
config: Arc<Config>,
}
/// TCPClientFactory implements the Factory trait for creating TCPClient instances.
#[tonic::async_trait]
impl Factory<String, TCPClient> for TCPClientFactory {
type Error = Error;
/// Creates a new TCPClient for the given address.
async fn make_client(&self, addr: &String) -> Result<TCPClient> {
Ok(TCPClient::new(self.config.clone(), addr.clone()))
}
@ -309,12 +311,10 @@ impl Factory<String, TCPClient> for TCPClientFactory {
impl TCPDownloader {
/// new returns a new TCPDownloader.
pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
let factory = TCPClientFactory {
config: config.clone(),
};
Self {
client_pool: PoolBuilder::new(factory)
client_pool: PoolBuilder::new(TCPClientFactory {
config: config.clone(),
})
.capacity(capacity)
.idle_timeout(idle_timeout)
.build(),
@ -346,6 +346,7 @@ impl Downloader for TCPDownloader {
) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
let entry = self.get_client_entry(addr).await?;
let request_guard = entry.request_guard();
match entry.client.download_piece(number, task_id).await {
Ok((reader, offset, digest)) => Ok((Box::new(reader), offset, digest)),
Err(err) => {
@ -370,6 +371,7 @@ impl Downloader for TCPDownloader {
) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
let entry = self.get_client_entry(addr).await?;
let request_guard = entry.request_guard();
match entry
.client
.download_persistent_cache_piece(number, task_id)