feat: optimize proxy request by dfdaemon (#249)

Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
Gaius 2024-02-01 15:12:09 +08:00 committed by GitHub
parent 4e8a2e2dc0
commit f97fd1bff2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 288 additions and 284 deletions

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
use crate::config::dfdaemon::Config;
use crate::config::dfdaemon::{Config, Rule};
use crate::grpc::dfdaemon_download::DfdaemonDownloadClient;
use crate::shutdown;
use crate::task::Task;
@ -24,7 +24,9 @@ use crate::utils::http::{
use crate::{Error as ClientError, Result as ClientResult};
use bytes::Bytes;
use dragonfly_api::common::v2::{Download, TaskType};
use dragonfly_api::dfdaemon::v2::{download_task_response, DownloadTaskRequest};
use dragonfly_api::dfdaemon::v2::{
download_task_response, DownloadTaskRequest, DownloadTaskStartedResponse,
};
use futures_util::TryStreamExt;
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
use hyper::body::Frame;
@ -46,6 +48,7 @@ use tracing::{error, info, instrument, Span};
pub mod header;
// Response is the response of the proxy server.
pub type Response = hyper::Response<BoxBody<Bytes, ClientError>>;
// Proxy is the proxy server.
@ -88,8 +91,6 @@ impl Proxy {
#[instrument(skip_all)]
pub async fn run(&self) -> ClientResult<()> {
let listener = TcpListener::bind(self.addr).await?;
// Start the proxy server and wait for it to finish.
info!("proxy server listening on {}", self.addr);
loop {
@ -161,281 +162,19 @@ pub async fn http_handler(
task: Arc<Task>,
request: Request<hyper::body::Incoming>,
) -> ClientResult<Response> {
let Some(host) = request.uri().host() else {
error!("CONNECT host is not socket addr: {:?}", request.uri());
let mut response = Response::new(full("CONNECT must be to a socket address"));
*response.status_mut() = http::StatusCode::BAD_REQUEST;
return Ok(response);
};
let port = request.uri().port_u16().unwrap_or(80);
if let Some(rules) = config.proxy.rules.clone() {
for rule in rules.iter() {
if rule.regex.is_match(request.uri().to_string().as_str()) {
// Convert the Reqwest header to the Hyper header.
let request_header = hyper_headermap_to_reqwest_headermap(request.headers());
// Construct the download url.
let url =
match make_download_url(request.uri(), rule.use_tls, rule.redirect.clone()) {
Ok(url) => url,
Err(err) => {
let mut response = Response::new(full(
err.to_string().to_string().as_bytes().to_vec(),
));
*response.status_mut() = http::StatusCode::BAD_REQUEST;
return Ok(response);
}
};
// Get parameters from the header.
let tag = header::get_tag(&request_header);
let application = header::get_application(&request_header);
let priority = header::get_priority(&request_header);
let piece_length = header::get_piece_length(&request_header);
let filtered_query_params = header::get_filtered_query_params(
&request_header,
rule.filtered_query_params.clone(),
);
let request_header = reqwest_headermap_to_hashmap(&request_header);
// Initialize the dfdaemon download client.
let dfdaemon_download_client = match DfdaemonDownloadClient::new_unix(
config.download.server.socket_path.clone(),
)
.await
{
Ok(client) => client,
Err(err) => {
let mut response =
Response::new(full(err.to_string().to_string().as_bytes().to_vec()));
*response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
return Ok(response);
}
};
// Download the task by the dfdaemon download client.
let response = match dfdaemon_download_client
.download_task(DownloadTaskRequest {
download: Some(Download {
url,
digest: None,
// Download range use header range in HTTP protocol.
range: None,
r#type: TaskType::Dfdaemon as i32,
tag,
application,
priority,
filtered_query_params,
request_header,
piece_length,
output_path: None,
timeout: None,
need_back_to_source: false,
}),
})
.await
{
Ok(response) => response,
Err(err) => {
let mut response =
Response::new(full(err.to_string().to_string().as_bytes().to_vec()));
*response.status_mut() = http::StatusCode::BAD_REQUEST;
return Ok(response);
}
};
// Handle the response from the download grpc server.
let mut out_stream = response.into_inner();
let Ok(Some(message)) = out_stream.message().await else {
let mut response =
Response::new(full("download task response failed".as_bytes().to_vec()));
*response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
return Ok(response);
};
let download_task_started_response = match message.response {
Some(download_task_response::Response::DownloadTaskStartedResponse(
mut response,
)) => {
// Insert the content range header to the resopnse header.
if let Some(range) = response.range.as_ref() {
response.response_header.insert(
reqwest::header::CONTENT_RANGE.to_string(),
format!(
"bytes {}-{}/{}",
range.start,
range.start + range.length - 1,
response.content_length
),
);
}
response
}
_ => {
let mut response = Response::new(full(
"download task response is not started".as_bytes().to_vec(),
));
*response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
return Ok(response);
}
};
// Write the task data to the reader.
let (reader, mut writer) = tokio::io::duplex(1024);
// Construct the response body.
let reader_stream = ReaderStream::new(reader);
let stream_body =
StreamBody::new(reader_stream.map_ok(Frame::data).map_err(ClientError::from));
let boxed_body = stream_body.boxed();
// Construct the response.
let mut response = Response::new(boxed_body);
*response.headers_mut() =
hashmap_to_hyper_header_map(&download_task_started_response.response_header)?;
*response.status_mut() = http::StatusCode::OK;
// Write task data to pipe. If grpc received error message,
// shutdown the writer.
tokio::spawn(async move {
// Initialize the hashmap of the finished piece readers.
let mut finished_piece_readers = HashMap::new();
while let Some(message) = match out_stream.message().await {
Ok(message) => message,
Err(err) => {
error!("download task response error: {}", err);
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
} {
match message.response {
Some(
download_task_response::Response::DownloadTaskStartedResponse(_),
) => {
error!("download task started response is duplicated");
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
Some(
download_task_response::Response::DownloadPieceFinishedResponse(
response,
),
) => {
let piece = match response.piece {
Some(piece) => piece,
None => {
error!("download piece finished response piece is empty");
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
};
let mut need_piece_number = match download_task_started_response
.pieces
.first()
{
Some(piece) => piece.number,
None => {
error!("download task started response pieces is empty");
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
};
let piece_reader = match task
.piece
.download_from_local_peer_into_async_read(
message.task_id.as_str(),
piece.number,
piece.length,
download_task_started_response.range.clone(),
true,
)
.await
{
Ok(piece_reader) => piece_reader,
Err(err) => {
error!("download piece reader error: {}", err);
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
};
// Sort by piece number and return to reader in order.
finished_piece_readers.insert(piece.number, piece_reader);
while let Some(piece_reader) =
finished_piece_readers.get_mut(&need_piece_number)
{
info!("copy piece to stream: {}", need_piece_number);
if let Err(err) =
tokio::io::copy(piece_reader, &mut writer).await
{
error!("download piece reader error: {}", err);
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
need_piece_number += 1;
}
}
None => {
error!("download task response is empty");
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
}
}
info!("copy finished");
});
// Send response
return Ok(response);
}
}
let request_uri = request.uri();
if let Some(rule) =
find_matching_rule(config.proxy.rules.clone(), request_uri.to_string().as_str())
{
info!("proxy HTTP request via dfdaemon for URI: {}", request_uri);
return proxy_http_by_dfdaemon(config, task, rule.clone(), request).await;
}
// Proxy the request to the remote server directly.
info!("proxy http request to remote server directly");
let stream = TcpStream::connect((host, port)).await?;
let io = TokioIo::new(stream);
let (mut sender, conn) = Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.handshake(io)
.await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
error!("connection failed: {:?}", err);
}
});
let response = sender.send_request(request).await?;
Ok(response.map(|b| b.map_err(ClientError::from).boxed()))
info!(
"proxy HTTP request directly to remote server for URI: {}",
request_uri
);
proxy_http(request).await
}
// https_handler handles the https request.
@ -455,8 +194,8 @@ pub async fn https_handler(
}
}
// Proxy the request to the remote server directly.
info!("proxy https request to remote server directly");
// Proxy the request directly to the remote server.
info!("proxy HTTPS request directly to remote server");
if let Some(addr) = host_addr(request.uri()) {
tokio::task::spawn(async move {
match hyper::upgrade::on(request).await {
@ -471,14 +210,243 @@ pub async fn https_handler(
Ok(Response::new(empty()))
} else {
error!("CONNECT host is not socket addr: {:?}", request.uri());
let mut response = Response::new(full("CONNECT must be to a socket address"));
*response.status_mut() = http::StatusCode::BAD_REQUEST;
Ok(response)
return Ok(make_error_response(
http::StatusCode::BAD_REQUEST,
"CONNECT must be to a socket address",
));
}
}
// proxy_http_by_dfdaemon proxies the HTTP request via the dfdaemon.
async fn proxy_http_by_dfdaemon(
config: Arc<Config>,
task: Arc<Task>,
rule: Rule,
request: Request<hyper::body::Incoming>,
) -> ClientResult<Response> {
// Initialize the dfdaemon download client.
let dfdaemon_download_client =
match DfdaemonDownloadClient::new_unix(config.download.server.socket_path.clone()).await {
Ok(client) => client,
Err(err) => {
error!("create dfdaemon download client failed: {}", err);
return Ok(make_error_response(
http::StatusCode::INTERNAL_SERVER_ERROR,
err.to_string().as_str(),
));
}
};
// Make the download task request.
let download_task_request = match make_download_task_request(request, rule) {
Ok(download_task_request) => download_task_request,
Err(err) => {
error!("make download task request failed: {}", err);
return Ok(make_error_response(
http::StatusCode::INTERNAL_SERVER_ERROR,
err.to_string().as_str(),
));
}
};
// Download the task by the dfdaemon download client.
let response = match dfdaemon_download_client
.download_task(download_task_request)
.await
{
Ok(response) => response,
Err(err) => {
error!("initiate download task failed: {}", err);
return Ok(make_error_response(
http::StatusCode::INTERNAL_SERVER_ERROR,
err.to_string().as_str(),
));
}
};
// Handle the response from the download grpc server.
let mut out_stream = response.into_inner();
let Ok(Some(message)) = out_stream.message().await else {
error!("response message failed");
return Ok(make_error_response(
http::StatusCode::INTERNAL_SERVER_ERROR,
"response message failed",
));
};
// Handle the download task started response.
let Some(download_task_response::Response::DownloadTaskStartedResponse(
download_task_started_response,
)) = message.response
else {
error!("response is not started");
return Ok(make_error_response(
http::StatusCode::INTERNAL_SERVER_ERROR,
"response is not started",
));
};
// Write the task data to the reader.
let (reader, mut writer) = tokio::io::duplex(1024);
// Construct the response body.
let reader_stream = ReaderStream::new(reader);
let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data).map_err(ClientError::from));
let boxed_body = stream_body.boxed();
// Construct the response.
let mut response = Response::new(boxed_body);
*response.headers_mut() = make_response_headers(download_task_started_response.clone())?;
*response.status_mut() = http::StatusCode::OK;
// Write task data to pipe. If grpc received error message,
// shutdown the writer.
tokio::spawn(async move {
// Initialize the hashmap of the finished piece readers and pieces.
let mut finished_piece_readers = HashMap::new();
// Get the first piece number from the started response.
let Some(first_piece) = download_task_started_response.pieces.first() else {
error!("reponse pieces is empty");
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
};
let mut need_piece_number = first_piece.number;
// Read piece data from stream and write to pipe. If the piece data is
// not in order, store it in the hashmap, and write it to the pipe
// when the previous piece data is written.
while let Ok(Some(message)) = out_stream.message().await {
if let Some(download_task_response::Response::DownloadPieceFinishedResponse(response)) =
message.response
{
let Some(piece) = response.piece else {
error!("response piece is empty");
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
};
let piece_reader = match task
.piece
.download_from_local_peer_into_async_read(
message.task_id.as_str(),
piece.number,
piece.length,
download_task_started_response.range.clone(),
true,
)
.await
{
Ok(piece_reader) => piece_reader,
Err(err) => {
error!("download piece reader error: {}", err);
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
};
// Write the piece data to the pipe in order.
finished_piece_readers.insert(piece.number, piece_reader);
while let Some(piece_reader) = finished_piece_readers.get_mut(&need_piece_number) {
info!("copy piece to stream: {}", need_piece_number);
if let Err(err) = tokio::io::copy(piece_reader, &mut writer).await {
error!("download piece reader error: {}", err);
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
need_piece_number += 1;
}
} else {
error!("response unknown message");
if let Err(err) = writer.shutdown().await {
error!("writer shutdown error: {}", err);
}
return;
}
}
info!("copy finished");
});
Ok(response)
}
// proxy_http proxies the HTTP request directly to the remote server.
async fn proxy_http(request: Request<hyper::body::Incoming>) -> ClientResult<Response> {
let Some(host) = request.uri().host() else {
error!("CONNECT host is not socket addr: {:?}", request.uri());
return Ok(make_error_response(
http::StatusCode::BAD_REQUEST,
"CONNECT must be to a socket address",
));
};
let port = request.uri().port_u16().unwrap_or(80);
let stream = TcpStream::connect((host, port)).await?;
let io = TokioIo::new(stream);
let (mut sender, conn) = Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.handshake(io)
.await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
error!("connection failed: {:?}", err);
}
});
let response = sender.send_request(request).await?;
Ok(response.map(|b| b.map_err(ClientError::from).boxed()))
}
// make_download_task_requet makes a download task request by the request.
#[instrument(skip_all)]
fn make_download_task_request(
request: Request<hyper::body::Incoming>,
rule: Rule,
) -> ClientResult<DownloadTaskRequest> {
// Convert the Reqwest header to the Hyper header.
let reqwest_request_header = hyper_headermap_to_reqwest_headermap(request.headers());
// Construct the download url.
Ok(DownloadTaskRequest {
download: Some(Download {
url: make_download_url(request.uri(), rule.use_tls, rule.redirect.clone())?,
digest: None,
// Download range use header range in HTTP protocol.
range: None,
r#type: TaskType::Dfdaemon as i32,
tag: header::get_tag(&reqwest_request_header),
application: header::get_application(&reqwest_request_header),
priority: header::get_priority(&reqwest_request_header),
filtered_query_params: header::get_filtered_query_params(
&reqwest_request_header,
rule.filtered_query_params.clone(),
),
request_header: reqwest_headermap_to_hashmap(&reqwest_request_header),
piece_length: header::get_piece_length(&reqwest_request_header),
output_path: None,
timeout: None,
need_back_to_source: false,
}),
})
}
// make_download_url makes a download url by the given uri.
#[instrument(skip_all)]
fn make_download_url(
@ -503,6 +471,42 @@ fn make_download_url(
Ok(http::Uri::from_parts(parts)?.to_string())
}
// make_response_headers makes the response headers.
#[instrument(skip_all)]
fn make_response_headers(
mut download_task_started_response: DownloadTaskStartedResponse,
) -> ClientResult<hyper::header::HeaderMap> {
// Insert the content range header to the resopnse header.
if let Some(range) = download_task_started_response.range.as_ref() {
download_task_started_response.response_header.insert(
reqwest::header::CONTENT_RANGE.to_string(),
format!(
"bytes {}-{}/{}",
range.start,
range.start + range.length - 1,
download_task_started_response.content_length
),
);
};
hashmap_to_hyper_header_map(&download_task_started_response.response_header)
}
// find_matching_rule returns whether the dfdaemon should be used to download the task.
// If the dfdaemon should be used, return the matched rule.
#[instrument(skip_all)]
fn find_matching_rule(rules: Option<Vec<Rule>>, url: &str) -> Option<Rule> {
rules?.iter().find(|rule| rule.regex.is_match(url)).cloned()
}
// make_error_response makes an error response with the given status and message.
#[instrument(skip_all)]
fn make_error_response(status: http::StatusCode, message: &str) -> Response {
let mut response = Response::new(full(message.as_bytes().to_vec()));
*response.status_mut() = status;
response
}
// empty returns an empty body.
#[instrument(skip_all)]
fn empty() -> BoxBody<Bytes, ClientError> {