From f97fd1bff26403c5725a503af0065a2deb5e991a Mon Sep 17 00:00:00 2001 From: Gaius Date: Thu, 1 Feb 2024 15:12:09 +0800 Subject: [PATCH] feat: optimize proxy request by dfdaemon (#249) Signed-off-by: Gaius --- src/proxy/mod.rs | 572 ++++++++++++++++++++++++----------------------- 1 file changed, 288 insertions(+), 284 deletions(-) diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 428d2903..121faf5a 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -14,7 +14,7 @@ * limitations under the License. */ -use crate::config::dfdaemon::Config; +use crate::config::dfdaemon::{Config, Rule}; use crate::grpc::dfdaemon_download::DfdaemonDownloadClient; use crate::shutdown; use crate::task::Task; @@ -24,7 +24,9 @@ use crate::utils::http::{ use crate::{Error as ClientError, Result as ClientResult}; use bytes::Bytes; use dragonfly_api::common::v2::{Download, TaskType}; -use dragonfly_api::dfdaemon::v2::{download_task_response, DownloadTaskRequest}; +use dragonfly_api::dfdaemon::v2::{ + download_task_response, DownloadTaskRequest, DownloadTaskStartedResponse, +}; use futures_util::TryStreamExt; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; use hyper::body::Frame; @@ -46,6 +48,7 @@ use tracing::{error, info, instrument, Span}; pub mod header; +// Response is the response of the proxy server. pub type Response = hyper::Response>; // Proxy is the proxy server. @@ -88,8 +91,6 @@ impl Proxy { #[instrument(skip_all)] pub async fn run(&self) -> ClientResult<()> { let listener = TcpListener::bind(self.addr).await?; - - // Start the proxy server and wait for it to finish. info!("proxy server listening on {}", self.addr); loop { @@ -161,281 +162,19 @@ pub async fn http_handler( task: Arc, request: Request, ) -> ClientResult { - let Some(host) = request.uri().host() else { - error!("CONNECT host is not socket addr: {:?}", request.uri()); - let mut response = Response::new(full("CONNECT must be to a socket address")); - *response.status_mut() = http::StatusCode::BAD_REQUEST; - return Ok(response); - }; - let port = request.uri().port_u16().unwrap_or(80); - - if let Some(rules) = config.proxy.rules.clone() { - for rule in rules.iter() { - if rule.regex.is_match(request.uri().to_string().as_str()) { - // Convert the Reqwest header to the Hyper header. - let request_header = hyper_headermap_to_reqwest_headermap(request.headers()); - - // Construct the download url. - let url = - match make_download_url(request.uri(), rule.use_tls, rule.redirect.clone()) { - Ok(url) => url, - Err(err) => { - let mut response = Response::new(full( - err.to_string().to_string().as_bytes().to_vec(), - )); - *response.status_mut() = http::StatusCode::BAD_REQUEST; - return Ok(response); - } - }; - - // Get parameters from the header. - let tag = header::get_tag(&request_header); - let application = header::get_application(&request_header); - let priority = header::get_priority(&request_header); - let piece_length = header::get_piece_length(&request_header); - let filtered_query_params = header::get_filtered_query_params( - &request_header, - rule.filtered_query_params.clone(), - ); - let request_header = reqwest_headermap_to_hashmap(&request_header); - - // Initialize the dfdaemon download client. - let dfdaemon_download_client = match DfdaemonDownloadClient::new_unix( - config.download.server.socket_path.clone(), - ) - .await - { - Ok(client) => client, - Err(err) => { - let mut response = - Response::new(full(err.to_string().to_string().as_bytes().to_vec())); - *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; - return Ok(response); - } - }; - - // Download the task by the dfdaemon download client. - let response = match dfdaemon_download_client - .download_task(DownloadTaskRequest { - download: Some(Download { - url, - digest: None, - // Download range use header range in HTTP protocol. - range: None, - r#type: TaskType::Dfdaemon as i32, - tag, - application, - priority, - filtered_query_params, - request_header, - piece_length, - output_path: None, - timeout: None, - need_back_to_source: false, - }), - }) - .await - { - Ok(response) => response, - Err(err) => { - let mut response = - Response::new(full(err.to_string().to_string().as_bytes().to_vec())); - *response.status_mut() = http::StatusCode::BAD_REQUEST; - return Ok(response); - } - }; - - // Handle the response from the download grpc server. - let mut out_stream = response.into_inner(); - let Ok(Some(message)) = out_stream.message().await else { - let mut response = - Response::new(full("download task response failed".as_bytes().to_vec())); - *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; - return Ok(response); - }; - - let download_task_started_response = match message.response { - Some(download_task_response::Response::DownloadTaskStartedResponse( - mut response, - )) => { - // Insert the content range header to the resopnse header. - if let Some(range) = response.range.as_ref() { - response.response_header.insert( - reqwest::header::CONTENT_RANGE.to_string(), - format!( - "bytes {}-{}/{}", - range.start, - range.start + range.length - 1, - response.content_length - ), - ); - } - - response - } - _ => { - let mut response = Response::new(full( - "download task response is not started".as_bytes().to_vec(), - )); - *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR; - return Ok(response); - } - }; - - // Write the task data to the reader. - let (reader, mut writer) = tokio::io::duplex(1024); - - // Construct the response body. - let reader_stream = ReaderStream::new(reader); - let stream_body = - StreamBody::new(reader_stream.map_ok(Frame::data).map_err(ClientError::from)); - let boxed_body = stream_body.boxed(); - - // Construct the response. - let mut response = Response::new(boxed_body); - *response.headers_mut() = - hashmap_to_hyper_header_map(&download_task_started_response.response_header)?; - *response.status_mut() = http::StatusCode::OK; - - // Write task data to pipe. If grpc received error message, - // shutdown the writer. - tokio::spawn(async move { - // Initialize the hashmap of the finished piece readers. - let mut finished_piece_readers = HashMap::new(); - - while let Some(message) = match out_stream.message().await { - Ok(message) => message, - Err(err) => { - error!("download task response error: {}", err); - if let Err(err) = writer.shutdown().await { - error!("writer shutdown error: {}", err); - } - - return; - } - } { - match message.response { - Some( - download_task_response::Response::DownloadTaskStartedResponse(_), - ) => { - error!("download task started response is duplicated"); - if let Err(err) = writer.shutdown().await { - error!("writer shutdown error: {}", err); - } - - return; - } - Some( - download_task_response::Response::DownloadPieceFinishedResponse( - response, - ), - ) => { - let piece = match response.piece { - Some(piece) => piece, - None => { - error!("download piece finished response piece is empty"); - if let Err(err) = writer.shutdown().await { - error!("writer shutdown error: {}", err); - } - - return; - } - }; - - let mut need_piece_number = match download_task_started_response - .pieces - .first() - { - Some(piece) => piece.number, - None => { - error!("download task started response pieces is empty"); - if let Err(err) = writer.shutdown().await { - error!("writer shutdown error: {}", err); - } - - return; - } - }; - - let piece_reader = match task - .piece - .download_from_local_peer_into_async_read( - message.task_id.as_str(), - piece.number, - piece.length, - download_task_started_response.range.clone(), - true, - ) - .await - { - Ok(piece_reader) => piece_reader, - Err(err) => { - error!("download piece reader error: {}", err); - if let Err(err) = writer.shutdown().await { - error!("writer shutdown error: {}", err); - } - - return; - } - }; - - // Sort by piece number and return to reader in order. - finished_piece_readers.insert(piece.number, piece_reader); - while let Some(piece_reader) = - finished_piece_readers.get_mut(&need_piece_number) - { - info!("copy piece to stream: {}", need_piece_number); - if let Err(err) = - tokio::io::copy(piece_reader, &mut writer).await - { - error!("download piece reader error: {}", err); - if let Err(err) = writer.shutdown().await { - error!("writer shutdown error: {}", err); - } - - return; - } - need_piece_number += 1; - } - } - None => { - error!("download task response is empty"); - if let Err(err) = writer.shutdown().await { - error!("writer shutdown error: {}", err); - } - - return; - } - } - } - - info!("copy finished"); - }); - - // Send response - return Ok(response); - } - } + let request_uri = request.uri(); + if let Some(rule) = + find_matching_rule(config.proxy.rules.clone(), request_uri.to_string().as_str()) + { + info!("proxy HTTP request via dfdaemon for URI: {}", request_uri); + return proxy_http_by_dfdaemon(config, task, rule.clone(), request).await; } - // Proxy the request to the remote server directly. - info!("proxy http request to remote server directly"); - let stream = TcpStream::connect((host, port)).await?; - let io = TokioIo::new(stream); - let (mut sender, conn) = Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(io) - .await?; - - tokio::task::spawn(async move { - if let Err(err) = conn.await { - error!("connection failed: {:?}", err); - } - }); - - let response = sender.send_request(request).await?; - Ok(response.map(|b| b.map_err(ClientError::from).boxed())) + info!( + "proxy HTTP request directly to remote server for URI: {}", + request_uri + ); + proxy_http(request).await } // https_handler handles the https request. @@ -455,8 +194,8 @@ pub async fn https_handler( } } - // Proxy the request to the remote server directly. - info!("proxy https request to remote server directly"); + // Proxy the request directly to the remote server. + info!("proxy HTTPS request directly to remote server"); if let Some(addr) = host_addr(request.uri()) { tokio::task::spawn(async move { match hyper::upgrade::on(request).await { @@ -471,14 +210,243 @@ pub async fn https_handler( Ok(Response::new(empty())) } else { - error!("CONNECT host is not socket addr: {:?}", request.uri()); - let mut response = Response::new(full("CONNECT must be to a socket address")); - *response.status_mut() = http::StatusCode::BAD_REQUEST; - - Ok(response) + return Ok(make_error_response( + http::StatusCode::BAD_REQUEST, + "CONNECT must be to a socket address", + )); } } +// proxy_http_by_dfdaemon proxies the HTTP request via the dfdaemon. +async fn proxy_http_by_dfdaemon( + config: Arc, + task: Arc, + rule: Rule, + request: Request, +) -> ClientResult { + // Initialize the dfdaemon download client. + let dfdaemon_download_client = + match DfdaemonDownloadClient::new_unix(config.download.server.socket_path.clone()).await { + Ok(client) => client, + Err(err) => { + error!("create dfdaemon download client failed: {}", err); + return Ok(make_error_response( + http::StatusCode::INTERNAL_SERVER_ERROR, + err.to_string().as_str(), + )); + } + }; + + // Make the download task request. + let download_task_request = match make_download_task_request(request, rule) { + Ok(download_task_request) => download_task_request, + Err(err) => { + error!("make download task request failed: {}", err); + return Ok(make_error_response( + http::StatusCode::INTERNAL_SERVER_ERROR, + err.to_string().as_str(), + )); + } + }; + + // Download the task by the dfdaemon download client. + let response = match dfdaemon_download_client + .download_task(download_task_request) + .await + { + Ok(response) => response, + Err(err) => { + error!("initiate download task failed: {}", err); + return Ok(make_error_response( + http::StatusCode::INTERNAL_SERVER_ERROR, + err.to_string().as_str(), + )); + } + }; + + // Handle the response from the download grpc server. + let mut out_stream = response.into_inner(); + let Ok(Some(message)) = out_stream.message().await else { + error!("response message failed"); + return Ok(make_error_response( + http::StatusCode::INTERNAL_SERVER_ERROR, + "response message failed", + )); + }; + + // Handle the download task started response. + let Some(download_task_response::Response::DownloadTaskStartedResponse( + download_task_started_response, + )) = message.response + else { + error!("response is not started"); + return Ok(make_error_response( + http::StatusCode::INTERNAL_SERVER_ERROR, + "response is not started", + )); + }; + + // Write the task data to the reader. + let (reader, mut writer) = tokio::io::duplex(1024); + + // Construct the response body. + let reader_stream = ReaderStream::new(reader); + let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data).map_err(ClientError::from)); + let boxed_body = stream_body.boxed(); + + // Construct the response. + let mut response = Response::new(boxed_body); + *response.headers_mut() = make_response_headers(download_task_started_response.clone())?; + *response.status_mut() = http::StatusCode::OK; + + // Write task data to pipe. If grpc received error message, + // shutdown the writer. + tokio::spawn(async move { + // Initialize the hashmap of the finished piece readers and pieces. + let mut finished_piece_readers = HashMap::new(); + + // Get the first piece number from the started response. + let Some(first_piece) = download_task_started_response.pieces.first() else { + error!("reponse pieces is empty"); + if let Err(err) = writer.shutdown().await { + error!("writer shutdown error: {}", err); + } + + return; + }; + let mut need_piece_number = first_piece.number; + + // Read piece data from stream and write to pipe. If the piece data is + // not in order, store it in the hashmap, and write it to the pipe + // when the previous piece data is written. + while let Ok(Some(message)) = out_stream.message().await { + if let Some(download_task_response::Response::DownloadPieceFinishedResponse(response)) = + message.response + { + let Some(piece) = response.piece else { + error!("response piece is empty"); + if let Err(err) = writer.shutdown().await { + error!("writer shutdown error: {}", err); + } + + return; + }; + + let piece_reader = match task + .piece + .download_from_local_peer_into_async_read( + message.task_id.as_str(), + piece.number, + piece.length, + download_task_started_response.range.clone(), + true, + ) + .await + { + Ok(piece_reader) => piece_reader, + Err(err) => { + error!("download piece reader error: {}", err); + if let Err(err) = writer.shutdown().await { + error!("writer shutdown error: {}", err); + } + + return; + } + }; + + // Write the piece data to the pipe in order. + finished_piece_readers.insert(piece.number, piece_reader); + while let Some(piece_reader) = finished_piece_readers.get_mut(&need_piece_number) { + info!("copy piece to stream: {}", need_piece_number); + if let Err(err) = tokio::io::copy(piece_reader, &mut writer).await { + error!("download piece reader error: {}", err); + if let Err(err) = writer.shutdown().await { + error!("writer shutdown error: {}", err); + } + + return; + } + + need_piece_number += 1; + } + } else { + error!("response unknown message"); + if let Err(err) = writer.shutdown().await { + error!("writer shutdown error: {}", err); + } + + return; + } + } + + info!("copy finished"); + }); + + Ok(response) +} + +// proxy_http proxies the HTTP request directly to the remote server. +async fn proxy_http(request: Request) -> ClientResult { + let Some(host) = request.uri().host() else { + error!("CONNECT host is not socket addr: {:?}", request.uri()); + return Ok(make_error_response( + http::StatusCode::BAD_REQUEST, + "CONNECT must be to a socket address", + )); + }; + let port = request.uri().port_u16().unwrap_or(80); + + let stream = TcpStream::connect((host, port)).await?; + let io = TokioIo::new(stream); + let (mut sender, conn) = Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(io) + .await?; + + tokio::task::spawn(async move { + if let Err(err) = conn.await { + error!("connection failed: {:?}", err); + } + }); + + let response = sender.send_request(request).await?; + Ok(response.map(|b| b.map_err(ClientError::from).boxed())) +} + +// make_download_task_requet makes a download task request by the request. +#[instrument(skip_all)] +fn make_download_task_request( + request: Request, + rule: Rule, +) -> ClientResult { + // Convert the Reqwest header to the Hyper header. + let reqwest_request_header = hyper_headermap_to_reqwest_headermap(request.headers()); + + // Construct the download url. + Ok(DownloadTaskRequest { + download: Some(Download { + url: make_download_url(request.uri(), rule.use_tls, rule.redirect.clone())?, + digest: None, + // Download range use header range in HTTP protocol. + range: None, + r#type: TaskType::Dfdaemon as i32, + tag: header::get_tag(&reqwest_request_header), + application: header::get_application(&reqwest_request_header), + priority: header::get_priority(&reqwest_request_header), + filtered_query_params: header::get_filtered_query_params( + &reqwest_request_header, + rule.filtered_query_params.clone(), + ), + request_header: reqwest_headermap_to_hashmap(&reqwest_request_header), + piece_length: header::get_piece_length(&reqwest_request_header), + output_path: None, + timeout: None, + need_back_to_source: false, + }), + }) +} + // make_download_url makes a download url by the given uri. #[instrument(skip_all)] fn make_download_url( @@ -503,6 +471,42 @@ fn make_download_url( Ok(http::Uri::from_parts(parts)?.to_string()) } +// make_response_headers makes the response headers. +#[instrument(skip_all)] +fn make_response_headers( + mut download_task_started_response: DownloadTaskStartedResponse, +) -> ClientResult { + // Insert the content range header to the resopnse header. + if let Some(range) = download_task_started_response.range.as_ref() { + download_task_started_response.response_header.insert( + reqwest::header::CONTENT_RANGE.to_string(), + format!( + "bytes {}-{}/{}", + range.start, + range.start + range.length - 1, + download_task_started_response.content_length + ), + ); + }; + + hashmap_to_hyper_header_map(&download_task_started_response.response_header) +} + +// find_matching_rule returns whether the dfdaemon should be used to download the task. +// If the dfdaemon should be used, return the matched rule. +#[instrument(skip_all)] +fn find_matching_rule(rules: Option>, url: &str) -> Option { + rules?.iter().find(|rule| rule.regex.is_match(url)).cloned() +} + +// make_error_response makes an error response with the given status and message. +#[instrument(skip_all)] +fn make_error_response(status: http::StatusCode, message: &str) -> Response { + let mut response = Response::new(full(message.as_bytes().to_vec())); + *response.status_mut() = status; + response +} + // empty returns an empty body. #[instrument(skip_all)] fn empty() -> BoxBody {