diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs index 423ac3b..568aa00 100644 --- a/src/transaction/mod.rs +++ b/src/transaction/mod.rs @@ -23,11 +23,11 @@ mod requests; #[allow(clippy::module_inception)] mod transaction; +#[derive(Debug, Clone)] pub enum Mutation { Put(Value), Del, Lock, - Rollback, } impl Mutation { @@ -43,8 +43,16 @@ impl Mutation { } Mutation::Del => pb.set_op(kvrpcpb::Op::Del), Mutation::Lock => pb.set_op(kvrpcpb::Op::Lock), - Mutation::Rollback => pb.set_op(kvrpcpb::Op::Rollback), }; pb } + + /// Returns a `Some` if the value can be determined by this mutation. Otherwise, returns `None`. + fn get_value(&self) -> Option { + match self { + Mutation::Put(value) => Some(value.clone()), + Mutation::Del => Some(Value::default()), + Mutation::Lock => None, + } + } } diff --git a/src/transaction/transaction.rs b/src/transaction/transaction.rs index cf12630..0e93f86 100644 --- a/src/transaction/transaction.rs +++ b/src/transaction/transaction.rs @@ -51,8 +51,13 @@ impl Transaction { /// txn.commit().await.unwrap(); /// # }); /// ``` - pub async fn get(&self, _key: impl Into) -> Result { - unimplemented!() + pub async fn get(&self, key: impl Into) -> Result { + let key = key.into(); + if let Some(value) = self.get_from_mutations(&key) { + Ok(value) + } else { + self.snapshot.get(key).await + } } /// Gets the values associated with the given keys. @@ -74,9 +79,38 @@ impl Transaction { /// ``` pub async fn batch_get( &self, - _keys: impl IntoIterator>, + keys: impl IntoIterator>, ) -> Result> { - unimplemented!() + let mut result = Vec::new(); + let mut keys_from_snapshot = Vec::new(); + let mut result_indices_from_snapshot = Vec::new(); + + // Try to fill the result vector from mutations + for key in keys { + let key = key.into(); + if let Some(value) = self.get_from_mutations(&key) { + result.push((key, value).into()); + } else { + keys_from_snapshot.push(key); + result_indices_from_snapshot.push(result.len()); + // Push a placeholder + result.push(KvPair::default()); + } + } + + // Get others from snapshot + let kv_pairs_from_snapshot = self + .snapshot + .batch_get(keys_from_snapshot.into_iter()) + .await?; + for (kv_pair, index) in kv_pairs_from_snapshot + .into_iter() + .zip(result_indices_from_snapshot) + { + result[index] = kv_pair; + } + + Ok(result) } pub fn scan(&self, _range: impl RangeBounds) -> Scanner { @@ -145,8 +179,11 @@ impl Transaction { /// # }); /// ``` pub fn lock_keys(&mut self, keys: impl IntoIterator>) { - self.mutations - .extend(keys.into_iter().map(|key| (key.into(), Mutation::Lock))); + for key in keys { + let key = key.into(); + // Mutated keys don't need a lock. + self.mutations.entry(key).or_insert(Mutation::Lock); + } } /// Commits the actions of the transaction. @@ -209,6 +246,12 @@ impl Transaction { } async fn prewrite(&mut self) -> Result<()> { + // TODO: Too many clones. Consider using bytes::Byte. + let _rpc_mutations: Vec<_> = self + .mutations + .iter() + .map(|(k, v)| v.clone().with_key(k.clone())) + .collect(); unimplemented!() } @@ -219,6 +262,10 @@ impl Transaction { async fn commit_secondary(&mut self) -> Result<()> { unimplemented!() } + + fn get_from_mutations(&self, key: &Key) -> Option { + self.mutations.get(key).and_then(Mutation::get_value) + } } pub struct TxnInfo {