diff --git a/src/transaction/buffer.rs b/src/transaction/buffer.rs index 1d7f61b..1287a28 100644 --- a/src/transaction/buffer.rs +++ b/src/transaction/buffer.rs @@ -6,7 +6,6 @@ use std::{ future::Future, }; use tikv_client_proto::kvrpcpb; -use tokio::sync::{Mutex, MutexGuard}; #[derive(Default)] struct InnerBuffer { @@ -31,17 +30,17 @@ impl InnerBuffer { /// A caching layer which buffers reads and writes in a transaction. #[derive(Default)] pub struct Buffer { - mutations: Mutex, + inner: InnerBuffer, } impl Buffer { /// Get the primary key of the buffer. pub async fn get_primary_key(&self) -> Option { - self.mutations.lock().await.primary_key.clone() + self.inner.primary_key.clone() } /// Get the primary key of the buffer, if not exists, use `key` as the primary key. - pub async fn get_primary_key_or(&self, key: &Key) -> Key { - self.mutations.lock().await.get_primary_key_or(key).clone() + pub async fn get_primary_key_or(&mut self, key: &Key) -> Key { + self.inner.get_primary_key_or(key).clone() } /// Get a value from the buffer. @@ -55,7 +54,7 @@ impl Buffer { /// Get a value from the buffer. If the value is not present, run `f` to get /// the value. - pub async fn get_or_else(&self, key: Key, f: F) -> Result> + pub async fn get_or_else(&mut self, key: Key, f: F) -> Result> where F: FnOnce(Key) -> Fut, Fut: Future>>, @@ -64,8 +63,7 @@ impl Buffer { MutationValue::Determined(value) => Ok(value), MutationValue::Undetermined => { let value = f(key.clone()).await?; - let mut mutations = self.mutations.lock().await; - Self::update_cache(&mut mutations, key, value.clone()); + Self::update_cache(&mut self.inner, key, value.clone()); Ok(value) } } @@ -76,7 +74,7 @@ impl Buffer { /// /// only used for snapshot read (i.e. not for `batch_get_for_update`) pub async fn batch_get_or_else( - &self, + &mut self, keys: impl Iterator, f: F, ) -> Result> @@ -85,7 +83,6 @@ impl Buffer { Fut: Future>>, { let (cached_results, undetermined_keys) = { - let mutations = self.mutations.lock().await; // Partition the keys into those we have buffered and those we have to // get from the store. let (undetermined_keys, cached_results): ( @@ -93,7 +90,8 @@ impl Buffer { Vec<(Key, MutationValue)>, ) = keys .map(|key| { - let value = mutations + let value = self + .inner .entry_map .get(&key) .map(BufferEntry::get_value) @@ -111,11 +109,10 @@ impl Buffer { }; let fetched_results = f(Box::new(undetermined_keys)).await?; - let mut mutations = self.mutations.lock().await; for kvpair in &fetched_results { let key = kvpair.0.clone(); let value = Some(kvpair.1.clone()); - Self::update_cache(&mut mutations, key, value); + Self::update_cache(&mut self.inner, key, value); } let results = cached_results.chain(fetched_results.into_iter()); @@ -124,7 +121,7 @@ impl Buffer { /// Run `f` to fetch entries in `range` from TiKV. Combine them with mutations in local buffer. Returns the results. pub async fn scan_and_fetch( - &self, + &mut self, range: BoundRange, limit: u32, f: F, @@ -134,8 +131,7 @@ impl Buffer { Fut: Future>>, { // read from local buffer - let mut mutations = self.mutations.lock().await; - let mutation_range = mutations.entry_map.range(range.clone()); + let mutation_range = self.inner.entry_map.range(range.clone()); // fetch from TiKV // fetch more entries because some of them may be deleted. @@ -166,7 +162,7 @@ impl Buffer { // update local buffer for (k, v) in &results { - Self::update_cache(&mut mutations, k.clone(), Some(v.clone())); + Self::update_cache(&mut self.inner, k.clone(), Some(v.clone())); } let mut res = results @@ -179,10 +175,10 @@ impl Buffer { } /// Lock the given key if necessary. - pub async fn lock(&self, key: Key) { - let mutations = &mut self.mutations.lock().await; - mutations.primary_key.get_or_insert(key.clone()); - let value = mutations + pub async fn lock(&mut self, key: Key) { + self.inner.primary_key.get_or_insert(key.clone()); + let value = self + .inner .entry_map .entry(key) // Mutated keys don't need a lock. @@ -194,25 +190,19 @@ impl Buffer { } /// Insert a value into the buffer (does not write through). - pub async fn put(&self, key: Key, value: Value) { - self.mutations - .lock() - .await - .insert(key, BufferEntry::Put(value)); + pub async fn put(&mut self, key: Key, value: Value) { + self.inner.insert(key, BufferEntry::Put(value)); } /// Mark a value as Insert mutation into the buffer (does not write through). - pub async fn insert(&self, key: Key, value: Value) { - self.mutations - .lock() - .await - .insert(key, BufferEntry::Insert(value)); + pub async fn insert(&mut self, key: Key, value: Value) { + self.inner.insert(key, BufferEntry::Insert(value)); } /// Mark a value as deleted. - pub async fn delete(&self, key: Key) { - let mut mutations = self.mutations.lock().await; - let value = mutations + pub async fn delete(&mut self, key: Key) { + let value = self + .inner .entry_map .entry(key.clone()) .or_insert(BufferEntry::Del); @@ -224,14 +214,12 @@ impl Buffer { new_value = BufferEntry::Del } - mutations.insert(key, new_value); + self.inner.insert(key, new_value); } /// Converts the buffered mutations to the proto buffer version pub async fn to_proto_mutations(&self) -> Vec { - self.mutations - .lock() - .await + self.inner .entry_map .iter() .filter_map(|(key, mutation)| mutation.to_proto_with_key(key)) @@ -239,16 +227,14 @@ impl Buffer { } async fn get_from_mutations(&self, key: &Key) -> MutationValue { - self.mutations - .lock() - .await + self.inner .entry_map .get(&key) .map(BufferEntry::get_value) .unwrap_or(MutationValue::Undetermined) } - fn update_cache(buffer: &mut MutexGuard, key: Key, value: Option) { + fn update_cache(buffer: &mut InnerBuffer, key: Key, value: Option) { match buffer.entry_map.get(&key) { Some(BufferEntry::Locked(None)) => { buffer @@ -378,7 +364,7 @@ mod tests { #[tokio::test] #[allow(unreachable_code)] async fn set_and_get_from_buffer() { - let buffer = Buffer::default(); + let mut buffer = Buffer::default(); buffer .put(b"key1".to_vec().into(), b"value1".to_vec()) .await; @@ -411,7 +397,7 @@ mod tests { #[tokio::test] #[allow(unreachable_code)] async fn insert_and_get_from_buffer() { - let buffer = Buffer::default(); + let mut buffer = Buffer::default(); buffer .insert(b"key1".to_vec().into(), b"value1".to_vec()) .await; @@ -453,13 +439,13 @@ mod tests { let v2: Value = b"value2".to_vec(); let v2_ = v2.clone(); - let buffer = Buffer::default(); + let mut buffer = Buffer::default(); let r1 = block_on(buffer.get_or_else(k1.clone(), move |_| ready(Ok(Some(v1_))))); let r2 = block_on(buffer.get_or_else(k1.clone(), move |_| ready(panic!()))); assert_eq!(r1.unwrap().unwrap(), v1); assert_eq!(r2.unwrap().unwrap(), v1); - let buffer = Buffer::default(); + let mut buffer = Buffer::default(); let r1 = block_on( buffer.batch_get_or_else(vec![k1.clone(), k2.clone()].into_iter(), move |_| { ready(Ok(vec![(k1_, v1__).into(), (k2_, v2_).into()])) diff --git a/src/transaction/snapshot.rs b/src/transaction/snapshot.rs index 28711c4..3267ccd 100644 --- a/src/transaction/snapshot.rs +++ b/src/transaction/snapshot.rs @@ -19,18 +19,18 @@ pub struct Snapshot { impl Snapshot { /// Get the value associated with the given key. - pub async fn get(&self, key: impl Into) -> Result> { + pub async fn get(&mut self, key: impl Into) -> Result> { self.transaction.get(key).await } /// Check whether the key exists. - pub async fn key_exists(&self, key: impl Into) -> Result { + pub async fn key_exists(&mut self, key: impl Into) -> Result { self.transaction.key_exists(key).await } /// Get the values associated with the given keys. pub async fn batch_get( - &self, + &mut self, keys: impl IntoIterator>, ) -> Result> { self.transaction.batch_get(keys).await @@ -38,7 +38,7 @@ impl Snapshot { /// Scan a range, return at most `limit` key-value pairs that lying in the range. pub async fn scan( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -47,7 +47,7 @@ impl Snapshot { /// Scan a range, return at most `limit` keys that lying in the range. pub async fn scan_keys( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -56,7 +56,7 @@ impl Snapshot { /// Unimplemented. Similar to scan, but in the reverse direction. #[allow(dead_code)] - fn scan_reverse(&self, range: impl RangeBounds) -> BoxStream> { + fn scan_reverse(&mut self, range: impl RangeBounds) -> BoxStream> { self.transaction.scan_reverse(range) } } diff --git a/src/transaction/transaction.rs b/src/transaction/transaction.rs index 1fe8be4..3b5fb97 100644 --- a/src/transaction/transaction.rs +++ b/src/transaction/transaction.rs @@ -93,7 +93,7 @@ impl Transaction { /// txn.commit().await.unwrap(); /// # }); /// ``` - pub async fn get(&self, key: impl Into) -> Result> { + pub async fn get(&mut self, key: impl Into) -> Result> { self.check_allow_operation().await?; let timestamp = self.timestamp.clone(); let rpc = self.rpc.clone(); @@ -170,7 +170,7 @@ impl Transaction { /// txn.commit().await.unwrap(); /// # }); /// ``` - pub async fn key_exists(&self, key: impl Into) -> Result { + pub async fn key_exists(&mut self, key: impl Into) -> Result { let key = key.into(); Ok(self.scan_keys(key.clone()..=key, 1).await?.next().is_some()) } @@ -202,7 +202,7 @@ impl Transaction { /// # }); /// ``` pub async fn batch_get( - &self, + &mut self, keys: impl IntoIterator>, ) -> Result> { self.check_allow_operation().await?; @@ -299,7 +299,7 @@ impl Transaction { /// # }); /// ``` pub async fn scan( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -333,7 +333,7 @@ impl Transaction { /// # }); /// ``` pub async fn scan_keys( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -589,7 +589,7 @@ impl Transaction { } async fn scan_inner( - &self, + &mut self, range: impl Into, limit: u32, key_only: bool, diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 74e7c9b..7f37f9f 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -102,7 +102,7 @@ async fn crud() -> Result<()> { txn.commit().await?; // Read again from TiKV - let snapshot = client.snapshot( + let mut snapshot = client.snapshot( client.current_timestamp().await?, // TODO needed because pessimistic does not check locks (#235) TransactionOptions::new_optimistic(), @@ -341,9 +341,9 @@ async fn txn_bank_transfer() -> Result<()> { .await?; let chosen_people = people.iter().choose_multiple(&mut rng, 2); let alice = chosen_people[0]; - let mut alice_balance = get_txn_u32(&txn, alice.clone()).await?; + let mut alice_balance = get_txn_u32(&mut txn, alice.clone()).await?; let bob = chosen_people[1]; - let mut bob_balance = get_txn_u32(&txn, bob.clone()).await?; + let mut bob_balance = get_txn_u32(&mut txn, bob.clone()).await?; if alice_balance == 0 { txn.rollback().await?; continue; @@ -362,7 +362,7 @@ async fn txn_bank_transfer() -> Result<()> { let mut new_sum = 0; let mut txn = client.begin_optimistic().await?; for person in people.iter() { - new_sum += get_txn_u32(&txn, person.clone()).await?; + new_sum += get_txn_u32(&mut txn, person.clone()).await?; } assert_eq!(sum, new_sum); txn.commit().await?; @@ -716,7 +716,7 @@ async fn get_u32(client: &RawClient, key: Vec) -> Result { } // helper function -async fn get_txn_u32(txn: &Transaction, key: Vec) -> Result { +async fn get_txn_u32(txn: &mut Transaction, key: Vec) -> Result { let x = txn.get(key).await?.unwrap(); let boxed_slice = x.into_boxed_slice(); let array: Box<[u8; 4]> = boxed_slice