diff --git a/src/kv/key.rs b/src/kv/key.rs index 1b4f060..ae6f364 100644 --- a/src/kv/key.rs +++ b/src/kv/key.rs @@ -17,6 +17,7 @@ use super::HexRepr; use crate::kv::codec::BytesEncoder; use crate::kv::codec::{self}; use crate::proto::kvrpcpb; +use crate::proto::kvrpcpb::KvPair; const _PROPTEST_KEY_MAX: usize = 1024 * 2; // 2 KB @@ -80,6 +81,20 @@ impl AsRef for kvrpcpb::Mutation { } } +pub struct KvPairTTL(pub KvPair, pub u64); + +impl AsRef for KvPairTTL { + fn as_ref(&self) -> &Key { + self.0.key.as_ref() + } +} + +impl From for (KvPair, u64) { + fn from(value: KvPairTTL) -> Self { + (value.0.clone(), value.1) + } +} + impl Key { /// The empty key. pub const EMPTY: Self = Key(Vec::new()); diff --git a/src/kv/mod.rs b/src/kv/mod.rs index 489110e..c48f476 100644 --- a/src/kv/mod.rs +++ b/src/kv/mod.rs @@ -11,6 +11,7 @@ mod value; pub use bound_range::BoundRange; pub use bound_range::IntoOwnedRange; pub use key::Key; +pub use key::KvPairTTL; pub use kvpair::KvPair; pub use value::Value; diff --git a/src/raw/client.rs b/src/raw/client.rs index e885b4e..da84eb9 100644 --- a/src/raw/client.rs +++ b/src/raw/client.rs @@ -880,7 +880,7 @@ mod tests { async fn test_batch_put_with_ttl() -> Result<()> { let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( move |req: &dyn Any| { - if let Some(_) = req.downcast_ref::() { + if req.downcast_ref::().is_some() { let resp = kvrpcpb::RawBatchPutResponse { ..Default::default() }; @@ -898,8 +898,8 @@ mod tests { keyspace: Keyspace::Enable { keyspace_id: 0 }, }; let pairs = vec![ - KvPair(vec![11].into(), vec![12].into()), - KvPair(vec![11].into(), vec![12].into()), + KvPair(vec![11].into(), vec![12]), + KvPair(vec![11].into(), vec![12]), ]; let ttls = vec![0, 0]; assert!(client.batch_put_with_ttl(pairs, ttls).await.is_ok()); diff --git a/src/raw/requests.rs b/src/raw/requests.rs index f6e876d..df99e93 100644 --- a/src/raw/requests.rs +++ b/src/raw/requests.rs @@ -1,7 +1,6 @@ // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. use std::any::Any; -use std::collections::HashMap; use std::ops::Range; use std::sync::Arc; use std::time::Duration; @@ -9,10 +8,12 @@ use std::time::Duration; use async_trait::async_trait; use futures::stream::BoxStream; use futures::StreamExt; + use tonic::transport::Channel; use super::RawRpcRequest; use crate::collect_single; +use crate::kv::KvPairTTL; use crate::pd::PdClient; use crate::proto::kvrpcpb; use crate::proto::metapb; @@ -198,31 +199,19 @@ impl Shardable for kvrpcpb::RawBatchPutRequest { &self, pd_client: &Arc, ) -> BoxStream<'static, Result<(Self::Shard, RegionStore)>> { - // Maintain a map of the pair and its associated ttl let kvs = self.pairs.clone(); - let kv_pair = kvs.into_iter().map(Into::::into); - let kv_ttl = kv_pair.zip(self.ttls.clone()).collect::>(); - let mut pairs = self.pairs.clone(); - pairs.sort_by(|a, b| a.key.cmp(&b.key)); - store_stream_for_keys( - pairs.into_iter().map(Into::::into), - pd_client.clone(), - ) - .map(move |r| { - let s = r.map(|(kv, store)| { - let kv_ttls = kv - .into_iter() - .map(|k: KvPair| { - let kv: kvrpcpb::KvPair = k.clone().into(); - let ttl = *kv_ttl.get(&k).unwrap(); - (kv, ttl) - }) - .collect::>(); - (kv_ttls, store) - }); - s - }) - .boxed() + let ttls = self.ttls.clone(); + let mut kv_ttl: Vec = kvs + .iter() + .zip(ttls) + .map(|(kv, ttl)| KvPairTTL(kv.clone(), ttl)) + .collect(); + kv_ttl.sort_by(|a, b| a.0.key.clone().cmp(&b.0.key)); + store_stream_for_keys(kv_ttl.into_iter(), pd_client.clone()) + .map(move |r| { + r.map(|(kv, store): (Vec<(kvrpcpb::KvPair, u64)>, RegionStore)| (kv, store)) + }) + .boxed() } fn apply_shard(&mut self, shard: Self::Shard, store: &RegionStore) -> Result<()> { @@ -580,6 +569,7 @@ impl HasLocks for kvrpcpb::RawCoprocessorResponse {} #[cfg(test)] mod test { use std::any::Any; + use std::collections::HashMap; use std::ops::Deref; use std::sync::Mutex; @@ -592,7 +582,6 @@ mod test { use crate::request::Keyspace; use crate::request::Plan; - #[rstest::rstest] #[case(Keyspace::Disable)] #[case(Keyspace::Enable { keyspace_id: 0 })] @@ -639,14 +628,11 @@ mod test { #[tokio::test] async fn test_raw_batch_put() -> Result<()> { - let region1_kvs = vec![KvPair(vec![9].into(), vec![12].into())]; + let region1_kvs = vec![KvPair(vec![9].into(), vec![12])]; let region1_ttls = vec![0]; let region2_kvs = vec![ - KvPair(vec![11].into(), vec![12].into()), - KvPair( - "FFF".to_string().as_bytes().to_vec().into(), - vec![12].into(), - ), + KvPair(vec![11].into(), vec![12]), + KvPair("FFF".to_string().as_bytes().to_vec().into(), vec![12]), ]; let region2_ttls = vec![0, 1];