From a601c39c1fc1e0dc68439c6146977a6645a3f81b Mon Sep 17 00:00:00 2001 From: Yilin Chen Date: Wed, 14 Aug 2019 13:14:29 +0800 Subject: [PATCH] batch_get returns an iterator Signed-off-by: Yilin Chen --- src/transaction/mod.rs | 73 ++++++++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs index 26a58eb..dc1f851 100644 --- a/src/transaction/mod.rs +++ b/src/transaction/mod.rs @@ -15,10 +15,7 @@ pub use self::requests::Scanner; use crate::{Key, Result, Value}; use derive_new::new; use kvproto::kvrpcpb; -use std::{ - collections::{BTreeMap, HashMap}, - ops::RangeBounds, -}; +use std::{collections::BTreeMap, ops::RangeBounds}; mod client; pub(crate) mod requests; @@ -119,8 +116,8 @@ impl Transaction { } } - /// Gets the values associated with the given keys. Non-existent keys are not included in the - /// result. + /// Gets the values associated with the given keys. The returned iterator is in the same order + /// as the given keys. /// /// ```rust,no_run /// # #![feature(async_await)] @@ -132,7 +129,11 @@ impl Transaction { /// # let connected_client = connecting_client.await.unwrap(); /// let mut txn = connected_client.begin().await.unwrap(); /// let keys = vec!["TiKV".to_owned(), "TiDB".to_owned()]; - /// let result: HashMap = txn.batch_get(keys).await.unwrap(); + /// let result: HashMap = txn + /// .batch_get(keys) + /// .await + /// .unwrap() + /// .filter_map(|(k, v)| v.map(move |v| (k, v))).collect(); /// // Finish the transaction... /// txn.commit().await.unwrap(); /// # }); @@ -140,25 +141,37 @@ impl Transaction { pub async fn batch_get( &self, keys: impl IntoIterator>, - ) -> Result> { - let mut result = HashMap::new(); + ) -> Result)>> { let mut keys_from_snapshot = Vec::new(); - - // Try to get the result from buffered mutations first - for key in keys { - let key = key.into(); - match self.get_from_mutations(&key) { - MutationValue::Determined(Some(value)) => { - result.insert(key, value); + let mut results_in_buffer = keys + .into_iter() + .map(|key| { + let key = key.into(); + let mutation_value = self.get_from_mutations(&key); + if let MutationValue::Undetermined = mutation_value { + keys_from_snapshot.push(key.clone()); } - MutationValue::Determined(None) => {} - MutationValue::Undetermined => keys_from_snapshot.push(key), + (key, mutation_value) + }) + .collect::>() + .into_iter(); + let mut results_from_snapshot = self + .snapshot + .batch_get(keys_from_snapshot) + .await? + .peekable(); + Ok(std::iter::from_fn(move || { + let (key, mutation_value) = results_in_buffer.next()?; + match mutation_value { + MutationValue::Determined(value) => Some((key, value)), + MutationValue::Undetermined => match results_from_snapshot.peek() { + Some((key_from_snapshot, _)) if &key == key_from_snapshot => { + results_from_snapshot.next() + } + _ => Some((key, None)), + }, } - } - - // Get others from snapshot - result.extend(self.snapshot.batch_get(keys_from_snapshot).await?); - Ok(result) + })) } pub fn scan(&self, _range: impl RangeBounds) -> Scanner { @@ -349,8 +362,8 @@ impl Snapshot { unimplemented!() } - /// Gets the values associated with the given keys. Non-existent keys are not included in the - /// result. + /// Gets the values associated with the given keys. The returned iterator is in the same order + /// as the given keys. /// /// ```rust,no_run /// # #![feature(async_await)] @@ -362,14 +375,18 @@ impl Snapshot { /// # let connected_client = connecting_client.await.unwrap(); /// let snapshot = connected_client.snapshot().await.unwrap(); /// let keys = vec!["TiKV".to_owned(), "TiDB".to_owned()]; - /// let result: HashMap = snapshot.batch_get(keys).await.unwrap(); + /// let result: HashMap = snapshot + /// .batch_get(keys) + /// .await + /// .unwrap() + /// .filter_map(|(k, v)| v.map(move |v| (k, v))).collect(); /// # }); /// ``` pub async fn batch_get( &self, _keys: impl IntoIterator>, - ) -> Result> { - unimplemented!() + ) -> Result)>> { + Ok(std::iter::repeat_with(|| unimplemented!())) } pub fn scan(&self, range: impl RangeBounds) -> Scanner {