diff --git a/src/transaction/buffer.rs b/src/transaction/buffer.rs index 5c8f897..513cd1a 100644 --- a/src/transaction/buffer.rs +++ b/src/transaction/buffer.rs @@ -1,58 +1,36 @@ // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. use crate::{BoundRange, Key, KvPair, Result, Value}; -use derive_new::new; use std::{ collections::{btree_map::Entry, BTreeMap, HashMap}, future::Future, }; use tikv_client_proto::kvrpcpb; -use tokio::sync::{Mutex, MutexGuard}; - -#[derive(new)] -struct InnerBuffer { - #[new(default)] - primary_key: Option, - #[new(default)] - entry_map: BTreeMap, - is_pessimistic: bool, -} - -impl InnerBuffer { - fn insert(&mut self, key: impl Into, entry: BufferEntry) { - let key = key.into(); - if !matches!(entry, BufferEntry::Cached(_) | BufferEntry::CheckNotExist) { - self.primary_key.get_or_insert_with(|| key.clone()); - } - self.entry_map.insert(key, entry); - } - - /// Set the primary key if it is not set - pub fn primary_key_or(&mut self, key: &Key) { - self.primary_key.get_or_insert(key.clone()); - } -} /// A caching layer which buffers reads and writes in a transaction. pub struct Buffer { - inner: Mutex, + primary_key: Option, + entry_map: BTreeMap, + is_pessimistic: bool, } impl Buffer { pub fn new(is_pessimistic: bool) -> Buffer { Buffer { - inner: Mutex::new(InnerBuffer::new(is_pessimistic)), + primary_key: None, + entry_map: BTreeMap::new(), + is_pessimistic, } } /// Get the primary key of the buffer. pub async fn get_primary_key(&self) -> Option { - self.inner.lock().await.primary_key.clone() + self.primary_key.clone() } /// Set the primary key if it is not set - pub async fn primary_key_or(&self, key: &Key) { - self.inner.lock().await.primary_key_or(key); + pub async fn primary_key_or(&mut self, key: &Key) { + self.primary_key.get_or_insert_with(|| key.clone()); } /// Get a value from the buffer. @@ -66,7 +44,7 @@ impl Buffer { /// Get a value from the buffer. If the value is not present, run `f` to get /// the value. - pub async fn get_or_else(&self, key: Key, f: F) -> Result> + pub async fn get_or_else(&mut self, key: Key, f: F) -> Result> where F: FnOnce(Key) -> Fut, Fut: Future>>, @@ -75,8 +53,7 @@ impl Buffer { MutationValue::Determined(value) => Ok(value), MutationValue::Undetermined => { let value = f(key.clone()).await?; - let mut mutations = self.inner.lock().await; - Self::update_cache(&mut mutations, key, value.clone()); + self.update_cache(key, value.clone()); Ok(value) } } @@ -87,7 +64,7 @@ impl Buffer { /// /// only used for snapshot read (i.e. not for `batch_get_for_update`) pub async fn batch_get_or_else( - &self, + &mut self, keys: impl Iterator, f: F, ) -> Result> @@ -96,7 +73,6 @@ impl Buffer { Fut: Future>>, { let (cached_results, undetermined_keys) = { - let mutations = self.inner.lock().await; // Partition the keys into those we have buffered and those we have to // get from the store. let (undetermined_keys, cached_results): ( @@ -104,7 +80,7 @@ impl Buffer { Vec<(Key, MutationValue)>, ) = keys .map(|key| { - let value = mutations + let value = self .entry_map .get(&key) .map(BufferEntry::get_value) @@ -122,11 +98,10 @@ impl Buffer { }; let fetched_results = f(Box::new(undetermined_keys)).await?; - let mut mutations = self.inner.lock().await; for kvpair in &fetched_results { let key = kvpair.0.clone(); let value = Some(kvpair.1.clone()); - Self::update_cache(&mut mutations, key, value); + self.update_cache(key, value); } let results = cached_results.chain(fetched_results.into_iter()); @@ -135,7 +110,7 @@ impl Buffer { /// Run `f` to fetch entries in `range` from TiKV. Combine them with mutations in local buffer. Returns the results. pub async fn scan_and_fetch( - &self, + &mut self, range: BoundRange, limit: u32, f: F, @@ -145,8 +120,7 @@ impl Buffer { Fut: Future>>, { // read from local buffer - let mut mutations = self.inner.lock().await; - let mutation_range = mutations.entry_map.range(range.clone()); + let mutation_range = self.entry_map.range(range.clone()); // fetch from TiKV // fetch more entries because some of them may be deleted. @@ -177,7 +151,7 @@ impl Buffer { // update local buffer for (k, v) in &results { - Self::update_cache(&mut mutations, k.clone(), Some(v.clone())); + self.update_cache(k.clone(), Some(v.clone())); } let mut res = results @@ -190,10 +164,9 @@ impl Buffer { } /// Lock the given key if necessary. - pub async fn lock(&self, key: Key) { - let mutations = &mut self.inner.lock().await; - mutations.primary_key.get_or_insert_with(|| key.clone()); - let value = mutations + pub async fn lock(&mut self, key: Key) { + self.primary_key.get_or_insert_with(|| key.clone()); + let value = self .entry_map .entry(key) // Mutated keys don't need a lock. @@ -205,27 +178,25 @@ impl Buffer { } /// Insert a value into the buffer (does not write through). - pub async fn put(&self, key: Key, value: Value) { - self.inner.lock().await.insert(key, BufferEntry::Put(value)); + pub async fn put(&mut self, key: Key, value: Value) { + self.insert_entry(key, BufferEntry::Put(value)); } /// Mark a value as Insert mutation into the buffer (does not write through). - pub async fn insert(&self, key: Key, value: Value) { - let mut mutations = self.inner.lock().await; - let mut entry = mutations.entry_map.entry(key.clone()); + pub async fn insert(&mut self, key: Key, value: Value) { + let mut entry = self.entry_map.entry(key.clone()); match entry { Entry::Occupied(ref mut o) if matches!(o.get(), BufferEntry::Del) => { o.insert(BufferEntry::Put(value)); } - _ => mutations.insert(key, BufferEntry::Insert(value)), + _ => self.insert_entry(key, BufferEntry::Insert(value)), } } /// Mark a value as deleted. - pub async fn delete(&self, key: Key) { - let mut mutations = self.inner.lock().await; - let is_pessimistic = mutations.is_pessimistic; - let mut entry = mutations.entry_map.entry(key.clone()); + pub async fn delete(&mut self, key: Key) { + let is_pessimistic = self.is_pessimistic; + let mut entry = self.entry_map.entry(key.clone()); match entry { Entry::Occupied(ref mut o) @@ -233,40 +204,32 @@ impl Buffer { { o.insert(BufferEntry::CheckNotExist); } - _ => mutations.insert(key, BufferEntry::Del), + _ => self.insert_entry(key, BufferEntry::Del), } } /// Converts the buffered mutations to the proto buffer version pub async fn to_proto_mutations(&self) -> Vec { - self.inner - .lock() - .await - .entry_map + self.entry_map .iter() .filter_map(|(key, mutation)| mutation.to_proto_with_key(key)) .collect() } async fn get_from_mutations(&self, key: &Key) -> MutationValue { - self.inner - .lock() - .await - .entry_map + self.entry_map .get(&key) .map(BufferEntry::get_value) .unwrap_or(MutationValue::Undetermined) } - fn update_cache(buffer: &mut MutexGuard, key: Key, value: Option) { - match buffer.entry_map.get(&key) { + fn update_cache(&mut self, key: Key, value: Option) { + match self.entry_map.get(&key) { Some(BufferEntry::Locked(None)) => { - buffer - .entry_map - .insert(key, BufferEntry::Locked(Some(value))); + self.entry_map.insert(key, BufferEntry::Locked(Some(value))); } None => { - buffer.entry_map.insert(key, BufferEntry::Cached(value)); + self.entry_map.insert(key, BufferEntry::Cached(value)); } Some(BufferEntry::Cached(v)) | Some(BufferEntry::Locked(Some(v))) => { assert!(&value == v); @@ -285,6 +248,14 @@ impl Buffer { } } } + + fn insert_entry(&mut self, key: impl Into, entry: BufferEntry) { + let key = key.into(); + if !matches!(entry, BufferEntry::Cached(_) | BufferEntry::CheckNotExist) { + self.primary_key.get_or_insert_with(|| key.clone()); + } + self.entry_map.insert(key, entry); + } } // The state of a key-value pair in the buffer. @@ -388,7 +359,7 @@ mod tests { #[tokio::test] #[allow(unreachable_code)] async fn set_and_get_from_buffer() { - let buffer = Buffer::new(false); + let mut buffer = Buffer::new(false); buffer .put(b"key1".to_vec().into(), b"value1".to_vec()) .await; @@ -421,7 +392,7 @@ mod tests { #[tokio::test] #[allow(unreachable_code)] async fn insert_and_get_from_buffer() { - let buffer = Buffer::new(false); + let mut buffer = Buffer::new(false); buffer .insert(b"key1".to_vec().into(), b"value1".to_vec()) .await; @@ -463,13 +434,13 @@ mod tests { let v2: Value = b"value2".to_vec(); let v2_ = v2.clone(); - let buffer = Buffer::new(false); + let mut buffer = Buffer::new(false); let r1 = block_on(buffer.get_or_else(k1.clone(), move |_| ready(Ok(Some(v1_))))); let r2 = block_on(buffer.get_or_else(k1.clone(), move |_| ready(panic!()))); assert_eq!(r1.unwrap().unwrap(), v1); assert_eq!(r2.unwrap().unwrap(), v1); - let buffer = Buffer::new(false); + let mut buffer = Buffer::new(false); let r1 = block_on( buffer.batch_get_or_else(vec![k1.clone(), k2.clone()].into_iter(), move |_| { ready(Ok(vec![(k1_, v1__).into(), (k2_, v2_).into()])) diff --git a/src/transaction/snapshot.rs b/src/transaction/snapshot.rs index 28711c4..3267ccd 100644 --- a/src/transaction/snapshot.rs +++ b/src/transaction/snapshot.rs @@ -19,18 +19,18 @@ pub struct Snapshot { impl Snapshot { /// Get the value associated with the given key. - pub async fn get(&self, key: impl Into) -> Result> { + pub async fn get(&mut self, key: impl Into) -> Result> { self.transaction.get(key).await } /// Check whether the key exists. - pub async fn key_exists(&self, key: impl Into) -> Result { + pub async fn key_exists(&mut self, key: impl Into) -> Result { self.transaction.key_exists(key).await } /// Get the values associated with the given keys. pub async fn batch_get( - &self, + &mut self, keys: impl IntoIterator>, ) -> Result> { self.transaction.batch_get(keys).await @@ -38,7 +38,7 @@ impl Snapshot { /// Scan a range, return at most `limit` key-value pairs that lying in the range. pub async fn scan( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -47,7 +47,7 @@ impl Snapshot { /// Scan a range, return at most `limit` keys that lying in the range. pub async fn scan_keys( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -56,7 +56,7 @@ impl Snapshot { /// Unimplemented. Similar to scan, but in the reverse direction. #[allow(dead_code)] - fn scan_reverse(&self, range: impl RangeBounds) -> BoxStream> { + fn scan_reverse(&mut self, range: impl RangeBounds) -> BoxStream> { self.transaction.scan_reverse(range) } } diff --git a/src/transaction/transaction.rs b/src/transaction/transaction.rs index 71801fa..496bed4 100644 --- a/src/transaction/transaction.rs +++ b/src/transaction/transaction.rs @@ -93,7 +93,7 @@ impl Transaction { /// txn.commit().await.unwrap(); /// # }); /// ``` - pub async fn get(&self, key: impl Into) -> Result> { + pub async fn get(&mut self, key: impl Into) -> Result> { self.check_allow_operation().await?; let timestamp = self.timestamp.clone(); let rpc = self.rpc.clone(); @@ -184,7 +184,7 @@ impl Transaction { /// txn.commit().await.unwrap(); /// # }); /// ``` - pub async fn key_exists(&self, key: impl Into) -> Result { + pub async fn key_exists(&mut self, key: impl Into) -> Result { let key = key.into(); Ok(self.scan_keys(key.clone()..=key, 1).await?.next().is_some()) } @@ -216,7 +216,7 @@ impl Transaction { /// # }); /// ``` pub async fn batch_get( - &self, + &mut self, keys: impl IntoIterator>, ) -> Result> { self.check_allow_operation().await?; @@ -307,7 +307,7 @@ impl Transaction { /// # }); /// ``` pub async fn scan( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -341,7 +341,7 @@ impl Transaction { /// # }); /// ``` pub async fn scan_keys( - &self, + &mut self, range: impl Into, limit: u32, ) -> Result> { @@ -602,7 +602,7 @@ impl Transaction { } async fn scan_inner( - &self, + &mut self, range: impl Into, limit: u32, key_only: bool, diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 5888aef..0b1d757 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -103,7 +103,7 @@ async fn crud() -> Result<()> { txn.commit().await?; // Read again from TiKV - let snapshot = client.snapshot( + let mut snapshot = client.snapshot( client.current_timestamp().await?, // TODO needed because pessimistic does not check locks (#235) TransactionOptions::new_optimistic(), @@ -339,9 +339,9 @@ async fn txn_bank_transfer() -> Result<()> { .await?; let chosen_people = people.iter().choose_multiple(&mut rng, 2); let alice = chosen_people[0]; - let mut alice_balance = get_txn_u32(&txn, alice.clone()).await?; + let mut alice_balance = get_txn_u32(&mut txn, alice.clone()).await?; let bob = chosen_people[1]; - let mut bob_balance = get_txn_u32(&txn, bob.clone()).await?; + let mut bob_balance = get_txn_u32(&mut txn, bob.clone()).await?; if alice_balance == 0 { txn.rollback().await?; continue; @@ -360,7 +360,7 @@ async fn txn_bank_transfer() -> Result<()> { let mut new_sum = 0; let mut txn = client.begin_optimistic().await?; for person in people.iter() { - new_sum += get_txn_u32(&txn, person.clone()).await?; + new_sum += get_txn_u32(&mut txn, person.clone()).await?; } assert_eq!(sum, new_sum); txn.commit().await?; @@ -827,7 +827,7 @@ async fn get_u32(client: &RawClient, key: Vec) -> Result { } // helper function -async fn get_txn_u32(txn: &Transaction, key: Vec) -> Result { +async fn get_txn_u32(txn: &mut Transaction, key: Vec) -> Result { let x = txn.get(key).await?.unwrap(); let boxed_slice = x.into_boxed_slice(); let array: Box<[u8; 4]> = boxed_slice