diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 17a3452..b9c178c 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -9,7 +9,7 @@ use std::{ convert::TryInto, env, }; -use tikv_client::{Config, Key, RawClient, Result, TransactionClient, Value}; +use tikv_client::{Config, Key, RawClient, Result, Transaction, TransactionClient, Value}; /// The limit of scan in each iteration in `clear_tikv`. const SCAN_BATCH_SIZE: u32 = 1000; @@ -177,17 +177,15 @@ async fn raw_bank_transfer() -> Fallible<()> { let mut alice_balance = get_u32(&client, alice.clone()).await?; let bob = chosen_people[1]; let mut bob_balance = get_u32(&client, bob.clone()).await?; - if alice_balance > bob_balance { - let transfer = rng.gen_range(0, alice_balance); - alice_balance -= transfer; - bob_balance += transfer; - client - .put(alice.clone(), alice_balance.to_be_bytes().to_vec()) - .await?; - client - .put(bob.clone(), bob_balance.to_be_bytes().to_vec()) - .await?; - } + let transfer = rng.gen_range(0, alice_balance); + alice_balance -= transfer; + bob_balance += transfer; + client + .put(alice.clone(), alice_balance.to_be_bytes().to_vec()) + .await?; + client + .put(bob.clone(), bob_balance.to_be_bytes().to_vec()) + .await?; } // check @@ -199,6 +197,53 @@ async fn raw_bank_transfer() -> Fallible<()> { Fallible::Ok(()) } +#[tokio::test] +#[serial] +async fn txn_bank_transfer() -> Fallible<()> { + clear_tikv().await?; + let config = Config::new(pd_addrs()); + let client = TransactionClient::new(config).await?; + let mut rng = thread_rng(); + + let people = gen_u32_keys(NUM_PEOPLE, &mut rng); + let mut txn = client.begin().await?; + let mut sum: u32 = 0; + for person in &people { + let init = rng.gen::() as u32; + sum += init as u32; + txn.set(person.clone(), init.to_be_bytes().to_vec()).await?; + } + txn.commit().await?; + + // transfer + for _ in 0..NUM_TRNASFER { + let mut txn = client.begin().await?; + let chosen_people = people.iter().choose_multiple(&mut rng, 2); + let alice = chosen_people[0]; + let mut alice_balance = get_txn_u32(&txn, alice.clone()).await?; + let bob = chosen_people[1]; + let mut bob_balance = get_txn_u32(&txn, bob.clone()).await?; + let transfer = rng.gen_range(0, alice_balance); + alice_balance -= transfer; + bob_balance += transfer; + txn.set(alice.clone(), alice_balance.to_be_bytes().to_vec()) + .await?; + txn.set(bob.clone(), bob_balance.to_be_bytes().to_vec()) + .await?; + txn.commit().await?; + } + + // check + let mut new_sum = 0; + let mut txn = client.begin().await?; + for person in people.iter() { + new_sum += get_txn_u32(&txn, person.clone()).await?; + } + assert_eq!(sum, new_sum); + txn.commit().await?; + Fallible::Ok(()) +} + #[tokio::test] #[serial] async fn raw() -> Fallible<()> { @@ -339,6 +384,16 @@ async fn get_u32(client: &RawClient, key: Vec) -> Fallible { Fallible::Ok(u32::from_be_bytes(*array)) } +// helper function +async fn get_txn_u32(txn: &Transaction, key: Vec) -> Fallible { + let x = txn.get(key).await?.unwrap(); + let boxed_slice = x.into_boxed_slice(); + let array: Box<[u8; 4]> = boxed_slice + .try_into() + .expect("Value should not exceed u32 (4 * u8)"); + Fallible::Ok(u32::from_be_bytes(*array)) +} + // helper function fn gen_u32_keys(num: u32, rng: &mut impl Rng) -> HashSet> { let mut set = HashSet::new();