diff --git a/src/main/java/org/tikv/common/region/RegionStoreClient.java b/src/main/java/org/tikv/common/region/RegionStoreClient.java index ba742c872b..22607b2bdb 100644 --- a/src/main/java/org/tikv/common/region/RegionStoreClient.java +++ b/src/main/java/org/tikv/common/region/RegionStoreClient.java @@ -357,7 +357,7 @@ public class RegionStoreClient extends AbstractRegionStoreClient { this, lockResolverClient, resp -> resp.hasRegionError() ? resp.getRegionError() : null, - resp -> null, + resp -> resp.hasError() ? resp.getError() : null, resolveLockResult -> addResolvedLocks(version, resolveLockResult.getResolvedLocks()), version, forWrite); @@ -366,13 +366,14 @@ public class RegionStoreClient extends AbstractRegionStoreClient { // we need to update region after retry region = regionManager.getRegionByKey(startKey, backOffer); - if (isScanSuccess(backOffer, resp)) { - return doScan(resp); + if (handleScanResponse(backOffer, resp, version, forWrite)) { + return resp.getPairsList(); } } } - private boolean isScanSuccess(BackOffer backOffer, ScanResponse resp) { + private boolean handleScanResponse( + BackOffer backOffer, ScanResponse resp, long version, boolean forWrite) { if (resp == null) { this.regionManager.onRequestFail(region); throw new TiClientInternalException("ScanResponse failed without a cause"); @@ -381,28 +382,35 @@ public class RegionStoreClient extends AbstractRegionStoreClient { backOffer.doBackOff(BoRegionMiss, new RegionException(resp.getRegionError())); return false; } - return true; - } - // TODO: resolve locks after scan - private List doScan(ScanResponse resp) { - // Check if kvPair contains error, it should be a Lock if hasError is true. - List kvPairs = resp.getPairsList(); - List newKvPairs = new ArrayList<>(); - for (KvPair kvPair : kvPairs) { + // Resolve locks + // Note: Memory lock conflict is returned by both `ScanResponse.error` & + // `ScanResponse.pairs[0].error`, while other key errors are returned by + // `ScanResponse.pairs.error` + // See https://github.com/pingcap/kvproto/pull/697 + List locks = new ArrayList<>(); + for (KvPair kvPair : resp.getPairsList()) { if (kvPair.hasError()) { Lock lock = AbstractLockResolverClient.extractLockFromKeyErr(kvPair.getError(), codec); - newKvPairs.add( - KvPair.newBuilder() - .setError(kvPair.getError()) - .setValue(kvPair.getValue()) - .setKey(lock.getKey()) - .build()); - } else { - newKvPairs.add(codec.decodeKvPair(kvPair)); + locks.add(lock); } } - return Collections.unmodifiableList(newKvPairs); + if (!locks.isEmpty()) { + ResolveLockResult resolveLockResult = + lockResolverClient.resolveLocks(backOffer, version, locks, forWrite); + addResolvedLocks(version, resolveLockResult.getResolvedLocks()); + + long msBeforeExpired = resolveLockResult.getMsBeforeTxnExpired(); + if (msBeforeExpired > 0) { + // if not resolve all locks, we wait and retry + backOffer.doBackOffWithMaxSleep( + BoTxnLockFast, msBeforeExpired, new KeyException(locks.toString())); + } + + return false; + } + + return true; } public List scan(BackOffer backOffer, ByteString startKey, long version) { diff --git a/src/test/java/org/tikv/common/KVMockServer.java b/src/test/java/org/tikv/common/KVMockServer.java index 69d8a55ee0..ea09270cfc 100644 --- a/src/test/java/org/tikv/common/KVMockServer.java +++ b/src/test/java/org/tikv/common/KVMockServer.java @@ -46,6 +46,7 @@ import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.tikv.common.key.Key; +import org.tikv.common.meta.TiTimestamp; import org.tikv.common.region.TiRegion; import org.tikv.kvproto.Coprocessor; import org.tikv.kvproto.Errorpb; @@ -67,6 +68,10 @@ public class KVMockServer extends TikvGrpc.TikvImplBase { private final Map> keyErrMap = new HashMap<>(); + private final Map> lockMap = new HashMap<>(); + private final Map> txnStatusMap = + new HashMap<>(); + // for KV error public static final int ABORT = 1; public static final int RETRY = 2; @@ -117,9 +122,68 @@ public class KVMockServer extends TikvGrpc.TikvImplBase { regionErrMap.put(toRawKey(key.getBytes(StandardCharsets.UTF_8)), builder); } + public void removeError(String key) { + regionErrMap.remove(toRawKey(key.getBytes(StandardCharsets.UTF_8))); + } + + // putWithLock is used to "prewrite" key-value without "commit" + public void putWithLock( + ByteString key, ByteString value, ByteString primaryKey, Long startTs, Long ttl) { + put(key, value); + + Kvrpcpb.LockInfo.Builder lock = + Kvrpcpb.LockInfo.newBuilder() + .setPrimaryLock(primaryKey) + .setLockVersion(startTs) + .setKey(key) + .setLockTtl(ttl); + lockMap.put(toRawKey(key), () -> lock); + } + + public void removeLock(ByteString key) { + lockMap.remove(toRawKey(key)); + } + + public boolean hasLock(ByteString key) { + return lockMap.containsKey(toRawKey(key)); + } + + // putTxnStatus is used to save transaction status + // commitTs > 0: committed + // commitTs == 0 && key is empty: rollback + // commitTs == 0 && key not empty: locked by key + public void putTxnStatus(Long startTs, Long commitTs, ByteString key) { + if (commitTs > 0 || (commitTs == 0 && key.isEmpty())) { // committed || rollback + Kvrpcpb.CheckTxnStatusResponse.Builder txnStatus = + Kvrpcpb.CheckTxnStatusResponse.newBuilder() + .setCommitVersion(commitTs) + .setLockTtl(0) + .setAction(Kvrpcpb.Action.NoAction); + txnStatusMap.put(startTs, () -> txnStatus); + } else { // locked + Kvrpcpb.LockInfo.Builder lock = lockMap.get(toRawKey(key)).get(); + Kvrpcpb.CheckTxnStatusResponse.Builder txnStatus = + Kvrpcpb.CheckTxnStatusResponse.newBuilder() + .setCommitVersion(commitTs) + .setLockTtl(lock.getLockTtl()) + .setAction(Kvrpcpb.Action.NoAction) + .setLockInfo(lock); + txnStatusMap.put(startTs, () -> txnStatus); + } + } + + // putTxnStatus is used to save transaction status + // commitTs > 0: committed + // commitTs == 0: rollback + public void putTxnStatus(Long startTs, Long commitTs) { + putTxnStatus(startTs, commitTs, ByteString.EMPTY); + } + public void clearAllMap() { dataMap.clear(); regionErrMap.clear(); + lockMap.clear(); + txnStatusMap.clear(); } private Errorpb.Error verifyContext(Context context) throws Exception { @@ -255,9 +319,12 @@ public class KVMockServer extends TikvGrpc.TikvImplBase { return; } + Supplier lock = lockMap.get(key); Supplier errProvider = keyErrMap.remove(key); if (errProvider != null) { builder.setError(errProvider.get().build()); + } else if (lock != null) { + builder.setError(Kvrpcpb.KeyError.newBuilder().setLocked(lock.get())); } else { ByteString value = dataMap.get(key); builder.setValue(value); @@ -299,11 +366,17 @@ public class KVMockServer extends TikvGrpc.TikvImplBase { kvs.entrySet() .stream() .map( - kv -> - Kvrpcpb.KvPair.newBuilder() - .setKey(kv.getKey().toByteString()) - .setValue(kv.getValue()) - .build()) + kv -> { + Kvrpcpb.KvPair.Builder kvBuilder = + Kvrpcpb.KvPair.newBuilder() + .setKey(kv.getKey().toByteString()) + .setValue(kv.getValue()); + Supplier lock = lockMap.get(kv.getKey()); + if (lock != null) { + kvBuilder.setError(Kvrpcpb.KeyError.newBuilder().setLocked(lock.get())); + } + return kvBuilder.build(); + }) .collect(Collectors.toList())); } responseObserver.onNext(builder.build()); @@ -354,6 +427,96 @@ public class KVMockServer extends TikvGrpc.TikvImplBase { } } + @Override + public void kvCheckTxnStatus( + org.tikv.kvproto.Kvrpcpb.CheckTxnStatusRequest request, + io.grpc.stub.StreamObserver + responseObserver) { + logger.info("KVMockServer.kvCheckTxnStatus"); + try { + Long startTs = request.getLockTs(); + Long currentTs = request.getCurrentTs(); + logger.info("kvCheckTxnStatus for txn: " + startTs); + Kvrpcpb.CheckTxnStatusResponse.Builder builder = Kvrpcpb.CheckTxnStatusResponse.newBuilder(); + + Error e = verifyContext(request.getContext()); + if (e != null) { + responseObserver.onNext(builder.setRegionError(e).build()); + responseObserver.onCompleted(); + return; + } + + Supplier txnStatus = txnStatusMap.get(startTs); + if (txnStatus != null) { + Kvrpcpb.CheckTxnStatusResponse resp = txnStatus.get().build(); + if (resp.getCommitVersion() == 0 + && resp.getLockTtl() > 0 + && TiTimestamp.extractPhysical(startTs) + resp.getLockInfo().getLockTtl() + < TiTimestamp.extractPhysical(currentTs)) { + ByteString key = resp.getLockInfo().getKey(); + logger.info( + String.format( + "kvCheckTxnStatus rollback expired txn: %d, remove lock: %s", + startTs, key.toStringUtf8())); + removeLock(key); + putTxnStatus(startTs, 0L, ByteString.EMPTY); + resp = txnStatusMap.get(startTs).get().build(); + } + logger.info("kvCheckTxnStatus resp: " + resp); + responseObserver.onNext(resp); + } else { + builder.setError( + Kvrpcpb.KeyError.newBuilder() + .setTxnNotFound( + Kvrpcpb.TxnNotFound.newBuilder() + .setPrimaryKey(request.getPrimaryKey()) + .setStartTs(startTs))); + logger.info("kvCheckTxnStatus, TxnNotFound"); + responseObserver.onNext(builder.build()); + } + responseObserver.onCompleted(); + } catch (Exception e) { + logger.error("kvCheckTxnStatus error: " + e); + responseObserver.onError(Status.INTERNAL.asRuntimeException()); + } + } + + @Override + public void kvResolveLock( + org.tikv.kvproto.Kvrpcpb.ResolveLockRequest request, + io.grpc.stub.StreamObserver responseObserver) { + logger.info("KVMockServer.kvResolveLock"); + try { + Long startTs = request.getStartVersion(); + Long commitTs = request.getCommitVersion(); + logger.info( + String.format( + "kvResolveLock for txn: %d, commitTs: %d, keys: %d", + startTs, commitTs, request.getKeysCount())); + Kvrpcpb.ResolveLockResponse.Builder builder = Kvrpcpb.ResolveLockResponse.newBuilder(); + + Error e = verifyContext(request.getContext()); + if (e != null) { + responseObserver.onNext(builder.setRegionError(e).build()); + responseObserver.onCompleted(); + return; + } + + if (request.getKeysCount() == 0) { + lockMap.entrySet().removeIf(entry -> entry.getValue().get().getLockVersion() == startTs); + } else { + for (int i = 0; i < request.getKeysCount(); i++) { + removeLock(request.getKeys(i)); + } + } + + responseObserver.onNext(builder.build()); + responseObserver.onCompleted(); + } catch (Exception e) { + responseObserver.onError(Status.INTERNAL.asRuntimeException()); + } + } + @Override public void coprocessor( org.tikv.kvproto.Coprocessor.Request requestWrap, diff --git a/src/test/java/org/tikv/common/MockServerTest.java b/src/test/java/org/tikv/common/MockServerTest.java index 02cab4c46f..db9ae5694b 100644 --- a/src/test/java/org/tikv/common/MockServerTest.java +++ b/src/test/java/org/tikv/common/MockServerTest.java @@ -39,6 +39,8 @@ public class MockServerTest extends PDMockServerTest { public void setup() throws IOException { super.setup(); + port = GrpcUtils.getFreePort(); + Metapb.Region r = Metapb.Region.newBuilder() .setRegionEpoch(Metapb.RegionEpoch.newBuilder().setConfVer(1).setVersion(2)) @@ -51,7 +53,7 @@ public class MockServerTest extends PDMockServerTest { List s = ImmutableList.of( Metapb.Store.newBuilder() - .setAddress("localhost:1234") + .setAddress(LOCAL_ADDR + ":" + port) .setVersion("5.0.0") .setId(13) .build()); @@ -70,6 +72,6 @@ public class MockServerTest extends PDMockServerTest { (request) -> Pdpb.GetStoreResponse.newBuilder().setStore(store).build()); } server = new KVMockServer(); - port = server.start(region); + server.start(region, port); } } diff --git a/src/test/java/org/tikv/common/PDClientMockTest.java b/src/test/java/org/tikv/common/PDClientMockTest.java index a8074d9457..6837334fee 100644 --- a/src/test/java/org/tikv/common/PDClientMockTest.java +++ b/src/test/java/org/tikv/common/PDClientMockTest.java @@ -74,9 +74,12 @@ public class PDClientMockTest extends PDMockServerTest { @Test public void testTso() throws Exception { try (PDClient client = session.getPDClient()) { + Long current = System.currentTimeMillis(); TiTimestamp ts = client.getTimestamp(defaultBackOff()); - // Test pdServer is set to generate physical == logical + 1 - assertEquals(ts.getPhysical(), ts.getLogical() + 1); + // Test pdServer is set to generate physical to current, logical to 1 + assertTrue(ts.getPhysical() >= current); + assertTrue(ts.getPhysical() < current + 100); + assertEquals(ts.getLogical(), 1); } } diff --git a/src/test/java/org/tikv/common/PDMockServer.java b/src/test/java/org/tikv/common/PDMockServer.java index 723034f1e3..99ccb66bbb 100644 --- a/src/test/java/org/tikv/common/PDMockServer.java +++ b/src/test/java/org/tikv/common/PDMockServer.java @@ -75,8 +75,17 @@ public class PDMockServer extends PDGrpc.PDImplBase { @Override public StreamObserver tso(StreamObserver resp) { return new StreamObserver() { - private int physical = 1; - private int logical = 0; + private long physical = System.currentTimeMillis(); + private long logical = 0; + + private void updateTso() { + logical++; + if (logical >= (1 << 18)) { + logical = 0; + physical++; + } + physical = Math.max(physical, System.currentTimeMillis()); + } @Override public void onNext(TsoRequest value) {} @@ -86,7 +95,8 @@ public class PDMockServer extends PDGrpc.PDImplBase { @Override public void onCompleted() { - resp.onNext(GrpcUtils.makeTsoResponse(clusterId, physical++, logical++)); + updateTso(); + resp.onNext(GrpcUtils.makeTsoResponse(clusterId, physical, logical)); resp.onCompleted(); } }; diff --git a/src/test/java/org/tikv/common/RegionStoreClientTest.java b/src/test/java/org/tikv/common/RegionStoreClientTest.java index 1a03ad80e2..bb288c48ae 100644 --- a/src/test/java/org/tikv/common/RegionStoreClientTest.java +++ b/src/test/java/org/tikv/common/RegionStoreClientTest.java @@ -17,15 +17,16 @@ package org.tikv.common; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; import java.util.List; import java.util.Optional; import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.tikv.common.exception.KeyException; import org.tikv.common.region.RegionManager; import org.tikv.common.region.RegionStoreClient; import org.tikv.common.region.RegionStoreClient.RegionStoreClientBuilder; @@ -40,6 +41,7 @@ import org.tikv.kvproto.Kvrpcpb; import org.tikv.kvproto.Metapb; public class RegionStoreClientTest extends MockServerTest { + private static final Logger logger = LoggerFactory.getLogger(MockServerTest.class); private RegionStoreClient createClientV2() { return createClient("2.1.19"); @@ -49,6 +51,10 @@ public class RegionStoreClientTest extends MockServerTest { return createClient("3.0.12"); } + private RegionStoreClient createClientV4() { + return createClient("6.1.0"); + } + private RegionStoreClient createClient(String version) { Metapb.Store meta = Metapb.Store.newBuilder() @@ -161,30 +167,130 @@ public class RegionStoreClientTest extends MockServerTest { @Test public void scanTest() { - doScanTest(createClientV3()); + doScanTest(createClientV4()); } public void doScanTest(RegionStoreClient client) { + Long startTs = session.getTimestamp().getVersion(); + server.put("key1", "value1"); server.put("key2", "value2"); server.put("key4", "value4"); server.put("key5", "value5"); - List kvs = client.scan(defaultBackOff(), ByteString.copyFromUtf8("key2"), 1); - assertEquals(3, kvs.size()); + + // put lock will expire in 1s + ByteString key6 = ByteString.copyFromUtf8("key6"); + server.putWithLock(key6, ByteString.copyFromUtf8("value6"), key6, startTs, 100L); + server.putTxnStatus(startTs, 0L, key6); + assertTrue(server.hasLock(key6)); + + List kvs = + client.scan( + defaultBackOff(), ByteString.copyFromUtf8("key2"), session.getTimestamp().getVersion()); + assertEquals(4, kvs.size()); kvs.forEach( kv -> assertEquals( kv.getKey().toStringUtf8().replace("key", "value"), kv.getValue().toStringUtf8())); + assertFalse(server.hasLock(key6)); + // put region error server.putError( "error1", () -> Errorpb.Error.newBuilder().setServerIsBusy(ServerIsBusy.getDefaultInstance())); try { - client.scan(defaultBackOff(), ByteString.copyFromUtf8("error1"), 1); + client.scan( + defaultBackOff(), ByteString.copyFromUtf8("error1"), session.getTimestamp().getVersion()); fail(); } catch (Exception e) { assertTrue(true); } + server.removeError("error1"); + + // put lock + Long startTs7 = session.getTimestamp().getVersion(); + ByteString key7 = ByteString.copyFromUtf8("key7"); + server.putWithLock(key7, ByteString.copyFromUtf8("value7"), key7, startTs7, 3000L); + server.putTxnStatus(startTs7, 0L, key7); + assertTrue(server.hasLock(key7)); + try { + client.scan( + defaultBackOff(), ByteString.copyFromUtf8("key2"), session.getTimestamp().getVersion()); + fail(); + } catch (Exception e) { + KeyException keyException = (KeyException) e.getCause(); + assertTrue(keyException.getMessage().contains("org.tikv.txn.Lock")); + } + assertTrue(server.hasLock(key7)); + + server.clearAllMap(); + client.close(); + } + + @Test + public void resolveLocksTest() { + doResolveLocksTest(createClientV4()); + } + + public void doResolveLocksTest(RegionStoreClient client) { + ByteString primaryKey = ByteString.copyFromUtf8("primary"); + server.put(primaryKey, ByteString.copyFromUtf8("value0")); + + // get with committed lock + { + Long startTs = session.getTimestamp().getVersion(); + Long commitTs = session.getTimestamp().getVersion(); + logger.info("startTs: " + startTs); + + ByteString key1 = ByteString.copyFromUtf8("key1"); + ByteString value1 = ByteString.copyFromUtf8("value1"); + server.putWithLock(key1, value1, primaryKey, startTs, 1L); + server.putTxnStatus(startTs, commitTs); + assertTrue(server.hasLock(key1)); + + ByteString expected1 = client.get(defaultBackOff(), key1, 200); + assertEquals(value1, expected1); + assertFalse(server.hasLock(key1)); + } + + // get with not expired lock. + { + Long startTs = session.getTimestamp().getVersion(); + logger.info("startTs: " + startTs); + + ByteString key2 = ByteString.copyFromUtf8("key2"); + ByteString value2 = ByteString.copyFromUtf8("value2"); + server.putWithLock(key2, value2, key2, startTs, 3000L); + server.putTxnStatus(startTs, 0L, key2); + assertTrue(server.hasLock(key2)); + + try { + client.get(defaultBackOff(), key2, session.getTimestamp().getVersion()); + fail(); + } catch (Exception e) { + KeyException keyException = (KeyException) e.getCause(); + assertTrue(keyException.getMessage().contains("org.tikv.txn.Lock")); + } + assertTrue(server.hasLock(key2)); + } + + // get with expired lock. + { + Long startTs = session.getTimestamp().getVersion(); + logger.info("startTs: " + startTs); + + ByteString key3 = ByteString.copyFromUtf8("key3"); + ByteString value3 = ByteString.copyFromUtf8("value3"); + server.putWithLock(key3, value3, key3, startTs, 100L); + server.putTxnStatus(startTs, 0L, key3); + assertTrue(server.hasLock(key3)); + + ByteString expected3 = + client.get(defaultBackOff(), key3, session.getTimestamp().getVersion()); + assertEquals(expected3, value3); + assertFalse(server.hasLock(key3)); + } + server.clearAllMap(); client.close(); }