refactor(dragonfly-client-util): request and pool logic (#1370)
Signed-off-by: Gaius <gaius.qi@gmail.com>
parent 826aeabf08
commit 02eddfbd10
@@ -1227,6 +1227,8 @@ dependencies = [
  "protobuf 3.7.2",
  "rcgen",
  "reqwest",
+ "reqwest-middleware",
+ "reqwest-tracing",
  "rustix 1.1.2",
  "rustls 0.22.4",
  "rustls-pemfile 2.2.0",
@@ -116,6 +116,7 @@ dashmap = "6.1.0"
 hostname = "^0.4"
 tonic-health = "0.12.3"
 hashring = "0.3.6"
+reqwest-tracing = "0.5"

 [profile.release]
 opt-level = 3
@@ -25,16 +25,16 @@ tracing.workspace = true
 opendal.workspace = true
 percent-encoding.workspace = true
 futures.workspace = true
+reqwest-tracing.workspace = true
 reqwest-retry = "0.7"
-reqwest-tracing = "0.5"
 libloading = "0.8.8"

 [dev-dependencies]
 tempfile.workspace = true
+wiremock = "0.6.4"
 rustls-pki-types.workspace = true
 rustls-pemfile.workspace = true
 hyper.workspace = true
 hyper-util.workspace = true
 tokio-rustls.workspace = true
 rcgen.workspace = true
-wiremock = "0.6.4"
@@ -41,6 +41,8 @@ tonic-health.workspace = true
 local-ip-address.workspace = true
 hostname.workspace = true
 dashmap.workspace = true
+reqwest-tracing.workspace = true
+reqwest-middleware.workspace = true
 rustix = { version = "1.1.2", features = ["fs"] }
 base64 = "0.22.1"
 pnet = "0.35.0"
@@ -14,42 +14,60 @@
  * limitations under the License.
  */

-use hashring::HashRing;
 use std::fmt;
 use std::hash::Hash;
 use std::net::SocketAddr;

+use hashring::HashRing;
+
+/// A virtual node (vnode) on the consistent hash ring.
+/// Each physical node (SocketAddr) is represented by multiple vnodes to better
+/// balance key distribution across the ring.
 #[derive(Debug, Copy, Clone, Hash, PartialEq)]
 pub struct VNode {
+    /// The replica index of this vnode for its physical node (0..replica_count-1).
     id: usize,

+    /// The physical node address this vnode represents.
     addr: SocketAddr,
 }

+/// VNode implements virtual node for consistent hashing.
 impl VNode {
+    /// Creates a new virtual node with the given replica id and physical address.
     fn new(id: usize, addr: SocketAddr) -> Self {
        VNode { id, addr }
     }
 }

+/// VNode implements Display trait to format.
 impl fmt::Display for VNode {
+    /// Formats the virtual node as "address|id" as the key for the hash ring.
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{}|{}", self.addr, self.id)
     }
 }

+/// VNode implements methods for hash ring operations.
 impl VNode {
+    /// Returns a reference to the physical node address associated with this vnode.
     pub fn addr(&self) -> &SocketAddr {
         &self.addr
     }
 }

+/// A consistent hash ring that uses virtual nodes (vnodes) to improve key distribution.
+/// When a physical node is added, replica_count vnodes are inserted into the ring.
 pub struct VNodeHashRing {
+    /// Number of vnodes to create per physical node.
     replica_count: usize,

+    /// The underlying hash ring that stores vnodes.
     ring: HashRing<VNode>,
 }

+/// VNodeHashRing implements methods for managing the hash ring.
 impl VNodeHashRing {
+    /// Creates a new vnode-based hash ring.
     pub fn new(replica_count: usize) -> Self {
         VNodeHashRing {
             replica_count,
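A quick illustration of why the vnode indirection helps: each physical address is inserted several times under distinct "address|id" keys, so keys spread more evenly and removing one host only remaps small slices of the ring. A minimal, self-contained sketch against the hashring crate; the address, replica count, and task key are made-up placeholders:

```rust
use hashring::HashRing;
use std::net::SocketAddr;

fn main() {
    let mut ring: HashRing<String> = HashRing::new();
    let addr: SocketAddr = "10.0.0.1:4001".parse().unwrap();

    // Insert three vnodes for a single physical node, using the same
    // "address|id" key format that VNode's Display impl produces.
    for id in 0..3 {
        ring.add(format!("{}|{}", addr, id));
    }

    // A task id hashes onto one vnode; splitting on '|' recovers the
    // physical address, which is what VNodeHashRing exposes via addr().
    let vnode = ring.get(&"task-1234").unwrap();
    println!("task-1234 -> {}", vnode);
}
```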
@@ -22,6 +22,5 @@ pub mod id_generator;
 pub mod net;
 pub mod pool;
 pub mod request;
-pub mod selector;
 pub mod shutdown;
 pub mod tls;
@@ -34,14 +34,18 @@ pub struct RequestGuard {
     active_requests: Arc<AtomicUsize>,
 }

+/// RequestGuard implements the request guard pattern.
 impl RequestGuard {
+    /// Create a new request guard.
     fn new(active_requests: Arc<AtomicUsize>) -> Self {
         active_requests.fetch_add(1, Ordering::SeqCst);
         Self { active_requests }
     }
 }

+/// RequestGuard decrements the active request count when dropped.
 impl Drop for RequestGuard {
+    /// Decrement the active request count.
     fn drop(&mut self) {
         self.active_requests.fetch_sub(1, Ordering::SeqCst);
     }
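The guard above is plain RAII: the count is incremented in the constructor and decremented in Drop, so it stays correct on every exit path, including `?` early returns and panics. A self-contained sketch of the same pattern:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

struct RequestGuard {
    active_requests: Arc<AtomicUsize>,
}

impl RequestGuard {
    /// Increment the active request count on creation.
    fn new(active_requests: Arc<AtomicUsize>) -> Self {
        active_requests.fetch_add(1, Ordering::SeqCst);
        Self { active_requests }
    }
}

impl Drop for RequestGuard {
    /// Decrement the count on any exit path.
    fn drop(&mut self) {
        self.active_requests.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let counter = Arc::new(AtomicUsize::new(0));
    {
        let _guard = RequestGuard::new(counter.clone());
        assert_eq!(counter.load(Ordering::SeqCst), 1); // request in flight
    } // guard dropped here
    assert_eq!(counter.load(Ordering::SeqCst), 0);
}
```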
@@ -60,7 +64,9 @@ pub struct Entry<T> {
     actived_at: Arc<std::sync::Mutex<Instant>>,
 }

+/// Entry methods for managing client state.
 impl<T> Entry<T> {
+    /// Create a new entry with the given client.
     fn new(client: T) -> Self {
         Self {
             client,
@@ -75,8 +81,8 @@ impl<T> Entry<T> {
     }

     /// Update the last active time.
-    fn update_actived_at(&self) {
-        *self.actived_at.lock().unwrap() = Instant::now();
+    fn set_actived_at(&self, actived_at: Instant) {
+        *self.actived_at.lock().unwrap() = actived_at;
     }

     /// Check if the client has active requests.
@@ -96,6 +102,7 @@ impl<T> Entry<T> {
 pub trait Factory<K, T> {
     type Error;

+    /// Create a new client for the given key.
     async fn make_client(&self, key: &K) -> Result<T, Self::Error>;
 }

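The `Factory` trait keeps the pool generic over how clients are constructed. A toy implementation under stated assumptions: it uses native async-fn-in-trait (Rust 1.75+) instead of the `#[tonic::async_trait]` attribute the real implementations use, and the String "client" is purely illustrative:

```rust
// Toy mirror of the Factory trait; the real code routes through
// #[tonic::async_trait] and returns actual client types.
trait Factory<K, T> {
    type Error;
    async fn make_client(&self, key: &K) -> Result<T, Self::Error>;
}

struct EchoFactory;

impl Factory<String, String> for EchoFactory {
    type Error = std::convert::Infallible;

    async fn make_client(&self, key: &String) -> Result<String, Self::Error> {
        // A real factory would dial a connection here.
        Ok(format!("client-for-{key}"))
    }
}

#[tokio::main]
async fn main() {
    let factory = EchoFactory;
    let client = factory.make_client(&"http://proxy1".to_string()).await.unwrap();
    assert_eq!(client, "client-for-http://proxy1");
}
```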
@@ -119,6 +126,7 @@ pub struct Pool<K, T, F> {
     cleanup_at: Arc<Mutex<Instant>>,
 }

+/// Builder for creating a client pool.
 pub struct Builder<K, T, F> {
     factory: F,
     capacity: usize,
@@ -126,6 +134,7 @@ pub struct Builder<K, T, F> {
     _phantom: PhantomData<(K, T)>,
 }

+/// Builder methods for configuring and building the pool.
 impl<K, T, F> Builder<K, T, F>
 where
     K: Clone + Eq + Hash + std::fmt::Display,
@@ -166,6 +175,14 @@ where
     }
 }

+/// Generic client pool for managing reusable client instances with automatic cleanup.
+///
+/// This client pool provides connection reuse, automatic cleanup, and capacity management
+/// capabilities, primarily used for:
+/// - Connection Reuse: Reuse existing client instances to avoid repeated creation overhead.
+/// - Automatic Cleanup: Periodically remove idle clients that exceed timeout thresholds.
+/// - Capacity Control: Limit maximum client count to prevent resource exhaustion.
+/// - Thread Safety: Use async locks and atomic operations for high-concurrency access.
 impl<K, T, F> Pool<K, T, F>
 where
     K: Clone + Eq + Hash + std::fmt::Display,
@@ -182,7 +199,7 @@ where
         let clients = self.clients.lock().await;
         if let Some(entry) = clients.get(key) {
             debug!("reusing client: {}", key);
-            entry.update_actived_at();
+            entry.set_actived_at(Instant::now());
             return Ok(entry.clone());
         }
     }
@@ -190,11 +207,10 @@ where
         // Create new client.
         debug!("creating client: {}", key);
         let client = self.factory.make_client(key).await?;

         let mut clients = self.clients.lock().await;
         let entry = clients.entry(key.clone()).or_insert(Entry::new(client));
+        entry.set_actived_at(Instant::now());

-        entry.update_actived_at();
         Ok(entry.clone())
     }

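The entry path above is a classic lookup-then-create: take the async lock for a cheap hit, release it while the factory does potentially slow work, then re-lock and insert with `or_insert` so that if two callers race, one client wins and both share it. A minimal sketch of that shape, with String standing in for the real client type:

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

async fn entry(
    clients: &Mutex<HashMap<String, Arc<String>>>,
    key: &str,
) -> Arc<String> {
    // Fast path: reuse an existing client under the lock.
    if let Some(client) = clients.lock().await.get(key) {
        return client.clone();
    }

    // Slow path: create outside the lock, then insert; a racing
    // creator's entry wins and this one is discarded.
    let client = Arc::new(format!("client-for-{key}"));
    clients
        .lock()
        .await
        .entry(key.to_string())
        .or_insert(client)
        .clone()
}

#[tokio::main]
async fn main() {
    let clients = Mutex::new(HashMap::new());
    let a = entry(&clients, "http://proxy1").await;
    let b = entry(&clients, "http://proxy1").await;
    assert!(Arc::ptr_eq(&a, &b)); // second call reuses the first client
}
```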
@@ -231,7 +247,6 @@ where
             let is_recent = idle_duration <= self.idle_timeout;
-
             let should_retain = has_active_requests || (!exceeds_capacity && is_recent);

             if !should_retain {
                 info!(
                     "removing idle client: {}, exceeds_capacity: {}, idle_duration: {}s",
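The cleanup rule collapses to a single boolean: keep a client if it still has requests in flight, or if the pool is within capacity and the client was active recently. The predicate extracted verbatim, with a small truth-table check:

```rust
// The retention predicate from the cleanup pass, stated stand-alone.
fn should_retain(has_active_requests: bool, exceeds_capacity: bool, is_recent: bool) -> bool {
    has_active_requests || (!exceeds_capacity && is_recent)
}

fn main() {
    assert!(should_retain(true, true, false));    // in-flight requests always pin a client
    assert!(should_retain(false, false, true));   // idle but recent and within capacity
    assert!(!should_retain(false, true, true));   // over capacity: evict even recent idles
    assert!(!should_retain(false, false, false)); // idle too long: evict
}
```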
@@ -14,25 +14,19 @@
  * limitations under the License.
  */

-use dragonfly_client_core::Error as DFError;
 use reqwest;
 use std::collections::HashMap;

 #[derive(Debug, thiserror::Error)]
 pub enum Error {
-    #[error(transparent)]
-    Base(#[from] DFError),
+    #[error("request timeout: {0}")]
+    RequestTimeout(String),

-    #[allow(clippy::enum_variant_names)]
-    #[error(transparent)]
-    ReqwestError(#[from] reqwest::Error),
+    #[error("invalid argument: {0}")]
+    InvalidArgument(String),

-    #[allow(clippy::enum_variant_names)]
-    #[error(transparent)]
-    TonicTransportError(#[from] tonic::transport::Error),
-
-    #[error(transparent)]
-    InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
+    #[error("request internal error: {0}")]
+    Internal(String),

     #[allow(clippy::enum_variant_names)]
     #[error(transparent)]
@@ -47,38 +41,42 @@ pub enum Error {
     DfdaemonError(#[from] DfdaemonError),
 }

-// BackendError is error detail for Backend.
+/// BackendError is error detail for Backend.
 #[derive(Debug, thiserror::Error)]
-#[error("error occurred in the backend server, message: {message:?}, header: {header:?}, status_code: {status_code:?}")]
+#[error(
+    "backend server error, message: {message:?}, header: {header:?}, status_code: {status_code:?}"
+)]
 pub struct BackendError {
-    // Backend error message.
+    /// Backend error message.
     pub message: Option<String>,

-    // Backend HTTP response header.
+    /// Backend HTTP response header.
     pub header: HashMap<String, String>,

-    // Backend HTTP status code.
+    /// Backend HTTP status code.
     pub status_code: Option<reqwest::StatusCode>,
 }

-// ProxyError is error detail for Proxy.
+/// ProxyError is error detail for Proxy.
 #[derive(Debug, thiserror::Error)]
-#[error("error occurred in the proxy server, message: {message:?}, header: {header:?}, status_code: {status_code:?}")]
+#[error(
+    "proxy server error, message: {message:?}, header: {header:?}, status_code: {status_code:?}"
+)]
 pub struct ProxyError {
-    // Proxy error message.
+    /// Proxy error message.
     pub message: Option<String>,

-    // Proxy HTTP response header.
+    /// Proxy HTTP response header.
     pub header: HashMap<String, String>,

-    // Proxy HTTP status code.
+    /// Proxy HTTP status code.
     pub status_code: Option<reqwest::StatusCode>,
 }

-// DfdaemonError is error detail for Dfdaemon.
+/// DfdaemonError is error detail for Dfdaemon.
 #[derive(Debug, thiserror::Error)]
-#[error("error occurred in the dfdaemon, message: {message:?}")]
+#[error("dfdaemon error, message: {message:?}")]
 pub struct DfdaemonError {
-    // Dfdaemon error message.
+    /// Dfdaemon error message.
     pub message: Option<String>,
 }
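Two things change in this file: `//` comments become `///` so they land in rustdoc, and the `#[error(...)]` strings are shortened. That attribute is what thiserror compiles into the Display impl. A small stand-alone illustration of how such a message renders; note `status_code` is simplified to `u16` here, while the real field is `reqwest::StatusCode`:

```rust
use std::collections::HashMap;
use thiserror::Error;

// Hypothetical mirror of BackendError, just to show the rendered message.
#[derive(Debug, Error)]
#[error(
    "backend server error, message: {message:?}, header: {header:?}, status_code: {status_code:?}"
)]
struct BackendError {
    message: Option<String>,
    header: HashMap<String, String>,
    status_code: Option<u16>, // reqwest::StatusCode in the real code
}

fn main() {
    let err = BackendError {
        message: Some("origin returned 503".into()),
        header: HashMap::new(),
        status_code: Some(503),
    };
    // Display comes entirely from the #[error(...)] attribute.
    println!("{}", err);
}
```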
@@ -15,23 +15,24 @@
  */

 mod errors;
+mod selector;

 use crate::http::headermap_to_hashmap;
 use crate::id_generator::{IDGenerator, TaskIDParameter};
 use crate::pool::{Builder as PoolBuilder, Entry, Factory, Pool};
-use crate::selector::{SeedPeerSelector, Selector};
 use bytes::BytesMut;
 use dragonfly_api::scheduler::v2::scheduler_client::SchedulerClient;
-use dragonfly_client_core::error::DFError;
 use errors::{BackendError, DfdaemonError, Error, ProxyError};
 use futures::TryStreamExt;
 use hostname;
 use local_ip_address::local_ip;
-use reqwest::{header::HeaderMap, header::HeaderValue, redirect::Policy, Client};
+use reqwest::{header::HeaderMap, header::HeaderValue, Client};
+use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
+use reqwest_tracing::TracingMiddleware;
 use rustix::path::Arg;
 use rustls_pki_types::CertificateDer;
+use selector::{SeedPeerSelector, Selector};
 use std::io::{Error as IOError, ErrorKind};
-use std::net::{IpAddr, Ipv4Addr};
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::io::AsyncRead;
@@ -51,54 +52,45 @@ const DEFAULT_CLIENT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);
 /// DEFAULT_CLIENT_POOL_CAPACITY is the default capacity of the client pool.
 const DEFAULT_CLIENT_POOL_CAPACITY: usize = 128;

+/// DEFAULT_SCHEDULER_REQUEST_TIMEOUT is the default timeout (5 seconds) for requests to the
+/// scheduler service.
+const DEFAULT_SCHEDULER_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Result is a specialized Result type for the proxy module.
 pub type Result<T> = std::result::Result<T, Error>;

+/// Body is the type alias for the response body reader.
 pub type Body = Box<dyn AsyncRead + Send + Unpin>;

+/// Defines the interface for sending requests via the Dragonfly.
+///
+/// This trait enables interaction with remote servers through the Dragonfly, providing methods
+/// for performing GET requests with flexible response handling. It is designed for clients that
+/// need to communicate with the Dragonfly seed client efficiently, supporting both streaming and
+/// buffered response processing. The trait shields the complex request logic between the client
+/// and the Dragonfly seed client's proxy, abstracting the underlying communication details to
+/// simplify client implementation and usage.
 #[tonic::async_trait]
 pub trait Request {
-    /// get sends the get request to the remote server and returns the response reader.
+    /// Sends a GET request to a remote server via the Dragonfly and returns a response
+    /// with a streaming body.
+    ///
+    /// This method is designed for scenarios where the response body is expected to be processed as a
+    /// stream, allowing efficient handling of large or continuous data. The response includes metadata
+    /// such as status codes and headers, along with a streaming `Body` for accessing the response content.
     async fn get(&self, request: GetRequest) -> Result<GetResponse<Body>>;

-    /// get_into sends the get request to the remote server and writes the response to the input buffer.
+    /// Sends a GET request to a remote server via the Dragonfly and writes the response
+    /// body directly into the provided buffer.
+    ///
+    /// This method is optimized for scenarios where the response body needs to be stored directly in
+    /// memory, avoiding the overhead of streaming for smaller or fixed-size responses. The provided
+    /// `BytesMut` buffer is used to store the response content, and the response metadata (e.g., status
+    /// and headers) is returned separately.
     async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse>;
 }

-pub struct Proxy {
-    /// seed_peer_selector is the selector service for selecting seed peers.
-    seed_peer_selector: Arc<SeedPeerSelector>,
-
-    /// max_retries is the number of times to retry a request.
-    max_retries: u8,
-
-    /// client_pool is the pool of clients.
-    client_pool: Pool<String, Client, HTTPClientFactory>,
-
-    /// id_generator is the task id generator.
-    id_generator: Arc<IDGenerator>,
-}
-
-pub struct Builder {
-    /// scheduler_endpoint is the endpoint of the scheduler service.
-    scheduler_endpoint: String,
-
-    /// health_check_interval is the interval of health check for selector(seed peers).
-    health_check_interval: Duration,
-
-    /// max_retries is the number of times to retry a request.
-    max_retries: u8,
-}
-
-impl Default for Builder {
-    fn default() -> Self {
-        Self {
-            scheduler_endpoint: "".to_string(),
-            health_check_interval: Duration::from_secs(60),
-            max_retries: 1,
-        }
-    }
-}
-
+/// GetRequest represents a GET request to be sent via the Dragonfly.
 pub struct GetRequest {
     /// url is the url of the request.
     pub url: String,
@@ -137,6 +129,7 @@ pub struct GetRequest {
     pub client_cert: Option<Vec<CertificateDer<'static>>>,
 }

+/// GetResponse represents a GET response received via the Dragonfly.
 pub struct GetResponse<R = Body>
 where
     R: AsyncRead + Unpin,
@@ -155,52 +148,64 @@ where
 }

+/// Factory for creating HTTPClient instances.
+#[derive(Debug, Clone, Default)]
 struct HTTPClientFactory {}

+/// HTTPClientFactory implements Factory for creating reqwest::Client instances with proxy support.
 #[tonic::async_trait]
-impl Factory<String, Client> for HTTPClientFactory {
+impl Factory<String, ClientWithMiddleware> for HTTPClientFactory {
     type Error = Error;
-    async fn make_client(&self, proxy_addr: &String) -> Result<Client> {
-        // Disable automatic compression to prevent double-decompression issues.
-        //
-        // Problem scenario:
-        // 1. Origin server supports gzip and returns "content-encoding: gzip" header.
-        // 2. Backend decompresses the response and stores uncompressed content to disk.
-        // 3. When user's client downloads via dfdaemon proxy, the original "content-encoding: gzip"
-        //    header is forwarded to it.
-        // 4. User's client attempts to decompress the already-decompressed content, causing errors.
-        //
-        // Solution: Disable all compression formats (gzip, brotli, zstd, deflate) to ensure
-        // we receive and store uncompressed content, eliminating the double-decompression issue.
-        let mut client_builder = Client::builder()
-            .no_gzip()
-            .no_brotli()
-            .no_zstd()
-            .no_deflate()
-            .redirect(Policy::none())
+
+    /// Creates a new reqwest::Client configured to use the specified proxy address.
+    async fn make_client(&self, proxy_addr: &String) -> Result<ClientWithMiddleware> {
+        // TODO(chlins): Support client certificates and set `danger_accept_invalid_certs`
+        // based on the certificates.
+        let client = Client::builder()
+            .hickory_dns(true)
+            .danger_accept_invalid_certs(true)
             .pool_max_idle_per_host(POOL_MAX_IDLE_PER_HOST)
-            .tcp_keepalive(KEEP_ALIVE_INTERVAL);
+            .tcp_keepalive(KEEP_ALIVE_INTERVAL)
+            .proxy(reqwest::Proxy::all(proxy_addr).map_err(|err| {
+                Error::Internal(format!("failed to set proxy {}: {}", proxy_addr, err))
+            })?)
+            .build()
+            .map_err(|err| Error::Internal(format!("failed to build reqwest client: {}", err)))?;

-        // Add the proxy configuration if provided.
-        if !proxy_addr.is_empty() {
-            let proxy = reqwest::Proxy::all(proxy_addr)?;
-            client_builder = client_builder.proxy(proxy);
-        }
-
-        let client = client_builder.build()?;
-
-        Ok(client)
+        Ok(ClientBuilder::new(client)
+            .with(TracingMiddleware::default())
+            .build())
     }
 }

-impl Proxy {
-    pub fn builder() -> Builder {
-        Builder::default()
+/// Builder is the builder for Proxy.
+pub struct Builder {
+    /// scheduler_endpoint is the endpoint of the scheduler service.
+    scheduler_endpoint: String,
+
+    /// scheduler_request_timeout is the timeout of the request to the scheduler service.
+    scheduler_request_timeout: Duration,
+
+    /// health_check_interval is the interval of health check for selector(seed peers).
+    health_check_interval: Duration,
+
+    /// max_retries is the number of times to retry a request.
+    max_retries: u8,
+}
+
+/// Builder implements Default trait.
+impl Default for Builder {
+    /// default returns a default Builder.
+    fn default() -> Self {
+        Self {
+            scheduler_endpoint: "".to_string(),
+            scheduler_request_timeout: DEFAULT_SCHEDULER_REQUEST_TIMEOUT,
+            health_check_interval: Duration::from_secs(60),
+            max_retries: 1,
+        }
+    }
+}
+
+/// Builder implements the builder pattern for Proxy.
 impl Builder {
     /// Sets the scheduler endpoint.
     pub fn scheduler_endpoint(mut self, endpoint: String) -> Self {
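The factory now returns a `ClientWithMiddleware` rather than a bare `reqwest::Client`, so every outgoing request is wrapped in a tracing span without touching call sites. A pared-down sketch of that composition, assuming a placeholder proxy URL and boxing errors for brevity:

```rust
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::TracingMiddleware;

fn build_traced_client(
    proxy_addr: &str,
) -> Result<ClientWithMiddleware, Box<dyn std::error::Error>> {
    // Route all traffic through the given proxy address.
    let client = Client::builder()
        .proxy(reqwest::Proxy::all(proxy_addr)?)
        .build()?;

    // Wrap the client so each request emits a tracing span.
    Ok(ClientBuilder::new(client)
        .with(TracingMiddleware::default())
        .build())
}
```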
@@ -208,6 +213,12 @@ impl Builder {
         self
     }

+    /// Sets the scheduler request timeout.
+    pub fn scheduler_request_timeout(mut self, timeout: Duration) -> Self {
+        self.scheduler_request_timeout = timeout;
+        self
+    }
+
     /// Sets the health check interval.
     pub fn health_check_interval(mut self, interval: Duration) -> Self {
         self.health_check_interval = interval;
@@ -220,40 +231,51 @@ impl Builder {
         self
     }

     /// Builds and returns a Proxy instance.
     pub async fn build(self) -> Result<Proxy> {
-        // Validate input parameters.
-        Self::validate(
-            &self.scheduler_endpoint,
-            self.health_check_interval,
-            self.max_retries,
-        )?;
+        self.validate()?;

         // Create the scheduler channel.
-        let scheduler_channel = Endpoint::from_shared(self.scheduler_endpoint.to_string())?
+        let scheduler_channel = Endpoint::from_shared(self.scheduler_endpoint.to_string())
+            .map_err(|err| Error::InvalidArgument(err.to_string()))?
+            .connect_timeout(self.scheduler_request_timeout)
+            .timeout(self.scheduler_request_timeout)
             .connect()
-            .await?;
+            .await
+            .map_err(|err| {
+                Error::Internal(format!(
+                    "failed to connect to scheduler {}: {}",
+                    self.scheduler_endpoint, err
+                ))
+            })?;

         // Create scheduler client.
         let scheduler_client = SchedulerClient::new(scheduler_channel);

         // Create seed peer selector.
-        let seed_peer_selector =
-            Arc::new(SeedPeerSelector::new(scheduler_client, self.health_check_interval).await?);
-        let seed_peer_selector_clone = Arc::clone(&seed_peer_selector);
+        let seed_peer_selector = Arc::new(
+            SeedPeerSelector::new(scheduler_client, self.health_check_interval)
+                .await
+                .map_err(|err| {
+                    Error::Internal(format!("failed to create seed peer selector: {}", err))
+                })?,
+        );
+
+        let seed_peer_selector_clone = seed_peer_selector.clone();
         tokio::spawn(async move {
             // Run the selector service in the background to refresh the seed peers periodically.
             seed_peer_selector_clone.run().await;
         });

         // Get local IP address and hostname.
-        let local_ip = local_ip().unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
+        let local_ip = local_ip().unwrap().to_string();
         let hostname = hostname::get().unwrap().to_string_lossy().to_string();
-        let id_generator = IDGenerator::new(local_ip.to_string(), hostname, true);

+        let id_generator = IDGenerator::new(local_ip, hostname, true);
         let proxy = Proxy {
             seed_peer_selector,
             max_retries: self.max_retries,
-            client_pool: PoolBuilder::new(HTTPClientFactory {})
+            client_pool: PoolBuilder::new(HTTPClientFactory::default())
                 .capacity(DEFAULT_CLIENT_POOL_CAPACITY)
                 .idle_timeout(DEFAULT_CLIENT_POOL_IDLE_TIMEOUT)
                 .build(),
@@ -263,38 +285,73 @@ impl Builder {
         Ok(proxy)
     }

-    fn validate(
-        scheduler_endpoint: &str,
-        health_check_duration: Duration,
-        max_retries: u8,
-    ) -> Result<()> {
-        if let Err(err) = url::Url::parse(scheduler_endpoint) {
-            return Err(
-                DFError::ValidationError(format!("invalid scheduler endpoint: {}", err)).into(),
-            );
+    /// validate validates the input parameters.
+    fn validate(&self) -> Result<()> {
+        if let Err(err) = url::Url::parse(&self.scheduler_endpoint) {
+            return Err(Error::InvalidArgument(err.to_string()));
         };

-        if health_check_duration.as_secs() < 1 || health_check_duration.as_secs() > 600 {
-            return Err(DFError::ValidationError(
-                "health check duration must be between 1 and 600 seconds".to_string(),
-            )
-            .into());
+        if self.scheduler_request_timeout.as_millis() < 100 {
+            return Err(Error::InvalidArgument(
+                "scheduler request timeout must be at least 100 milliseconds".to_string(),
+            ));
         }

-        if max_retries > 10 {
-            return Err(DFError::ValidationError(
+        if self.health_check_interval.as_secs() < 1 || self.health_check_interval.as_secs() > 600 {
+            return Err(Error::InvalidArgument(
+                "health check interval must be between 1 and 600 seconds".to_string(),
+            ));
+        }
+
+        if self.max_retries > 10 {
+            return Err(Error::InvalidArgument(
                 "max retries must be between 0 and 10".to_string(),
-            )
-            .into());
+            ));
         }

         Ok(())
     }
 }

+/// Proxy is the HTTP proxy client that sends requests via Dragonfly.
+pub struct Proxy {
+    /// seed_peer_selector is the selector service for selecting seed peers.
+    seed_peer_selector: Arc<SeedPeerSelector>,
+
+    /// max_retries is the number of times to retry a request.
+    max_retries: u8,
+
+    /// client_pool is the pool of clients.
+    client_pool: Pool<String, ClientWithMiddleware, HTTPClientFactory>,
+
+    /// id_generator is the task id generator.
+    id_generator: Arc<IDGenerator>,
+}
+
+/// Proxy implements the proxy client that sends requests via Dragonfly.
+impl Proxy {
+    /// builder returns a new Builder for Proxy.
+    pub fn builder() -> Builder {
+        Builder::default()
+    }
+}
+
+/// Implements the interface for sending requests via the Dragonfly.
+///
+/// This struct enables interaction with remote servers through the Dragonfly, providing methods
+/// for performing GET requests with flexible response handling. It is designed for clients that
+/// need to communicate with the Dragonfly seed client efficiently, supporting both streaming and
+/// buffered response processing. It shields the complex request logic between the client and the
+/// Dragonfly seed client's proxy, abstracting the underlying communication details to simplify
+/// client implementation and usage.
 #[tonic::async_trait]
 impl Request for Proxy {
-    /// Performs a GET request, handling custom errors and returning a streaming response.
+    /// Sends a GET request to a remote server via the Dragonfly and returns a response
+    /// with a streaming body.
+    ///
+    /// This method is designed for scenarios where the response body is expected to be processed as a
+    /// stream, allowing efficient handling of large or continuous data. The response includes metadata
+    /// such as status codes and headers, along with a streaming `Body` for accessing the response content.
     #[instrument(skip_all)]
     async fn get(&self, request: GetRequest) -> Result<GetResponse> {
         let response = self.try_send(&request).await?;
@@ -314,16 +371,25 @@ impl Request for Proxy {
         })
     }

-    /// Performs a GET request, handling custom errors and collecting the response into a buffer.
+    /// Sends a GET request to a remote server via the Dragonfly and writes the response
+    /// body directly into the provided buffer.
+    ///
+    /// This method is optimized for scenarios where the response body needs to be stored directly in
+    /// memory, avoiding the overhead of streaming for smaller or fixed-size responses. The provided
+    /// `BytesMut` buffer is used to store the response content, and the response metadata (e.g., status
+    /// and headers) is returned separately.
     #[instrument(skip_all)]
     async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse> {
-        let operation = async {
+        let get_into = async {
             let response = self.try_send(&request).await?;
             let status = response.status();
             let headers = response.headers().clone();

             if status.is_success() {
-                let bytes = response.bytes().await?;
+                let bytes = response.bytes().await.map_err(|err| {
+                    Error::Internal(format!("failed to read response body: {}", err))
+                })?;
+
                 buf.extend_from_slice(&bytes);
             }

@@ -336,20 +402,24 @@ impl Request for Proxy {
         };

         // Apply timeout which will properly cancel the operation when timeout is reached.
-        tokio::time::timeout(request.timeout, operation)
+        tokio::time::timeout(request.timeout, get_into)
             .await
-            .map_err(|_| {
-                DFError::Unknown(format!("request timed out after {:?}", request.timeout))
-            })?
+            .map_err(|err| Error::RequestTimeout(err.to_string()))?
     }
 }

+/// Proxy implements proxy request logic.
 impl Proxy {
     /// Creates reqwest clients with proxy configuration for the given request.
     #[instrument(skip_all)]
-    async fn client_entries(&self, request: &GetRequest) -> Result<Vec<Entry<Client>>> {
-        // The request should be processed by the proxy, so generate the task ID to select a seed peer as proxy server.
-        let parameter = match request.content_for_calculating_task_id.as_ref() {
+    async fn client_entries(
+        &self,
+        request: &GetRequest,
+    ) -> Result<Vec<Entry<ClientWithMiddleware>>> {
+        // Generate task id for selecting seed peer.
+        let task_id = self
+            .id_generator
+            .task_id(match request.content_for_calculating_task_id.as_ref() {
                 Some(content) => TaskIDParameter::Content(content.clone()),
                 None => TaskIDParameter::URLBased {
                     url: request.url.clone(),
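Wrapping the whole `get_into` future in `tokio::time::timeout` means everything inside it, including the body read, is cancelled when the deadline fires; the resulting `Elapsed` error is then mapped to `Error::RequestTimeout`. The pattern in isolation:

```rust
use std::time::Duration;
use tokio::time::timeout;

#[tokio::main]
async fn main() {
    // Sketch of the cancellation pattern: the future is dropped
    // (and all work inside it stops) when the deadline fires.
    let slow = async {
        tokio::time::sleep(Duration::from_secs(10)).await;
        "done"
    };

    match timeout(Duration::from_millis(50), slow).await {
        Ok(value) => println!("finished: {}", value),
        Err(elapsed) => println!("request timeout: {}", elapsed), // maps to Error::RequestTimeout
    }
}
```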
@@ -358,22 +428,29 @@ impl Proxy {
                     application: request.application.clone(),
                     filtered_query_params: request.filtered_query_params.clone(),
                 },
-        };
-        let task_id = self.id_generator.task_id(parameter)?;
+            })
+            .map_err(|err| Error::Internal(format!("failed to generate task id: {}", err)))?;

-        // Select seed peers.
+        // Select seed peers for downloading.
         let seed_peers = self
             .seed_peer_selector
             .select(task_id, self.max_retries as u32)
             .await
-            .map_err(|err| DFError::Unknown(format!("failed to select seed peers: {:?}", err)))?;
-
+            .map_err(|err| {
+                Error::Internal(format!(
+                    "failed to select seed peers from scheduler: {}",
+                    err
+                ))
+            })?;
         debug!("selected seed peers: {:?}", seed_peers);

-        let mut client_entries = Vec::new();
+        let mut client_entries = Vec::with_capacity(seed_peers.len());
         for peer in seed_peers.iter() {
-            let proxy_addr = format!("http://{}:{}", peer.ip, peer.proxy_port);
-            let client_entry = self.client_pool.entry(&proxy_addr).await?;
+            // TODO(chlins): Support client https scheme.
+            let client_entry = self
+                .client_pool
+                .entry(&format!("http://{}:{}", peer.ip, peer.proxy_port))
+                .await?;
             client_entries.push(client_entry);
         }

@@ -386,10 +463,9 @@ impl Proxy {
         // Create client and send the request.
         let entries = self.client_entries(request).await?;
         if entries.is_empty() {
-            return Err(DFError::Unknown(
+            return Err(Error::Internal(
                 "no available client entries to send request".to_string(),
-            )
-            .into());
+            ));
         }

         for (index, entry) in entries.iter().enumerate() {
@@ -400,6 +476,8 @@ impl Proxy {
                         "failed to send request to client entry {:?}: {:?}",
                         entry.client, err
                     );
+
+                    // If this is the last entry, return the error.
                     if index == entries.len() - 1 {
                         return Err(err);
                     }
@@ -407,108 +485,143 @@ impl Proxy {
             }
         }

-        Err(DFError::Unknown("all retries failed".to_string()).into())
+        Err(Error::Internal(
+            "failed to send request to any client entry".to_string(),
+        ))
     }

     /// Send a request to the specified URL via client entry with the given headers.
     #[instrument(skip_all)]
-    async fn send(&self, entry: &Entry<Client>, request: &GetRequest) -> Result<reqwest::Response> {
-        let headers = make_request_headers(request)?;
+    async fn send(
+        &self,
+        entry: &Entry<ClientWithMiddleware>,
+        request: &GetRequest,
+    ) -> Result<reqwest::Response> {
+        let headers = self.make_request_headers(request)?;
         let response = entry
             .client
             .get(&request.url)
             .headers(headers.clone())
             .timeout(request.timeout)
             .send()
-            .await?;
+            .await
+            .map_err(|err| Error::Internal(err.to_string()))?;

-        // Check for custom Dragonfly error headers.
-        if let Some(error_type) = response
-            .headers()
-            .get("X-Dragonfly-Error-Type")
-            .and_then(|value| value.to_str().ok())
-        {
-            // If a known error type is found, consume the response body
-            // to get the message and return a specific error.
         let status = response.status();
-        let headers = response.headers().clone();
+        if status.is_success() {
+            return Ok(response);
+        }

-            return match error_type {
-                "backend" => Err(Error::BackendError(BackendError {
-                    message: response.text().await.ok(),
-                    header: headermap_to_hashmap(&headers),
+        let response_headers = response.headers().clone();
+        let header_map = headermap_to_hashmap(&response_headers);
+        let message = response.text().await.ok();
+        let error_type = response_headers
+            .get("X-Dragonfly-Error-Type")
+            .and_then(|v| v.to_str().ok());
+
+        match error_type {
+            Some("backend") => Err(Error::BackendError(BackendError {
+                message,
+                header: header_map,
                 status_code: Some(status),
             })),
-                "proxy" => Err(Error::ProxyError(ProxyError {
-                    message: response.text().await.ok(),
-                    header: headermap_to_hashmap(&headers),
+            Some("proxy") => Err(Error::ProxyError(ProxyError {
+                message,
+                header: header_map,
                 status_code: Some(status),
             })),
-                "dfdaemon" => Err(Error::DfdaemonError(DfdaemonError {
-                    message: response.text().await.ok(),
+            Some("dfdaemon") => Err(Error::DfdaemonError(DfdaemonError { message })),
+            Some(other) => Err(Error::ProxyError(ProxyError {
+                message: Some(format!("unknown error type from proxy: {}", other)),
+                header: header_map,
+                status_code: Some(status),
             })),
-                // Other error case we handle it as unknown error.
-                _ => Err(DFError::Unknown(format!("unknown error type: {}", error_type)).into()),
-            };
+            None => Err(Error::ProxyError(ProxyError {
+                message: Some(format!("unexpected status code from proxy: {}", status)),
+                header: header_map,
+                status_code: Some(status),
+            })),
+        }
     }

-        if !response.status().is_success() {
-            return Err(
-                DFError::Unknown(format!("unexpected status code: {}", response.status())).into(),
-            );
-        }
-
-        Ok(response)
-    }
-}
-
-/// make_request_headers applies p2p related headers to the request headers.
-fn make_request_headers(request: &GetRequest) -> Result<HeaderMap> {
+    /// make_request_headers applies p2p related headers to the request headers.
+    #[instrument(skip_all)]
+    fn make_request_headers(&self, request: &GetRequest) -> Result<HeaderMap> {
         let mut headers = request.header.clone().unwrap_or_default();

         // Apply the p2p related headers to the request headers.
         if let Some(piece_length) = request.piece_length {
             headers.insert(
                 "X-Dragonfly-Piece-Length",
-                piece_length.to_string().parse()?,
+                piece_length.to_string().parse().map_err(|err| {
+                    Error::InvalidArgument(format!("invalid piece length: {}", err))
+                })?,
             );
         }

         if let Some(tag) = request.tag.clone() {
-            headers.insert("X-Dragonfly-Tag", tag.to_string().parse()?);
+            headers.insert(
+                "X-Dragonfly-Tag",
+                tag.to_string()
+                    .parse()
+                    .map_err(|err| Error::InvalidArgument(format!("invalid tag: {}", err)))?,
+            );
         }

         if let Some(application) = request.application.clone() {
-            headers.insert("X-Dragonfly-Application", application.to_string().parse()?);
+            headers.insert(
+                "X-Dragonfly-Application",
+                application.to_string().parse().map_err(|err| {
+                    Error::InvalidArgument(format!("invalid application: {}", err))
+                })?,
+            );
         }

-        if let Some(content_for_calculating_task_id) = request.content_for_calculating_task_id.clone() {
+        if let Some(content_for_calculating_task_id) =
+            request.content_for_calculating_task_id.clone()
+        {
             headers.insert(
                 "X-Dragonfly-Content-For-Calculating-Task-ID",
-                content_for_calculating_task_id.to_string().parse()?,
+                content_for_calculating_task_id
+                    .to_string()
+                    .parse()
+                    .map_err(|err| {
+                        Error::InvalidArgument(format!(
+                            "invalid content for calculating task id: {}",
+                            err
+                        ))
+                    })?,
             );
         }

         if let Some(priority) = request.priority {
-            headers.insert("X-Dragonfly-Priority", priority.to_string().parse()?);
+            headers.insert(
+                "X-Dragonfly-Priority",
+                priority
+                    .to_string()
+                    .parse()
+                    .map_err(|err| Error::InvalidArgument(format!("invalid priority: {}", err)))?,
+            );
         }

         if !request.filtered_query_params.is_empty() {
             let value = request.filtered_query_params.join(",");
-            headers.insert("X-Dragonfly-Filtered-Query-Params", value.parse()?);
+            headers.insert(
+                "X-Dragonfly-Filtered-Query-Params",
+                value.parse().map_err(|err| {
+                    Error::InvalidArgument(format!("invalid filtered query params: {}", err))
+                })?,
+            );
         }

         // Always use p2p for sdk scenarios.
         headers.insert("X-Dragonfly-Use-P2P", HeaderValue::from_static("true"));

         Ok(headers)
     }
+}

 #[cfg(test)]
 mod tests {
     use super::*;
     use dragonfly_api::scheduler::v2::ListHostsResponse;
-    use dragonfly_client_core::error::DFError;
     use mocktail::prelude::*;
     use std::time::Duration;

@@ -521,20 +634,16 @@ mod tests {
         });

         let server = MockServer::new_grpc("scheduler.v2.Scheduler").with_mocks(mocks);
-        match server.start().await {
-            Ok(_) => Ok(server),
-            Err(err) => Err(DFError::Unknown(format!(
-                "failed to start mock scheduler server: {}",
-                err,
-            ))
-            .into()),
-        }
+        server.start().await.map_err(|err| {
+            Error::Internal(format!("failed to start mock scheduler server: {}", err))
+        })?;
+
+        Ok(server)
     }

     #[tokio::test]
-    async fn test_proxy_new_success() -> Result<()> {
-        let mock_server = setup_mock_scheduler().await?;
-
+    async fn test_proxy_new_success() {
+        let mock_server = setup_mock_scheduler().await.unwrap();
         let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
         let result = Proxy::builder()
             .scheduler_endpoint(scheduler_endpoint)
@@ -542,10 +651,7 @@ mod tests {
             .await;

         assert!(result.is_ok());
-        let proxy = result.unwrap();
-        assert_eq!(proxy.max_retries, 1);
-
-        Ok(())
+        assert_eq!(result.unwrap().max_retries, 1);
     }

     #[tokio::test]
@@ -556,15 +662,12 @@ mod tests {
             .await;

         assert!(result.is_err());
-        assert!(matches!(
-            result,
-            Err(Error::Base(DFError::ValidationError(_)))
-        ));
+        assert!(matches!(result, Err(Error::InvalidArgument(_))));
     }

     #[tokio::test]
-    async fn test_proxy_new_invalid_retry_times() -> Result<()> {
-        let mock_server = setup_mock_scheduler().await?;
+    async fn test_proxy_new_invalid_retry_times() {
+        let mock_server = setup_mock_scheduler().await.unwrap();

         let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
         let result = Proxy::builder()
@@ -574,17 +677,12 @@ mod tests {
             .await;

         assert!(result.is_err());
-        assert!(matches!(
-            result,
-            Err(Error::Base(DFError::ValidationError(_)))
-        ));
-
-        Ok(())
+        assert!(matches!(result, Err(Error::InvalidArgument(_))));
     }

     #[tokio::test]
-    async fn test_proxy_new_invalid_health_check_interval() -> Result<()> {
-        let mock_server = setup_mock_scheduler().await?;
+    async fn test_proxy_new_invalid_health_check_interval() {
+        let mock_server = setup_mock_scheduler().await.unwrap();

         let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
         let result = Proxy::builder()
@@ -594,12 +692,7 @@ mod tests {
             .await;

         assert!(result.is_err());
-        assert!(matches!(
-            result,
-            Err(Error::Base(DFError::ValidationError(_)))
-        ));
-
-        Ok(())
+        assert!(matches!(result, Err(Error::InvalidArgument(_))));
     }

     #[tokio::test]
@@ -641,6 +734,7 @@ mod tests {

         // Create another client.
         let _ = pool.entry(&"http://proxy2.com".to_string()).await.unwrap();

+        // Still should be 1 because the proxy1 client should have been cleaned up.
         assert_eq!(pool.size().await, 1);
     }
@@ -18,7 +18,10 @@ use crate::hashring::VNodeHashRing;
 use crate::shutdown;
 use dragonfly_api::common::v2::Host;
 use dragonfly_api::scheduler::v2::scheduler_client::SchedulerClient;
-use dragonfly_client_core::{Error, Result};
+use dragonfly_client_core::{
+    error::{ErrorType, OrErr},
+    Error, Result,
+};
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::time::Duration;
@@ -33,6 +36,7 @@ use tracing::{debug, error, info, instrument};
 /// Selector is the interface for selecting item from a list of items by a specific criteria.
 #[tonic::async_trait]
 pub trait Selector: Send + Sync {
+    /// select selects items based on the given task_id and number of replicas.
     async fn select(&self, task_id: String, replicas: u32) -> Result<Vec<Host>>;
 }

@@ -42,6 +46,7 @@ const SEED_PEERS_HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(5);
 /// DEFAULT_VNODES_PER_HOST is the default number of virtual nodes per host.
 const DEFAULT_VNODES_PER_HOST: usize = 3;

+/// SeedPeers holds the data of seed peers.
 struct SeedPeers {
     /// hosts is the list of seed peers.
     hosts: HashMap<String, Host>,
@@ -50,6 +55,7 @@ struct SeedPeers {
     hashring: VNodeHashRing,
 }

+/// SeedPeerSelector is the implementation of Selector for seed peers.
 pub struct SeedPeerSelector {
     /// health_check_interval is the interval of health check for seed peers.
     health_check_interval: Duration,
@@ -60,10 +66,11 @@ pub struct SeedPeerSelector {
     /// seed_peers is the data of the seed peer selector.
     seed_peers: RwLock<SeedPeers>,

-    /// mutex is used to protect refresh.
+    /// mutex is used to protect hot refresh.
     mutex: Mutex<()>,
 }

+/// SeedPeerSelector implements a selector that selects seed peers from the scheduler service.
 impl SeedPeerSelector {
     /// new creates a new seed peer selector.
     pub async fn new(
@@ -86,10 +93,6 @@ impl SeedPeerSelector {

     /// run starts the seed peer selector service.
     pub async fn run(&self) {
-        // Receive shutdown signal.
-        let mut shutdown = shutdown::Shutdown::default();
-
-        // Start the refresh loop.
         let mut interval = tokio::time::interval(self.health_check_interval);
         loop {
             tokio::select! {
@@ -99,7 +102,7 @@ impl SeedPeerSelector {
                         Ok(_) => debug!("succeed to refresh seed peers"),
                     }
                 }
-                _ = shutdown.recv() => {
+                _ = shutdown::shutdown_signal() => {
                     info!("seed peer selector service is shutting down");
                     return;
                 }
@@ -135,11 +138,13 @@ impl SeedPeerSelector {

         let mut hosts = HashMap::with_capacity(seed_peers_length);
         let mut hashring = VNodeHashRing::new(DEFAULT_VNODES_PER_HOST);

         while let Some(result) = join_set.join_next().await {
             match result {
                 Ok(Ok(peer)) => {
-                    let addr: SocketAddr = format!("{}:{}", peer.ip, peer.port).parse().unwrap();
+                    let addr: SocketAddr = format!("{}:{}", peer.ip, peer.port)
+                        .parse()
+                        .or_err(ErrorType::ParseError)?;
+
                     hashring.add(addr);
                     hosts.insert(addr.to_string(), peer);
                 }
@@ -179,16 +184,16 @@ impl SeedPeerSelector {
     /// check_health checks the health of each seed peer.
     #[instrument(skip_all)]
     async fn check_health(addr: &str) -> Result<HealthCheckResponse> {
-        let endpoint =
-            Endpoint::from_shared(addr.to_string())?.timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT);
+        let channel = Endpoint::from_shared(addr.to_string())?
+            .connect_timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT)
+            .connect()
+            .await?;

-        let channel = endpoint.connect().await?;
         let mut client = HealthGRPCClient::new(channel);

-        let request = tonic::Request::new(HealthCheckRequest {
+        let mut request = tonic::Request::new(HealthCheckRequest {
             service: "".to_string(),
         });

+        request.set_timeout(SEED_PEERS_HEALTH_CHECK_TIMEOUT);
         let response = client.check(request).await?;
         Ok(response.into_inner())
     }
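The reworked probe sets two independent deadlines: a connect timeout on the endpoint and a per-request gRPC timeout (`set_timeout` is propagated as the `grpc-timeout` header). A self-contained sketch against tonic-health; the address and five-second durations are placeholders:

```rust
use std::time::Duration;
use tonic::transport::Endpoint;
use tonic_health::pb::{health_client::HealthClient, HealthCheckRequest};

async fn probe(addr: &str) -> Result<(), Box<dyn std::error::Error>> {
    // Bound how long we wait for the TCP/HTTP2 handshake.
    let channel = Endpoint::from_shared(addr.to_string())?
        .connect_timeout(Duration::from_secs(5))
        .connect()
        .await?;

    let mut client = HealthClient::new(channel);
    let mut request = tonic::Request::new(HealthCheckRequest {
        service: "".to_string(), // empty service name checks overall server health
    });
    // Propagated as the grpc-timeout header so the server can enforce it too.
    request.set_timeout(Duration::from_secs(5));

    let response = client.check(request).await?;
    println!("status: {:?}", response.into_inner().status());
    Ok(())
}
```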
@@ -200,21 +205,21 @@ impl Selector for SeedPeerSelector {
     async fn select(&self, task_id: String, replicas: u32) -> Result<Vec<Host>> {
         // Acquire a read lock and perform all logic within it.
         let seed_peers = self.seed_peers.read().await;

         if seed_peers.hosts.is_empty() {
             return Err(Error::HostNotFound("seed peers".to_string()));
         }

         // The number of replicas cannot exceed the total number of seed peers.
-        let desired_replicas = std::cmp::min(replicas as usize, seed_peers.hashring.len());
+        let expected_replicas = std::cmp::min(replicas as usize, seed_peers.hashring.len());
+        debug!("expected replicas: {}", expected_replicas);

         // Get replica nodes from the hash ring.
         let vnodes = seed_peers
             .hashring
-            .get_with_replicas(&task_id, desired_replicas)
+            .get_with_replicas(&task_id, expected_replicas)
             .unwrap_or_default();

-        let seed_peers = vnodes
+        let seed_peers: Vec<Host> = vnodes
             .into_iter()
             .filter_map(|vnode| {
                 seed_peers
@@ -224,6 +229,10 @@ impl Selector for SeedPeerSelector {
             })
             .collect();

+        if seed_peers.is_empty() {
+            return Err(Error::HostNotFound("selected seed peers".to_string()));
+        }
+
         Ok(seed_peers)
     }
 }
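`select` resolves ring vnodes back to `Host` records with `filter_map`, silently skipping vnodes whose host disappeared between refreshes, and only errors if nothing resolves. The shape of that resolution, with String standing in for Host:

```rust
use std::collections::HashMap;

fn main() {
    let mut hosts: HashMap<String, String> = HashMap::new();
    hosts.insert("10.0.0.1:4000".into(), "seed-peer-1".into());

    // Addresses produced by the hash ring lookup; the second one is stale.
    let vnode_addrs = vec!["10.0.0.1:4000".to_string(), "10.0.0.9:4000".to_string()];

    // filter_map drops vnodes whose host entry vanished between refreshes.
    let selected: Vec<String> = vnode_addrs
        .iter()
        .filter_map(|addr| hosts.get(addr).cloned())
        .collect();

    assert_eq!(selected, vec!["seed-peer-1".to_string()]);
}
```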
@@ -98,9 +98,12 @@ struct DfdaemonUploadClientFactory {
     config: Arc<Config>,
 }

+/// DfdaemonUploadClientFactory implements the Factory trait for creating DfdaemonUploadClient
+/// instances.
 #[tonic::async_trait]
 impl Factory<String, DfdaemonUploadClient> for DfdaemonUploadClientFactory {
     type Error = Error;
+    /// Creates a new DfdaemonUploadClient for the given address.
     async fn make_client(&self, addr: &String) -> Result<DfdaemonUploadClient> {
         DfdaemonUploadClient::new(self.config.clone(), format!("http://{}", addr), true).await
     }
@@ -121,13 +124,9 @@ pub struct GRPCDownloader {
 impl GRPCDownloader {
     /// new returns a new GRPCDownloader.
     pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
-        let factory = DfdaemonUploadClientFactory {
-            config: config.clone(),
-        };
-
         Self {
-            config,
-            client_pool: PoolBuilder::new(factory)
+            config: config.clone(),
+            client_pool: PoolBuilder::new(DfdaemonUploadClientFactory { config })
                 .capacity(capacity)
                 .idle_timeout(idle_timeout)
                 .build(),
@@ -297,9 +296,12 @@ struct TCPClientFactory {
     config: Arc<Config>,
 }

+/// TCPClientFactory implements the Factory trait for creating TCPClient instances.
 #[tonic::async_trait]
 impl Factory<String, TCPClient> for TCPClientFactory {
     type Error = Error;
+
+    /// Creates a new TCPClient for the given address.
     async fn make_client(&self, addr: &String) -> Result<TCPClient> {
         Ok(TCPClient::new(self.config.clone(), addr.clone()))
     }
@@ -309,12 +311,10 @@ impl Factory<String, TCPClient> for TCPClientFactory {
 impl TCPDownloader {
     /// new returns a new TCPDownloader.
     pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
-        let factory = TCPClientFactory {
-            config: config.clone(),
-        };
-
         Self {
-            client_pool: PoolBuilder::new(factory)
+            client_pool: PoolBuilder::new(TCPClientFactory {
+                config: config.clone(),
+            })
                 .capacity(capacity)
                 .idle_timeout(idle_timeout)
                 .build(),
@@ -346,6 +346,7 @@ impl Downloader for TCPDownloader {
     ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
         let entry = self.get_client_entry(addr).await?;
+        let request_guard = entry.request_guard();

         match entry.client.download_piece(number, task_id).await {
             Ok((reader, offset, digest)) => Ok((Box::new(reader), offset, digest)),
             Err(err) => {
@@ -370,6 +371,7 @@ impl Downloader for TCPDownloader {
     ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
         let entry = self.get_client_entry(addr).await?;
+        let request_guard = entry.request_guard();

         match entry
             .client
             .download_persistent_cache_piece(number, task_id)