feat: optimize proxy request by dfdaemon (#249)

Signed-off-by: Gaius <gaius.qi@gmail.com>
Gaius committed on 2024-02-01 15:12:09 +08:00 (committed by GitHub)
parent 4e8a2e2dc0
commit f97fd1bff2
1 changed file with 288 additions and 284 deletions
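The change splits the old all-in-one http_handler into helpers: find_matching_rule picks the proxy rule whose regex matches the request URI, proxy_http_by_dfdaemon converts the request into a DownloadTaskRequest and streams the task pieces back, and proxy_http falls back to a direct connection. The sketch below only mirrors that dispatch shape; Rule here is a simplified stand-in with a single regex field (using the regex crate), and the two proxy functions are stubs rather than the crate's real implementations in the diff below.

// Sketch only: `Rule` is reduced to a single regex field and the two proxy
// functions are stubs; the real types and handlers are in the patch below.
use regex::Regex;

#[derive(Clone)]
struct Rule {
    regex: Regex,
}

// Mirrors the find_matching_rule helper added by this patch: return the first
// rule whose regex matches the request URL, if any.
fn find_matching_rule(rules: Option<Vec<Rule>>, url: &str) -> Option<Rule> {
    rules?.iter().find(|rule| rule.regex.is_match(url)).cloned()
}

// Stand-in for proxy_http_by_dfdaemon, which builds a DownloadTaskRequest and
// streams the downloaded pieces back to the client.
fn proxy_http_by_dfdaemon(_rule: &Rule, url: &str) -> String {
    format!("proxied via dfdaemon: {url}")
}

// Stand-in for proxy_http, which connects to the remote server directly.
fn proxy_http(url: &str) -> String {
    format!("proxied directly: {url}")
}

// The http_handler dispatch after this change: rule match -> dfdaemon,
// otherwise fall through to the direct proxy.
fn handle(rules: Option<Vec<Rule>>, url: &str) -> String {
    if let Some(rule) = find_matching_rule(rules, url) {
        return proxy_http_by_dfdaemon(&rule, url);
    }
    proxy_http(url)
}

fn main() {
    let rules = Some(vec![Rule {
        regex: Regex::new(r"\.example\.org").unwrap(),
    }]);
    println!("{}", handle(rules.clone(), "http://mirror.example.org/file"));
    println!("{}", handle(rules, "http://other.test/file"));
}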


@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-use crate::config::dfdaemon::Config;
+use crate::config::dfdaemon::{Config, Rule};
 use crate::grpc::dfdaemon_download::DfdaemonDownloadClient;
 use crate::shutdown;
 use crate::task::Task;
@@ -24,7 +24,9 @@ use crate::utils::http::{
 use crate::{Error as ClientError, Result as ClientResult};
 use bytes::Bytes;
 use dragonfly_api::common::v2::{Download, TaskType};
-use dragonfly_api::dfdaemon::v2::{download_task_response, DownloadTaskRequest};
+use dragonfly_api::dfdaemon::v2::{
+    download_task_response, DownloadTaskRequest, DownloadTaskStartedResponse,
+};
 use futures_util::TryStreamExt;
 use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
 use hyper::body::Frame;
@@ -46,6 +48,7 @@ use tracing::{error, info, instrument, Span};
 
 pub mod header;
 
+// Response is the response of the proxy server.
 pub type Response = hyper::Response<BoxBody<Bytes, ClientError>>;
 
 // Proxy is the proxy server.
@@ -88,8 +91,6 @@ impl Proxy {
     #[instrument(skip_all)]
     pub async fn run(&self) -> ClientResult<()> {
         let listener = TcpListener::bind(self.addr).await?;
-
-        // Start the proxy server and wait for it to finish.
         info!("proxy server listening on {}", self.addr);
 
         loop {
@@ -161,125 +162,128 @@ pub async fn http_handler(
     task: Arc<Task>,
     request: Request<hyper::body::Incoming>,
 ) -> ClientResult<Response> {
-    let Some(host) = request.uri().host() else {
-        error!("CONNECT host is not socket addr: {:?}", request.uri());
-        let mut response = Response::new(full("CONNECT must be to a socket address"));
-        *response.status_mut() = http::StatusCode::BAD_REQUEST;
-        return Ok(response);
-    };
-    let port = request.uri().port_u16().unwrap_or(80);
-
+    let request_uri = request.uri();
+    if let Some(rule) =
+        find_matching_rule(config.proxy.rules.clone(), request_uri.to_string().as_str())
+    {
+        info!("proxy HTTP request via dfdaemon for URI: {}", request_uri);
+        return proxy_http_by_dfdaemon(config, task, rule.clone(), request).await;
+    }
+
+    info!(
+        "proxy HTTP request directly to remote server for URI: {}",
+        request_uri
+    );
+    proxy_http(request).await
+}
+
+// https_handler handles the https request.
+#[instrument(skip_all)]
+pub async fn https_handler(
+    config: Arc<Config>,
+    request: Request<hyper::body::Incoming>,
+) -> ClientResult<Response> {
     if let Some(rules) = config.proxy.rules.clone() {
         for rule in rules.iter() {
             if rule.regex.is_match(request.uri().to_string().as_str()) {
-                // Convert the Reqwest header to the Hyper header.
-                let request_header = hyper_headermap_to_reqwest_headermap(request.headers());
-
-                // Construct the download url.
-                let url =
-                    match make_download_url(request.uri(), rule.use_tls, rule.redirect.clone()) {
-                        Ok(url) => url,
-                        Err(err) => {
-                            let mut response = Response::new(full(
-                                err.to_string().to_string().as_bytes().to_vec(),
-                            ));
-                            *response.status_mut() = http::StatusCode::BAD_REQUEST;
-                            return Ok(response);
-                        }
-                    };
-
-                // Get parameters from the header.
-                let tag = header::get_tag(&request_header);
-                let application = header::get_application(&request_header);
-                let priority = header::get_priority(&request_header);
-                let piece_length = header::get_piece_length(&request_header);
-                let filtered_query_params = header::get_filtered_query_params(
-                    &request_header,
-                    rule.filtered_query_params.clone(),
-                );
-                let request_header = reqwest_headermap_to_hashmap(&request_header);
-
-                // Initialize the dfdaemon download client.
-                let dfdaemon_download_client = match DfdaemonDownloadClient::new_unix(
-                    config.download.server.socket_path.clone(),
-                )
-                .await
-                {
-                    Ok(client) => client,
-                    Err(err) => {
-                        let mut response =
-                            Response::new(full(err.to_string().to_string().as_bytes().to_vec()));
-                        *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
-                        return Ok(response);
-                    }
-                };
-
-                // Download the task by the dfdaemon download client.
-                let response = match dfdaemon_download_client
-                    .download_task(DownloadTaskRequest {
-                        download: Some(Download {
-                            url,
-                            digest: None,
-                            // Download range use header range in HTTP protocol.
-                            range: None,
-                            r#type: TaskType::Dfdaemon as i32,
-                            tag,
-                            application,
-                            priority,
-                            filtered_query_params,
-                            request_header,
-                            piece_length,
-                            output_path: None,
-                            timeout: None,
-                            need_back_to_source: false,
-                        }),
-                    })
-                    .await
-                {
-                    Ok(response) => response,
-                    Err(err) => {
-                        let mut response =
-                            Response::new(full(err.to_string().to_string().as_bytes().to_vec()));
-                        *response.status_mut() = http::StatusCode::BAD_REQUEST;
-                        return Ok(response);
-                    }
-                };
-
-                // Handle the response from the download grpc server.
-                let mut out_stream = response.into_inner();
-                let Ok(Some(message)) = out_stream.message().await else {
-                    let mut response =
-                        Response::new(full("download task response failed".as_bytes().to_vec()));
-                    *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
-                    return Ok(response);
-                };
-
-                let download_task_started_response = match message.response {
-                    Some(download_task_response::Response::DownloadTaskStartedResponse(
-                        mut response,
-                    )) => {
-                        // Insert the content range header to the resopnse header.
-                        if let Some(range) = response.range.as_ref() {
-                            response.response_header.insert(
-                                reqwest::header::CONTENT_RANGE.to_string(),
-                                format!(
-                                    "bytes {}-{}/{}",
-                                    range.start,
-                                    range.start + range.length - 1,
-                                    response.content_length
-                                ),
-                            );
-                        }
-
-                        response
-                    }
-                    _ => {
-                        let mut response = Response::new(full(
-                            "download task response is not started".as_bytes().to_vec(),
-                        ));
-                        *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
-                        return Ok(response);
-                    }
-                };
-
-                // Write the task data to the reader.
+                // TODO: handle https request.
+                let mut response = Response::new(full("CONNECT must be to a socket address"));
+                *response.status_mut() = http::StatusCode::BAD_REQUEST;
+                return Ok(response);
+            }
+        }
+    }
+
+    // Proxy the request directly to the remote server.
+    info!("proxy HTTPS request directly to remote server");
+    if let Some(addr) = host_addr(request.uri()) {
+        tokio::task::spawn(async move {
+            match hyper::upgrade::on(request).await {
+                Ok(upgraded) => {
+                    if let Err(e) = tunnel(upgraded, addr).await {
+                        error!("server io error: {}", e);
+                    };
+                }
+                Err(e) => error!("upgrade error: {}", e),
+            }
+        });
+
+        Ok(Response::new(empty()))
+    } else {
+        return Ok(make_error_response(
+            http::StatusCode::BAD_REQUEST,
+            "CONNECT must be to a socket address",
+        ));
+    }
+}
+
+// proxy_http_by_dfdaemon proxies the HTTP request via the dfdaemon.
+async fn proxy_http_by_dfdaemon(
+    config: Arc<Config>,
+    task: Arc<Task>,
+    rule: Rule,
+    request: Request<hyper::body::Incoming>,
+) -> ClientResult<Response> {
+    // Initialize the dfdaemon download client.
+    let dfdaemon_download_client =
+        match DfdaemonDownloadClient::new_unix(config.download.server.socket_path.clone()).await {
+            Ok(client) => client,
+            Err(err) => {
+                error!("create dfdaemon download client failed: {}", err);
+                return Ok(make_error_response(
+                    http::StatusCode::INTERNAL_SERVER_ERROR,
+                    err.to_string().as_str(),
+                ));
+            }
+        };
+
+    // Make the download task request.
+    let download_task_request = match make_download_task_request(request, rule) {
+        Ok(download_task_request) => download_task_request,
+        Err(err) => {
+            error!("make download task request failed: {}", err);
+            return Ok(make_error_response(
+                http::StatusCode::INTERNAL_SERVER_ERROR,
+                err.to_string().as_str(),
+            ));
+        }
+    };
+
+    // Download the task by the dfdaemon download client.
+    let response = match dfdaemon_download_client
+        .download_task(download_task_request)
+        .await
+    {
+        Ok(response) => response,
+        Err(err) => {
+            error!("initiate download task failed: {}", err);
+            return Ok(make_error_response(
+                http::StatusCode::INTERNAL_SERVER_ERROR,
+                err.to_string().as_str(),
+            ));
+        }
+    };
+
+    // Handle the response from the download grpc server.
+    let mut out_stream = response.into_inner();
+    let Ok(Some(message)) = out_stream.message().await else {
+        error!("response message failed");
+        return Ok(make_error_response(
+            http::StatusCode::INTERNAL_SERVER_ERROR,
+            "response message failed",
+        ));
+    };
+
+    // Handle the download task started response.
+    let Some(download_task_response::Response::DownloadTaskStartedResponse(
+        download_task_started_response,
+    )) = message.response
+    else {
+        error!("response is not started");
+        return Ok(make_error_response(
+            http::StatusCode::INTERNAL_SERVER_ERROR,
+            "response is not started",
+        ));
+    };
+
+    // Write the task data to the reader.
@@ -287,74 +291,45 @@ pub async fn http_handler(
-                // Construct the response body.
-                let reader_stream = ReaderStream::new(reader);
-                let stream_body =
-                    StreamBody::new(reader_stream.map_ok(Frame::data).map_err(ClientError::from));
-                let boxed_body = stream_body.boxed();
-
-                // Construct the response.
-                let mut response = Response::new(boxed_body);
-                *response.headers_mut() =
-                    hashmap_to_hyper_header_map(&download_task_started_response.response_header)?;
-                *response.status_mut() = http::StatusCode::OK;
-
-                // Write task data to pipe. If grpc received error message,
-                // shutdown the writer.
-                tokio::spawn(async move {
-                    // Initialize the hashmap of the finished piece readers.
-                    let mut finished_piece_readers = HashMap::new();
-                    while let Some(message) = match out_stream.message().await {
-                        Ok(message) => message,
-                        Err(err) => {
-                            error!("download task response error: {}", err);
-                            if let Err(err) = writer.shutdown().await {
-                                error!("writer shutdown error: {}", err);
-                            }
-                            return;
-                        }
-                    } {
-                        match message.response {
-                            Some(
-                                download_task_response::Response::DownloadTaskStartedResponse(_),
-                            ) => {
-                                error!("download task started response is duplicated");
-                                if let Err(err) = writer.shutdown().await {
-                                    error!("writer shutdown error: {}", err);
-                                }
-                                return;
-                            }
-                            Some(
-                                download_task_response::Response::DownloadPieceFinishedResponse(
-                                    response,
-                                ),
-                            ) => {
-                                let piece = match response.piece {
-                                    Some(piece) => piece,
-                                    None => {
-                                        error!("download piece finished response piece is empty");
-                                        if let Err(err) = writer.shutdown().await {
-                                            error!("writer shutdown error: {}", err);
-                                        }
-                                        return;
-                                    }
-                                };
-
-                                let mut need_piece_number = match download_task_started_response
-                                    .pieces
-                                    .first()
-                                {
-                                    Some(piece) => piece.number,
-                                    None => {
-                                        error!("download task started response pieces is empty");
-                                        if let Err(err) = writer.shutdown().await {
-                                            error!("writer shutdown error: {}", err);
-                                        }
-                                        return;
-                                    }
-                                };
-
-                                let piece_reader = match task
+    // Construct the response body.
+    let reader_stream = ReaderStream::new(reader);
+    let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data).map_err(ClientError::from));
+    let boxed_body = stream_body.boxed();
+
+    // Construct the response.
+    let mut response = Response::new(boxed_body);
+    *response.headers_mut() = make_response_headers(download_task_started_response.clone())?;
+    *response.status_mut() = http::StatusCode::OK;
+
+    // Write task data to pipe. If grpc received error message,
+    // shutdown the writer.
+    tokio::spawn(async move {
+        // Initialize the hashmap of the finished piece readers and pieces.
+        let mut finished_piece_readers = HashMap::new();
+
+        // Get the first piece number from the started response.
+        let Some(first_piece) = download_task_started_response.pieces.first() else {
+            error!("reponse pieces is empty");
+            if let Err(err) = writer.shutdown().await {
+                error!("writer shutdown error: {}", err);
+            }
+            return;
+        };
+        let mut need_piece_number = first_piece.number;
+
+        // Read piece data from stream and write to pipe. If the piece data is
+        // not in order, store it in the hashmap, and write it to the pipe
+        // when the previous piece data is written.
+        while let Ok(Some(message)) = out_stream.message().await {
+            if let Some(download_task_response::Response::DownloadPieceFinishedResponse(response)) =
+                message.response
+            {
+                let Some(piece) = response.piece else {
+                    error!("response piece is empty");
+                    if let Err(err) = writer.shutdown().await {
+                        error!("writer shutdown error: {}", err);
+                    }
+                    return;
+                };
+
+                let piece_reader = match task
@@ -379,15 +354,11 @@ pub async fn http_handler(
-                                    }
-                                };
-
-                                // Sort by piece number and return to reader in order.
-                                finished_piece_readers.insert(piece.number, piece_reader);
-                                while let Some(piece_reader) =
-                                    finished_piece_readers.get_mut(&need_piece_number)
-                                {
-                                    info!("copy piece to stream: {}", need_piece_number);
-                                    if let Err(err) =
-                                        tokio::io::copy(piece_reader, &mut writer).await
-                                    {
-                                        error!("download piece reader error: {}", err);
-                                        if let Err(err) = writer.shutdown().await {
-                                            error!("writer shutdown error: {}", err);
+                    }
+                };
+
+                // Write the piece data to the pipe in order.
+                finished_piece_readers.insert(piece.number, piece_reader);
+                while let Some(piece_reader) = finished_piece_readers.get_mut(&need_piece_number) {
+                    info!("copy piece to stream: {}", need_piece_number);
+                    if let Err(err) = tokio::io::copy(piece_reader, &mut writer).await {
+                        error!("download piece reader error: {}", err);
+                        if let Err(err) = writer.shutdown().await {
+                            error!("writer shutdown error: {}", err);
@@ -395,11 +366,11 @@ pub async fn http_handler(
-                                        return;
-                                    }
-
-                                    need_piece_number += 1;
-                                }
-                            }
-                            None => {
-                                error!("download task response is empty");
-                                if let Err(err) = writer.shutdown().await {
-                                    error!("writer shutdown error: {}", err);
-                                }
+                        return;
+                    }
+
+                    need_piece_number += 1;
+                }
+            } else {
+                error!("response unknown message");
+                if let Err(err) = writer.shutdown().await {
+                    error!("writer shutdown error: {}", err);
+                }
@@ -407,19 +378,24 @@ pub async fn http_handler(
-                                return;
-                            }
-                        }
-                    }
-                    info!("copy finished");
-                });
-
-                // Send response
-                return Ok(response);
-            }
-        }
-    }
-
-    // Proxy the request to the remote server directly.
-    info!("proxy http request to remote server directly");
-
+                return;
+            }
+        }
+
+        info!("copy finished");
+    });
+
+    Ok(response)
+}
+
+// proxy_http proxies the HTTP request directly to the remote server.
+async fn proxy_http(request: Request<hyper::body::Incoming>) -> ClientResult<Response> {
+    let Some(host) = request.uri().host() else {
+        error!("CONNECT host is not socket addr: {:?}", request.uri());
+        return Ok(make_error_response(
+            http::StatusCode::BAD_REQUEST,
+            "CONNECT must be to a socket address",
+        ));
+    };
+    let port = request.uri().port_u16().unwrap_or(80);
+
     let stream = TcpStream::connect((host, port)).await?;
     let io = TokioIo::new(stream);
     let (mut sender, conn) = Builder::new()
@@ -438,45 +414,37 @@ pub async fn http_handler(
     Ok(response.map(|b| b.map_err(ClientError::from).boxed()))
 }
 
-// https_handler handles the https request.
-#[instrument(skip_all)]
-pub async fn https_handler(
-    config: Arc<Config>,
-    request: Request<hyper::body::Incoming>,
-) -> ClientResult<Response> {
-    if let Some(rules) = config.proxy.rules.clone() {
-        for rule in rules.iter() {
-            if rule.regex.is_match(request.uri().to_string().as_str()) {
-                // TODO: handle https request.
-                let mut response = Response::new(full("CONNECT must be to a socket address"));
-                *response.status_mut() = http::StatusCode::BAD_REQUEST;
-                return Ok(response);
-            }
-        }
-    }
-
-    // Proxy the request to the remote server directly.
-    info!("proxy https request to remote server directly");
-    if let Some(addr) = host_addr(request.uri()) {
-        tokio::task::spawn(async move {
-            match hyper::upgrade::on(request).await {
-                Ok(upgraded) => {
-                    if let Err(e) = tunnel(upgraded, addr).await {
-                        error!("server io error: {}", e);
-                    };
-                }
-                Err(e) => error!("upgrade error: {}", e),
-            }
-        });
-
-        Ok(Response::new(empty()))
-    } else {
-        error!("CONNECT host is not socket addr: {:?}", request.uri());
-        let mut response = Response::new(full("CONNECT must be to a socket address"));
-        *response.status_mut() = http::StatusCode::BAD_REQUEST;
-
-        Ok(response)
-    }
-}
+// make_download_task_requet makes a download task request by the request.
+#[instrument(skip_all)]
+fn make_download_task_request(
+    request: Request<hyper::body::Incoming>,
+    rule: Rule,
+) -> ClientResult<DownloadTaskRequest> {
+    // Convert the Reqwest header to the Hyper header.
+    let reqwest_request_header = hyper_headermap_to_reqwest_headermap(request.headers());
+
+    // Construct the download url.
+    Ok(DownloadTaskRequest {
+        download: Some(Download {
+            url: make_download_url(request.uri(), rule.use_tls, rule.redirect.clone())?,
+            digest: None,
+            // Download range use header range in HTTP protocol.
+            range: None,
+            r#type: TaskType::Dfdaemon as i32,
+            tag: header::get_tag(&reqwest_request_header),
+            application: header::get_application(&reqwest_request_header),
+            priority: header::get_priority(&reqwest_request_header),
+            filtered_query_params: header::get_filtered_query_params(
+                &reqwest_request_header,
+                rule.filtered_query_params.clone(),
+            ),
+            request_header: reqwest_headermap_to_hashmap(&reqwest_request_header),
+            piece_length: header::get_piece_length(&reqwest_request_header),
+            output_path: None,
+            timeout: None,
+            need_back_to_source: false,
+        }),
+    })
+}
 
 // make_download_url makes a download url by the given uri.
@@ -503,6 +471,42 @@ fn make_download_url(
     Ok(http::Uri::from_parts(parts)?.to_string())
 }
 
+// make_response_headers makes the response headers.
+#[instrument(skip_all)]
+fn make_response_headers(
+    mut download_task_started_response: DownloadTaskStartedResponse,
+) -> ClientResult<hyper::header::HeaderMap> {
+    // Insert the content range header to the resopnse header.
+    if let Some(range) = download_task_started_response.range.as_ref() {
+        download_task_started_response.response_header.insert(
+            reqwest::header::CONTENT_RANGE.to_string(),
+            format!(
+                "bytes {}-{}/{}",
+                range.start,
+                range.start + range.length - 1,
+                download_task_started_response.content_length
+            ),
+        );
+    };
+
+    hashmap_to_hyper_header_map(&download_task_started_response.response_header)
+}
+
+// find_matching_rule returns whether the dfdaemon should be used to download the task.
+// If the dfdaemon should be used, return the matched rule.
+#[instrument(skip_all)]
+fn find_matching_rule(rules: Option<Vec<Rule>>, url: &str) -> Option<Rule> {
+    rules?.iter().find(|rule| rule.regex.is_match(url)).cloned()
+}
+
+// make_error_response makes an error response with the given status and message.
+#[instrument(skip_all)]
+fn make_error_response(status: http::StatusCode, message: &str) -> Response {
+    let mut response = Response::new(full(message.as_bytes().to_vec()));
+    *response.status_mut() = status;
+    response
+}
+
 // empty returns an empty body.
 #[instrument(skip_all)]
 fn empty() -> BoxBody<Bytes, ClientError> {