diff --git a/Cargo.lock b/Cargo.lock index e8b00df8..c8968956 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1001,6 +1001,12 @@ dependencies = [ "const-random", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dragonfly-api" version = "2.1.70" @@ -1222,6 +1228,7 @@ dependencies = [ "dragonfly-client-util", "fs2", "leaky-bucket", + "mockall", "nix 0.30.1", "num_cpus", "prost-wkt-types", @@ -1235,6 +1242,7 @@ dependencies = [ "tempfile", "tokio", "tokio-util", + "tonic", "tracing", "vortex-protocol", "walkdir", @@ -1451,6 +1459,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" + [[package]] name = "fs2" version = "0.4.3" @@ -2738,6 +2752,32 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "mocktail" version = "0.3.0" @@ -3598,6 +3638,32 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "prettyplease" version = "0.2.17" @@ -5110,6 +5176,12 @@ dependencies = [ "redox_termios", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "testing_table" version = "0.3.0" diff --git a/dragonfly-client-storage/Cargo.toml b/dragonfly-client-storage/Cargo.toml index e66a9d39..5b54b860 100644 --- a/dragonfly-client-storage/Cargo.toml +++ b/dragonfly-client-storage/Cargo.toml @@ -31,11 +31,13 @@ bytesize.workspace = true leaky-bucket.workspace = true vortex-protocol.workspace = true rustls.workspace = true +tonic.workspace = true num_cpus = "1.17" bincode = "1.3.3" walkdir = "2.5.0" quinn = "0.11.9" socket2 = "0.6.0" +mockall = "0.13.1" [dev-dependencies] tempfile.workspace = true diff --git a/dragonfly-client-storage/src/client/mod.rs b/dragonfly-client-storage/src/client/mod.rs index c2d8bc5a..148688bb 100644 --- a/dragonfly-client-storage/src/client/mod.rs +++ b/dragonfly-client-storage/src/client/mod.rs @@ -17,7 +17,22 @@ pub mod quic; pub mod tcp; +use bytes::{Bytes, BytesMut}; +use dragonfly_client_config::dfdaemon::Config; +use dragonfly_client_core::{Error as ClientError, Result as ClientResult}; +use std::sync::Arc; use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; +use tokio::time; +use tracing::{error, instrument}; +use vortex_protocol::{ + tlv::{ + download_persistent_cache_piece::DownloadPersistentCachePiece, + download_piece::DownloadPiece, error::Error as VortexError, persistent_cache_piece_content, + piece_content, Tag, + }, + Header, Vortex, HEADER_SIZE, +}; /// DEFAULT_SEND_BUFFER_SIZE is the default size of the send buffer for network connections. const DEFAULT_SEND_BUFFER_SIZE: usize = 16 * 1024 * 1024; @@ -30,3 +45,604 @@ const DEFAULT_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(5); /// DEFAULT_MAX_IDLE_TIMEOUT is the default maximum idle timeout for connections. const DEFAULT_MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(300); + +/// Client defines a generic client interface for storage service protocols. +#[tonic::async_trait] +pub trait Client { + /// Downloads a piece from the server using the vortex protocol. + /// + /// This is the main entry point for downloading a piece. It applies + /// a timeout based on the configuration and handles connection timeouts gracefully. + #[instrument(skip_all)] + async fn download_piece( + &self, + number: u32, + task_id: &str, + ) -> ClientResult<(Box, u64, String)> { + time::timeout( + self.config().download.piece_timeout, + self.handle_download_piece(number, task_id), + ) + .await + .inspect_err(|err| { + error!("connect timeout to {}: {}", self.addr(), err); + })? + } + /// Internal handler for downloading a piece. + /// + /// This method performs the actual protocol communication: + /// 1. Creates a download piece request. + /// 2. Establishes connection and sends the request. + /// 3. Reads and validates the response header. + /// 4. Processes the piece content based on the response type. + #[instrument(skip_all)] + async fn handle_download_piece( + &self, + number: u32, + task_id: &str, + ) -> ClientResult<(Box, u64, String)> { + let request: Bytes = Vortex::DownloadPiece( + Header::new_download_piece(), + DownloadPiece::new(task_id.to_string(), number), + ) + .into(); + + let (mut reader, _writer) = self.connect_and_write_request(request).await?; + let header = self.read_header(&mut reader).await?; + match header.tag() { + Tag::PieceContent => { + let piece_content: piece_content::PieceContent = self + .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) + .await?; + + let metadata = piece_content.metadata(); + Ok((reader, metadata.offset, metadata.digest)) + } + Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await), + _ => Err(ClientError::Unknown(format!( + "unexpected tag: {:?}", + header.tag() + ))), + } + } + + /// Downloads a persistent cache piece from the server using the vortex protocol. + /// + /// Similar to `download_piece` but specifically for persistent cache piece. + #[instrument(skip_all)] + async fn download_persistent_cache_piece( + &self, + number: u32, + task_id: &str, + ) -> ClientResult<(Box, u64, String)> { + time::timeout( + self.config().download.piece_timeout, + self.handle_download_persistent_cache_piece(number, task_id), + ) + .await + .inspect_err(|err| { + error!("connect timeout to {}: {}", self.addr(), err); + })? + } + + /// Internal handler for downloading a persistent cache piece. + /// + /// Implements the same protocol flow as `handle_download_piece` but uses + /// persistent cache specific request/response types. + #[instrument(skip_all)] + async fn handle_download_persistent_cache_piece( + &self, + number: u32, + task_id: &str, + ) -> ClientResult<(Box, u64, String)> { + let request: Bytes = Vortex::DownloadPersistentCachePiece( + Header::new_download_persistent_cache_piece(), + DownloadPersistentCachePiece::new(task_id.to_string(), number), + ) + .into(); + + let (mut reader, _writer) = self.connect_and_write_request(request).await?; + let header = self.read_header(&mut reader).await?; + match header.tag() { + Tag::PersistentCachePieceContent => { + let persistent_cache_piece_content: persistent_cache_piece_content::PersistentCachePieceContent = + self.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) + .await?; + + let metadata = persistent_cache_piece_content.metadata(); + Ok((reader, metadata.offset, metadata.digest)) + } + Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await), + _ => Err(ClientError::Unknown(format!( + "unexpected tag: {:?}", + header.tag() + ))), + } + } + + /// Establishes connection and writes a vortex protocol request. + /// + /// This is a low-level utility function that handles the connection + /// lifecycle and request transmission. It ensures proper error handling + /// and connection cleanup. + async fn connect_and_write_request( + &self, + request: Bytes, + ) -> ClientResult<( + Box, + Box, + )>; + + /// Reads and parses a vortex protocol header. + /// + /// The header contains metadata about the following message, including + /// the message type (tag) and payload length. This is critical for + /// proper protocol message framing. + #[instrument(skip_all)] + async fn read_header( + &self, + reader: &mut Box, + ) -> ClientResult
{ + let mut header_bytes = BytesMut::with_capacity(HEADER_SIZE); + header_bytes.resize(HEADER_SIZE, 0); + reader + .read_exact(&mut header_bytes) + .await + .inspect_err(|err| { + error!("failed to receive header: {}", err); + })?; + + Header::try_from(header_bytes.freeze()).map_err(Into::into) + } + + /// Reads and parses piece content with variable-length metadata. + /// + /// This generic function handles the two-stage reading process for + /// piece content: first reading the metadata length, then reading + /// the actual metadata, and finally constructing the complete message. + #[instrument(skip_all)] + async fn read_piece_content( + &self, + reader: &mut Box, + metadata_length_size: usize, + ) -> ClientResult + where + T: TryFrom, + T::Error: Into, + T: 'static, + { + let mut metadata_length_bytes = BytesMut::with_capacity(metadata_length_size); + metadata_length_bytes.resize(metadata_length_size, 0); + reader + .read_exact(&mut metadata_length_bytes) + .await + .inspect_err(|err| { + error!("failed to receive metadata length: {}", err); + })?; + let metadata_length = u32::from_be_bytes(metadata_length_bytes[..].try_into()?) as usize; + + let mut metadata_bytes = BytesMut::with_capacity(metadata_length); + metadata_bytes.resize(metadata_length, 0); + reader + .read_exact(&mut metadata_bytes) + .await + .inspect_err(|err| { + error!("failed to receive metadata: {}", err); + })?; + + let mut content_bytes = BytesMut::with_capacity(metadata_length_size + metadata_length); + content_bytes.extend_from_slice(&metadata_length_bytes); + content_bytes.extend_from_slice(&metadata_bytes); + content_bytes.freeze().try_into().map_err(Into::into) + } + + /// Reads and processes error responses from the server. + /// + /// When the server responds with an error tag, this function reads + /// the error payload and converts it into an appropriate client error. + /// This provides structured error handling for protocol-level failures. + #[instrument(skip_all)] + async fn read_error( + &self, + reader: &mut Box, + header_length: usize, + ) -> ClientError { + let mut error_bytes = BytesMut::with_capacity(header_length); + error_bytes.resize(header_length, 0); + if let Err(err) = reader.read_exact(&mut error_bytes).await { + error!("failed to receive error: {}", err); + return ClientError::IO(err); + }; + + error_bytes + .freeze() + .try_into() + .map(|error: VortexError| { + ClientError::VortexProtocolStatus(error.code(), error.message().to_string()) + }) + .unwrap_or_else(|err| { + error!("failed to extract error: {}", err); + ClientError::Unknown(format!("failed to extract error: {}", err)) + }) + } + + /// Access to client configuration. + fn config(&self) -> &Arc; + + /// Access to client address. + fn addr(&self) -> &str; +} + +#[cfg(test)] +mod tests { + use super::*; + + use bytes::{BufMut, Bytes, BytesMut}; + use dragonfly_client_config::dfdaemon::Download; + use dragonfly_client_core::{Error as ClientError, Result as ClientResult}; + use std::io::Cursor; + use tokio::io::duplex; + use tokio::time::sleep; + + struct Mock { + config: Arc, + + addr: String, + + timeout: Duration, + } + + #[tonic::async_trait] + impl Client for Mock { + async fn handle_download_piece( + &self, + _number: u32, + _task_id: &str, + ) -> ClientResult<(Box, u64, String)> { + sleep(self.timeout).await; + + Ok((Box::new(Cursor::new(b"data")), 0, "hash".to_string())) + } + + async fn handle_download_persistent_cache_piece( + &self, + _number: u32, + _task_id: &str, + ) -> ClientResult<(Box, u64, String)> { + sleep(self.timeout).await; + + Ok((Box::new(Cursor::new(b"data")), 0, "hash".to_string())) + } + + async fn connect_and_write_request( + &self, + _request: Bytes, + ) -> ClientResult<( + Box, + Box, + )> { + let (reader, writer) = duplex(1); + Ok((Box::new(reader), Box::new(writer))) + } + + fn config(&self) -> &Arc { + &self.config + } + + fn addr(&self) -> &str { + &self.addr + } + } + + #[tokio::test] + async fn test_download_piece() { + let addr = "127.0.0.1:8080".to_string(); + let config = Arc::new(Config { + download: Download { + piece_timeout: Duration::from_secs(2), + ..Default::default() + }, + ..Default::default() + }); + + let mut mock = Mock { + config, + addr, + timeout: Duration::from_secs(1), + }; + let result = mock.download_piece(1, "task").await; + assert!(result.is_ok()); + if let Ok((mut reader, offset, digest)) = result { + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"data"); + assert_eq!(offset, 0); + assert_eq!(digest, "hash"); + } + + mock.timeout = Duration::from_secs(3); + let result = mock.download_piece(1, "task").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_download_persistent_cache_piece() { + let addr = "127.0.0.1:8080".to_string(); + let config = Arc::new(Config { + download: Download { + piece_timeout: Duration::from_secs(2), + ..Default::default() + }, + ..Default::default() + }); + + let mut mock = Mock { + config, + addr, + timeout: Duration::from_secs(1), + }; + let result = mock.download_persistent_cache_piece(1, "task").await; + assert!(result.is_ok()); + if let Ok((mut reader, offset, digest)) = result { + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"data"); + assert_eq!(offset, 0); + assert_eq!(digest, "hash"); + } + + mock.timeout = Duration::from_secs(3); + let result = mock.download_persistent_cache_piece(1, "task").await; + assert!(result.is_err()); + } + + mockall::mock! { + Client {} + + #[tonic::async_trait] + impl Client for Client { + async fn connect_and_write_request( + &self, + request: Bytes, + ) -> ClientResult<(Box, Box)>; + + async fn read_header(&self, reader: &mut Box) -> ClientResult
; + + async fn read_piece_content( + &self, + reader: &mut Box, + metadata_length_size: usize, + ) -> ClientResult + where + T: TryFrom, + T::Error: Into, + T: 'static; + + async fn read_error(&self, reader: &mut Box, header_length: usize) -> ClientError; + + fn config(&self) -> &Arc; + + fn addr(&self) -> &str; + } + } + + #[tokio::test] + async fn test_handle_download_piece() { + let mut mock = MockClient::new(); + mock.expect_connect_and_write_request().returning(|_| { + let (reader, writer) = duplex(1); + Ok((Box::new(reader), Box::new(writer))) + }); + + mock.expect_read_header() + .times(1) + .returning(|_| Ok(Header::new_piece_content(1024))); + mock.expect_read_piece_content().returning(|_, _| { + Ok(piece_content::PieceContent::new( + 42, + 1024, + 2048, + "a".repeat(32), + "test_parent_id".to_string(), + 1, + Duration::from_secs(5), + chrono::DateTime::from_timestamp(1693152000, 0) + .unwrap() + .naive_utc(), + )) + }); + let result = mock.handle_download_piece(1, "task").await; + assert!(result.is_ok()); + if let Ok((_, offset, digest)) = result { + assert_eq!(offset, 1024); + assert_eq!(digest, "a".repeat(32)); + } + + mock.expect_read_header() + .times(1) + .returning(|_| Ok(Header::new_error(1024))); + mock.expect_read_error() + .returning(|_, _| ClientError::Unknown("test".to_string())); + let result = mock.handle_download_piece(1, "task").await; + assert!(result.is_err()); + if let Err(err) = result { + assert!(format!("{:?}", err).contains("test")); + } + + mock.expect_read_header() + .returning(|_| Ok(Header::new_close())); + let result = mock.handle_download_piece(1, "task").await; + assert!(result.is_err()); + if let Err(err) = result { + assert!(format!("{:?}", err).contains("unexpected tag")); + } + } + + #[tokio::test] + async fn test_handle_download_persistent_cache_piece() { + let mut mock = MockClient::new(); + mock.expect_connect_and_write_request().returning(|_| { + let (reader, writer) = duplex(1); + Ok((Box::new(reader), Box::new(writer))) + }); + + mock.expect_read_header() + .times(1) + .returning(|_| Ok(Header::new_persistent_cache_piece_content(1024))); + mock.expect_read_piece_content().returning(|_, _| { + Ok( + persistent_cache_piece_content::PersistentCachePieceContent::new( + 42, + 1024, + 2048, + "a".repeat(32), + "test_parent_id".to_string(), + 1, + Duration::from_secs(5), + chrono::DateTime::from_timestamp(1693152000, 0) + .unwrap() + .naive_utc(), + ), + ) + }); + let result = mock.handle_download_persistent_cache_piece(1, "task").await; + assert!(result.is_ok()); + if let Ok((_, offset, digest)) = result { + assert_eq!(offset, 1024); + assert_eq!(digest, "a".repeat(32)); + } + + mock.expect_read_header() + .times(1) + .returning(|_| Ok(Header::new_error(1024))); + mock.expect_read_error() + .returning(|_, _| ClientError::Unknown("test".to_string())); + let result = mock.handle_download_persistent_cache_piece(1, "task").await; + assert!(result.is_err()); + if let Err(err) = result { + assert!(format!("{:?}", err).contains("test")); + } + + mock.expect_read_header() + .returning(|_| Ok(Header::new_close())); + let result = mock.handle_download_persistent_cache_piece(1, "task").await; + assert!(result.is_err()); + if let Err(err) = result { + assert!(format!("{:?}", err).contains("unexpected tag")); + } + } + + #[tokio::test] + async fn test_read_header() { + let addr = "127.0.0.1:8080".to_string(); + let config = Arc::new(Config::default()); + let mock = Mock { + config, + addr, + timeout: Duration::from_secs(1), + }; + + let mut reader: Box = Box::new(Cursor::new(b"HEADER_SIZE")); + let result = mock.read_header(&mut reader).await; + assert!(result.is_ok()); + + let mut reader: Box = Box::new(Cursor::new(b"")); + let result = mock.read_header(&mut reader).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_read_piece_content() { + let addr = "127.0.0.1:8080".to_string(); + let config = Arc::new(Config::default()); + let mock = Mock { + config, + addr, + timeout: Duration::from_secs(1), + }; + + let piece_content = piece_content::PieceContent::new( + 42, + 1024, + 2048, + "a".repeat(32), + "test_parent_id".to_string(), + 1, + Duration::from_secs(5), + chrono::DateTime::from_timestamp(1693152000, 0) + .unwrap() + .naive_utc(), + ); + let piece_content_bytes: Bytes = piece_content.into(); + let mut reader: Box = + Box::new(Cursor::new(piece_content_bytes)); + let result: ClientResult = mock + .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) + .await; + assert!(result.is_ok()); + if let Ok(content) = result { + assert_eq!(content.metadata().number, 42); + assert_eq!(content.metadata().offset, 1024); + assert_eq!(content.metadata().length, 2048); + assert_eq!(content.metadata().digest, "a".repeat(32)); + assert_eq!(content.metadata().parent_id, "test_parent_id".to_string()); + assert_eq!(content.metadata().traffic_type, 1); + assert_eq!(content.metadata().cost, Duration::from_secs(5)); + } + + let mut reader: Box = Box::new(Cursor::new(b"")); + let result: ClientResult = mock + .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) + .await; + assert!(result.is_err()); + + let mut reader: Box = Box::new(Cursor::new(b"METADATA")); + let result: ClientResult = mock + .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) + .await; + assert!(result.is_err()); + + let data = { + let mut bytes = BytesMut::new(); + bytes.put_u32(100); + bytes.put(&vec![0u8; 50][..]); + bytes.freeze() + }; + let mut reader: Box = Box::new(Cursor::new(data)); + let result: ClientResult = mock + .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_read_error() { + let addr = "127.0.0.1:8080".to_string(); + let config = Arc::new(Config::default()); + let mock = Mock { + config, + addr, + timeout: Duration::from_secs(1), + }; + + let expected_code = vortex_protocol::tlv::error::Code::NotFound; + let expected_message = "Resource not found".to_string(); + let error = VortexError::new(expected_code, expected_message.clone()); + let bytes: Bytes = error.into(); + let bytes_length = bytes.len(); + let mut reader: Box = Box::new(Cursor::new(bytes)); + let result = mock.read_error(&mut reader, bytes_length).await; + assert!(matches!(result, + ClientError::VortexProtocolStatus(code, ref message) + if code == expected_code && message == &expected_message + )); + + let mut reader: Box = Box::new(Cursor::new(b"")); + let result = mock.read_error(&mut reader, 5).await; + assert!(matches!(result, ClientError::IO(_))); + + let mut reader: Box = Box::new(Cursor::new(b"")); + let result = mock.read_error(&mut reader, 0).await; + assert!(matches!(result, ClientError::Unknown(_))); + } +} diff --git a/dragonfly-client-storage/src/client/quic.rs b/dragonfly-client-storage/src/client/quic.rs index 30be13ef..32722ae5 100644 --- a/dragonfly-client-storage/src/client/quic.rs +++ b/dragonfly-client-storage/src/client/quic.rs @@ -14,28 +14,20 @@ * limitations under the License. */ -use bytes::{Bytes, BytesMut}; +use crate::client::Client; +use bytes::Bytes; use dragonfly_client_config::dfdaemon::Config; use dragonfly_client_core::{ error::{ErrorType, OrErr}, Error as ClientError, Result as ClientResult, }; use quinn::crypto::rustls::QuicClientConfig; -use quinn::{AckFrequencyConfig, ClientConfig, Endpoint, RecvStream, SendStream, TransportConfig}; +use quinn::{AckFrequencyConfig, ClientConfig, Endpoint, TransportConfig}; use rustls_pki_types::{CertificateDer, ServerName, UnixTime}; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::AsyncRead; -use tokio::time; +use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, instrument}; -use vortex_protocol::{ - tlv::{ - download_persistent_cache_piece::DownloadPersistentCachePiece, - download_piece::DownloadPiece, error::Error as VortexError, persistent_cache_piece_content, - piece_content, Tag, - }, - Header, Vortex, HEADER_SIZE, -}; /// QUICClient is a QUIC-based client for quic storage service. #[derive(Clone)] @@ -53,118 +45,10 @@ impl QUICClient { pub fn new(config: Arc, addr: String) -> Self { Self { config, addr } } +} - /// Downloads a piece from the server using the vortex protocol. - /// - /// This is the main entry point for downloading a piece. It applies - /// a timeout based on the configuration and handles connection timeouts gracefully. - #[instrument(skip_all)] - pub async fn download_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - time::timeout( - self.config.download.piece_timeout, - self.handle_download_piece(number, task_id), - ) - .await - .inspect_err(|err| { - error!("connect timeout to {}: {}", self.addr, err); - })? - } - /// Internal handler for downloading a piece. - /// - /// This method performs the actual protocol communication: - /// 1. Creates a download piece request. - /// 2. Establishes QUIC connection and sends the request. - /// 3. Reads and validates the response header. - /// 4. Processes the piece content based on the response type. - #[instrument(skip_all)] - async fn handle_download_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - let request: Bytes = Vortex::DownloadPiece( - Header::new_download_piece(), - DownloadPiece::new(task_id.to_string(), number), - ) - .into(); - - let (mut reader, _writer) = self.connect_and_write_request(request).await?; - let header = self.read_header(&mut reader).await?; - match header.tag() { - Tag::PieceContent => { - let piece_content: piece_content::PieceContent = self - .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) - .await?; - - let metadata = piece_content.metadata(); - Ok((reader, metadata.offset, metadata.digest)) - } - Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await), - _ => Err(ClientError::Unknown(format!( - "unexpected tag: {:?}", - header.tag() - ))), - } - } - - /// Downloads a persistent cache piece from the server using the vortex protocol. - /// - /// Similar to `download_piece` but specifically for persistent cache piece. - #[instrument(skip_all)] - pub async fn download_persistent_cache_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - time::timeout( - self.config.download.piece_timeout, - self.handle_download_persistent_cache_piece(number, task_id), - ) - .await - .inspect_err(|err| { - error!("connect timeout to {}: {}", self.addr, err); - })? - } - - /// Internal handler for downloading a persistent cache piece. - /// - /// Implements the same protocol flow as `handle_download_piece` but uses - /// persistent cache specific request/response types. - #[instrument(skip_all)] - async fn handle_download_persistent_cache_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - let request: Bytes = Vortex::DownloadPersistentCachePiece( - Header::new_download_persistent_cache_piece(), - DownloadPersistentCachePiece::new(task_id.to_string(), number), - ) - .into(); - - let (mut reader, _writer) = self.connect_and_write_request(request).await?; - let header = self.read_header(&mut reader).await?; - match header.tag() { - Tag::PersistentCachePieceContent => { - let persistent_cache_piece_content: persistent_cache_piece_content::PersistentCachePieceContent = - self.read_piece_content(&mut reader, persistent_cache_piece_content::METADATA_LENGTH_SIZE) - .await?; - - let metadata = persistent_cache_piece_content.metadata(); - Ok((reader, metadata.offset, metadata.digest)) - } - Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await), - _ => Err(ClientError::Unknown(format!( - "unexpected tag: {:?}", - header.tag() - ))), - } - } - +#[tonic::async_trait] +impl Client for QUICClient { /// Establishes QUIC connection and writes a vortex protocol request. /// /// This is a low-level utility function that handles the QUIC connection @@ -174,7 +58,10 @@ impl QUICClient { async fn connect_and_write_request( &self, request: Bytes, - ) -> ClientResult<(RecvStream, SendStream)> { + ) -> ClientResult<( + Box, + Box, + )> { let mut client_config = ClientConfig::new(Arc::new( QuicClientConfig::try_from( quinn::rustls::ClientConfig::builder() @@ -218,85 +105,17 @@ impl QUICClient { .await .inspect_err(|err| error!("failed to send request: {}", err))?; - Ok((reader, writer)) + Ok((Box::new(reader), Box::new(writer))) } - /// Reads and parses a vortex protocol header from the QUIC stream. - /// - /// The header contains metadata about the following message, including - /// the message type (tag) and payload length. This is critical for - /// proper protocol message framing. - #[instrument(skip_all)] - async fn read_header(&self, reader: &mut RecvStream) -> ClientResult
{ - let mut header_bytes = BytesMut::with_capacity(HEADER_SIZE); - header_bytes.resize(HEADER_SIZE, 0); - reader - .read_exact(&mut header_bytes) - .await - .inspect_err(|err| error!("failed to receive header: {}", err))?; - - Header::try_from(header_bytes.freeze()).map_err(Into::into) + /// Access to client configuration. + fn config(&self) -> &Arc { + &self.config } - /// Reads and parses piece content with variable-length metadata. - /// - /// This generic function handles the two-stage reading process for - /// piece content: first reading the metadata length, then reading - /// the actual metadata, and finally constructing the complete message. - #[instrument(skip_all)] - async fn read_piece_content( - &self, - reader: &mut RecvStream, - metadata_length_size: usize, - ) -> ClientResult - where - T: TryFrom>, - { - let mut metadata_length_bytes = BytesMut::with_capacity(metadata_length_size); - metadata_length_bytes.resize(metadata_length_size, 0); - reader - .read_exact(&mut metadata_length_bytes) - .await - .inspect_err(|err| error!("failed to receive metadata length: {}", err))?; - let metadata_length = u32::from_be_bytes(metadata_length_bytes[..].try_into()?) as usize; - - let mut metadata_bytes = BytesMut::with_capacity(metadata_length); - metadata_bytes.resize(metadata_length, 0); - reader - .read_exact(&mut metadata_bytes) - .await - .inspect_err(|err| error!("failed to receive metadata: {}", err))?; - - let mut content_bytes = BytesMut::with_capacity(metadata_length_size + metadata_length); - content_bytes.extend_from_slice(&metadata_length_bytes); - content_bytes.extend_from_slice(&metadata_bytes); - content_bytes.freeze().try_into().map_err(Into::into) - } - - /// Reads and processes error responses from the server. - /// - /// When the server responds with an error tag, this function reads - /// the error payload and converts it into an appropriate client error. - /// This provides structured error handling for protocol-level failures. - #[instrument(skip_all)] - async fn read_error(&self, reader: &mut RecvStream, header_length: usize) -> ClientError { - let mut error_bytes = BytesMut::with_capacity(header_length); - error_bytes.resize(header_length, 0); - if let Err(err) = reader.read_exact(&mut error_bytes).await { - error!("failed to receive error: {}", err); - return ClientError::Unknown(err.to_string()); - }; - - error_bytes - .freeze() - .try_into() - .map(|error: VortexError| { - ClientError::VortexProtocolStatus(error.code(), error.message().to_string()) - }) - .unwrap_or_else(|err| { - error!("failed to extract error: {}", err); - ClientError::Unknown(format!("failed to extract error: {}", err)) - }) + /// Access to client address. + fn addr(&self) -> &str { + &self.addr } } diff --git a/dragonfly-client-storage/src/client/tcp.rs b/dragonfly-client-storage/src/client/tcp.rs index 2a4b8da0..a840cf04 100644 --- a/dragonfly-client-storage/src/client/tcp.rs +++ b/dragonfly-client-storage/src/client/tcp.rs @@ -14,23 +14,14 @@ * limitations under the License. */ -use bytes::{Bytes, BytesMut}; +use crate::client::Client; +use bytes::Bytes; use dragonfly_client_config::dfdaemon::Config; -use dragonfly_client_core::{Error as ClientError, Result as ClientResult}; +use dragonfly_client_core::Result as ClientResult; use socket2::{SockRef, TcpKeepalive}; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::time; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{error, instrument}; -use vortex_protocol::{ - tlv::{ - download_persistent_cache_piece::DownloadPersistentCachePiece, - download_piece::DownloadPiece, error::Error as VortexError, persistent_cache_piece_content, - piece_content, Tag, - }, - Header, Vortex, HEADER_SIZE, -}; /// TCPClient is a TCP-based client for tcp storage service. #[derive(Clone)] @@ -48,118 +39,10 @@ impl TCPClient { pub fn new(config: Arc, addr: String) -> Self { Self { config, addr } } +} - /// Downloads a piece from the server using the vortex protocol. - /// - /// This is the main entry point for downloading a piece. It applies - /// a timeout based on the configuration and handles connection timeouts gracefully. - #[instrument(skip_all)] - pub async fn download_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - time::timeout( - self.config.download.piece_timeout, - self.handle_download_piece(number, task_id), - ) - .await - .inspect_err(|err| { - error!("connect timeout to {}: {}", self.addr, err); - })? - } - /// Internal handler for downloading a piece. - /// - /// This method performs the actual protocol communication: - /// 1. Creates a download piece request. - /// 2. Establishes TCP connection and sends the request. - /// 3. Reads and validates the response header. - /// 4. Processes the piece content based on the response type. - #[instrument(skip_all)] - async fn handle_download_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - let request: Bytes = Vortex::DownloadPiece( - Header::new_download_piece(), - DownloadPiece::new(task_id.to_string(), number), - ) - .into(); - - let (mut reader, _writer) = self.connect_and_write_request(request).await?; - let header = self.read_header(&mut reader).await?; - match header.tag() { - Tag::PieceContent => { - let piece_content: piece_content::PieceContent = self - .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) - .await?; - - let metadata = piece_content.metadata(); - Ok((reader, metadata.offset, metadata.digest)) - } - Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await), - _ => Err(ClientError::Unknown(format!( - "unexpected tag: {:?}", - header.tag() - ))), - } - } - - /// Downloads a persistent cache piece from the server using the vortex protocol. - /// - /// Similar to `download_piece` but specifically for persistent cache piece. - #[instrument(skip_all)] - pub async fn download_persistent_cache_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - time::timeout( - self.config.download.piece_timeout, - self.handle_download_persistent_cache_piece(number, task_id), - ) - .await - .inspect_err(|err| { - error!("connect timeout to {}: {}", self.addr, err); - })? - } - - /// Internal handler for downloading a persistent cache piece. - /// - /// Implements the same protocol flow as `handle_download_piece` but uses - /// persistent cache specific request/response types. - #[instrument(skip_all)] - async fn handle_download_persistent_cache_piece( - &self, - number: u32, - task_id: &str, - ) -> ClientResult<(impl AsyncRead, u64, String)> { - let request: Bytes = Vortex::DownloadPersistentCachePiece( - Header::new_download_persistent_cache_piece(), - DownloadPersistentCachePiece::new(task_id.to_string(), number), - ) - .into(); - - let (mut reader, _writer) = self.connect_and_write_request(request).await?; - let header = self.read_header(&mut reader).await?; - match header.tag() { - Tag::PersistentCachePieceContent => { - let persistent_cache_piece_content: persistent_cache_piece_content::PersistentCachePieceContent = - self.read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE) - .await?; - - let metadata = persistent_cache_piece_content.metadata(); - Ok((reader, metadata.offset, metadata.digest)) - } - Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await), - _ => Err(ClientError::Unknown(format!( - "unexpected tag: {:?}", - header.tag() - ))), - } - } - +#[tonic::async_trait] +impl Client for TCPClient { /// Establishes TCP connection and writes a vortex protocol request. /// /// This is a low-level utility function that handles the TCP connection @@ -169,7 +52,10 @@ impl TCPClient { async fn connect_and_write_request( &self, request: Bytes, - ) -> ClientResult<(OwnedReadHalf, OwnedWriteHalf)> { + ) -> ClientResult<( + Box, + Box, + )> { let stream = tokio::net::TcpStream::connect(self.addr.clone()).await?; let socket = SockRef::from(&stream); socket.set_tcp_nodelay(true)?; @@ -189,90 +75,16 @@ impl TCPClient { error!("failed to flush request: {}", err); })?; - Ok((reader, writer)) + Ok((Box::new(reader), Box::new(writer))) } - /// Reads and parses a vortex protocol header from the TCP stream. - /// - /// The header contains metadata about the following message, including - /// the message type (tag) and payload length. This is critical for - /// proper protocol message framing. - #[instrument(skip_all)] - async fn read_header(&self, reader: &mut OwnedReadHalf) -> ClientResult
{ - let mut header_bytes = BytesMut::with_capacity(HEADER_SIZE); - header_bytes.resize(HEADER_SIZE, 0); - reader - .read_exact(&mut header_bytes) - .await - .inspect_err(|err| { - error!("failed to receive header: {}", err); - })?; - - Header::try_from(header_bytes.freeze()).map_err(Into::into) + /// Access to client configuration. + fn config(&self) -> &Arc { + &self.config } - /// Reads and parses piece content with variable-length metadata. - /// - /// This generic function handles the two-stage reading process for - /// piece content: first reading the metadata length, then reading - /// the actual metadata, and finally constructing the complete message. - #[instrument(skip_all)] - async fn read_piece_content( - &self, - reader: &mut OwnedReadHalf, - metadata_length_size: usize, - ) -> ClientResult - where - T: TryFrom>, - { - let mut metadata_length_bytes = BytesMut::with_capacity(metadata_length_size); - metadata_length_bytes.resize(metadata_length_size, 0); - reader - .read_exact(&mut metadata_length_bytes) - .await - .inspect_err(|err| { - error!("failed to receive metadata length: {}", err); - })?; - let metadata_length = u32::from_be_bytes(metadata_length_bytes[..].try_into()?) as usize; - - let mut metadata_bytes = BytesMut::with_capacity(metadata_length); - metadata_bytes.resize(metadata_length, 0); - reader - .read_exact(&mut metadata_bytes) - .await - .inspect_err(|err| { - error!("failed to receive metadata: {}", err); - })?; - - let mut content_bytes = BytesMut::with_capacity(metadata_length_size + metadata_length); - content_bytes.extend_from_slice(&metadata_length_bytes); - content_bytes.extend_from_slice(&metadata_bytes); - content_bytes.freeze().try_into().map_err(Into::into) - } - - /// Reads and processes error responses from the server. - /// - /// When the server responds with an error tag, this function reads - /// the error payload and converts it into an appropriate client error. - /// This provides structured error handling for protocol-level failures. - #[instrument(skip_all)] - async fn read_error(&self, reader: &mut OwnedReadHalf, header_length: usize) -> ClientError { - let mut error_bytes = BytesMut::with_capacity(header_length); - error_bytes.resize(header_length, 0); - if let Err(err) = reader.read_exact(&mut error_bytes).await { - error!("failed to receive error: {}", err); - return ClientError::IO(err); - }; - - error_bytes - .freeze() - .try_into() - .map(|error: VortexError| { - ClientError::VortexProtocolStatus(error.code(), error.message().to_string()) - }) - .unwrap_or_else(|err| { - error!("failed to extract error: {}", err); - ClientError::Unknown(format!("failed to extract error: {}", err)) - }) + /// Access to client address. + fn addr(&self) -> &str { + &self.addr } } diff --git a/dragonfly-client-storage/src/server/tcp.rs b/dragonfly-client-storage/src/server/tcp.rs index 2f6d8192..0fd47eea 100644 --- a/dragonfly-client-storage/src/server/tcp.rs +++ b/dragonfly-client-storage/src/server/tcp.rs @@ -503,3 +503,415 @@ impl TCPServerHandler { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use dragonfly_client_config::dfdaemon::Config; + use std::time::Duration; + use tempfile::tempdir; + + async fn create_test_tcp_server_handler() -> TCPServerHandler { + let config = Arc::new(Config::default()); + let dir = tempdir().unwrap(); + let log_dir = dir.path().join("log"); + + let storage = Storage::new(config.clone(), dir.path(), log_dir) + .await + .unwrap(); + let storage = Arc::new(storage); + + let id_generator = IDGenerator::new( + "127.0.0.1".to_string(), + config.host.hostname.clone(), + config.seed_peer.enable, + ); + let id_generator = Arc::new(id_generator); + + let upload_rate_limiter = Arc::new( + RateLimiter::builder() + .initial(config.upload.rate_limit.as_u64() as usize) + .refill(config.upload.rate_limit.as_u64() as usize) + .max(config.upload.rate_limit.as_u64() as usize) + .interval(Duration::from_secs(1)) + .fair(false) + .build(), + ); + + TCPServerHandler { + id_generator, + storage, + upload_rate_limiter, + } + } + + #[tokio::test] + async fn test_read_header_success() { + let handler = create_test_tcp_server_handler().await; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _writer) = stream.into_split(); + handler.read_header(&mut reader).await + }); + + let client_stream = TcpStream::connect(addr).await.unwrap(); + let (_reader, mut writer) = client_stream.into_split(); + + let header = Header::new(Tag::DownloadPiece, 100); + let header_bytes: Bytes = header.into(); + writer.write_all(&header_bytes).await.unwrap(); + writer.flush().await.unwrap(); + + let result = server_handle.await.unwrap(); + assert!(result.is_ok()); + if let Ok(header) = result { + assert_eq!(header.tag(), Tag::DownloadPiece); + assert_eq!(header.length(), 100); + } + } + + #[tokio::test] + async fn test_read_header_insufficient_data() { + let handler = create_test_tcp_server_handler().await; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _writer) = stream.into_split(); + handler.read_header(&mut reader).await + }); + + let client_stream = TcpStream::connect(addr).await.unwrap(); + let (_reader, mut writer) = client_stream.into_split(); + + let partial_data = vec![0u8; HEADER_SIZE - 1]; + writer.write_all(&partial_data).await.unwrap(); + writer.flush().await.unwrap(); + drop(writer); + + let result = server_handle.await.unwrap(); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_read_header_connection_closed() { + let handler = create_test_tcp_server_handler().await; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _writer) = stream.into_split(); + handler.read_header(&mut reader).await + }); + + let client_stream = TcpStream::connect(addr).await.unwrap(); + drop(client_stream); + + let result = server_handle.await.unwrap(); + assert!(result.is_err()); + } + + const HEADER_LENGTH: usize = 68; + + #[tokio::test] + async fn test_read_download_piece_success() { + let handler = create_test_tcp_server_handler().await; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _writer) = stream.into_split(); + handler + .read_download_piece(&mut reader, HEADER_LENGTH) + .await + }); + + let client_stream = TcpStream::connect(addr).await.unwrap(); + let (_reader, mut writer) = client_stream.into_split(); + + let task_id = "a".repeat(64); + let piece_number = 42; + let download_piece = DownloadPiece::new(task_id.clone(), piece_number); + let bytes: Bytes = download_piece.into(); + writer.write_all(&bytes).await.unwrap(); + writer.flush().await.unwrap(); + + let result: ClientResult = server_handle.await.unwrap(); + assert!(result.is_ok()); + if let Ok(download_piece) = result { + assert_eq!(download_piece.task_id(), task_id); + assert_eq!(download_piece.piece_number(), piece_number); + } + } + + #[tokio::test] + async fn test_read_download_piece_insufficient_data() { + let handler = create_test_tcp_server_handler().await; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _writer) = stream.into_split(); + handler + .read_download_piece(&mut reader, HEADER_LENGTH) + .await + }); + + let client_stream = TcpStream::connect(addr).await.unwrap(); + let (_reader, mut writer) = client_stream.into_split(); + + let partial_data = vec![0u8; HEADER_LENGTH - 1]; + writer.write_all(&partial_data).await.unwrap(); + writer.flush().await.unwrap(); + drop(writer); + + let result: ClientResult = server_handle.await.unwrap(); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_read_download_piece_connection_closed() { + let handler = create_test_tcp_server_handler().await; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (mut reader, _writer) = stream.into_split(); + handler + .read_download_piece(&mut reader, HEADER_LENGTH) + .await + }); + + let client_stream = TcpStream::connect(addr).await.unwrap(); + drop(client_stream); + + let result: ClientResult = server_handle.await.unwrap(); + assert!(result.is_err()); + } + + async fn create_test_tcp_pair() -> (TcpStream, TcpStream) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let client_task = tokio::spawn(async move { TcpStream::connect(addr).await.unwrap() }); + + let (server_stream, _) = listener.accept().await.unwrap(); + let client_stream = client_task.await.unwrap(); + + (server_stream, client_stream) + } + + #[tokio::test] + async fn test_write_response_success() { + let handler = create_test_tcp_server_handler().await; + let (server_stream, mut client_stream) = create_test_tcp_pair().await; + let (_server_reader, mut server_writer) = server_stream.into_split(); + + let test_data = Bytes::from("Hello from server!"); + + let write_task = tokio::spawn(async move { + handler + .write_response(test_data.clone(), &mut server_writer) + .await + }); + + let mut response_buffer = vec![0u8; 1024]; + let bytes_read = client_stream.read(&mut response_buffer).await.unwrap(); + + assert_eq!(&response_buffer[..bytes_read], b"Hello from server!"); + assert!(write_task.await.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_write_response_large_data() { + let handler = create_test_tcp_server_handler().await; + let (server_stream, mut client_stream) = create_test_tcp_pair().await; + let (_server_reader, mut server_writer) = server_stream.into_split(); + + let large_data = Bytes::from(vec![42u8; 1024 * 1024]); + + let write_task = tokio::spawn(async move { + handler + .write_response(large_data.clone(), &mut server_writer) + .await + }); + + let mut total_received = 0; + let mut buffer = vec![0u8; 8192]; + + while total_received < 1024 * 1024 { + match client_stream.read(&mut buffer).await { + Ok(0) => break, + Ok(n) => { + total_received += n; + assert!(buffer[..n].iter().all(|&b| b == 42)); + } + Err(e) => panic!("Read error: {}", e), + } + } + + assert_eq!(total_received, 1024 * 1024); + assert!(write_task.await.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_write_response_client_slow_read() { + let handler = create_test_tcp_server_handler().await; + let (server_stream, mut client_stream) = create_test_tcp_pair().await; + let (_server_reader, mut server_writer) = server_stream.into_split(); + + let test_data = Bytes::from(vec![1u8; 64 * 1024]); + + let write_task = + tokio::spawn( + async move { handler.write_response(test_data, &mut server_writer).await }, + ); + + tokio::spawn(async move { + let mut buffer = vec![0u8; 1024]; + let mut total_read = 0; + + while total_read < 64 * 1024 { + tokio::time::sleep(Duration::from_millis(10)).await; + match client_stream.read(&mut buffer).await { + Ok(0) => break, + Ok(n) => total_read += n, + Err(_) => break, + } + } + }); + + let result = tokio::time::timeout(Duration::from_secs(5), write_task).await; + assert!(result.is_ok()); + assert!(result.unwrap().unwrap().is_ok()); + } + + #[tokio::test] + async fn test_write_stream_success() { + let handler = create_test_tcp_server_handler().await; + let (server_stream, mut client_stream) = create_test_tcp_pair().await; + let (_server_reader, mut server_writer) = server_stream.into_split(); + + let test_data = b"Stream content for testing".to_vec(); + let mut mock_stream = std::io::Cursor::new(test_data.clone()); + + let write_task = tokio::spawn(async move { + handler + .write_stream(&mut mock_stream, &mut server_writer) + .await + }); + + let mut response_buffer = vec![0u8; 1024]; + let bytes_read = client_stream.read(&mut response_buffer).await.unwrap(); + + assert_eq!(&response_buffer[..bytes_read], test_data.as_slice()); + assert!(write_task.await.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_write_stream_large_stream() { + let handler = create_test_tcp_server_handler().await; + let (server_stream, mut client_stream) = create_test_tcp_pair().await; + let (_server_reader, mut server_writer) = server_stream.into_split(); + + let large_data = vec![123u8; 10 * 1024 * 1024]; + let mut mock_stream = std::io::Cursor::new(large_data.clone()); + + let write_task = tokio::spawn(async move { + handler + .write_stream(&mut mock_stream, &mut server_writer) + .await + }); + + let mut total_received = 0; + let mut buffer = vec![0u8; 64 * 1024]; + + while total_received < 10 * 1024 * 1024 { + match tokio::time::timeout(Duration::from_secs(1), client_stream.read(&mut buffer)) + .await + { + Ok(Ok(0)) => break, + Ok(Ok(n)) => { + total_received += n; + assert!(buffer[..n].iter().all(|&b| b == 123)); + } + Ok(Err(e)) => panic!("Read error: {}", e), + Err(_) => panic!("Read timeout"), + } + } + + assert_eq!(total_received, 10 * 1024 * 1024); + assert!(write_task.await.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_write_stream_with_slow_source() { + let handler = create_test_tcp_server_handler().await; + let (server_stream, mut client_stream) = create_test_tcp_pair().await; + let (_server_reader, mut server_writer) = server_stream.into_split(); + + struct SlowReader { + data: std::io::Cursor>, + delay: Duration, + } + + impl AsyncRead for SlowReader { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + cx.waker().wake_by_ref(); + std::thread::sleep(self.delay); + + let mut temp_buf = vec![0u8; std::cmp::min(buf.remaining(), 10)]; + match std::io::Read::read(&mut self.data, &mut temp_buf) { + Ok(0) => std::task::Poll::Ready(Ok(())), + Ok(n) => { + buf.put_slice(&temp_buf[..n]); + std::task::Poll::Ready(Ok(())) + } + Err(e) => std::task::Poll::Ready(Err(e)), + } + } + } + + let test_data = b"Slow stream data".to_vec(); + let mut slow_stream = SlowReader { + data: std::io::Cursor::new(test_data.clone()), + delay: Duration::from_millis(5), + }; + + let write_task = tokio::spawn(async move { + handler + .write_stream(&mut slow_stream, &mut server_writer) + .await + }); + + let mut response_buffer = vec![0u8; 1024]; + let bytes_read = tokio::time::timeout( + Duration::from_secs(2), + client_stream.read(&mut response_buffer), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(&response_buffer[..bytes_read], test_data.as_slice()); + assert!(write_task.await.unwrap().is_ok()); + } +} diff --git a/dragonfly-client/src/resource/piece_downloader.rs b/dragonfly-client/src/resource/piece_downloader.rs index 6947ec1e..df8a08ff 100644 --- a/dragonfly-client/src/resource/piece_downloader.rs +++ b/dragonfly-client/src/resource/piece_downloader.rs @@ -18,7 +18,9 @@ use crate::grpc::dfdaemon_upload::DfdaemonUploadClient; use dragonfly_api::dfdaemon::v2::{DownloadPersistentCachePieceRequest, DownloadPieceRequest}; use dragonfly_client_config::dfdaemon::Config; use dragonfly_client_core::{Error, Result}; -use dragonfly_client_storage::{client::quic::QUICClient, client::tcp::TCPClient, metadata}; +use dragonfly_client_storage::{ + client::quic::QUICClient, client::tcp::TCPClient, client::Client, metadata, +}; use dragonfly_client_util::pool::{Builder as PoolBuilder, Entry, Factory, Pool}; use std::io::Cursor; use std::sync::Arc;