diff --git a/examples/pessimistic.rs b/examples/pessimistic.rs index f4e8c44..ce14394 100644 --- a/examples/pessimistic.rs +++ b/examples/pessimistic.rs @@ -71,4 +71,5 @@ async fn main() { .await .expect("Committing read-only transaction should not fail"); println!("{:?}", (key1, result)); + txn3.commit().await.unwrap(); } diff --git a/src/transaction/buffer.rs b/src/transaction/buffer.rs index 265b2a9..390a6d4 100644 --- a/src/transaction/buffer.rs +++ b/src/transaction/buffer.rs @@ -8,13 +8,38 @@ use std::{ use tikv_client_proto::kvrpcpb; use tokio::sync::Mutex; +#[derive(Default)] +struct Mutations { + primary_key: Option, + key_mutation_map: BTreeMap, +} + +impl Mutations { + fn insert(&mut self, key: impl Into, mutation: Mutation) { + let key = key.into(); + if self.primary_key.is_none() { + self.primary_key = Some(key.clone()); + } + self.key_mutation_map.insert(key, mutation); + } + + pub fn get_primary_key_or(&mut self, key: Key) -> Key { + self.primary_key.get_or_insert(key).clone() + } +} + /// A caching layer which buffers reads and writes in a transaction. #[derive(Default)] pub struct Buffer { - mutations: Mutex>, + mutations: Mutex, } impl Buffer { + /// Get the primary key of the buffer, if not exists, use `key` as the primary key. + pub async fn get_primary_key_or(&self, key: Key) -> Key { + self.mutations.lock().await.get_primary_key_or(key) + } + /// Get a value from the buffer. If the value is not present, run `f` to get /// the value. pub async fn get_or_else(&self, key: Key, f: F) -> Result> @@ -54,6 +79,7 @@ impl Buffer { ) = keys .map(|key| { let value = mutations + .key_mutation_map .get(&key) .map(Mutation::get_value) .unwrap_or(MutationValue::Undetermined); @@ -92,7 +118,7 @@ impl Buffer { { // read from local buffer let mut mutations = self.mutations.lock().await; - let mutation_range = mutations.range(range.clone()); + let mutation_range = mutations.key_mutation_map.range(range.clone()); // fetch from TiKV // fetch more entries because some of them may be deleted. @@ -137,8 +163,10 @@ impl Buffer { /// Lock the given key if necessary. pub async fn lock(&self, key: Key) { - let mut mutations = self.mutations.lock().await; + let mutations = &mut self.mutations.lock().await; + mutations.primary_key.get_or_insert(key.clone()); let value = mutations + .key_mutation_map .entry(key) // Mutated keys don't need a lock. .or_insert(Mutation::Lock); @@ -163,10 +191,14 @@ impl Buffer { /// Converts the buffered mutations to the proto buffer version pub async fn to_proto_mutations(&self) -> Vec { - self.mutations - .lock() - .await + let mutations = self.mutations.lock().await; + let (primary, other) = mutations + .key_mutation_map .iter() + .partition::, _>(|(key, _)| *key == mutations.primary_key.as_ref().unwrap()); + primary + .into_iter() + .chain(other.into_iter()) .filter_map(|(key, mutation)| mutation.to_proto_with_key(key)) .collect() } @@ -175,6 +207,7 @@ impl Buffer { self.mutations .lock() .await + .key_mutation_map .get(&key) .map(Mutation::get_value) .unwrap_or(MutationValue::Undetermined) @@ -269,7 +302,10 @@ mod tests { )) .unwrap() .collect::>(), - vec![KvPair(Key::from(b"key1".to_vec()), b"value".to_vec()),] + vec![KvPair( + Key::from(b"key1".to_vec()), + Value::from(b"value".to_vec()), + ),] ); } diff --git a/src/transaction/transaction.rs b/src/transaction/transaction.rs index bd448db..75eb599 100644 --- a/src/transaction/transaction.rs +++ b/src/transaction/transaction.rs @@ -468,7 +468,7 @@ impl Transaction { /// Pessimistically lock the keys. /// - /// Once resovled it acquires a lock on the key in TiKV. + /// Once resolved it acquires a lock on the key in TiKV. /// The lock prevents other transactions from mutating the entry until it is released. /// /// Only valid for pessimistic transactions, panics if called on an optimistic transaction. @@ -481,19 +481,19 @@ impl Transaction { "`pessimistic_lock` is only valid to use with pessimistic transactions" ); - let mut keys: Vec> = keys + let keys: Vec> = keys .into_iter() .map(|it| it.into()) .map(|it: Key| it.into()) .collect(); - keys.sort(); - let primary_lock = keys[0].clone(); + let first_key = keys[0].clone(); + let primary_lock = self.buffer.get_primary_key_or(first_key.into()).await; let lock_ttl = DEFAULT_LOCK_TTL; let for_update_ts = self.rpc.clone().get_timestamp().await.unwrap().version(); self.options.push_for_update_ts(for_update_ts); new_pessimistic_lock_request( keys, - primary_lock.into(), + primary_lock, self.timestamp.version(), lock_ttl, for_update_ts,