diff --git a/src/main/java/org/tikv/common/apiversion/CodecUtils.java b/src/main/java/org/tikv/common/apiversion/CodecUtils.java index 1c6cfea9fa..a2b0725be5 100644 --- a/src/main/java/org/tikv/common/apiversion/CodecUtils.java +++ b/src/main/java/org/tikv/common/apiversion/CodecUtils.java @@ -23,7 +23,7 @@ import org.tikv.common.codec.CodecDataInput; import org.tikv.common.codec.CodecDataOutput; // TODO(iosmanthus): use ByteString.wrap to avoid once more copying. -class CodecUtils { +public class CodecUtils { public static ByteString encode(ByteString key) { CodecDataOutput cdo = new CodecDataOutput(); BytesCodec.writeBytes(cdo, key.toByteArray()); diff --git a/src/main/java/org/tikv/common/apiversion/RequestKeyV2Codec.java b/src/main/java/org/tikv/common/apiversion/RequestKeyV2Codec.java index 11db11c641..ab86fb5e02 100644 --- a/src/main/java/org/tikv/common/apiversion/RequestKeyV2Codec.java +++ b/src/main/java/org/tikv/common/apiversion/RequestKeyV2Codec.java @@ -81,22 +81,21 @@ public class RequestKeyV2Codec implements RequestKeyCodec { if (!start.isEmpty()) { start = CodecUtils.decode(start); - if (ByteString.unsignedLexicographicalComparator().compare(start, keyPrefix) < 0) { - start = ByteString.EMPTY; - } else { - start = decodeKey(start); - } } if (!end.isEmpty()) { end = CodecUtils.decode(end); - if (ByteString.unsignedLexicographicalComparator().compare(end, infiniteEndKey) >= 0) { - end = ByteString.EMPTY; - } else { - end = decodeKey(end); - } } + if (ByteString.unsignedLexicographicalComparator().compare(start, infiniteEndKey) >= 0 + || (!end.isEmpty() + && ByteString.unsignedLexicographicalComparator().compare(end, keyPrefix) <= 0)) { + throw new IllegalArgumentException("region out of keyspace" + region.toString()); + } + + start = start.startsWith(keyPrefix) ? start.substring(keyPrefix.size()) : ByteString.EMPTY; + end = end.startsWith(keyPrefix) ? end.substring(keyPrefix.size()) : ByteString.EMPTY; + return builder.setStartKey(start).setEndKey(end).build(); } } diff --git a/src/test/java/org/tikv/common/apiversion/RequestKeyCodecTest.java b/src/test/java/org/tikv/common/apiversion/RequestKeyCodecTest.java index 871a20cdf2..60ffa53e70 100644 --- a/src/test/java/org/tikv/common/apiversion/RequestKeyCodecTest.java +++ b/src/test/java/org/tikv/common/apiversion/RequestKeyCodecTest.java @@ -17,8 +17,7 @@ package org.tikv.common.apiversion; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; @@ -177,5 +176,79 @@ public class RequestKeyCodecTest { decoded = v2.decodeRegion(region); assertEquals(start, decoded.getStartKey()); assertEquals(ByteString.EMPTY, decoded.getEndKey()); + + // test region out of keyspace + region = + Region.newBuilder() + .setStartKey(ByteString.EMPTY) + .setEndKey(CodecUtils.encode(v2.keyPrefix)) + .build(); + + try { + decoded = v2.decodeRegion(region); + fail(); + } catch (Exception ignored) { + } + + region = + Region.newBuilder() + .setStartKey(CodecUtils.encode(v2.infiniteEndKey)) + .setEndKey(ByteString.EMPTY) + .build(); + try { + decoded = v2.decodeRegion(region); + fail(); + } catch (Exception ignored) { + } + + // case: regionStartKey == "" < keyPrefix < regionEndKey < infiniteEndKey + region = + Region.newBuilder() + .setStartKey(ByteString.EMPTY) + .setEndKey(CodecUtils.encode(v2.keyPrefix.concat(ByteString.copyFromUtf8("0")))) + .build(); + decoded = v2.decodeRegion(region); + assertTrue(decoded.getStartKey().isEmpty()); + assertEquals(ByteString.copyFromUtf8("0"), decoded.getEndKey()); + + // case: "" < regionStartKey < keyPrefix < regionEndKey < infiniteEndKey < "" + region = + Region.newBuilder() + .setStartKey(CodecUtils.encode(ByteString.copyFromUtf8("m_123"))) + .setEndKey(CodecUtils.encode(v2.keyPrefix.concat(ByteString.copyFromUtf8("0")))) + .build(); + decoded = v2.decodeRegion(region); + assertEquals(ByteString.EMPTY, decoded.getStartKey()); + assertEquals(ByteString.copyFromUtf8("0"), decoded.getEndKey()); + + // case: "" < regionStartKey < keyPrefix < infiniteEndKey < regionEndKey < "" + region = + Region.newBuilder() + .setStartKey(CodecUtils.encode(ByteString.copyFromUtf8("m_123"))) + .setEndKey(CodecUtils.encode(v2.infiniteEndKey.concat(ByteString.copyFromUtf8("0")))) + .build(); + decoded = v2.decodeRegion(region); + assertEquals(ByteString.EMPTY, decoded.getStartKey()); + assertEquals(ByteString.EMPTY, decoded.getEndKey()); + + // case: keyPrefix < regionStartKey < infiniteEndKey < regionEndKey < "" + region = + Region.newBuilder() + .setStartKey(CodecUtils.encode(v2.keyPrefix.concat(ByteString.copyFromUtf8("0")))) + .setEndKey(CodecUtils.encode(v2.infiniteEndKey.concat(ByteString.copyFromUtf8("0")))) + .build(); + decoded = v2.decodeRegion(region); + assertEquals(ByteString.copyFromUtf8("0"), decoded.getStartKey()); + assertTrue(decoded.getEndKey().isEmpty()); + + // case: keyPrefix < regionStartKey < infiniteEndKey < regionEndKey == "" + region = + Region.newBuilder() + .setStartKey(CodecUtils.encode(v2.keyPrefix.concat(ByteString.copyFromUtf8("0")))) + .setEndKey(ByteString.EMPTY) + .build(); + decoded = v2.decodeRegion(region); + assertEquals(ByteString.copyFromUtf8("0"), decoded.getStartKey()); + assertTrue(decoded.getEndKey().isEmpty()); } }