refactor(dragonfly-client-util): request and pool logic (#1370)

Signed-off-by: Gaius <gaius.qi@gmail.com>
Gaius 2025-09-23 18:18:16 +08:00 committed by GitHub
parent 826aeabf08
commit 02eddfbd10
11 changed files with 460 additions and 320 deletions

Cargo.lock generated
View File

@@ -1227,6 +1227,8 @@ dependencies = [
 "protobuf 3.7.2",
 "rcgen",
 "reqwest",
+"reqwest-middleware",
+"reqwest-tracing",
 "rustix 1.1.2",
 "rustls 0.22.4",
 "rustls-pemfile 2.2.0",

View File

@@ -116,6 +116,7 @@ dashmap = "6.1.0"
 hostname = "^0.4"
 tonic-health = "0.12.3"
 hashring = "0.3.6"
+reqwest-tracing = "0.5"

 [profile.release]
 opt-level = 3

View File

@@ -25,16 +25,16 @@ tracing.workspace = true
 opendal.workspace = true
 percent-encoding.workspace = true
 futures.workspace = true
+reqwest-tracing.workspace = true
 reqwest-retry = "0.7"
-reqwest-tracing = "0.5"
 libloading = "0.8.8"

 [dev-dependencies]
 tempfile.workspace = true
-wiremock = "0.6.4"
 rustls-pki-types.workspace = true
 rustls-pemfile.workspace = true
 hyper.workspace = true
 hyper-util.workspace = true
 tokio-rustls.workspace = true
 rcgen.workspace = true
+wiremock = "0.6.4"

View File

@@ -41,6 +41,8 @@ tonic-health.workspace = true
 local-ip-address.workspace = true
 hostname.workspace = true
 dashmap.workspace = true
+reqwest-tracing.workspace = true
+reqwest-middleware.workspace = true
 rustix = { version = "1.1.2", features = ["fs"] }
 base64 = "0.22.1"
 pnet = "0.35.0"

View File

@@ -14,42 +14,60 @@
  * limitations under the License.
  */

+use hashring::HashRing;
 use std::fmt;
 use std::hash::Hash;
 use std::net::SocketAddr;

-use hashring::HashRing;
-
+/// A virtual node (vnode) on the consistent hash ring.
+/// Each physical node (SocketAddr) is represented by multiple vnodes to better
+/// balance key distribution across the ring.
 #[derive(Debug, Copy, Clone, Hash, PartialEq)]
 pub struct VNode {
+    /// The replica index of this vnode for its physical node (0..replica_count-1).
     id: usize,
+    /// The physical node address this vnode represents.
     addr: SocketAddr,
 }

+/// VNode implements virtual node for consistent hashing.
 impl VNode {
+    /// Creates a new virtual node with the given replica id and physical address.
     fn new(id: usize, addr: SocketAddr) -> Self {
         VNode { id, addr }
     }
 }

+/// VNode implements Display trait to format.
 impl fmt::Display for VNode {
+    /// Formats the virtual node as "address|id" as the key for the hash ring.
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{}|{}", self.addr, self.id)
     }
 }

+/// VNode implements methods for hash ring operations.
 impl VNode {
+    /// Returns a reference to the physical node address associated with this vnode.
     pub fn addr(&self) -> &SocketAddr {
         &self.addr
     }
 }

+/// A consistent hash ring that uses virtual nodes (vnodes) to improve key distribution.
+/// When a physical node is added, replica_count vnodes are inserted into the ring.
 pub struct VNodeHashRing {
+    /// Number of vnodes to create per physical node.
    replica_count: usize,
+    /// The underlying hash ring that stores vnodes.
    ring: HashRing<VNode>,
 }

+/// VNodeHashRing implements methods for managing the hash ring.
 impl VNodeHashRing {
+    /// Creates a new vnode-based hash ring.
     pub fn new(replica_count: usize) -> Self {
         VNodeHashRing {
             replica_count,

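The vnode scheme is the standard consistent-hashing refinement: inserting each physical address into the ring several times under distinct keys evens out key distribution and reduces key movement when peers join or leave. A minimal standalone sketch of the idea, using the hashring crate directly with the "address|id" string keys that VNode's Display produces (addresses are illustrative):

use hashring::HashRing;
use std::net::SocketAddr;

fn main() {
    let replica_count = 3;
    let mut ring: HashRing<String> = HashRing::new();

    for addr in ["10.0.0.1:4001", "10.0.0.2:4001"] {
        let addr: SocketAddr = addr.parse().unwrap();
        // Insert replica_count vnodes per physical node, keyed "address|id".
        for id in 0..replica_count {
            ring.add(format!("{}|{}", addr, id));
        }
    }

    // The same task id always lands on the same vnode, and therefore the
    // same seed peer, as long as the ring membership is unchanged.
    let vnode = ring.get(&"task-1234").unwrap();
    println!("task-1234 -> {}", vnode);
}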
View File

@@ -22,6 +22,5 @@ pub mod id_generator;
 pub mod net;
 pub mod pool;
 pub mod request;
-pub mod selector;
 pub mod shutdown;
 pub mod tls;

View File

@@ -34,14 +34,18 @@ pub struct RequestGuard {
     active_requests: Arc<AtomicUsize>,
 }

+/// RequestGuard implements the request guard pattern.
 impl RequestGuard {
+    /// Create a new request guard.
     fn new(active_requests: Arc<AtomicUsize>) -> Self {
         active_requests.fetch_add(1, Ordering::SeqCst);
         Self { active_requests }
     }
 }

+/// RequestGuard decrements the active request count when dropped.
 impl Drop for RequestGuard {
+    /// Decrement the active request count.
     fn drop(&mut self) {
         self.active_requests.fetch_sub(1, Ordering::SeqCst);
     }
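The guard is plain RAII over an atomic counter: incrementing in the constructor and decrementing in Drop means early returns and panics still release the slot, so the pool's cleanup pass never evicts a client mid-request. A self-contained sketch of the same pattern:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

struct RequestGuard {
    active_requests: Arc<AtomicUsize>,
}

impl RequestGuard {
    fn new(active_requests: Arc<AtomicUsize>) -> Self {
        // Count the request as active for the guard's whole lifetime.
        active_requests.fetch_add(1, Ordering::SeqCst);
        Self { active_requests }
    }
}

impl Drop for RequestGuard {
    fn drop(&mut self) {
        self.active_requests.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let active = Arc::new(AtomicUsize::new(0));
    {
        let _guard = RequestGuard::new(active.clone());
        assert_eq!(active.load(Ordering::SeqCst), 1);
    } // _guard dropped here, count released
    assert_eq!(active.load(Ordering::SeqCst), 0);
}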
@@ -60,7 +64,9 @@ pub struct Entry<T> {
     actived_at: Arc<std::sync::Mutex<Instant>>,
 }

+/// Entry methods for managing client state.
 impl<T> Entry<T> {
+    /// Create a new entry with the given client.
     fn new(client: T) -> Self {
         Self {
             client,

@@ -75,8 +81,8 @@ impl<T> Entry<T> {
     }

     /// Update the last active time.
-    fn update_actived_at(&self) {
-        *self.actived_at.lock().unwrap() = Instant::now();
+    fn set_actived_at(&self, actived_at: Instant) {
+        *self.actived_at.lock().unwrap() = actived_at;
     }

     /// Check if the client has active requests.

@@ -96,6 +102,7 @@ impl<T> Entry<T> {
 pub trait Factory<K, T> {
     type Error;

+    /// Create a new client for the given key.
     async fn make_client(&self, key: &K) -> Result<T, Self::Error>;
 }
@@ -119,6 +126,7 @@ pub struct Pool<K, T, F> {
     cleanup_at: Arc<Mutex<Instant>>,
 }

+/// Builder for creating a client pool.
 pub struct Builder<K, T, F> {
     factory: F,
     capacity: usize,

@@ -126,6 +134,7 @@ pub struct Builder<K, T, F> {
     _phantom: PhantomData<(K, T)>,
 }

+/// Builder methods for configuring and building the pool.
 impl<K, T, F> Builder<K, T, F>
 where
     K: Clone + Eq + Hash + std::fmt::Display,

@@ -166,6 +175,14 @@ where
     }
 }

+/// Generic client pool for managing reusable client instances with automatic cleanup.
+///
+/// This client pool provides connection reuse, automatic cleanup, and capacity management
+/// capabilities, primarily used for:
+/// - Connection Reuse: Reuse existing client instances to avoid repeated creation overhead.
+/// - Automatic Cleanup: Periodically remove idle clients that exceed timeout thresholds.
+/// - Capacity Control: Limit maximum client count to prevent resource exhaustion.
+/// - Thread Safety: Use async locks and atomic operations for high-concurrency access.
 impl<K, T, F> Pool<K, T, F>
 where
     K: Clone + Eq + Hash + std::fmt::Display,
@@ -182,7 +199,7 @@ where
         let clients = self.clients.lock().await;
         if let Some(entry) = clients.get(key) {
             debug!("reusing client: {}", key);
-            entry.update_actived_at();
+            entry.set_actived_at(Instant::now());
             return Ok(entry.clone());
         }
     }

@@ -190,11 +207,10 @@ where
         // Create new client.
         debug!("creating client: {}", key);
         let client = self.factory.make_client(key).await?;

         let mut clients = self.clients.lock().await;
         let entry = clients.entry(key.clone()).or_insert(Entry::new(client));
-
-        entry.update_actived_at();
+        entry.set_actived_at(Instant::now());

         Ok(entry.clone())
     }
@@ -231,7 +247,6 @@ where
         let is_recent = idle_duration <= self.idle_timeout;
-
         let should_retain = has_active_requests || (!exceeds_capacity && is_recent);
         if !should_retain {
             info!(
                 "removing idle client: {}, exceeds_capacity: {}, idle_duration: {}s",

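Putting the pieces together, a caller implements Factory once and lets the pool decide when to create, reuse, or evict. A sketch against the Builder/Pool/Entry API shown above; EchoClient is a hypothetical stand-in for a real connection type, and the pool types are assumed to be in scope:

use std::time::Duration;

#[derive(Clone)]
struct EchoClient {
    addr: String,
}

struct EchoClientFactory;

#[tonic::async_trait]
impl Factory<String, EchoClient> for EchoClientFactory {
    type Error = std::io::Error;

    async fn make_client(&self, key: &String) -> Result<EchoClient, Self::Error> {
        Ok(EchoClient { addr: key.clone() })
    }
}

async fn demo() -> Result<(), std::io::Error> {
    let pool = Builder::new(EchoClientFactory)
        .capacity(64)
        .idle_timeout(Duration::from_secs(300))
        .build();

    // The first call creates the client; later calls with the same key reuse
    // it and refresh its last-active timestamp.
    let entry = pool.entry(&"10.0.0.1:4001".to_string()).await?;
    let _guard = entry.request_guard(); // counted as active until dropped
    println!("connected to {}", entry.client.addr);
    Ok(())
}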
View File

@@ -14,25 +14,19 @@
  * limitations under the License.
  */

-use dragonfly_client_core::Error as DFError;
 use reqwest;
 use std::collections::HashMap;

 #[derive(Debug, thiserror::Error)]
 pub enum Error {
-    #[error(transparent)]
-    Base(#[from] DFError),
+    #[error("request timeout: {0}")]
+    RequestTimeout(String),

-    #[allow(clippy::enum_variant_names)]
-    #[error(transparent)]
-    ReqwestError(#[from] reqwest::Error),
+    #[error("invalid argument: {0}")]
+    InvalidArgument(String),

-    #[allow(clippy::enum_variant_names)]
-    #[error(transparent)]
-    TonicTransportError(#[from] tonic::transport::Error),
-
-    #[error(transparent)]
-    InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
+    #[error("request internal error: {0}")]
+    Internal(String),

     #[allow(clippy::enum_variant_names)]
     #[error(transparent)]
@@ -47,38 +41,42 @@ pub enum Error {
     DfdaemonError(#[from] DfdaemonError),
 }

-// BackendError is error detail for Backend.
+/// BackendError is error detail for Backend.
 #[derive(Debug, thiserror::Error)]
-#[error("error occurred in the backend server, message: {message:?}, header: {header:?}, status_code: {status_code:?}")]
+#[error(
+    "backend server error, message: {message:?}, header: {header:?}, status_code: {status_code:?}"
+)]
 pub struct BackendError {
-    // Backend error message.
+    /// Backend error message.
     pub message: Option<String>,
-    // Backend HTTP response header.
+    /// Backend HTTP response header.
     pub header: HashMap<String, String>,
-    // Backend HTTP status code.
+    /// Backend HTTP status code.
     pub status_code: Option<reqwest::StatusCode>,
 }

-// ProxyError is error detail for Proxy.
+/// ProxyError is error detail for Proxy.
 #[derive(Debug, thiserror::Error)]
-#[error("error occurred in the proxy server, message: {message:?}, header: {header:?}, status_code: {status_code:?}")]
+#[error(
+    "proxy server error, message: {message:?}, header: {header:?}, status_code: {status_code:?}"
+)]
 pub struct ProxyError {
-    // Proxy error message.
+    /// Proxy error message.
     pub message: Option<String>,
-    // Proxy HTTP response header.
+    /// Proxy HTTP response header.
     pub header: HashMap<String, String>,
-    // Proxy HTTP status code.
+    /// Proxy HTTP status code.
     pub status_code: Option<reqwest::StatusCode>,
 }

-// DfdaemonError is error detail for Dfdaemon.
+/// DfdaemonError is error detail for Dfdaemon.
 #[derive(Debug, thiserror::Error)]
-#[error("error occurred in the dfdaemon, message: {message:?}")]
+#[error("dfdaemon error, message: {message:?}")]
 pub struct DfdaemonError {
-    // Dfdaemon error message.
+    /// Dfdaemon error message.
     pub message: Option<String>,
 }
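With the transparent wrappers gone, callers branch on a small closed set of variants instead of downcasting through reqwest and tonic errors. A sketch of a caller-side match, assuming the variants above are in scope:

fn describe(err: &Error) -> String {
    match err {
        Error::RequestTimeout(msg) => format!("retryable timeout: {}", msg),
        Error::InvalidArgument(msg) => format!("caller bug: {}", msg),
        Error::Internal(msg) => format!("internal failure: {}", msg),
        Error::BackendError(err) => format!("backend returned {:?}", err.status_code),
        // ProxyError, DfdaemonError, and any remaining variants.
        _ => err.to_string(),
    }
}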

View File

@@ -15,23 +15,24 @@
  */

 mod errors;
+mod selector;

 use crate::http::headermap_to_hashmap;
 use crate::id_generator::{IDGenerator, TaskIDParameter};
 use crate::pool::{Builder as PoolBuilder, Entry, Factory, Pool};
-use crate::selector::{SeedPeerSelector, Selector};
 use bytes::BytesMut;
 use dragonfly_api::scheduler::v2::scheduler_client::SchedulerClient;
-use dragonfly_client_core::error::DFError;
 use errors::{BackendError, DfdaemonError, Error, ProxyError};
 use futures::TryStreamExt;
 use hostname;
 use local_ip_address::local_ip;
-use reqwest::{header::HeaderMap, header::HeaderValue, redirect::Policy, Client};
+use reqwest::{header::HeaderMap, header::HeaderValue, Client};
+use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
+use reqwest_tracing::TracingMiddleware;
 use rustix::path::Arg;
 use rustls_pki_types::CertificateDer;
+use selector::{SeedPeerSelector, Selector};
 use std::io::{Error as IOError, ErrorKind};
-use std::net::{IpAddr, Ipv4Addr};
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::io::AsyncRead;
@@ -51,54 +52,45 @@ const DEFAULT_CLIENT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);
 /// DEFAULT_CLIENT_POOL_CAPACITY is the default capacity of the client pool.
 const DEFAULT_CLIENT_POOL_CAPACITY: usize = 128;

+/// DEFAULT_SCHEDULER_REQUEST_TIMEOUT is the default timeout (5 seconds) for requests to the
+/// scheduler service.
+const DEFAULT_SCHEDULER_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Result is a specialized Result type for the proxy module.
 pub type Result<T> = std::result::Result<T, Error>;

+/// Body is the type alias for the response body reader.
 pub type Body = Box<dyn AsyncRead + Send + Unpin>;

+/// Defines the interface for sending requests via the Dragonfly.
+///
+/// This trait enables interaction with remote servers through the Dragonfly, providing methods
+/// for performing GET requests with flexible response handling. It is designed for clients that
+/// need to communicate with the Dragonfly seed client efficiently, supporting both streaming and
+/// buffered response processing. The trait shields the complex request logic between the client
+/// and the Dragonfly seed client's proxy, abstracting the underlying communication details to
+/// simplify client implementation and usage.
 #[tonic::async_trait]
 pub trait Request {
-    /// get sends the get request to the remote server and returns the response reader.
+    /// Sends a GET request to a remote server via the Dragonfly and returns a response
+    /// with a streaming body.
+    ///
+    /// This method is designed for scenarios where the response body is expected to be processed
+    /// as a stream, allowing efficient handling of large or continuous data. The response includes
+    /// metadata such as status codes and headers, along with a streaming `Body` for accessing the
+    /// response content.
     async fn get(&self, request: GetRequest) -> Result<GetResponse<Body>>;

-    /// get_into sends the get request to the remote server and writes the response to the input buffer.
+    /// Sends a GET request to a remote server via the Dragonfly and writes the response
+    /// body directly into the provided buffer.
+    ///
+    /// This method is optimized for scenarios where the response body needs to be stored directly
+    /// in memory, avoiding the overhead of streaming for smaller or fixed-size responses. The
+    /// provided `BytesMut` buffer is used to store the response content, and the response metadata
+    /// (e.g., status and headers) is returned separately.
     async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse>;
 }

-pub struct Proxy {
-    /// seed_peer_selector is the selector service for selecting seed peers.
-    seed_peer_selector: Arc<SeedPeerSelector>,
-    /// max_retries is the number of times to retry a request.
-    max_retries: u8,
-    /// client_pool is the pool of clients.
-    client_pool: Pool<String, Client, HTTPClientFactory>,
-    /// id_generator is the task id generator.
-    id_generator: Arc<IDGenerator>,
-}
-
-pub struct Builder {
-    /// scheduler_endpoint is the endpoint of the scheduler service.
-    scheduler_endpoint: String,
-    /// health_check_interval is the interval of health check for selector(seed peers).
-    health_check_interval: Duration,
-    /// max_retries is the number of times to retry a request.
-    max_retries: u8,
-}
-
-impl Default for Builder {
-    fn default() -> Self {
-        Self {
-            scheduler_endpoint: "".to_string(),
-            health_check_interval: Duration::from_secs(60),
-            max_retries: 1,
-        }
-    }
-}
-
+/// GetRequest represents a GET request to be sent via the Dragonfly.
 pub struct GetRequest {
     /// url is the url of the request.
     pub url: String,
@@ -137,6 +129,7 @@ pub struct GetRequest {
     pub client_cert: Option<Vec<CertificateDer<'static>>>,
 }

+/// GetResponse represents a GET response received via the Dragonfly.
 pub struct GetResponse<R = Body>
 where
     R: AsyncRead + Unpin,
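From the caller's side the trait keeps both paths symmetrical: get for streaming, get_into for buffering. A hypothetical call site for the buffered path; default_get_request() is an assumed helper that fills the remaining GetRequest fields, which are elided in this diff:

async fn fetch_small(proxy: &Proxy, url: &str) -> Result<Vec<u8>> {
    let mut buf = BytesMut::new();
    let request = GetRequest {
        url: url.to_string(),
        // Remaining fields elided; assume a helper provides sane defaults.
        ..default_get_request()
    };

    let _response = proxy.get_into(request, &mut buf).await?;
    Ok(buf.to_vec())
}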
@@ -155,52 +148,64 @@ where
 }

 /// Factory for creating HTTPClient instances.
+#[derive(Debug, Clone, Default)]
 struct HTTPClientFactory {}

+/// HTTPClientFactory implements Factory for creating reqwest::Client instances with proxy support.
 #[tonic::async_trait]
-impl Factory<String, Client> for HTTPClientFactory {
+impl Factory<String, ClientWithMiddleware> for HTTPClientFactory {
     type Error = Error;

-    async fn make_client(&self, proxy_addr: &String) -> Result<Client> {
-        // Disable automatic compression to prevent double-decompression issues.
-        //
-        // Problem scenario:
-        // 1. Origin server supports gzip and returns "content-encoding: gzip" header.
-        // 2. Backend decompresses the response and stores uncompressed content to disk.
-        // 3. When user's client downloads via dfdaemon proxy, the original "content-encoding: gzip"
-        //    header is forwarded to it.
-        // 4. User's client attempts to decompress the already-decompressed content, causing errors.
-        //
-        // Solution: Disable all compression formats (gzip, brotli, zstd, deflate) to ensure
-        // we receive and store uncompressed content, eliminating the double-decompression issue.
-        let mut client_builder = Client::builder()
-            .no_gzip()
-            .no_brotli()
-            .no_zstd()
-            .no_deflate()
-            .redirect(Policy::none())
+    /// Creates a new reqwest::Client configured to use the specified proxy address.
+    async fn make_client(&self, proxy_addr: &String) -> Result<ClientWithMiddleware> {
+        // TODO(chlins): Support client certificates and set `danger_accept_invalid_certs`
+        // based on the certificates.
+        let client = Client::builder()
             .hickory_dns(true)
             .danger_accept_invalid_certs(true)
             .pool_max_idle_per_host(POOL_MAX_IDLE_PER_HOST)
-            .tcp_keepalive(KEEP_ALIVE_INTERVAL);
-
-        // Add the proxy configuration if provided.
-        if !proxy_addr.is_empty() {
-            let proxy = reqwest::Proxy::all(proxy_addr)?;
-            client_builder = client_builder.proxy(proxy);
-        }
-
-        let client = client_builder.build()?;
-        Ok(client)
+            .tcp_keepalive(KEEP_ALIVE_INTERVAL)
+            .proxy(reqwest::Proxy::all(proxy_addr).map_err(|err| {
+                Error::Internal(format!("failed to set proxy {}: {}", proxy_addr, err))
+            })?)
+            .build()
+            .map_err(|err| Error::Internal(format!("failed to build reqwest client: {}", err)))?;
+
+        Ok(ClientBuilder::new(client)
+            .with(TracingMiddleware::default())
+            .build())
     }
 }
-impl Proxy {
-    pub fn builder() -> Builder {
-        Builder::default()
-    }
-}
-
+/// Builder is the builder for Proxy.
+pub struct Builder {
+    /// scheduler_endpoint is the endpoint of the scheduler service.
+    scheduler_endpoint: String,
+    /// scheduler_request_timeout is the timeout of the request to the scheduler service.
+    scheduler_request_timeout: Duration,
+    /// health_check_interval is the interval of health check for selector(seed peers).
+    health_check_interval: Duration,
+    /// max_retries is the number of times to retry a request.
+    max_retries: u8,
+}
+
+/// Builder implements Default trait.
+impl Default for Builder {
+    /// default returns a default Builder.
+    fn default() -> Self {
+        Self {
+            scheduler_endpoint: "".to_string(),
+            scheduler_request_timeout: DEFAULT_SCHEDULER_REQUEST_TIMEOUT,
+            health_check_interval: Duration::from_secs(60),
+            max_retries: 1,
+        }
+    }
+}
+
+/// Builder implements the builder pattern for Proxy.
 impl Builder {
     /// Sets the scheduler endpoint.
     pub fn scheduler_endpoint(mut self, endpoint: String) -> Self {
@@ -208,6 +213,12 @@ impl Builder {
         self
     }

+    /// Sets the scheduler request timeout.
+    pub fn scheduler_request_timeout(mut self, timeout: Duration) -> Self {
+        self.scheduler_request_timeout = timeout;
+        self
+    }
+
     /// Sets the health check interval.
     pub fn health_check_interval(mut self, interval: Duration) -> Self {
         self.health_check_interval = interval;
@@ -220,40 +231,51 @@ impl Builder {
         self
     }

+    /// Builds and returns a Proxy instance.
     pub async fn build(self) -> Result<Proxy> {
         // Validate input parameters.
-        Self::validate(
-            &self.scheduler_endpoint,
-            self.health_check_interval,
-            self.max_retries,
-        )?;
+        self.validate()?;

         // Create the scheduler channel.
-        let scheduler_channel = Endpoint::from_shared(self.scheduler_endpoint.to_string())?
+        let scheduler_channel = Endpoint::from_shared(self.scheduler_endpoint.to_string())
+            .map_err(|err| Error::InvalidArgument(err.to_string()))?
+            .connect_timeout(self.scheduler_request_timeout)
+            .timeout(self.scheduler_request_timeout)
             .connect()
-            .await?;
+            .await
+            .map_err(|err| {
+                Error::Internal(format!(
+                    "failed to connect to scheduler {}: {}",
+                    self.scheduler_endpoint, err
+                ))
+            })?;

         // Create scheduler client.
         let scheduler_client = SchedulerClient::new(scheduler_channel);

         // Create seed peer selector.
-        let seed_peer_selector =
-            Arc::new(SeedPeerSelector::new(scheduler_client, self.health_check_interval).await?);
-        let seed_peer_selector_clone = Arc::clone(&seed_peer_selector);
+        let seed_peer_selector = Arc::new(
+            SeedPeerSelector::new(scheduler_client, self.health_check_interval)
+                .await
+                .map_err(|err| {
+                    Error::Internal(format!("failed to create seed peer selector: {}", err))
+                })?,
+        );
+
+        let seed_peer_selector_clone = seed_peer_selector.clone();
         tokio::spawn(async move {
             // Run the selector service in the background to refresh the seed peers periodically.
             seed_peer_selector_clone.run().await;
         });

         // Get local IP address and hostname.
-        let local_ip = local_ip().unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
+        let local_ip = local_ip().unwrap().to_string();
         let hostname = hostname::get().unwrap().to_string_lossy().to_string();
-        let id_generator = IDGenerator::new(local_ip.to_string(), hostname, true);
+        let id_generator = IDGenerator::new(local_ip, hostname, true);

         let proxy = Proxy {
             seed_peer_selector,
             max_retries: self.max_retries,
-            client_pool: PoolBuilder::new(HTTPClientFactory {})
+            client_pool: PoolBuilder::new(HTTPClientFactory::default())
                 .capacity(DEFAULT_CLIENT_POOL_CAPACITY)
                 .idle_timeout(DEFAULT_CLIENT_POOL_IDLE_TIMEOUT)
                 .build(),
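Hypothetical wiring for the new builder knobs; the endpoint and values are illustrative, and the max_retries setter is assumed from its use in the tests below:

let proxy = Proxy::builder()
    .scheduler_endpoint("http://127.0.0.1:8002".to_string())
    .scheduler_request_timeout(Duration::from_secs(5))
    .health_check_interval(Duration::from_secs(60))
    .max_retries(3)
    .build()
    .await?;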
@@ -263,38 +285,73 @@ impl Builder {
         Ok(proxy)
     }

-    fn validate(
-        scheduler_endpoint: &str,
-        health_check_duration: Duration,
-        max_retries: u8,
-    ) -> Result<()> {
-        if let Err(err) = url::Url::parse(scheduler_endpoint) {
-            return Err(
-                DFError::ValidationError(format!("invalid scheduler endpoint: {}", err)).into(),
-            );
-        };
+    /// validate validates the input parameters.
+    fn validate(&self) -> Result<()> {
+        if let Err(err) = url::Url::parse(&self.scheduler_endpoint) {
+            return Err(Error::InvalidArgument(err.to_string()));
+        };

-        if health_check_duration.as_secs() < 1 || health_check_duration.as_secs() > 600 {
-            return Err(DFError::ValidationError(
-                "health check duration must be between 1 and 600 seconds".to_string(),
-            )
-            .into());
+        if self.scheduler_request_timeout.as_millis() < 100 {
+            return Err(Error::InvalidArgument(
+                "scheduler request timeout must be at least 100 milliseconds".to_string(),
+            ));
         }

-        if max_retries > 10 {
-            return Err(DFError::ValidationError(
+        if self.health_check_interval.as_secs() < 1 || self.health_check_interval.as_secs() > 600 {
+            return Err(Error::InvalidArgument(
+                "health check interval must be between 1 and 600 seconds".to_string(),
+            ));
+        }
+
+        if self.max_retries > 10 {
+            return Err(Error::InvalidArgument(
                 "max retries must be between 0 and 10".to_string(),
-            )
-            .into());
+            ));
         }

         Ok(())
     }
 }

+/// Proxy is the HTTP proxy client that sends requests via Dragonfly.
+pub struct Proxy {
+    /// seed_peer_selector is the selector service for selecting seed peers.
+    seed_peer_selector: Arc<SeedPeerSelector>,
+    /// max_retries is the number of times to retry a request.
+    max_retries: u8,
+    /// client_pool is the pool of clients.
+    client_pool: Pool<String, ClientWithMiddleware, HTTPClientFactory>,
+    /// id_generator is the task id generator.
+    id_generator: Arc<IDGenerator>,
+}
+
+/// Proxy implements the proxy client that sends requests via Dragonfly.
+impl Proxy {
+    /// builder returns a new Builder for Proxy.
+    pub fn builder() -> Builder {
+        Builder::default()
+    }
+}
+
+/// Implements the interface for sending requests via the Dragonfly.
+///
+/// This struct enables interaction with remote servers through the Dragonfly, providing methods
+/// for performing GET requests with flexible response handling. It is designed for clients that
+/// need to communicate with the Dragonfly seed client efficiently, supporting both streaming and
+/// buffered response processing. It shields the complex request logic between the client and the
+/// Dragonfly seed client's proxy, abstracting the underlying communication details to simplify
+/// client implementation and usage.
 #[tonic::async_trait]
 impl Request for Proxy {
-    /// Performs a GET request, handling custom errors and returning a streaming response.
+    /// Sends a GET request to a remote server via the Dragonfly and returns a response
+    /// with a streaming body.
+    ///
+    /// This method is designed for scenarios where the response body is expected to be processed
+    /// as a stream, allowing efficient handling of large or continuous data. The response includes
+    /// metadata such as status codes and headers, along with a streaming `Body` for accessing the
+    /// response content.
     #[instrument(skip_all)]
     async fn get(&self, request: GetRequest) -> Result<GetResponse> {
         let response = self.try_send(&request).await?;
@@ -314,16 +371,25 @@ impl Request for Proxy {
         })
     }

-    /// Performs a GET request, handling custom errors and collecting the response into a buffer.
+    /// Sends a GET request to a remote server via the Dragonfly and writes the response
+    /// body directly into the provided buffer.
+    ///
+    /// This method is optimized for scenarios where the response body needs to be stored directly
+    /// in memory, avoiding the overhead of streaming for smaller or fixed-size responses. The
+    /// provided `BytesMut` buffer is used to store the response content, and the response metadata
+    /// (e.g., status and headers) is returned separately.
     #[instrument(skip_all)]
     async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse> {
-        let operation = async {
+        let get_into = async {
             let response = self.try_send(&request).await?;
             let status = response.status();
             let headers = response.headers().clone();

             if status.is_success() {
-                let bytes = response.bytes().await?;
+                let bytes = response.bytes().await.map_err(|err| {
+                    Error::Internal(format!("failed to read response body: {}", err))
+                })?;
                 buf.extend_from_slice(&bytes);
             }
@@ -336,20 +402,24 @@ impl Request for Proxy {
         };

         // Apply timeout which will properly cancel the operation when timeout is reached.
-        tokio::time::timeout(request.timeout, operation)
+        tokio::time::timeout(request.timeout, get_into)
             .await
-            .map_err(|_| {
-                DFError::Unknown(format!("request timed out after {:?}", request.timeout))
-            })?
+            .map_err(|err| Error::RequestTimeout(err.to_string()))?
     }
 }

+/// Proxy implements proxy request logic.
 impl Proxy {
     /// Creates reqwest clients with proxy configuration for the given request.
     #[instrument(skip_all)]
-    async fn client_entries(&self, request: &GetRequest) -> Result<Vec<Entry<Client>>> {
-        // The request should be processed by the proxy, so generate the task ID to select a seed peer as proxy server.
-        let parameter = match request.content_for_calculating_task_id.as_ref() {
+    async fn client_entries(
+        &self,
+        request: &GetRequest,
+    ) -> Result<Vec<Entry<ClientWithMiddleware>>> {
+        // Generate task id for selecting seed peer.
+        let task_id = self
+            .id_generator
+            .task_id(match request.content_for_calculating_task_id.as_ref() {
                 Some(content) => TaskIDParameter::Content(content.clone()),
                 None => TaskIDParameter::URLBased {
                     url: request.url.clone(),
@@ -358,22 +428,29 @@ impl Proxy {
                     application: request.application.clone(),
                     filtered_query_params: request.filtered_query_params.clone(),
                 },
-        };
-        let task_id = self.id_generator.task_id(parameter)?;
+            })
+            .map_err(|err| Error::Internal(format!("failed to generate task id: {}", err)))?;

-        // Select seed peers.
+        // Select seed peers for downloading.
         let seed_peers = self
             .seed_peer_selector
             .select(task_id, self.max_retries as u32)
             .await
-            .map_err(|err| DFError::Unknown(format!("failed to select seed peers: {:?}", err)))?;
+            .map_err(|err| {
+                Error::Internal(format!(
+                    "failed to select seed peers from scheduler: {}",
+                    err
+                ))
+            })?;
         debug!("selected seed peers: {:?}", seed_peers);

-        let mut client_entries = Vec::new();
+        let mut client_entries = Vec::with_capacity(seed_peers.len());
         for peer in seed_peers.iter() {
-            let proxy_addr = format!("http://{}:{}", peer.ip, peer.proxy_port);
-            let client_entry = self.client_pool.entry(&proxy_addr).await?;
+            // TODO(chlins): Support client https scheme.
+            let client_entry = self
+                .client_pool
+                .entry(&format!("http://{}:{}", peer.ip, peer.proxy_port))
+                .await?;
             client_entries.push(client_entry);
         }
@@ -386,10 +463,9 @@ impl Proxy {
         // Create client and send the request.
         let entries = self.client_entries(request).await?;
         if entries.is_empty() {
-            return Err(DFError::Unknown(
+            return Err(Error::Internal(
                 "no available client entries to send request".to_string(),
-            )
-            .into());
+            ));
         }

         for (index, entry) in entries.iter().enumerate() {
@@ -400,6 +476,8 @@ impl Proxy {
                         "failed to send request to client entry {:?}: {:?}",
                         entry.client, err
                     );
+
+                    // If this is the last entry, return the error.
                     if index == entries.len() - 1 {
                         return Err(err);
                     }
@@ -407,108 +485,143 @@ impl Proxy {
             }
         }

-        Err(DFError::Unknown("all retries failed".to_string()).into())
+        Err(Error::Internal(
+            "failed to send request to any client entry".to_string(),
+        ))
     }

     /// Send a request to the specified URL via client entry with the given headers.
     #[instrument(skip_all)]
-    async fn send(&self, entry: &Entry<Client>, request: &GetRequest) -> Result<reqwest::Response> {
-        let headers = make_request_headers(request)?;
+    async fn send(
+        &self,
+        entry: &Entry<ClientWithMiddleware>,
+        request: &GetRequest,
+    ) -> Result<reqwest::Response> {
+        let headers = self.make_request_headers(request)?;
         let response = entry
             .client
             .get(&request.url)
             .headers(headers.clone())
             .timeout(request.timeout)
             .send()
-            .await?;
+            .await
+            .map_err(|err| Error::Internal(err.to_string()))?;

-        // Check for custom Dragonfly error headers.
-        if let Some(error_type) = response
-            .headers()
-            .get("X-Dragonfly-Error-Type")
-            .and_then(|value| value.to_str().ok())
-        {
-            // If a known error type is found, consume the response body
-            // to get the message and return a specific error.
-            let status = response.status();
-            let headers = response.headers().clone();
+        let status = response.status();
+        if status.is_success() {
+            return Ok(response);
+        }

-            return match error_type {
-                "backend" => Err(Error::BackendError(BackendError {
-                    message: response.text().await.ok(),
-                    header: headermap_to_hashmap(&headers),
-                    status_code: Some(status),
-                })),
-                "proxy" => Err(Error::ProxyError(ProxyError {
-                    message: response.text().await.ok(),
-                    header: headermap_to_hashmap(&headers),
-                    status_code: Some(status),
-                })),
-                "dfdaemon" => Err(Error::DfdaemonError(DfdaemonError {
-                    message: response.text().await.ok(),
-                })),
-                // Other error case we handle it as unknown error.
-                _ => Err(DFError::Unknown(format!("unknown error type: {}", error_type)).into()),
-            };
-        }
+        let response_headers = response.headers().clone();
+        let header_map = headermap_to_hashmap(&response_headers);
+        let message = response.text().await.ok();
+        let error_type = response_headers
+            .get("X-Dragonfly-Error-Type")
+            .and_then(|v| v.to_str().ok());
+
+        match error_type {
+            Some("backend") => Err(Error::BackendError(BackendError {
+                message,
+                header: header_map,
+                status_code: Some(status),
+            })),
+            Some("proxy") => Err(Error::ProxyError(ProxyError {
+                message,
+                header: header_map,
+                status_code: Some(status),
+            })),
+            Some("dfdaemon") => Err(Error::DfdaemonError(DfdaemonError { message })),
+            Some(other) => Err(Error::ProxyError(ProxyError {
+                message: Some(format!("unknown error type from proxy: {}", other)),
+                header: header_map,
+                status_code: Some(status),
+            })),
+            None => Err(Error::ProxyError(ProxyError {
+                message: Some(format!("unexpected status code from proxy: {}", status)),
+                header: header_map,
+                status_code: Some(status),
+            })),
+        }
     }
-        if !response.status().is_success() {
-            return Err(
-                DFError::Unknown(format!("unexpected status code: {}", response.status())).into(),
-            );
-        }
-
-        Ok(response)
-    }
-}
-
-/// make_request_headers applies p2p related headers to the request headers.
-fn make_request_headers(request: &GetRequest) -> Result<HeaderMap> {
+    /// make_request_headers applies p2p related headers to the request headers.
+    #[instrument(skip_all)]
+    fn make_request_headers(&self, request: &GetRequest) -> Result<HeaderMap> {
         let mut headers = request.header.clone().unwrap_or_default();

+        // Apply the p2p related headers to the request headers.
         if let Some(piece_length) = request.piece_length {
             headers.insert(
                 "X-Dragonfly-Piece-Length",
-                piece_length.to_string().parse()?,
+                piece_length.to_string().parse().map_err(|err| {
+                    Error::InvalidArgument(format!("invalid piece length: {}", err))
+                })?,
             );
         }

         if let Some(tag) = request.tag.clone() {
-            headers.insert("X-Dragonfly-Tag", tag.to_string().parse()?);
+            headers.insert(
+                "X-Dragonfly-Tag",
+                tag.to_string()
+                    .parse()
+                    .map_err(|err| Error::InvalidArgument(format!("invalid tag: {}", err)))?,
+            );
         }

         if let Some(application) = request.application.clone() {
-            headers.insert("X-Dragonfly-Application", application.to_string().parse()?);
+            headers.insert(
+                "X-Dragonfly-Application",
+                application.to_string().parse().map_err(|err| {
+                    Error::InvalidArgument(format!("invalid application: {}", err))
+                })?,
+            );
         }

-        if let Some(content_for_calculating_task_id) = request.content_for_calculating_task_id.clone() {
+        if let Some(content_for_calculating_task_id) =
+            request.content_for_calculating_task_id.clone()
+        {
             headers.insert(
                 "X-Dragonfly-Content-For-Calculating-Task-ID",
-                content_for_calculating_task_id.to_string().parse()?,
+                content_for_calculating_task_id
+                    .to_string()
+                    .parse()
+                    .map_err(|err| {
+                        Error::InvalidArgument(format!(
+                            "invalid content for calculating task id: {}",
+                            err
+                        ))
+                    })?,
             );
         }

         if let Some(priority) = request.priority {
-            headers.insert("X-Dragonfly-Priority", priority.to_string().parse()?);
+            headers.insert(
+                "X-Dragonfly-Priority",
+                priority
+                    .to_string()
+                    .parse()
+                    .map_err(|err| Error::InvalidArgument(format!("invalid priority: {}", err)))?,
+            );
         }

         if !request.filtered_query_params.is_empty() {
             let value = request.filtered_query_params.join(",");
-            headers.insert("X-Dragonfly-Filtered-Query-Params", value.parse()?);
+            headers.insert(
+                "X-Dragonfly-Filtered-Query-Params",
+                value.parse().map_err(|err| {
+                    Error::InvalidArgument(format!("invalid filtered query params: {}", err))
+                })?,
+            );
         }

+        // Always use p2p for sdk scenarios.
         headers.insert("X-Dragonfly-Use-P2P", HeaderValue::from_static("true"));

         Ok(headers)
     }
 }
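For reference, the same hints can be sent with a plain reqwest client pointed at a dfdaemon proxy; the address here is assumed, and the header names mirror make_request_headers above:

async fn fetch_via_dfdaemon() -> Result<(), reqwest::Error> {
    // Route the request through a (hypothetical) local dfdaemon proxy.
    let client = reqwest::Client::builder()
        .proxy(reqwest::Proxy::all("http://127.0.0.1:4001")?)
        .build()?;

    let response = client
        .get("http://example.com/artifact.tar.gz")
        .header("X-Dragonfly-Tag", "demo")
        .header("X-Dragonfly-Use-P2P", "true")
        .send()
        .await?;

    println!("status: {}", response.status());
    Ok(())
}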
 #[cfg(test)]
 mod tests {
     use super::*;
     use dragonfly_api::scheduler::v2::ListHostsResponse;
-    use dragonfly_client_core::error::DFError;
     use mocktail::prelude::*;
     use std::time::Duration;
@@ -521,20 +634,16 @@ mod tests {
         });

         let server = MockServer::new_grpc("scheduler.v2.Scheduler").with_mocks(mocks);
-        match server.start().await {
-            Ok(_) => Ok(server),
-            Err(err) => Err(DFError::Unknown(format!(
-                "failed to start mock scheduler server: {}",
-                err,
-            ))
-            .into()),
-        }
+        server.start().await.map_err(|err| {
+            Error::Internal(format!("failed to start mock scheduler server: {}", err))
+        })?;
+
+        Ok(server)
     }
     #[tokio::test]
-    async fn test_proxy_new_success() -> Result<()> {
-        let mock_server = setup_mock_scheduler().await?;
+    async fn test_proxy_new_success() {
+        let mock_server = setup_mock_scheduler().await.unwrap();
         let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
         let result = Proxy::builder()
             .scheduler_endpoint(scheduler_endpoint)

@@ -542,10 +651,7 @@ mod tests {
             .await;

         assert!(result.is_ok());
-        let proxy = result.unwrap();
-        assert_eq!(proxy.max_retries, 1);
-
-        Ok(())
+        assert_eq!(result.unwrap().max_retries, 1);
     }

     #[tokio::test]
@@ -556,15 +662,12 @@ mod tests {
             .await;

         assert!(result.is_err());
-        assert!(matches!(
-            result,
-            Err(Error::Base(DFError::ValidationError(_)))
-        ));
+        assert!(matches!(result, Err(Error::InvalidArgument(_))));
     }
     #[tokio::test]
-    async fn test_proxy_new_invalid_retry_times() -> Result<()> {
-        let mock_server = setup_mock_scheduler().await?;
+    async fn test_proxy_new_invalid_retry_times() {
+        let mock_server = setup_mock_scheduler().await.unwrap();
         let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
         let result = Proxy::builder()

@@ -574,17 +677,12 @@ mod tests {
            .await;

         assert!(result.is_err());
-        assert!(matches!(
-            result,
-            Err(Error::Base(DFError::ValidationError(_)))
-        ));
-
-        Ok(())
+        assert!(matches!(result, Err(Error::InvalidArgument(_))));
     }
     #[tokio::test]
-    async fn test_proxy_new_invalid_health_check_interval() -> Result<()> {
-        let mock_server = setup_mock_scheduler().await?;
+    async fn test_proxy_new_invalid_health_check_interval() {
+        let mock_server = setup_mock_scheduler().await.unwrap();
         let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
         let result = Proxy::builder()

@@ -594,12 +692,7 @@ mod tests {
             .await;

         assert!(result.is_err());
-        assert!(matches!(
-            result,
-            Err(Error::Base(DFError::ValidationError(_)))
-        ));
-
-        Ok(())
+        assert!(matches!(result, Err(Error::InvalidArgument(_))));
     }

     #[tokio::test]
@@ -641,6 +734,7 @@ mod tests {
         // Create another client.
         let _ = pool.entry(&"http://proxy2.com".to_string()).await.unwrap();
+        // Still should be 1 because the proxy1 client should have been cleaned up.
         assert_eq!(pool.size().await, 1);
     }

View File

@@ -18,7 +18,10 @@ use crate::hashring::VNodeHashRing;
 use crate::shutdown;
 use dragonfly_api::common::v2::Host;
 use dragonfly_api::scheduler::v2::scheduler_client::SchedulerClient;
-use dragonfly_client_core::{Error, Result};
+use dragonfly_client_core::{
+    error::{ErrorType, OrErr},
+    Error, Result,
+};
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::time::Duration;
@@ -33,6 +36,7 @@ use tracing::{debug, error, info, instrument};
 /// Selector is the interface for selecting an item from a list of items by a specific criterion.
 #[tonic::async_trait]
 pub trait Selector: Send + Sync {
+    /// select selects items based on the given task_id and number of replicas.
     async fn select(&self, task_id: String, replicas: u32) -> Result<Vec<Host>>;
 }
@@ -42,6 +46,7 @@ const SEED_PEERS_HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(5);
 /// DEFAULT_VNODES_PER_HOST is the default number of virtual nodes per host.
 const DEFAULT_VNODES_PER_HOST: usize = 3;

+/// SeedPeers holds the data of seed peers.
 struct SeedPeers {
     /// hosts is the list of seed peers.
     hosts: HashMap<String, Host>,
@@ -50,6 +55,7 @@ struct SeedPeers {
     hashring: VNodeHashRing,
 }

+/// SeedPeerSelector is the implementation of Selector for seed peers.
 pub struct SeedPeerSelector {
     /// health_check_interval is the interval of health check for seed peers.
     health_check_interval: Duration,
@@ -60,10 +66,11 @@ pub struct SeedPeerSelector {
     /// seed_peers is the data of the seed peer selector.
     seed_peers: RwLock<SeedPeers>,

-    /// mutex is used to protect refresh.
+    /// mutex is used to protect hot refresh.
     mutex: Mutex<()>,
 }

+/// SeedPeerSelector implements a selector that selects seed peers from the scheduler service.
 impl SeedPeerSelector {
     /// new creates a new seed peer selector.
     pub async fn new(
/// run starts the seed peer selector service. /// run starts the seed peer selector service.
pub async fn run(&self) { pub async fn run(&self) {
// Receive shutdown signal.
let mut shutdown = shutdown::Shutdown::default();
// Start the refresh loop.
let mut interval = tokio::time::interval(self.health_check_interval); let mut interval = tokio::time::interval(self.health_check_interval);
loop { loop {
tokio::select! { tokio::select! {
@@ -99,7 +102,7 @@ impl SeedPeerSelector {
                     Ok(_) => debug!("succeed to refresh seed peers"),
                 }
             }
-            _ = shutdown.recv() => {
+            _ = shutdown::shutdown_signal() => {
                 info!("seed peer selector service is shutting down");
                 return;
             }
@@ -135,11 +138,13 @@ impl SeedPeerSelector {
         let mut hosts = HashMap::with_capacity(seed_peers_length);
         let mut hashring = VNodeHashRing::new(DEFAULT_VNODES_PER_HOST);

         while let Some(result) = join_set.join_next().await {
             match result {
                 Ok(Ok(peer)) => {
-                    let addr: SocketAddr = format!("{}:{}", peer.ip, peer.port).parse().unwrap();
+                    let addr: SocketAddr = format!("{}:{}", peer.ip, peer.port)
+                        .parse()
+                        .or_err(ErrorType::ParseError)?;
                     hashring.add(addr);
                     hosts.insert(addr.to_string(), peer);
                 }
@@ -179,16 +184,16 @@ impl SeedPeerSelector {
     /// check_health checks the health of each seed peer.
     #[instrument(skip_all)]
     async fn check_health(addr: &str) -> Result<HealthCheckResponse> {
-        let endpoint =
-            Endpoint::from_shared(addr.to_string())?.timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT);
+        let channel = Endpoint::from_shared(addr.to_string())?
+            .connect_timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT)
+            .connect()
+            .await?;

-        let channel = endpoint.connect().await?;
         let mut client = HealthGRPCClient::new(channel);
-        let request = tonic::Request::new(HealthCheckRequest {
+
+        let mut request = tonic::Request::new(HealthCheckRequest {
             service: "".to_string(),
         });
+        request.set_timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT);

         let response = client.check(request).await?;
         Ok(response.into_inner())
@@ -200,21 +205,21 @@ impl Selector for SeedPeerSelector {
     async fn select(&self, task_id: String, replicas: u32) -> Result<Vec<Host>> {
         // Acquire a read lock and perform all logic within it.
         let seed_peers = self.seed_peers.read().await;
-
         if seed_peers.hosts.is_empty() {
             return Err(Error::HostNotFound("seed peers".to_string()));
         }

         // The number of replicas cannot exceed the total number of seed peers.
-        let desired_replicas = std::cmp::min(replicas as usize, seed_peers.hashring.len());
+        let expected_replicas = std::cmp::min(replicas as usize, seed_peers.hashring.len());
+        debug!("expected replicas: {}", expected_replicas);

         // Get replica nodes from the hash ring.
         let vnodes = seed_peers
             .hashring
-            .get_with_replicas(&task_id, desired_replicas)
+            .get_with_replicas(&task_id, expected_replicas)
             .unwrap_or_default();

-        let seed_peers = vnodes
+        let seed_peers: Vec<Host> = vnodes
             .into_iter()
             .filter_map(|vnode| {
                 seed_peers
@@ -224,6 +229,10 @@ impl Selector for SeedPeerSelector {
             })
             .collect();

+        if seed_peers.is_empty() {
+            return Err(Error::HostNotFound("selected seed peers".to_string()));
+        }
+
         Ok(seed_peers)
     }
 }
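Seen from the proxy, retries and replicas line up one-to-one: with max_retries set to 3, the selector is asked for up to three distinct seed peers for one task id, and the hash ring keeps the first choice stable across calls. A sketch of the contract, assuming a Selector implementation is in scope:

async fn pick_peers(selector: &dyn Selector, task_id: String) -> Result<Vec<Host>> {
    let hosts = selector.select(task_id, 3).await?;

    // select() caps the result at the number of live seed peers and, after
    // this change, errors out instead of returning an empty list.
    assert!(!hosts.is_empty() && hosts.len() <= 3);
    Ok(hosts)
}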

View File

@@ -98,9 +98,12 @@ struct DfdaemonUploadClientFactory {
     config: Arc<Config>,
 }

+/// DfdaemonUploadClientFactory implements the Factory trait for creating DfdaemonUploadClient
+/// instances.
 #[tonic::async_trait]
 impl Factory<String, DfdaemonUploadClient> for DfdaemonUploadClientFactory {
     type Error = Error;

+    /// Creates a new DfdaemonUploadClient for the given address.
     async fn make_client(&self, addr: &String) -> Result<DfdaemonUploadClient> {
         DfdaemonUploadClient::new(self.config.clone(), format!("http://{}", addr), true).await
     }
@@ -121,13 +124,9 @@ pub struct GRPCDownloader {
 impl GRPCDownloader {
     /// new returns a new GRPCDownloader.
     pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
-        let factory = DfdaemonUploadClientFactory {
-            config: config.clone(),
-        };
-
         Self {
-            config,
-            client_pool: PoolBuilder::new(factory)
+            config: config.clone(),
+            client_pool: PoolBuilder::new(DfdaemonUploadClientFactory { config })
                 .capacity(capacity)
                 .idle_timeout(idle_timeout)
                 .build(),
@@ -297,9 +296,12 @@ struct TCPClientFactory {
     config: Arc<Config>,
 }

+/// TCPClientFactory implements the Factory trait for creating TCPClient instances.
 #[tonic::async_trait]
 impl Factory<String, TCPClient> for TCPClientFactory {
     type Error = Error;

+    /// Creates a new TCPClient for the given address.
     async fn make_client(&self, addr: &String) -> Result<TCPClient> {
         Ok(TCPClient::new(self.config.clone(), addr.clone()))
     }
@@ -309,12 +311,10 @@ impl Factory<String, TCPClient> for TCPClientFactory {
 impl TCPDownloader {
     /// new returns a new TCPDownloader.
     pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
-        let factory = TCPClientFactory {
-            config: config.clone(),
-        };
-
         Self {
-            client_pool: PoolBuilder::new(factory)
+            client_pool: PoolBuilder::new(TCPClientFactory {
+                config: config.clone(),
+            })
                 .capacity(capacity)
                 .idle_timeout(idle_timeout)
                 .build(),
@@ -346,6 +346,7 @@ impl Downloader for TCPDownloader {
     ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
         let entry = self.get_client_entry(addr).await?;
         let request_guard = entry.request_guard();
+
         match entry.client.download_piece(number, task_id).await {
             Ok((reader, offset, digest)) => Ok((Box::new(reader), offset, digest)),
             Err(err) => {
@@ -370,6 +371,7 @@ impl Downloader for TCPDownloader {
     ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
         let entry = self.get_client_entry(addr).await?;
         let request_guard = entry.request_guard();
+
         match entry
             .client
             .download_persistent_cache_piece(number, task_id)