diff --git a/src/main/java/org/tikv/RawKVClient.java b/src/main/java/org/tikv/RawKVClient.java index f2ea759a3c..dbaee4deec 100644 --- a/src/main/java/org/tikv/RawKVClient.java +++ b/src/main/java/org/tikv/RawKVClient.java @@ -1,9 +1,7 @@ package org.tikv; import com.google.protobuf.ByteString; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; +import java.util.*; import org.tikv.kvproto.Kvrpcpb; import org.tikv.kvproto.Metapb; import org.tikv.operation.iterator.RawScanIterator; @@ -48,6 +46,34 @@ public class RawKVClient { client.rawPut(defaultBackOff(), key, value); } + /** + * Put a list of raw key-value pair to TiKV + * + * @param kvPairs kvPairs + */ + public void batchPut(List kvPairs) { + Map, List> regionMap = new HashMap<>(); + for (Kvrpcpb.KvPair kvPair : kvPairs) { + Pair pair = regionManager.getRegionStorePairByRawKey(kvPair.getKey()); + regionMap.computeIfAbsent(pair, t -> new ArrayList<>()).add(kvPair); + } + + List remainingPairs = new ArrayList<>(); + + for (Map.Entry, List> entry : + regionMap.entrySet()) { + RegionStoreClient client = + RegionStoreClient.create(entry.getKey().first, entry.getKey().second, session); + if (!client.rawBatchPut(defaultBackOff(), entry.getValue())) { + remainingPairs.addAll(entry.getValue()); + } + } + if (!remainingPairs.isEmpty()) { + // re-splitting ranges + batchPut(remainingPairs); + } + } + /** * Get a raw key-value pair from TiKV if key exists * diff --git a/src/main/java/org/tikv/region/RegionStoreClient.java b/src/main/java/org/tikv/region/RegionStoreClient.java index c215aaa2cd..4964d4904a 100644 --- a/src/main/java/org/tikv/region/RegionStoreClient.java +++ b/src/main/java/org/tikv/region/RegionStoreClient.java @@ -30,12 +30,15 @@ import org.apache.log4j.Logger; import org.tikv.AbstractGRPCClient; import org.tikv.TiSession; import org.tikv.exception.*; +import org.tikv.kvproto.Errorpb; import org.tikv.kvproto.Kvrpcpb.BatchGetRequest; import org.tikv.kvproto.Kvrpcpb.BatchGetResponse; import org.tikv.kvproto.Kvrpcpb.Context; import org.tikv.kvproto.Kvrpcpb.GetRequest; import org.tikv.kvproto.Kvrpcpb.GetResponse; import org.tikv.kvproto.Kvrpcpb.KvPair; +import org.tikv.kvproto.Kvrpcpb.RawBatchPutRequest; +import org.tikv.kvproto.Kvrpcpb.RawBatchPutResponse; import org.tikv.kvproto.Kvrpcpb.RawDeleteRequest; import org.tikv.kvproto.Kvrpcpb.RawDeleteResponse; import org.tikv.kvproto.Kvrpcpb.RawGetRequest; @@ -310,6 +313,37 @@ public class RegionStoreClient extends AbstractGRPCClient kvPairs) { + if (kvPairs.isEmpty()) { + return true; + } + Supplier factory = + () -> + RawBatchPutRequest.newBuilder() + .setContext(region.getContext()) + .addAllPairs(kvPairs) + .build(); + KVErrorHandler handler = + new KVErrorHandler<>( + regionManager, + this, + region, + resp -> resp.hasRegionError() ? resp.getRegionError() : null); + RawBatchPutResponse resp = + callWithRetry(backOffer, TikvGrpc.METHOD_RAW_BATCH_PUT, factory, handler); + return handleRawBatchPut(resp, backOffer); + } + + private boolean handleRawBatchPut(RawBatchPutResponse resp, BackOffer backOffer) { + if (resp.hasRegionError()) { + Errorpb.Error regionError = resp.getRegionError(); + logger.warn( + "Re-splitting RawBatchPutRequest due to region error:" + regionError.getMessage()); + return false; + } + return true; + } + /** * Return a batch KvPair list containing limited key-value pairs starting from `key`, which are in * the same region diff --git a/src/test/java/org/tikv/RawKVClientTest.java b/src/test/java/org/tikv/RawKVClientTest.java index 5e642511dd..552536acd6 100644 --- a/src/test/java/org/tikv/RawKVClientTest.java +++ b/src/test/java/org/tikv/RawKVClientTest.java @@ -102,17 +102,24 @@ public class RawKVClientTest { @Test public void validate() { if (!initialized) return; - baseTest(100, 100, 100, 100, false); + baseTest(100, 100, 100, 100, false, false); + baseTest(100, 100, 100, 100, false, true); } /** Example of benchmarking base test */ public void benchmark() { if (!initialized) return; - baseTest(TEST_CASES, TEST_CASES, 200, 5000, true); + baseTest(TEST_CASES, TEST_CASES, 200, 5000, true, false); + baseTest(TEST_CASES, TEST_CASES, 200, 5000, true, true); } private void baseTest( - int putCases, int getCases, int scanCases, int deleteCases, boolean benchmark) { + int putCases, + int getCases, + int scanCases, + int deleteCases, + boolean benchmark, + boolean batchPut) { if (putCases > KEY_POOL_SIZE) { System.out.println( "Number of distinct orderedKeys required exceeded pool size " + KEY_POOL_SIZE); @@ -126,7 +133,11 @@ public class RawKVClientTest { prepare(); - rawPutTest(putCases, benchmark); + if (batchPut) { + rawBatchPutTest(putCases, benchmark); + } else { + rawPutTest(putCases, benchmark); + } rawGetTest(getCases, benchmark); rawScanTest(scanCases, benchmark); rawDeleteTest(deleteCases, benchmark); @@ -209,6 +220,51 @@ public class RawKVClientTest { } } + private void rawBatchPutTest(int putCases, boolean benchmark) { + System.out.println("put testing"); + if (benchmark) { + for (int i = 0; i < putCases; i++) { + ByteString key = orderedKeys.get(i), value = values.get(i); + data.put(key, value); + } + + long start = System.currentTimeMillis(); + int base = putCases / WORKER_CNT; + for (int cnt = 0; cnt < WORKER_CNT; cnt++) { + int i = cnt; + completionService.submit( + () -> { + List list = new ArrayList<>(); + for (int j = 0; j < base; j++) { + int num = i * base + j; + ByteString key = orderedKeys.get(num), value = values.get(num); + list.add(Kvrpcpb.KvPair.newBuilder().setKey(key).setValue(value).build()); + } + client.batchPut(list); + return null; + }); + } + awaitTimeOut(100); + long end = System.currentTimeMillis(); + System.out.println( + putCases + + " put: " + + (end - start) / 1000.0 + + "s workers=" + + WORKER_CNT + + " put=" + + rawKeys().size()); + } else { + List list = new ArrayList<>(); + for (int i = 0; i < putCases; i++) { + ByteString key = randomKeys.get(i), value = values.get(r.nextInt(KEY_POOL_SIZE)); + data.put(key, value); + list.add(Kvrpcpb.KvPair.newBuilder().setKey(key).setValue(value).build()); + } + checkBatchPut(list); + } + } + private void rawGetTest(int getCases, boolean benchmark) { System.out.println("get testing"); if (benchmark) { @@ -317,6 +373,13 @@ public class RawKVClientTest { assert client.get(key).equals(value); } + private void checkBatchPut(List pairs) { + client.batchPut(pairs); + for (Kvrpcpb.KvPair pair : pairs) { + assert client.get(pair.getKey()).equals(pair.getValue()); + } + } + private void checkScan(ByteString startKey, ByteString endKey, List ans) { List result = client.scan(startKey, endKey); assert result.equals(ans); diff --git a/src/test/java/org/tikv/txn/LockResolverTest.java b/src/test/java/org/tikv/txn/LockResolverTest.java index 40af722aef..2029402456 100644 --- a/src/test/java/org/tikv/txn/LockResolverTest.java +++ b/src/test/java/org/tikv/txn/LockResolverTest.java @@ -205,8 +205,7 @@ public class LockResolverTest { resp -> resp.hasRegionError() ? resp.getRegionError() : null); CommitResponse resp = - client.callWithRetry( - backOffer, TikvGrpc.METHOD_KV_COMMIT, factory, handler); + client.callWithRetry(backOffer, TikvGrpc.METHOD_KV_COMMIT, factory, handler); if (resp.hasRegionError()) { throw new RegionException(resp.getRegionError());