// Copyright 2020 TiKV Project Authors. Licensed under Apache-2.0.

//! A utility module for managing and retrying PD requests.

use crate::{
    region::{RegionId, RegionWithLeader, StoreId},
    stats::pd_stats,
    Error, Result, SecurityManager,
};
use async_trait::async_trait;
use futures_timer::Delay;
use grpcio::Environment;
use std::{
    fmt,
    sync::Arc,
    time::{Duration, Instant},
};
use tikv_client_pd::{Cluster, Connection};
use tikv_client_proto::{
    metapb,
    pdpb::{self, Timestamp},
};
use tokio::sync::RwLock;

// FIXME: these numbers and how they are used are all just cargo-culted in, there
// may be more optimal values.
const RECONNECT_INTERVAL_SEC: u64 = 1;
const MAX_REQUEST_COUNT: usize = 5;
const LEADER_CHANGE_RETRY: usize = 10;

#[async_trait]
pub trait RetryClientTrait {
    // These get_* functions will try multiple times to make a request, reconnecting as necessary.
    // They do not know about encoding; the caller should take care of that.
    async fn get_region(self: Arc<Self>, key: Vec<u8>) -> Result<RegionWithLeader>;

    async fn get_region_by_id(self: Arc<Self>, region_id: RegionId) -> Result<RegionWithLeader>;

    async fn get_store(self: Arc<Self>, id: StoreId) -> Result<metapb::Store>;

    async fn get_all_stores(self: Arc<Self>) -> Result<Vec<metapb::Store>>;

    async fn get_timestamp(self: Arc<Self>) -> Result<Timestamp>;

    async fn update_safepoint(self: Arc<Self>, safepoint: u64) -> Result<bool>;
}

/// Client for communication with a PD cluster. Has the facility to reconnect to the cluster.
pub struct RetryClient<Cl = Cluster> {
    // Tuple of the cluster and the instant of the cluster's last reconnect.
    cluster: RwLock<(Cl, Instant)>,
    connection: Connection,
    timeout: Duration,
}

#[cfg(test)]
impl<Cl> RetryClient<Cl> {
    pub fn new_with_cluster(
        env: Arc<Environment>,
        security_mgr: Arc<SecurityManager>,
        timeout: Duration,
        cluster: Cl,
    ) -> RetryClient<Cl> {
        let connection = Connection::new(env, security_mgr);
        RetryClient {
            cluster: RwLock::new((cluster, Instant::now())),
            connection,
            timeout,
        }
    }
}

macro_rules! retry {
    ($self: ident, $tag: literal, |$cluster: ident| $call: expr) => {{
        let stats = pd_stats($tag);
        let mut last_err = Ok(());
        for _ in 0..LEADER_CHANGE_RETRY {
            // Use a block here so the read-lock guard is dropped before reconnecting;
            // otherwise `reconnect` would try to acquire the write lock and deadlock.
            let res = {
                let $cluster = &$self.cluster.read().await.0;
                $call.await
            };

            match stats.done(res) {
                Ok(r) => return Ok(r),
                Err(e) => last_err = Err(e),
            }

            let mut reconnect_count = MAX_REQUEST_COUNT;
            while let Err(e) = $self.reconnect(RECONNECT_INTERVAL_SEC).await {
                reconnect_count -= 1;
                if reconnect_count == 0 {
                    return Err(e);
                }
                Delay::new(Duration::from_secs(RECONNECT_INTERVAL_SEC)).await;
            }
        }

        last_err?;
        unreachable!();
    }};
}

impl RetryClient<Cluster> {
    pub async fn connect(
        env: Arc<Environment>,
        endpoints: &[String],
        security_mgr: Arc<SecurityManager>,
        timeout: Duration,
    ) -> Result<RetryClient> {
        let connection = Connection::new(env, security_mgr);
        let cluster = RwLock::new((
            connection.connect_cluster(endpoints, timeout).await?,
            Instant::now(),
        ));
        Ok(RetryClient {
            cluster,
            connection,
            timeout,
        })
    }
}
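// A minimal usage sketch, not part of the crate's examples: how a caller might
// connect and look up a region. The endpoint address, timeout, and key are
// illustrative assumptions only; `env` and `security_mgr` come from the caller.
#[cfg(test)]
#[allow(dead_code)]
async fn example_usage(
    env: Arc<Environment>,
    security_mgr: Arc<SecurityManager>,
) -> Result<RegionWithLeader> {
    // Connect to a hypothetical local PD endpoint with a 2s request timeout.
    let client = RetryClient::connect(
        env,
        &["127.0.0.1:2379".to_owned()],
        security_mgr,
        Duration::from_secs(2),
    )
    .await?;
    // Each `get_*` method retries internally, reconnecting on failure, so a
    // single call already rides out transient PD leader changes.
    Arc::new(client).get_region(b"example-key".to_vec()).await
}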
#[async_trait]
impl RetryClientTrait for RetryClient<Cluster> {
    // These get_* functions will try multiple times to make a request, reconnecting as necessary.
    // They do not know about encoding; the caller should take care of that.
    async fn get_region(self: Arc<Self>, key: Vec<u8>) -> Result<RegionWithLeader> {
        retry!(self, "get_region", |cluster| {
            let key = key.clone();
            async {
                cluster
                    .get_region(key.clone(), self.timeout)
                    .await
                    .and_then(|resp| {
                        region_from_response(resp, || Error::RegionForKeyNotFound { key })
                    })
            }
        })
    }

    async fn get_region_by_id(self: Arc<Self>, region_id: RegionId) -> Result<RegionWithLeader> {
        retry!(self, "get_region_by_id", |cluster| async {
            cluster
                .get_region_by_id(region_id, self.timeout)
                .await
                .and_then(|resp| {
                    region_from_response(resp, || Error::RegionNotFoundInResponse { region_id })
                })
        })
    }

    async fn get_store(self: Arc<Self>, id: StoreId) -> Result<metapb::Store> {
        retry!(self, "get_store", |cluster| async {
            cluster
                .get_store(id, self.timeout)
                .await
                .map(|mut resp| resp.take_store())
        })
    }

    #[allow(dead_code)]
    async fn get_all_stores(self: Arc<Self>) -> Result<Vec<metapb::Store>> {
        retry!(self, "get_all_stores", |cluster| async {
            cluster
                .get_all_stores(self.timeout)
                .await
                .map(|mut resp| resp.take_stores().into_iter().map(Into::into).collect())
        })
    }

    async fn get_timestamp(self: Arc<Self>) -> Result<Timestamp> {
        retry!(self, "get_timestamp", |cluster| cluster.get_timestamp())
    }

    async fn update_safepoint(self: Arc<Self>, safepoint: u64) -> Result<bool> {
        retry!(self, "update_gc_safepoint", |cluster| async {
            cluster
                .update_safepoint(safepoint, self.timeout)
                .await
                .map(|resp| resp.get_new_safe_point() == safepoint)
        })
    }
}

impl fmt::Debug for RetryClient {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("pd::RetryClient")
            .field("timeout", &self.timeout)
            .finish()
    }
}

fn region_from_response(
    mut resp: pdpb::GetRegionResponse,
    err: impl FnOnce() -> Error,
) -> Result<RegionWithLeader> {
    let region = resp.region.take().ok_or_else(err)?;
    Ok(RegionWithLeader::new(region, resp.leader.take()))
}

// A node-like thing that can be connected to.
#[async_trait]
trait Reconnect {
    type Cl;
    async fn reconnect(&self, interval_sec: u64) -> Result<()>;
}

#[async_trait]
impl Reconnect for RetryClient<Cluster> {
    type Cl = Cluster;

    async fn reconnect(&self, interval_sec: u64) -> Result<()> {
        let reconnect_begin = Instant::now();
        let mut lock = self.cluster.write().await;
        let (cluster, last_connected) = &mut *lock;
        // If `last_connected + interval_sec` is greater than or equal to `reconnect_begin`,
        // a concurrent reconnect succeeded while this thread was waiting for the write lock,
        // so there is nothing left to do.
        let should_connect = reconnect_begin > *last_connected + Duration::from_secs(interval_sec);
        if should_connect {
            self.connection.reconnect(cluster, self.timeout).await?;
            *last_connected = Instant::now();
        }
        Ok(())
    }
}
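// A sketch, with illustrative values only, of the dedup window used in
// `reconnect` above: a reconnect observed to have finished within the last
// `interval_sec` seconds is treated as fresh, so the current task skips its own.
#[cfg(test)]
#[test]
fn reconnect_window_sketch() {
    // Simulate a concurrent reconnect that completed just now.
    let last_connected = Instant::now();
    let reconnect_begin = Instant::now();
    let should_connect =
        reconnect_begin > last_connected + Duration::from_secs(RECONNECT_INTERVAL_SEC);
    // We are inside the window, so the write-lock holder would not reconnect again.
    assert!(!should_connect);
}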
#[cfg(test)]
mod test {
    use super::*;
    use futures::{executor, future::ready};
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Mutex,
    };
    use tikv_client_common::internal_err;

    #[test]
    fn test_reconnect() {
        struct MockClient {
            reconnect_count: AtomicUsize,
            cluster: RwLock<((), Instant)>,
        }

        #[async_trait]
        impl Reconnect for MockClient {
            type Cl = ();

            async fn reconnect(&self, _: u64) -> Result<()> {
                self.reconnect_count
                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                // Not actually unimplemented, we just don't care about the error.
                Err(Error::Unimplemented)
            }
        }

        async fn retry_err(client: Arc<MockClient>) -> Result<()> {
            retry!(client, "test", |_c| ready(Err(internal_err!("whoops"))))
        }

        async fn retry_ok(client: Arc<MockClient>) -> Result<()> {
            retry!(client, "test", |_c| ready(Ok::<_, Error>(())))
        }

        executor::block_on(async {
            let client = Arc::new(MockClient {
                reconnect_count: AtomicUsize::new(0),
                cluster: RwLock::new(((), Instant::now())),
            });

            assert!(retry_err(client.clone()).await.is_err());
            assert_eq!(
                client
                    .reconnect_count
                    .load(std::sync::atomic::Ordering::SeqCst),
                MAX_REQUEST_COUNT
            );

            client
                .reconnect_count
                .store(0, std::sync::atomic::Ordering::SeqCst);
            assert!(retry_ok(client.clone()).await.is_ok());
            assert_eq!(
                client
                    .reconnect_count
                    .load(std::sync::atomic::Ordering::SeqCst),
                0
            );
        })
    }

    #[test]
    fn test_retry() {
        struct MockClient {
            cluster: RwLock<(AtomicUsize, Instant)>,
        }

        #[async_trait]
        impl Reconnect for MockClient {
            type Cl = Mutex<usize>;

            async fn reconnect(&self, _: u64) -> Result<()> {
                Ok(())
            }
        }

        async fn retry_max_err(
            client: Arc<MockClient>,
            max_retries: Arc<AtomicUsize>,
        ) -> Result<()> {
            retry!(client, "test", |c| {
                c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let max_retries = max_retries.fetch_sub(1, Ordering::SeqCst) - 1;
                if max_retries == 0 {
                    ready(Ok(()))
                } else {
                    ready(Err(internal_err!("whoops")))
                }
            })
        }

        async fn retry_max_ok(
            client: Arc<MockClient>,
            max_retries: Arc<AtomicUsize>,
        ) -> Result<()> {
            retry!(client, "test", |c| {
                c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let max_retries = max_retries.fetch_sub(1, Ordering::SeqCst) - 1;
                if max_retries == 0 {
                    ready(Ok(()))
                } else {
                    ready(Err(internal_err!("whoops")))
                }
            })
        }

        executor::block_on(async {
            let client = Arc::new(MockClient {
                cluster: RwLock::new((AtomicUsize::new(0), Instant::now())),
            });
            let max_retries = Arc::new(AtomicUsize::new(1000));

            assert!(retry_max_err(client.clone(), max_retries).await.is_err());
            assert_eq!(
                client.cluster.read().await.0.load(Ordering::SeqCst),
                LEADER_CHANGE_RETRY
            );

            let client = Arc::new(MockClient {
                cluster: RwLock::new((AtomicUsize::new(0), Instant::now())),
            });
            let max_retries = Arc::new(AtomicUsize::new(2));

            assert!(retry_max_ok(client.clone(), max_retries).await.is_ok());
            assert_eq!(client.cluster.read().await.0.load(Ordering::SeqCst), 2);
        })
    }
}
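// A sketch exercising `region_from_response` directly. Only the missing-region
// error path is checked; the response is built with `Default` so nothing is
// assumed about protobuf field setters, and the region id is arbitrary.
#[cfg(test)]
mod region_from_response_sketch {
    use super::*;

    #[test]
    fn missing_region_is_an_error() {
        // A default response carries no region, so the provided error is returned.
        let resp = pdpb::GetRegionResponse::default();
        let res = region_from_response(resp, || Error::RegionNotFoundInResponse { region_id: 1 });
        assert!(res.is_err());
    }
}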