diff --git a/src/rpc/client.rs b/src/rpc/client.rs
index 02984ae..6b2df52 100644
--- a/src/rpc/client.rs
+++ b/src/rpc/client.rs
@@ -226,7 +226,7 @@ impl RpcClient {
     }
 
     pub fn get_timestamp(self: Arc<Self>) -> impl Future<Output = Result<Timestamp>> {
-        Arc::clone(&self.pd).get_timestamp()
+        self.pd.clone().get_timestamp()
     }
 
     // Returns a Stream which iterates over the contexts for each region covered by range.
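For context on the change above, here is a minimal sketch of the `self: Arc<Self>` receiver pattern that `get_timestamp` relies on: cloning the inner `Arc` hands the returned future its own owned handle, so the future is not tied to a borrow of the caller. `Oracle` and `Client` are illustrative stand-ins, not this crate's types, and `futures::executor::block_on` is assumed only to drive the example.

use std::sync::Arc;

struct Oracle;

impl Oracle {
    async fn get_timestamp(self: Arc<Self>) -> u64 {
        // A real implementation would issue an RPC here; the future owns `self`,
        // so it can outlive any borrow of the caller.
        42
    }
}

struct Client {
    oracle: Arc<Oracle>,
}

impl Client {
    fn get_timestamp(self: Arc<Self>) -> impl std::future::Future<Output = u64> {
        // Cloning the inner Arc gives the child future its own owned handle,
        // mirroring `self.pd.clone().get_timestamp()` in the diff above.
        self.oracle.clone().get_timestamp()
    }
}

fn main() {
    let client = Arc::new(Client {
        oracle: Arc::new(Oracle),
    });
    assert_eq!(futures::executor::block_on(client.get_timestamp()), 42);
}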
diff --git a/src/rpc/pd/timestamp.rs b/src/rpc/pd/timestamp.rs
index 5d4015d..2b38392 100644
--- a/src/rpc/pd/timestamp.rs
+++ b/src/rpc/pd/timestamp.rs
@@ -2,6 +2,14 @@
 //! This module is the low-level mechanisms for getting timestamps from a PD
 //! cluster. It should be used via the `get_timestamp` API in `PdClient`.
+//!
+//! Once a `TimestampOracle` is created, there will be two futures running in a background working
+//! thread created automatically. The `get_timestamp` method creates a oneshot channel whose
+//! transmitter serves as a `TimestampRequest`. `TimestampRequest`s are sent to the working
+//! thread through a bounded multi-producer, single-consumer channel. The first future tries to
+//! exhaust the channel to get as many requests as possible and sends a single `TsoRequest` to the
+//! PD server. The other future receives `TsoResponse`s from the PD server and allocates timestamps
+//! for the requests.
 
 use super::Timestamp;
 use crate::{Error, Result};
 
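The request/response flow described in the module documentation above reduces to the following standalone sketch. It assumes only the futures 0.3 `mpsc` and `oneshot` channels plus `block_on`; the spawned worker and the `u64` payload are stand-ins for the real TSO worker and `Timestamp`, not code from this crate.

use futures::channel::{mpsc, oneshot};
use futures::executor::block_on;
use futures::{SinkExt, StreamExt};

fn main() {
    // Bounded channel carrying "requests", i.e. the transmitting half of a oneshot channel.
    let (mut request_tx, mut request_rx) = mpsc::channel::<oneshot::Sender<u64>>(64);

    // Worker side: drain requests and answer each one through its oneshot sender.
    let worker = std::thread::spawn(move || {
        block_on(async move {
            let mut next = 0u64;
            while let Some(request) = request_rx.next().await {
                next += 1;
                // Ignoring the error is fine: it only means the requester went away.
                let _ = request.send(next);
            }
        })
    });

    // Client side: create a oneshot channel, send its transmitter as the request,
    // then wait on the receiving half for the allocated value.
    let ts = block_on(async {
        let (request, response) = oneshot::channel();
        request_tx.send(request).await.unwrap();
        response.await.unwrap()
    });
    assert_eq!(ts, 1);

    drop(request_tx); // close the channel so the worker loop ends
    worker.join().unwrap();
}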
@@ -12,123 +20,122 @@ use futures::{
     executor::block_on,
     join, pin_mut,
     prelude::*,
-    task::{Context, Poll, Waker},
+    task::{AtomicWaker, Context, Poll},
 };
-use grpcio::{ClientDuplexReceiver, ClientDuplexSender, WriteFlags};
+use grpcio::WriteFlags;
 use kvproto::pdpb::{PdClient, *};
 use std::{cell::RefCell, collections::VecDeque, pin::Pin, rc::Rc, thread};
 
-const MAX_PENDING_COUNT: usize = 64;
+/// It is an empirical value.
+const MAX_BATCH_SIZE: usize = 64;
+
+/// TODO: This value should be adjustable.
+const MAX_PENDING_COUNT: usize = 1 << 16;
+
+type TimestampRequest = oneshot::Sender<Timestamp>;
 
 /// The timestamp oracle (TSO) which provides monotonically increasing timestamps.
 #[derive(Clone)]
 pub struct TimestampOracle {
-    /// The transmitter of a bounded channel which transports the sender of a oneshot channel to
-    /// the TSO working thread.
-    /// In the working thread, the `oneshot::Sender` is used to send back timestamp results.
-    result_sender_tx: mpsc::Sender<oneshot::Sender<Timestamp>>,
+    /// The transmitter of a bounded channel which transports requests for a single
+    /// timestamp to the TSO working thread. A bounded channel is used to prevent using
+    /// too much memory unexpectedly.
+    /// In the working thread, the `TimestampRequest`, which is actually a oneshot channel
+    /// sender, is used to send back the timestamp result.
+    request_tx: mpsc::Sender<TimestampRequest>,
 }
 
 impl TimestampOracle {
     pub fn new(cluster_id: u64, pd_client: &PdClient) -> Result<TimestampOracle> {
-        let (result_sender_tx, result_sender_rx) = mpsc::channel(MAX_PENDING_COUNT);
+        let (request_tx, request_rx) = mpsc::channel(MAX_BATCH_SIZE);
+        // FIXME: use tso_opt
+        let (rpc_sender, rpc_receiver) = pd_client.tso()?;
+
         // Start a background thread to handle TSO requests and responses
-        let worker = TsoWorker::new(cluster_id, pd_client, result_sender_rx)?;
-        thread::spawn(move || block_on(worker.run()));
+        thread::spawn(move || {
+            block_on(run_tso(
+                cluster_id,
+                rpc_sender.sink_compat().sink_err_into(),
+                rpc_receiver.compat().err_into(),
+                request_rx,
+            ))
+        });
 
-        Ok(TimestampOracle { result_sender_tx })
+        Ok(TimestampOracle { request_tx })
    }
 
     pub async fn get_timestamp(mut self) -> Result<Timestamp> {
-        let (result_sender, result_receiver) = oneshot::channel();
-        self.result_sender_tx
-            .send(result_sender)
+        let (request, response) = oneshot::channel();
+        self.request_tx
+            .send(request)
             .await
-            .map_err(|_| Error::internal_error("Result sender channel is closed"))?;
-        Ok(result_receiver.await?)
+            .map_err(|_| Error::internal_error("TimestampRequest channel is closed"))?;
+        Ok(response.await?)
     }
 }
 
-struct TsoWorker {
+async fn run_tso(
     cluster_id: u64,
-    result_sender_rx: mpsc::Receiver<oneshot::Sender<Timestamp>>,
-    rpc_sender: Compat01As03Sink<ClientDuplexSender<TsoRequest>, (TsoRequest, WriteFlags)>,
-    rpc_receiver: Compat01As03<ClientDuplexReceiver<TsoResponse>>,
-}
+    mut rpc_sender: impl Sink<(TsoRequest, WriteFlags), Error = Error> + Unpin,
+    mut rpc_receiver: impl Stream<Item = Result<TsoResponse>> + Unpin,
+    request_rx: mpsc::Receiver<TimestampRequest>,
+) {
+    // The `TimestampRequest`s which are waiting for the responses from the PD server
+    let pending_requests = Rc::new(RefCell::new(VecDeque::with_capacity(MAX_PENDING_COUNT)));
 
-impl TsoWorker {
-    fn new(
-        cluster_id: u64,
-        pd_client: &PdClient,
-        result_sender_rx: mpsc::Receiver<oneshot::Sender<Timestamp>>,
-    ) -> Result<TsoWorker> {
-        // FIXME: use tso_opt
-        let (rpc_sender, rpc_receiver) = pd_client.tso()?;
-        Ok(TsoWorker {
-            cluster_id,
-            result_sender_rx,
-            rpc_sender: rpc_sender.sink_compat(),
-            rpc_receiver: rpc_receiver.compat(),
-        })
-    }
+    // When there are too many pending requests, the `send_requests` future will refuse to fetch
+    // more requests from the bounded channel. This waker is used to wake up the sending future
+    // once the queue of pending requests is no longer full.
+    let sending_future_waker = Rc::new(AtomicWaker::new());
 
-    async fn run(mut self) {
-        let ctx = Rc::new(RefCell::new(TsoContext {
-            pending_queue: VecDeque::with_capacity(MAX_PENDING_COUNT),
-            waker: None,
-        }));
+    pin_mut!(request_rx);
+    let mut request_stream = TsoRequestStream {
+        cluster_id,
+        request_rx,
+        pending_requests: pending_requests.clone(),
+        self_waker: sending_future_waker.clone(),
+    };
 
-        let result_sender_rx = self.result_sender_rx;
-        pin_mut!(result_sender_rx);
-        let mut request_stream = TsoRequestStream {
-            cluster_id: self.cluster_id,
-            result_sender_rx,
-            ctx: ctx.clone(),
-        };
+    let send_requests = rpc_sender.send_all(&mut request_stream);
 
-        let send_requests = self.rpc_sender.send_all(&mut request_stream);
+    let receive_and_handle_responses = async move {
+        while let Some(Ok(resp)) = rpc_receiver.next().await {
+            let mut pending_requests = pending_requests.borrow_mut();
 
-        let mut rpc_receiver = self.rpc_receiver;
-        let receive_and_handle_responses = async move {
-            while let Some(Ok(resp)) = rpc_receiver.next().await {
-                let mut ctx = ctx.borrow_mut();
-                ctx.allocate_timestamps(&resp)?;
-                if let Some(waker) = &ctx.waker {
-                    waker.wake_by_ref();
-                }
+            // Wake up the sending future blocked by too many pending requests as we are consuming
+            // some of them here.
+            if pending_requests.len() == MAX_PENDING_COUNT {
+                sending_future_waker.wake();
            }
-            Err(Error::internal_error("TSO stream terminated"))
-        };
 
-        let _: (_, Result<()>) = join!(send_requests, receive_and_handle_responses);
-    }
+
+            allocate_timestamps(&resp, &mut pending_requests)?;
+        }
+        Err(Error::internal_error("TSO stream terminated"))
+    };
+
+    let (send_res, recv_res): (_, Result<()>) = join!(send_requests, receive_and_handle_responses);
+    error!("TSO send error: {:?}", send_res);
+    error!("TSO receive error: {:?}", recv_res);
 }
 
 struct TsoRequestStream<'a> {
     cluster_id: u64,
-    result_sender_rx: Pin<&'a mut mpsc::Receiver<oneshot::Sender<Timestamp>>>,
-    ctx: Rc<RefCell<TsoContext>>,
+    request_rx: Pin<&'a mut mpsc::Receiver<oneshot::Sender<Timestamp>>>,
+    pending_requests: Rc<RefCell<VecDeque<TimestampRequest>>>,
+    self_waker: Rc<AtomicWaker>,
 }
 
 impl<'a> Stream for TsoRequestStream<'a> {
     type Item = (TsoRequest, WriteFlags);
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
-        let ctx = self.ctx.clone();
-        let mut ctx = ctx.borrow_mut();
-
-        // Set the waker to the context, then the stream can be waked up after the pending queue
-        // is no longer full.
-        if ctx.waker.is_none() {
-            ctx.waker = Some(cx.waker().clone());
-        }
-
+        let pending_requests = self.pending_requests.clone();
+        let mut pending_requests = pending_requests.borrow_mut();
         let mut count = 0;
-        while ctx.pending_queue.len() < MAX_PENDING_COUNT {
-            match self.result_sender_rx.as_mut().poll_next(cx) {
+        while count < MAX_BATCH_SIZE && pending_requests.len() < MAX_PENDING_COUNT {
+            match self.request_rx.as_mut().poll_next(cx) {
                 Poll::Ready(Some(sender)) => {
-                    ctx.pending_queue.push_back(sender);
+                    pending_requests.push_back(sender);
                     count += 1;
                 }
                 Poll::Ready(None) => return Poll::Ready(None),
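To isolate the batching idea behind `TsoRequestStream` above, the following reduced sketch drains an inner channel up to a fixed batch size per poll and parks itself with an `AtomicWaker` when it has nothing to emit. `Batcher`, `MAX_BATCH` and the `u64` items are illustrative, and the pending-count backpressure of the real stream is omitted, so treat it as a sketch of the technique rather than the crate's implementation.

use std::pin::Pin;
use std::sync::Arc;

use futures::channel::mpsc;
use futures::stream::Stream;
use futures::task::{AtomicWaker, Context, Poll};

const MAX_BATCH: usize = 64;

/// Drains up to MAX_BATCH queued items per poll and emits them as a single batch.
struct Batcher {
    inner: mpsc::Receiver<u64>,
    /// Registered when no progress can be made; whoever clears the blocking
    /// condition calls `wake()` so this stream gets polled again.
    waker: Arc<AtomicWaker>,
}

impl Stream for Batcher {
    type Item = Vec<u64>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let mut batch = Vec::new();
        while batch.len() < MAX_BATCH {
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Ready(Some(item)) => batch.push(item),
                // Channel closed and nothing buffered: the stream is finished.
                Poll::Ready(None) if batch.is_empty() => return Poll::Ready(None),
                // Channel closed or temporarily empty: emit what we collected so far.
                Poll::Ready(None) | Poll::Pending => break,
            }
        }
        if batch.is_empty() {
            // Nothing to emit yet: park until new items (or an external wake) arrive.
            this.waker.register(cx.waker());
            Poll::Pending
        } else {
            Poll::Ready(Some(batch))
        }
    }
}

fn main() {
    let (mut tx, rx) = mpsc::channel(16);
    let waker = Arc::new(AtomicWaker::new());
    let mut batcher = Batcher { inner: rx, waker };

    futures::executor::block_on(async {
        use futures::{SinkExt, StreamExt};
        for i in 0..5u64 {
            tx.send(i).await.unwrap();
        }
        drop(tx); // close the channel so the batcher terminates
        while let Some(batch) = batcher.next().await {
            println!("batch of {} items: {:?}", batch.len(), batch);
        }
    });
}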
@@ -141,47 +148,47 @@ impl<'a> Stream for TsoRequestStream<'a> {
                 header: Some(RequestHeader {
                     cluster_id: self.cluster_id,
                 }),
-                count,
+                count: count as u32,
             };
             let write_flags = WriteFlags::default().buffer_hint(false);
             Poll::Ready(Some((req, write_flags)))
         } else {
+            // Register the waker so the stream can be woken up once the pending queue
+            // is no longer full.
+            self.self_waker.register(cx.waker());
             Poll::Pending
         }
     }
 }
 
-struct TsoContext {
-    pending_queue: VecDeque<oneshot::Sender<Timestamp>>,
-    waker: Option<Waker>,
-}
+fn allocate_timestamps(
+    resp: &TsoResponse,
+    pending_requests: &mut VecDeque<TimestampRequest>,
+) -> Result<()> {
+    // PD returns the timestamp with the biggest logical value. We can send back timestamps
+    // whose logical value is from `logical - count + 1` to `logical` using the senders
+    // in `pending_requests`.
+    let tail_ts = resp
+        .timestamp
+        .as_ref()
+        .ok_or_else(|| Error::internal_error("No timestamp in TsoResponse"))?;
 
-impl TsoContext {
-    fn allocate_timestamps(&mut self, resp: &TsoResponse) -> Result<()> {
-        // PD returns the timestamp with the biggest logical value. We can send back timestamps
-        // whose logical value is from `logical - count + 1` to `logical` using the senders
-        // in `pending`.
-        let tail_ts = resp
-            .timestamp
-            .as_ref()
-            .ok_or_else(|| Error::internal_error("No timestamp in TsoResponse"))?;
-        let mut offset = i64::from(resp.count);
-        while offset > 0 {
-            offset -= 1;
-            if let Some(sender) = self.pending_queue.pop_front() {
-                let ts = Timestamp {
-                    physical: tail_ts.physical,
-                    logical: tail_ts.logical - offset,
-                };
+    let mut offset = i64::from(resp.count);
+    while offset > 0 {
+        offset -= 1;
+        if let Some(sender) = pending_requests.pop_front() {
+            let ts = Timestamp {
+                physical: tail_ts.physical,
+                logical: tail_ts.logical - offset,
+            };
 
-                // It doesn't matter if the receiver of the channel is dropped.
-                let _ = sender.send(ts);
-            } else {
-                return Err(Error::internal_error(
-                    "PD gives more timestamps than expected",
-                ));
-            }
+            // It doesn't matter if the receiver of the channel is dropped.
+            let _ = sender.send(ts);
+        } else {
+            return Err(Error::internal_error(
+                "PD gives more timestamps than expected",
+            ));
        }
         }
-        Ok(())
     }
+    Ok(())
 }
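As a worked example of the allocation arithmetic in `allocate_timestamps`: PD returns only the last logical value of a batch together with a count, and the individual timestamps are recovered as `logical - count + 1 ..= logical`, paired with the queued requests in order. `Ts` and `expand` below are illustrative stand-ins, not part of the crate.

#[derive(Debug, PartialEq)]
struct Ts {
    physical: i64,
    logical: i64,
}

fn expand(tail_physical: i64, tail_logical: i64, count: u32) -> Vec<Ts> {
    (0..i64::from(count))
        .rev() // the offset counts down, so the i-th element answers the i-th queued request
        .map(|offset| Ts {
            physical: tail_physical,
            logical: tail_logical - offset,
        })
        .collect()
}

fn main() {
    // A response carrying logical = 42 and count = 3 stands for the logicals 40, 41 and 42.
    assert_eq!(
        expand(7, 42, 3),
        vec![
            Ts { physical: 7, logical: 40 },
            Ts { physical: 7, logical: 41 },
            Ts { physical: 7, logical: 42 },
        ]
    );
}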