Fan Tianlan 2025-09-29 22:41:38 +08:00 committed by GitHub
commit 15648484d0
7 changed files with 1141 additions and 406 deletions

Cargo.lock (generated, 72 lines changed)

@@ -1001,6 +1001,12 @@ dependencies = [
"const-random",
]
[[package]]
name = "downcast"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
[[package]]
name = "dragonfly-api"
version = "2.1.70"
@@ -1222,6 +1228,7 @@ dependencies = [
"dragonfly-client-util",
"fs2",
"leaky-bucket",
"mockall",
"nix 0.30.1",
"num_cpus",
"prost-wkt-types",
@@ -1235,6 +1242,7 @@ dependencies = [
"tempfile",
"tokio",
"tokio-util",
"tonic",
"tracing",
"vortex-protocol",
"walkdir",
@@ -1451,6 +1459,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fragile"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619"
[[package]]
name = "fs2"
version = "0.4.3"
@@ -2738,6 +2752,32 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "mockall"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
dependencies = [
"cfg-if",
"downcast",
"fragile",
"mockall_derive",
"predicates",
"predicates-tree",
]
[[package]]
name = "mockall_derive"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
dependencies = [
"cfg-if",
"proc-macro2",
"quote",
"syn 2.0.90",
]
[[package]]
name = "mocktail"
version = "0.3.0"
@@ -3598,6 +3638,32 @@ version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "predicates"
version = "3.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573"
dependencies = [
"anstyle",
"predicates-core",
]
[[package]]
name = "predicates-core"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa"
[[package]]
name = "predicates-tree"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c"
dependencies = [
"predicates-core",
"termtree",
]
[[package]]
name = "prettyplease"
version = "0.2.17"
@@ -5110,6 +5176,12 @@ dependencies = [
"redox_termios",
]
[[package]]
name = "termtree"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
[[package]]
name = "testing_table"
version = "0.3.0"


@@ -31,11 +31,13 @@ bytesize.workspace = true
leaky-bucket.workspace = true
vortex-protocol.workspace = true
rustls.workspace = true
tonic.workspace = true
num_cpus = "1.17"
bincode = "1.3.3"
walkdir = "2.5.0"
quinn = "0.11.9"
socket2 = "0.6.0"
mockall = "0.13.1"
[dev-dependencies]
tempfile.workspace = true
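
Two of the new dependencies exist for the client refactor further down: tonic supplies the #[tonic::async_trait] attribute (a re-export of the async-trait macro) that the new `Client` trait is declared with, and mockall generates the mocks its unit tests drive. A minimal sketch of that pairing, using a hypothetical `Greeter` trait in place of the real `Client`:

#[tonic::async_trait]
trait Greeter {
    async fn greet(&self, id: u32) -> String;
}

mockall::mock! {
    Greeter {}

    #[tonic::async_trait]
    impl Greeter for Greeter {
        async fn greet(&self, id: u32) -> String;
    }
}

#[tokio::test]
async fn test_greet() {
    let mut mock = MockGreeter::new();
    // Each expectation pins down the behaviour of one mocked method call.
    mock.expect_greet().returning(|id| format!("hello {}", id));
    assert_eq!(mock.greet(42).await, "hello 42");
}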


@@ -17,7 +17,22 @@
pub mod quic;
pub mod tcp;
use bytes::{Bytes, BytesMut};
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{Error as ClientError, Result as ClientResult};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::time;
use tracing::{error, instrument};
use vortex_protocol::{
tlv::{
download_persistent_cache_piece::DownloadPersistentCachePiece,
download_piece::DownloadPiece, error::Error as VortexError, persistent_cache_piece_content,
piece_content, Tag,
},
Header, Vortex, HEADER_SIZE,
};
/// DEFAULT_SEND_BUFFER_SIZE is the default size of the send buffer for network connections.
const DEFAULT_SEND_BUFFER_SIZE: usize = 16 * 1024 * 1024;
@@ -30,3 +45,604 @@ const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(5);
/// DEFAULT_MAX_IDLE_TIMEOUT is the default maximum idle timeout for connections.
const DEFAULT_MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Client defines a generic client interface for storage service protocols.
#[tonic::async_trait]
pub trait Client {
/// Downloads a piece from the server using the vortex protocol.
///
/// This is the main entry point for downloading a piece. It applies
/// a timeout based on the configuration and handles connection timeouts gracefully.
#[instrument(skip_all)]
async fn download_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
time::timeout(
self.config().download.piece_timeout,
self.handle_download_piece(number, task_id),
)
.await
.inspect_err(|err| {
error!("connect timeout to {}: {}", self.addr(), err);
})?
}
/// Internal handler for downloading a piece.
///
/// This method performs the actual protocol communication:
/// 1. Creates a download piece request.
/// 2. Establishes connection and sends the request.
/// 3. Reads and validates the response header.
/// 4. Processes the piece content based on the response type.
#[instrument(skip_all)]
async fn handle_download_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
let request: Bytes = Vortex::DownloadPiece(
Header::new_download_piece(),
DownloadPiece::new(task_id.to_string(), number),
)
.into();
let (mut reader, _writer) = self.connect_and_write_request(request).await?;
let header = self.read_header(&mut reader).await?;
match header.tag() {
Tag::PieceContent => {
let piece_content: piece_content::PieceContent = self
.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await?;
let metadata = piece_content.metadata();
Ok((reader, metadata.offset, metadata.digest))
}
Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
_ => Err(ClientError::Unknown(format!(
"unexpected tag: {:?}",
header.tag()
))),
}
}
/// Downloads a persistent cache piece from the server using the vortex protocol.
///
/// Similar to `download_piece` but specifically for persistent cache pieces.
#[instrument(skip_all)]
async fn download_persistent_cache_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
time::timeout(
self.config().download.piece_timeout,
self.handle_download_persistent_cache_piece(number, task_id),
)
.await
.inspect_err(|err| {
error!("connect timeout to {}: {}", self.addr(), err);
})?
}
/// Internal handler for downloading a persistent cache piece.
///
/// Implements the same protocol flow as `handle_download_piece` but uses
/// persistent cache specific request/response types.
#[instrument(skip_all)]
async fn handle_download_persistent_cache_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
let request: Bytes = Vortex::DownloadPersistentCachePiece(
Header::new_download_persistent_cache_piece(),
DownloadPersistentCachePiece::new(task_id.to_string(), number),
)
.into();
let (mut reader, _writer) = self.connect_and_write_request(request).await?;
let header = self.read_header(&mut reader).await?;
match header.tag() {
Tag::PersistentCachePieceContent => {
let persistent_cache_piece_content: persistent_cache_piece_content::PersistentCachePieceContent =
self.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await?;
let metadata = persistent_cache_piece_content.metadata();
Ok((reader, metadata.offset, metadata.digest))
}
Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
_ => Err(ClientError::Unknown(format!(
"unexpected tag: {:?}",
header.tag()
))),
}
}
/// Establishes connection and writes a vortex protocol request.
///
/// This is a low-level utility function that handles the connection
/// lifecycle and request transmission. It ensures proper error handling
/// and connection cleanup.
async fn connect_and_write_request(
&self,
request: Bytes,
) -> ClientResult<(
Box<dyn AsyncRead + Send + Unpin>,
Box<dyn AsyncWrite + Send + Unpin>,
)>;
/// Reads and parses a vortex protocol header.
///
/// The header contains metadata about the following message, including
/// the message type (tag) and payload length. This is critical for
/// proper protocol message framing.
#[instrument(skip_all)]
async fn read_header(
&self,
reader: &mut Box<dyn AsyncRead + Send + Unpin>,
) -> ClientResult<Header> {
let mut header_bytes = BytesMut::with_capacity(HEADER_SIZE);
header_bytes.resize(HEADER_SIZE, 0);
reader
.read_exact(&mut header_bytes)
.await
.inspect_err(|err| {
error!("failed to receive header: {}", err);
})?;
Header::try_from(header_bytes.freeze()).map_err(Into::into)
}
/// Reads and parses piece content with variable-length metadata.
///
/// This generic function handles the two-stage reading process for
/// piece content: first reading the metadata length, then reading
/// the actual metadata, and finally constructing the complete message.
#[instrument(skip_all)]
async fn read_piece_content<T>(
&self,
reader: &mut Box<dyn AsyncRead + Send + Unpin>,
metadata_length_size: usize,
) -> ClientResult<T>
where
T: TryFrom<Bytes>,
T::Error: Into<ClientError>,
T: 'static,
{
let mut metadata_length_bytes = BytesMut::with_capacity(metadata_length_size);
metadata_length_bytes.resize(metadata_length_size, 0);
reader
.read_exact(&mut metadata_length_bytes)
.await
.inspect_err(|err| {
error!("failed to receive metadata length: {}", err);
})?;
let metadata_length = u32::from_be_bytes(metadata_length_bytes[..].try_into()?) as usize;
let mut metadata_bytes = BytesMut::with_capacity(metadata_length);
metadata_bytes.resize(metadata_length, 0);
reader
.read_exact(&mut metadata_bytes)
.await
.inspect_err(|err| {
error!("failed to receive metadata: {}", err);
})?;
let mut content_bytes = BytesMut::with_capacity(metadata_length_size + metadata_length);
content_bytes.extend_from_slice(&metadata_length_bytes);
content_bytes.extend_from_slice(&metadata_bytes);
content_bytes.freeze().try_into().map_err(Into::into)
}
/// Reads and processes error responses from the server.
///
/// When the server responds with an error tag, this function reads
/// the error payload and converts it into an appropriate client error.
/// This provides structured error handling for protocol-level failures.
#[instrument(skip_all)]
async fn read_error(
&self,
reader: &mut Box<dyn AsyncRead + Send + Unpin>,
header_length: usize,
) -> ClientError {
let mut error_bytes = BytesMut::with_capacity(header_length);
error_bytes.resize(header_length, 0);
if let Err(err) = reader.read_exact(&mut error_bytes).await {
error!("failed to receive error: {}", err);
return ClientError::IO(err);
};
error_bytes
.freeze()
.try_into()
.map(|error: VortexError| {
ClientError::VortexProtocolStatus(error.code(), error.message().to_string())
})
.unwrap_or_else(|err| {
error!("failed to extract error: {}", err);
ClientError::Unknown(format!("failed to extract error: {}", err))
})
}
/// Access to client configuration.
fn config(&self) -> &Arc<Config>;
/// Access to client address.
fn addr(&self) -> &str;
}
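
A caller-side sketch of the trait just defined (the helper name, output path, and default config are assumptions, not part of this change): pick a transport-specific client, download one piece, and stream the returned reader wherever the bytes are needed.

use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::Result as ClientResult;
use dragonfly_client_storage::client::{tcp::TCPClient, Client};
use std::sync::Arc;

async fn fetch_piece(config: Arc<Config>, addr: String, task_id: &str, number: u32) -> ClientResult<()> {
    // Any type implementing `Client` works here; TCPClient is one of the two
    // transports in this change.
    let client = TCPClient::new(config, addr);

    // Applies the configured piece timeout and yields the piece body as a
    // boxed reader plus its offset and digest.
    let (mut reader, offset, digest) = client.download_piece(number, task_id).await?;

    let mut file = tokio::fs::File::create(format!("piece-{}", number)).await?;
    tokio::io::copy(&mut reader, &mut file).await?;
    println!("piece {} at offset {} with digest {}", number, offset, digest);
    Ok(())
}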
#[cfg(test)]
mod tests {
use super::*;
use bytes::{BufMut, Bytes, BytesMut};
use dragonfly_client_config::dfdaemon::Download;
use dragonfly_client_core::{Error as ClientError, Result as ClientResult};
use std::io::Cursor;
use tokio::io::duplex;
use tokio::time::sleep;
struct Mock {
config: Arc<Config>,
addr: String,
timeout: Duration,
}
#[tonic::async_trait]
impl Client for Mock {
async fn handle_download_piece(
&self,
_number: u32,
_task_id: &str,
) -> ClientResult<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
sleep(self.timeout).await;
Ok((Box::new(Cursor::new(b"data")), 0, "hash".to_string()))
}
async fn handle_download_persistent_cache_piece(
&self,
_number: u32,
_task_id: &str,
) -> ClientResult<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
sleep(self.timeout).await;
Ok((Box::new(Cursor::new(b"data")), 0, "hash".to_string()))
}
async fn connect_and_write_request(
&self,
_request: Bytes,
) -> ClientResult<(
Box<dyn AsyncRead + Send + Unpin>,
Box<dyn AsyncWrite + Send + Unpin>,
)> {
let (reader, writer) = duplex(1);
Ok((Box::new(reader), Box::new(writer)))
}
fn config(&self) -> &Arc<Config> {
&self.config
}
fn addr(&self) -> &str {
&self.addr
}
}
#[tokio::test]
async fn test_download_piece() {
let addr = "127.0.0.1:8080".to_string();
let config = Arc::new(Config {
download: Download {
piece_timeout: Duration::from_secs(2),
..Default::default()
},
..Default::default()
});
let mut mock = Mock {
config,
addr,
timeout: Duration::from_secs(1),
};
let result = mock.download_piece(1, "task").await;
assert!(result.is_ok());
if let Ok((mut reader, offset, digest)) = result {
let mut buf = Vec::new();
reader.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf, b"data");
assert_eq!(offset, 0);
assert_eq!(digest, "hash");
}
mock.timeout = Duration::from_secs(3);
let result = mock.download_piece(1, "task").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_download_persistent_cache_piece() {
let addr = "127.0.0.1:8080".to_string();
let config = Arc::new(Config {
download: Download {
piece_timeout: Duration::from_secs(2),
..Default::default()
},
..Default::default()
});
let mut mock = Mock {
config,
addr,
timeout: Duration::from_secs(1),
};
let result = mock.download_persistent_cache_piece(1, "task").await;
assert!(result.is_ok());
if let Ok((mut reader, offset, digest)) = result {
let mut buf = Vec::new();
reader.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf, b"data");
assert_eq!(offset, 0);
assert_eq!(digest, "hash");
}
mock.timeout = Duration::from_secs(3);
let result = mock.download_persistent_cache_piece(1, "task").await;
assert!(result.is_err());
}
mockall::mock! {
Client {}
#[tonic::async_trait]
impl Client for Client {
async fn connect_and_write_request(
&self,
request: Bytes,
) -> ClientResult<(Box<dyn AsyncRead + Send + Unpin>, Box<dyn AsyncWrite + Send + Unpin>)>;
async fn read_header(&self, reader: &mut Box<dyn AsyncRead + Send + Unpin>) -> ClientResult<Header>;
async fn read_piece_content<T>(
&self,
reader: &mut Box<dyn AsyncRead + Send + Unpin>,
metadata_length_size: usize,
) -> ClientResult<T>
where
T: TryFrom<Bytes>,
T::Error: Into<ClientError>,
T: 'static;
async fn read_error(&self, reader: &mut Box<dyn AsyncRead + Send + Unpin>, header_length: usize) -> ClientError;
fn config(&self) -> &Arc<Config>;
fn addr(&self) -> &str;
}
}
#[tokio::test]
async fn test_handle_download_piece() {
let mut mock = MockClient::new();
mock.expect_connect_and_write_request().returning(|_| {
let (reader, writer) = duplex(1);
Ok((Box::new(reader), Box::new(writer)))
});
mock.expect_read_header()
.times(1)
.returning(|_| Ok(Header::new_piece_content(1024)));
mock.expect_read_piece_content().returning(|_, _| {
Ok(piece_content::PieceContent::new(
42,
1024,
2048,
"a".repeat(32),
"test_parent_id".to_string(),
1,
Duration::from_secs(5),
chrono::DateTime::from_timestamp(1693152000, 0)
.unwrap()
.naive_utc(),
))
});
let result = mock.handle_download_piece(1, "task").await;
assert!(result.is_ok());
if let Ok((_, offset, digest)) = result {
assert_eq!(offset, 1024);
assert_eq!(digest, "a".repeat(32));
}
mock.expect_read_header()
.times(1)
.returning(|_| Ok(Header::new_error(1024)));
mock.expect_read_error()
.returning(|_, _| ClientError::Unknown("test".to_string()));
let result = mock.handle_download_piece(1, "task").await;
assert!(result.is_err());
if let Err(err) = result {
assert!(format!("{:?}", err).contains("test"));
}
mock.expect_read_header()
.returning(|_| Ok(Header::new_close()));
let result = mock.handle_download_piece(1, "task").await;
assert!(result.is_err());
if let Err(err) = result {
assert!(format!("{:?}", err).contains("unexpected tag"));
}
}
#[tokio::test]
async fn test_handle_download_persistent_cache_piece() {
let mut mock = MockClient::new();
mock.expect_connect_and_write_request().returning(|_| {
let (reader, writer) = duplex(1);
Ok((Box::new(reader), Box::new(writer)))
});
mock.expect_read_header()
.times(1)
.returning(|_| Ok(Header::new_persistent_cache_piece_content(1024)));
mock.expect_read_piece_content().returning(|_, _| {
Ok(
persistent_cache_piece_content::PersistentCachePieceContent::new(
42,
1024,
2048,
"a".repeat(32),
"test_parent_id".to_string(),
1,
Duration::from_secs(5),
chrono::DateTime::from_timestamp(1693152000, 0)
.unwrap()
.naive_utc(),
),
)
});
let result = mock.handle_download_persistent_cache_piece(1, "task").await;
assert!(result.is_ok());
if let Ok((_, offset, digest)) = result {
assert_eq!(offset, 1024);
assert_eq!(digest, "a".repeat(32));
}
mock.expect_read_header()
.times(1)
.returning(|_| Ok(Header::new_error(1024)));
mock.expect_read_error()
.returning(|_, _| ClientError::Unknown("test".to_string()));
let result = mock.handle_download_persistent_cache_piece(1, "task").await;
assert!(result.is_err());
if let Err(err) = result {
assert!(format!("{:?}", err).contains("test"));
}
mock.expect_read_header()
.returning(|_| Ok(Header::new_close()));
let result = mock.handle_download_persistent_cache_piece(1, "task").await;
assert!(result.is_err());
if let Err(err) = result {
assert!(format!("{:?}", err).contains("unexpected tag"));
}
}
#[tokio::test]
async fn test_read_header() {
let addr = "127.0.0.1:8080".to_string();
let config = Arc::new(Config::default());
let mock = Mock {
config,
addr,
timeout: Duration::from_secs(1),
};
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(b"HEADER_SIZE"));
let result = mock.read_header(&mut reader).await;
assert!(result.is_ok());
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(b""));
let result = mock.read_header(&mut reader).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_read_piece_content() {
let addr = "127.0.0.1:8080".to_string();
let config = Arc::new(Config::default());
let mock = Mock {
config,
addr,
timeout: Duration::from_secs(1),
};
let piece_content = piece_content::PieceContent::new(
42,
1024,
2048,
"a".repeat(32),
"test_parent_id".to_string(),
1,
Duration::from_secs(5),
chrono::DateTime::from_timestamp(1693152000, 0)
.unwrap()
.naive_utc(),
);
let piece_content_bytes: Bytes = piece_content.into();
let mut reader: Box<dyn AsyncRead + Send + Unpin> =
Box::new(Cursor::new(piece_content_bytes));
let result: ClientResult<piece_content::PieceContent> = mock
.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await;
assert!(result.is_ok());
if let Ok(content) = result {
assert_eq!(content.metadata().number, 42);
assert_eq!(content.metadata().offset, 1024);
assert_eq!(content.metadata().length, 2048);
assert_eq!(content.metadata().digest, "a".repeat(32));
assert_eq!(content.metadata().parent_id, "test_parent_id".to_string());
assert_eq!(content.metadata().traffic_type, 1);
assert_eq!(content.metadata().cost, Duration::from_secs(5));
}
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(b""));
let result: ClientResult<piece_content::PieceContent> = mock
.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await;
assert!(result.is_err());
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(b"METADATA"));
let result: ClientResult<piece_content::PieceContent> = mock
.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await;
assert!(result.is_err());
let data = {
let mut bytes = BytesMut::new();
bytes.put_u32(100);
bytes.put(&vec![0u8; 50][..]);
bytes.freeze()
};
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(data));
let result: ClientResult<piece_content::PieceContent> = mock
.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_read_error() {
let addr = "127.0.0.1:8080".to_string();
let config = Arc::new(Config::default());
let mock = Mock {
config,
addr,
timeout: Duration::from_secs(1),
};
let expected_code = vortex_protocol::tlv::error::Code::NotFound;
let expected_message = "Resource not found".to_string();
let error = VortexError::new(expected_code, expected_message.clone());
let bytes: Bytes = error.into();
let bytes_length = bytes.len();
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(bytes));
let result = mock.read_error(&mut reader, bytes_length).await;
assert!(matches!(result,
ClientError::VortexProtocolStatus(code, ref message)
if code == expected_code && message == &expected_message
));
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(b""));
let result = mock.read_error(&mut reader, 5).await;
assert!(matches!(result, ClientError::IO(_)));
let mut reader: Box<dyn AsyncRead + Send + Unpin> = Box::new(Cursor::new(b""));
let result = mock.read_error(&mut reader, 0).await;
assert!(matches!(result, ClientError::Unknown(_)));
}
}
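
To make the wire format the trait expects concrete, here is a self-contained sketch (not part of the change; FrameClient and all literal values are hypothetical) that hand-assembles one PieceContent response — header, length-prefixed metadata, then the raw body — and runs the trait's real default flow over it from an in-memory reader. It assumes PieceContent's byte encoding is exactly the length-prefixed metadata, which is what test_read_piece_content above relies on.

use bytes::Bytes;
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::Result as ClientResult;
use dragonfly_client_storage::client::Client;
use std::io::Cursor;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use vortex_protocol::{tlv::piece_content, Header};

struct FrameClient {
    config: Arc<Config>,
    frame: Vec<u8>,
}

#[tonic::async_trait]
impl Client for FrameClient {
    async fn connect_and_write_request(
        &self,
        _request: Bytes,
    ) -> ClientResult<(
        Box<dyn AsyncRead + Send + Unpin>,
        Box<dyn AsyncWrite + Send + Unpin>,
    )> {
        // Hand back the pre-built response frame instead of dialing a server.
        Ok((
            Box::new(Cursor::new(self.frame.clone())),
            Box::new(tokio::io::sink()),
        ))
    }

    fn config(&self) -> &Arc<Config> {
        &self.config
    }

    fn addr(&self) -> &str {
        "in-memory"
    }
}

#[tokio::test]
async fn test_download_piece_from_prebuilt_frame() {
    // Header first; its length field is nominal here because the default flow
    // only inspects it on the Error branch.
    let header_bytes: Bytes = Header::new_piece_content(1024).into();
    // Length-prefixed metadata, exactly what `read_piece_content` consumes.
    let metadata_bytes: Bytes = piece_content::PieceContent::new(
        1,
        0,
        5,
        "a".repeat(32),
        "parent".to_string(),
        1,
        Duration::from_secs(1),
        chrono::DateTime::from_timestamp(1693152000, 0)
            .unwrap()
            .naive_utc(),
    )
    .into();

    let mut frame = header_bytes.to_vec();
    frame.extend_from_slice(&metadata_bytes);
    // Whatever follows the metadata is the piece body handed back to the caller.
    frame.extend_from_slice(b"hello");

    let client = FrameClient {
        config: Arc::new(Config::default()),
        frame,
    };

    let (mut reader, offset, digest) = client.handle_download_piece(1, "task").await.unwrap();
    let mut body = Vec::new();
    reader.read_to_end(&mut body).await.unwrap();
    assert_eq!(offset, 0);
    assert_eq!(digest, "a".repeat(32));
    assert_eq!(body, b"hello");
}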


@@ -14,28 +14,20 @@
* limitations under the License.
*/
use bytes::{Bytes, BytesMut};
use crate::client::Client;
use bytes::Bytes;
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{
error::{ErrorType, OrErr},
Error as ClientError, Result as ClientResult,
};
use quinn::crypto::rustls::QuicClientConfig;
use quinn::{AckFrequencyConfig, ClientConfig, Endpoint, RecvStream, SendStream, TransportConfig};
use quinn::{AckFrequencyConfig, ClientConfig, Endpoint, TransportConfig};
use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::time;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, instrument};
use vortex_protocol::{
tlv::{
download_persistent_cache_piece::DownloadPersistentCachePiece,
download_piece::DownloadPiece, error::Error as VortexError, persistent_cache_piece_content,
piece_content, Tag,
},
Header, Vortex, HEADER_SIZE,
};
/// QUICClient is a QUIC-based client for the QUIC storage service.
#[derive(Clone)]
@@ -53,118 +45,10 @@ impl QUICClient {
pub fn new(config: Arc<Config>, addr: String) -> Self {
Self { config, addr }
}
/// Downloads a piece from the server using the vortex protocol.
///
/// This is the main entry point for downloading a piece. It applies
/// a timeout based on the configuration and handles connection timeouts gracefully.
#[instrument(skip_all)]
pub async fn download_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
time::timeout(
self.config.download.piece_timeout,
self.handle_download_piece(number, task_id),
)
.await
.inspect_err(|err| {
error!("connect timeout to {}: {}", self.addr, err);
})?
}
/// Internal handler for downloading a piece.
///
/// This method performs the actual protocol communication:
/// 1. Creates a download piece request.
/// 2. Establishes QUIC connection and sends the request.
/// 3. Reads and validates the response header.
/// 4. Processes the piece content based on the response type.
#[instrument(skip_all)]
async fn handle_download_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
let request: Bytes = Vortex::DownloadPiece(
Header::new_download_piece(),
DownloadPiece::new(task_id.to_string(), number),
)
.into();
let (mut reader, _writer) = self.connect_and_write_request(request).await?;
let header = self.read_header(&mut reader).await?;
match header.tag() {
Tag::PieceContent => {
let piece_content: piece_content::PieceContent = self
.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await?;
let metadata = piece_content.metadata();
Ok((reader, metadata.offset, metadata.digest))
}
Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
_ => Err(ClientError::Unknown(format!(
"unexpected tag: {:?}",
header.tag()
))),
}
}
/// Downloads a persistent cache piece from the server using the vortex protocol.
///
/// Similar to `download_piece` but specifically for persistent cache pieces.
#[instrument(skip_all)]
pub async fn download_persistent_cache_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
time::timeout(
self.config.download.piece_timeout,
self.handle_download_persistent_cache_piece(number, task_id),
)
.await
.inspect_err(|err| {
error!("connect timeout to {}: {}", self.addr, err);
})?
}
/// Internal handler for downloading a persistent cache piece.
///
/// Implements the same protocol flow as `handle_download_piece` but uses
/// persistent cache specific request/response types.
#[instrument(skip_all)]
async fn handle_download_persistent_cache_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
let request: Bytes = Vortex::DownloadPersistentCachePiece(
Header::new_download_persistent_cache_piece(),
DownloadPersistentCachePiece::new(task_id.to_string(), number),
)
.into();
let (mut reader, _writer) = self.connect_and_write_request(request).await?;
let header = self.read_header(&mut reader).await?;
match header.tag() {
Tag::PersistentCachePieceContent => {
let persistent_cache_piece_content: persistent_cache_piece_content::PersistentCachePieceContent =
self.read_piece_content(&mut reader, persistent_cache_piece_content::METADATA_LENGTH_SIZE)
.await?;
let metadata = persistent_cache_piece_content.metadata();
Ok((reader, metadata.offset, metadata.digest))
}
Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
_ => Err(ClientError::Unknown(format!(
"unexpected tag: {:?}",
header.tag()
))),
}
}
#[tonic::async_trait]
impl Client for QUICClient {
/// Establishes QUIC connection and writes a vortex protocol request.
///
/// This is a low-level utility function that handles the QUIC connection
@@ -174,7 +58,10 @@ impl QUICClient {
async fn connect_and_write_request(
&self,
request: Bytes,
) -> ClientResult<(RecvStream, SendStream)> {
) -> ClientResult<(
Box<dyn AsyncRead + Send + Unpin>,
Box<dyn AsyncWrite + Send + Unpin>,
)> {
let mut client_config = ClientConfig::new(Arc::new(
QuicClientConfig::try_from(
quinn::rustls::ClientConfig::builder()
@@ -218,85 +105,17 @@ impl QUICClient {
.await
.inspect_err(|err| error!("failed to send request: {}", err))?;
Ok((reader, writer))
Ok((Box::new(reader), Box::new(writer)))
}
/// Reads and parses a vortex protocol header from the QUIC stream.
///
/// The header contains metadata about the following message, including
/// the message type (tag) and payload length. This is critical for
/// proper protocol message framing.
#[instrument(skip_all)]
async fn read_header(&self, reader: &mut RecvStream) -> ClientResult<Header> {
let mut header_bytes = BytesMut::with_capacity(HEADER_SIZE);
header_bytes.resize(HEADER_SIZE, 0);
reader
.read_exact(&mut header_bytes)
.await
.inspect_err(|err| error!("failed to receive header: {}", err))?;
Header::try_from(header_bytes.freeze()).map_err(Into::into)
/// Access to client configuration.
fn config(&self) -> &Arc<Config> {
&self.config
}
/// Reads and parses piece content with variable-length metadata.
///
/// This generic function handles the two-stage reading process for
/// piece content: first reading the metadata length, then reading
/// the actual metadata, and finally constructing the complete message.
#[instrument(skip_all)]
async fn read_piece_content<T>(
&self,
reader: &mut RecvStream,
metadata_length_size: usize,
) -> ClientResult<T>
where
T: TryFrom<Bytes, Error: Into<ClientError>>,
{
let mut metadata_length_bytes = BytesMut::with_capacity(metadata_length_size);
metadata_length_bytes.resize(metadata_length_size, 0);
reader
.read_exact(&mut metadata_length_bytes)
.await
.inspect_err(|err| error!("failed to receive metadata length: {}", err))?;
let metadata_length = u32::from_be_bytes(metadata_length_bytes[..].try_into()?) as usize;
let mut metadata_bytes = BytesMut::with_capacity(metadata_length);
metadata_bytes.resize(metadata_length, 0);
reader
.read_exact(&mut metadata_bytes)
.await
.inspect_err(|err| error!("failed to receive metadata: {}", err))?;
let mut content_bytes = BytesMut::with_capacity(metadata_length_size + metadata_length);
content_bytes.extend_from_slice(&metadata_length_bytes);
content_bytes.extend_from_slice(&metadata_bytes);
content_bytes.freeze().try_into().map_err(Into::into)
}
/// Reads and processes error responses from the server.
///
/// When the server responds with an error tag, this function reads
/// the error payload and converts it into an appropriate client error.
/// This provides structured error handling for protocol-level failures.
#[instrument(skip_all)]
async fn read_error(&self, reader: &mut RecvStream, header_length: usize) -> ClientError {
let mut error_bytes = BytesMut::with_capacity(header_length);
error_bytes.resize(header_length, 0);
if let Err(err) = reader.read_exact(&mut error_bytes).await {
error!("failed to receive error: {}", err);
return ClientError::Unknown(err.to_string());
};
error_bytes
.freeze()
.try_into()
.map(|error: VortexError| {
ClientError::VortexProtocolStatus(error.code(), error.message().to_string())
})
.unwrap_or_else(|err| {
error!("failed to extract error: {}", err);
ClientError::Unknown(format!("failed to extract error: {}", err))
})
/// Access to client address.
fn addr(&self) -> &str {
&self.addr
}
}


@@ -14,23 +14,14 @@
* limitations under the License.
*/
use bytes::{Bytes, BytesMut};
use crate::client::Client;
use bytes::Bytes;
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{Error as ClientError, Result as ClientResult};
use dragonfly_client_core::Result as ClientResult;
use socket2::{SockRef, TcpKeepalive};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::time;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{error, instrument};
use vortex_protocol::{
tlv::{
download_persistent_cache_piece::DownloadPersistentCachePiece,
download_piece::DownloadPiece, error::Error as VortexError, persistent_cache_piece_content,
piece_content, Tag,
},
Header, Vortex, HEADER_SIZE,
};
/// TCPClient is a TCP-based client for the TCP storage service.
#[derive(Clone)]
@@ -48,118 +39,10 @@ impl TCPClient {
pub fn new(config: Arc<Config>, addr: String) -> Self {
Self { config, addr }
}
/// Downloads a piece from the server using the vortex protocol.
///
/// This is the main entry point for downloading a piece. It applies
/// a timeout based on the configuration and handles connection timeouts gracefully.
#[instrument(skip_all)]
pub async fn download_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
time::timeout(
self.config.download.piece_timeout,
self.handle_download_piece(number, task_id),
)
.await
.inspect_err(|err| {
error!("connect timeout to {}: {}", self.addr, err);
})?
}
/// Internal handler for downloading a piece.
///
/// This method performs the actual protocol communication:
/// 1. Creates a download piece request.
/// 2. Establishes TCP connection and sends the request.
/// 3. Reads and validates the response header.
/// 4. Processes the piece content based on the response type.
#[instrument(skip_all)]
async fn handle_download_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
let request: Bytes = Vortex::DownloadPiece(
Header::new_download_piece(),
DownloadPiece::new(task_id.to_string(), number),
)
.into();
let (mut reader, _writer) = self.connect_and_write_request(request).await?;
let header = self.read_header(&mut reader).await?;
match header.tag() {
Tag::PieceContent => {
let piece_content: piece_content::PieceContent = self
.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await?;
let metadata = piece_content.metadata();
Ok((reader, metadata.offset, metadata.digest))
}
Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
_ => Err(ClientError::Unknown(format!(
"unexpected tag: {:?}",
header.tag()
))),
}
}
/// Downloads a persistent cache piece from the server using the vortex protocol.
///
/// Similar to `download_piece` but specifically for persistent cache pieces.
#[instrument(skip_all)]
pub async fn download_persistent_cache_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
time::timeout(
self.config.download.piece_timeout,
self.handle_download_persistent_cache_piece(number, task_id),
)
.await
.inspect_err(|err| {
error!("connect timeout to {}: {}", self.addr, err);
})?
}
/// Internal handler for downloading a persistent cache piece.
///
/// Implements the same protocol flow as `handle_download_piece` but uses
/// persistent cache specific request/response types.
#[instrument(skip_all)]
async fn handle_download_persistent_cache_piece(
&self,
number: u32,
task_id: &str,
) -> ClientResult<(impl AsyncRead, u64, String)> {
let request: Bytes = Vortex::DownloadPersistentCachePiece(
Header::new_download_persistent_cache_piece(),
DownloadPersistentCachePiece::new(task_id.to_string(), number),
)
.into();
let (mut reader, _writer) = self.connect_and_write_request(request).await?;
let header = self.read_header(&mut reader).await?;
match header.tag() {
Tag::PersistentCachePieceContent => {
let persistent_cache_piece_content: persistent_cache_piece_content::PersistentCachePieceContent =
self.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
.await?;
let metadata = persistent_cache_piece_content.metadata();
Ok((reader, metadata.offset, metadata.digest))
}
Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
_ => Err(ClientError::Unknown(format!(
"unexpected tag: {:?}",
header.tag()
))),
}
}
#[tonic::async_trait]
impl Client for TCPClient {
/// Establishes TCP connection and writes a vortex protocol request.
///
/// This is a low-level utility function that handles the TCP connection
@@ -169,7 +52,10 @@ impl TCPClient {
async fn connect_and_write_request(
&self,
request: Bytes,
) -> ClientResult<(OwnedReadHalf, OwnedWriteHalf)> {
) -> ClientResult<(
Box<dyn AsyncRead + Send + Unpin>,
Box<dyn AsyncWrite + Send + Unpin>,
)> {
let stream = tokio::net::TcpStream::connect(self.addr.clone()).await?;
let socket = SockRef::from(&stream);
socket.set_tcp_nodelay(true)?;
@@ -189,90 +75,16 @@ impl TCPClient {
error!("failed to flush request: {}", err);
})?;
Ok((reader, writer))
Ok((Box::new(reader), Box::new(writer)))
}
/// Reads and parses a vortex protocol header from the TCP stream.
///
/// The header contains metadata about the following message, including
/// the message type (tag) and payload length. This is critical for
/// proper protocol message framing.
#[instrument(skip_all)]
async fn read_header(&self, reader: &mut OwnedReadHalf) -> ClientResult<Header> {
let mut header_bytes = BytesMut::with_capacity(HEADER_SIZE);
header_bytes.resize(HEADER_SIZE, 0);
reader
.read_exact(&mut header_bytes)
.await
.inspect_err(|err| {
error!("failed to receive header: {}", err);
})?;
Header::try_from(header_bytes.freeze()).map_err(Into::into)
/// Access to client configuration.
fn config(&self) -> &Arc<Config> {
&self.config
}
/// Reads and parses piece content with variable-length metadata.
///
/// This generic function handles the two-stage reading process for
/// piece content: first reading the metadata length, then reading
/// the actual metadata, and finally constructing the complete message.
#[instrument(skip_all)]
async fn read_piece_content<T>(
&self,
reader: &mut OwnedReadHalf,
metadata_length_size: usize,
) -> ClientResult<T>
where
T: TryFrom<Bytes, Error: Into<ClientError>>,
{
let mut metadata_length_bytes = BytesMut::with_capacity(metadata_length_size);
metadata_length_bytes.resize(metadata_length_size, 0);
reader
.read_exact(&mut metadata_length_bytes)
.await
.inspect_err(|err| {
error!("failed to receive metadata length: {}", err);
})?;
let metadata_length = u32::from_be_bytes(metadata_length_bytes[..].try_into()?) as usize;
let mut metadata_bytes = BytesMut::with_capacity(metadata_length);
metadata_bytes.resize(metadata_length, 0);
reader
.read_exact(&mut metadata_bytes)
.await
.inspect_err(|err| {
error!("failed to receive metadata: {}", err);
})?;
let mut content_bytes = BytesMut::with_capacity(metadata_length_size + metadata_length);
content_bytes.extend_from_slice(&metadata_length_bytes);
content_bytes.extend_from_slice(&metadata_bytes);
content_bytes.freeze().try_into().map_err(Into::into)
}
/// Reads and processes error responses from the server.
///
/// When the server responds with an error tag, this function reads
/// the error payload and converts it into an appropriate client error.
/// This provides structured error handling for protocol-level failures.
#[instrument(skip_all)]
async fn read_error(&self, reader: &mut OwnedReadHalf, header_length: usize) -> ClientError {
let mut error_bytes = BytesMut::with_capacity(header_length);
error_bytes.resize(header_length, 0);
if let Err(err) = reader.read_exact(&mut error_bytes).await {
error!("failed to receive error: {}", err);
return ClientError::IO(err);
};
error_bytes
.freeze()
.try_into()
.map(|error: VortexError| {
ClientError::VortexProtocolStatus(error.code(), error.message().to_string())
})
.unwrap_or_else(|err| {
error!("failed to extract error: {}", err);
ClientError::Unknown(format!("failed to extract error: {}", err))
})
/// Access to client address.
fn addr(&self) -> &str {
&self.addr
}
}
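
With the protocol flow living in the trait's provided methods, a new transport only has to supply the connection setup and the two accessors. A sketch of what that could look like for a hypothetical Unix-domain-socket transport (UDSClient is not part of this change; it treats addr as a socket path and is Unix-only):

use bytes::Bytes;
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::Result as ClientResult;
use dragonfly_client_storage::client::Client;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};

/// Hypothetical Unix-domain-socket transport; `addr` is interpreted as a socket path.
pub struct UDSClient {
    config: Arc<Config>,
    addr: String,
}

#[tonic::async_trait]
impl Client for UDSClient {
    async fn connect_and_write_request(
        &self,
        request: Bytes,
    ) -> ClientResult<(
        Box<dyn AsyncRead + Send + Unpin>,
        Box<dyn AsyncWrite + Send + Unpin>,
    )> {
        // Dial the socket, send the serialized vortex request, and hand the
        // halves back to the trait's default download flow.
        let stream = tokio::net::UnixStream::connect(&self.addr).await?;
        let (reader, mut writer) = stream.into_split();
        writer.write_all(&request).await?;
        writer.flush().await?;
        Ok((Box::new(reader), Box::new(writer)))
    }

    fn config(&self) -> &Arc<Config> {
        &self.config
    }

    fn addr(&self) -> &str {
        &self.addr
    }
}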


@@ -503,3 +503,415 @@ impl TCPServerHandler {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use dragonfly_client_config::dfdaemon::Config;
use std::time::Duration;
use tempfile::tempdir;
async fn create_test_tcp_server_handler() -> TCPServerHandler {
let config = Arc::new(Config::default());
let dir = tempdir().unwrap();
let log_dir = dir.path().join("log");
let storage = Storage::new(config.clone(), dir.path(), log_dir)
.await
.unwrap();
let storage = Arc::new(storage);
let id_generator = IDGenerator::new(
"127.0.0.1".to_string(),
config.host.hostname.clone(),
config.seed_peer.enable,
);
let id_generator = Arc::new(id_generator);
let upload_rate_limiter = Arc::new(
RateLimiter::builder()
.initial(config.upload.rate_limit.as_u64() as usize)
.refill(config.upload.rate_limit.as_u64() as usize)
.max(config.upload.rate_limit.as_u64() as usize)
.interval(Duration::from_secs(1))
.fair(false)
.build(),
);
TCPServerHandler {
id_generator,
storage,
upload_rate_limiter,
}
}
#[tokio::test]
async fn test_read_header_success() {
let handler = create_test_tcp_server_handler().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut reader, _writer) = stream.into_split();
handler.read_header(&mut reader).await
});
let client_stream = TcpStream::connect(addr).await.unwrap();
let (_reader, mut writer) = client_stream.into_split();
let header = Header::new(Tag::DownloadPiece, 100);
let header_bytes: Bytes = header.into();
writer.write_all(&header_bytes).await.unwrap();
writer.flush().await.unwrap();
let result = server_handle.await.unwrap();
assert!(result.is_ok());
if let Ok(header) = result {
assert_eq!(header.tag(), Tag::DownloadPiece);
assert_eq!(header.length(), 100);
}
}
#[tokio::test]
async fn test_read_header_insufficient_data() {
let handler = create_test_tcp_server_handler().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut reader, _writer) = stream.into_split();
handler.read_header(&mut reader).await
});
let client_stream = TcpStream::connect(addr).await.unwrap();
let (_reader, mut writer) = client_stream.into_split();
let partial_data = vec![0u8; HEADER_SIZE - 1];
writer.write_all(&partial_data).await.unwrap();
writer.flush().await.unwrap();
drop(writer);
let result = server_handle.await.unwrap();
assert!(result.is_err());
}
#[tokio::test]
async fn test_read_header_connection_closed() {
let handler = create_test_tcp_server_handler().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut reader, _writer) = stream.into_split();
handler.read_header(&mut reader).await
});
let client_stream = TcpStream::connect(addr).await.unwrap();
drop(client_stream);
let result = server_handle.await.unwrap();
assert!(result.is_err());
}
/// HEADER_LENGTH is the length in bytes of the serialized `DownloadPiece` payload
/// used by these tests: a 64-byte task id plus a 4-byte piece number.
const HEADER_LENGTH: usize = 68;
#[tokio::test]
async fn test_read_download_piece_success() {
let handler = create_test_tcp_server_handler().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut reader, _writer) = stream.into_split();
handler
.read_download_piece(&mut reader, HEADER_LENGTH)
.await
});
let client_stream = TcpStream::connect(addr).await.unwrap();
let (_reader, mut writer) = client_stream.into_split();
let task_id = "a".repeat(64);
let piece_number = 42;
let download_piece = DownloadPiece::new(task_id.clone(), piece_number);
let bytes: Bytes = download_piece.into();
writer.write_all(&bytes).await.unwrap();
writer.flush().await.unwrap();
let result: ClientResult<DownloadPiece> = server_handle.await.unwrap();
assert!(result.is_ok());
if let Ok(download_piece) = result {
assert_eq!(download_piece.task_id(), task_id);
assert_eq!(download_piece.piece_number(), piece_number);
}
}
#[tokio::test]
async fn test_read_download_piece_insufficient_data() {
let handler = create_test_tcp_server_handler().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut reader, _writer) = stream.into_split();
handler
.read_download_piece(&mut reader, HEADER_LENGTH)
.await
});
let client_stream = TcpStream::connect(addr).await.unwrap();
let (_reader, mut writer) = client_stream.into_split();
let partial_data = vec![0u8; HEADER_LENGTH - 1];
writer.write_all(&partial_data).await.unwrap();
writer.flush().await.unwrap();
drop(writer);
let result: ClientResult<DownloadPiece> = server_handle.await.unwrap();
assert!(result.is_err());
}
#[tokio::test]
async fn test_read_download_piece_connection_closed() {
let handler = create_test_tcp_server_handler().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut reader, _writer) = stream.into_split();
handler
.read_download_piece(&mut reader, HEADER_LENGTH)
.await
});
let client_stream = TcpStream::connect(addr).await.unwrap();
drop(client_stream);
let result: ClientResult<DownloadPiece> = server_handle.await.unwrap();
assert!(result.is_err());
}
async fn create_test_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client_task = tokio::spawn(async move { TcpStream::connect(addr).await.unwrap() });
let (server_stream, _) = listener.accept().await.unwrap();
let client_stream = client_task.await.unwrap();
(server_stream, client_stream)
}
#[tokio::test]
async fn test_write_response_success() {
let handler = create_test_tcp_server_handler().await;
let (server_stream, mut client_stream) = create_test_tcp_pair().await;
let (_server_reader, mut server_writer) = server_stream.into_split();
let test_data = Bytes::from("Hello from server!");
let write_task = tokio::spawn(async move {
handler
.write_response(test_data.clone(), &mut server_writer)
.await
});
let mut response_buffer = vec![0u8; 1024];
let bytes_read = client_stream.read(&mut response_buffer).await.unwrap();
assert_eq!(&response_buffer[..bytes_read], b"Hello from server!");
assert!(write_task.await.unwrap().is_ok());
}
#[tokio::test]
async fn test_write_response_large_data() {
let handler = create_test_tcp_server_handler().await;
let (server_stream, mut client_stream) = create_test_tcp_pair().await;
let (_server_reader, mut server_writer) = server_stream.into_split();
let large_data = Bytes::from(vec![42u8; 1024 * 1024]);
let write_task = tokio::spawn(async move {
handler
.write_response(large_data.clone(), &mut server_writer)
.await
});
let mut total_received = 0;
let mut buffer = vec![0u8; 8192];
while total_received < 1024 * 1024 {
match client_stream.read(&mut buffer).await {
Ok(0) => break,
Ok(n) => {
total_received += n;
assert!(buffer[..n].iter().all(|&b| b == 42));
}
Err(e) => panic!("Read error: {}", e),
}
}
assert_eq!(total_received, 1024 * 1024);
assert!(write_task.await.unwrap().is_ok());
}
#[tokio::test]
async fn test_write_response_client_slow_read() {
let handler = create_test_tcp_server_handler().await;
let (server_stream, mut client_stream) = create_test_tcp_pair().await;
let (_server_reader, mut server_writer) = server_stream.into_split();
let test_data = Bytes::from(vec![1u8; 64 * 1024]);
let write_task =
tokio::spawn(
async move { handler.write_response(test_data, &mut server_writer).await },
);
tokio::spawn(async move {
let mut buffer = vec![0u8; 1024];
let mut total_read = 0;
while total_read < 64 * 1024 {
tokio::time::sleep(Duration::from_millis(10)).await;
match client_stream.read(&mut buffer).await {
Ok(0) => break,
Ok(n) => total_read += n,
Err(_) => break,
}
}
});
let result = tokio::time::timeout(Duration::from_secs(5), write_task).await;
assert!(result.is_ok());
assert!(result.unwrap().unwrap().is_ok());
}
#[tokio::test]
async fn test_write_stream_success() {
let handler = create_test_tcp_server_handler().await;
let (server_stream, mut client_stream) = create_test_tcp_pair().await;
let (_server_reader, mut server_writer) = server_stream.into_split();
let test_data = b"Stream content for testing".to_vec();
let mut mock_stream = std::io::Cursor::new(test_data.clone());
let write_task = tokio::spawn(async move {
handler
.write_stream(&mut mock_stream, &mut server_writer)
.await
});
let mut response_buffer = vec![0u8; 1024];
let bytes_read = client_stream.read(&mut response_buffer).await.unwrap();
assert_eq!(&response_buffer[..bytes_read], test_data.as_slice());
assert!(write_task.await.unwrap().is_ok());
}
#[tokio::test]
async fn test_write_stream_large_stream() {
let handler = create_test_tcp_server_handler().await;
let (server_stream, mut client_stream) = create_test_tcp_pair().await;
let (_server_reader, mut server_writer) = server_stream.into_split();
let large_data = vec![123u8; 10 * 1024 * 1024];
let mut mock_stream = std::io::Cursor::new(large_data.clone());
let write_task = tokio::spawn(async move {
handler
.write_stream(&mut mock_stream, &mut server_writer)
.await
});
let mut total_received = 0;
let mut buffer = vec![0u8; 64 * 1024];
while total_received < 10 * 1024 * 1024 {
match tokio::time::timeout(Duration::from_secs(1), client_stream.read(&mut buffer))
.await
{
Ok(Ok(0)) => break,
Ok(Ok(n)) => {
total_received += n;
assert!(buffer[..n].iter().all(|&b| b == 123));
}
Ok(Err(e)) => panic!("Read error: {}", e),
Err(_) => panic!("Read timeout"),
}
}
assert_eq!(total_received, 10 * 1024 * 1024);
assert!(write_task.await.unwrap().is_ok());
}
#[tokio::test]
async fn test_write_stream_with_slow_source() {
let handler = create_test_tcp_server_handler().await;
let (server_stream, mut client_stream) = create_test_tcp_pair().await;
let (_server_reader, mut server_writer) = server_stream.into_split();
struct SlowReader {
data: std::io::Cursor<Vec<u8>>,
delay: Duration,
}
impl AsyncRead for SlowReader {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
cx.waker().wake_by_ref();
std::thread::sleep(self.delay);
let mut temp_buf = vec![0u8; std::cmp::min(buf.remaining(), 10)];
match std::io::Read::read(&mut self.data, &mut temp_buf) {
Ok(0) => std::task::Poll::Ready(Ok(())),
Ok(n) => {
buf.put_slice(&temp_buf[..n]);
std::task::Poll::Ready(Ok(()))
}
Err(e) => std::task::Poll::Ready(Err(e)),
}
}
}
let test_data = b"Slow stream data".to_vec();
let mut slow_stream = SlowReader {
data: std::io::Cursor::new(test_data.clone()),
delay: Duration::from_millis(5),
};
let write_task = tokio::spawn(async move {
handler
.write_stream(&mut slow_stream, &mut server_writer)
.await
});
let mut response_buffer = vec![0u8; 1024];
let bytes_read = tokio::time::timeout(
Duration::from_secs(2),
client_stream.read(&mut response_buffer),
)
.await
.unwrap()
.unwrap();
assert_eq!(&response_buffer[..bytes_read], test_data.as_slice());
assert!(write_task.await.unwrap().is_ok());
}
}


@@ -18,7 +18,9 @@ use crate::grpc::dfdaemon_upload::DfdaemonUploadClient;
use dragonfly_api::dfdaemon::v2::{DownloadPersistentCachePieceRequest, DownloadPieceRequest};
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{Error, Result};
use dragonfly_client_storage::{client::quic::QUICClient, client::tcp::TCPClient, metadata};
use dragonfly_client_storage::{
client::quic::QUICClient, client::tcp::TCPClient, client::Client, metadata,
};
use dragonfly_client_util::pool::{Builder as PoolBuilder, Entry, Factory, Pool};
use std::io::Cursor;
use std::sync::Arc;