memdb: introduce snapshot interface (#1623)

Signed-off-by: you06 <you1474600@gmail.com>

Co-authored-by: ekexium <eke@fastmail.com>
This commit is contained in:
you06 2025-04-15 23:17:37 +09:00 committed by GitHub
parent 2b8c6a7761
commit 183817ac81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 517 additions and 312 deletions

View File

@ -21,32 +21,35 @@ import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)
// getSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot
// of stage[0]
func (t *ART) getSnapshot() arena.MemDBCheckpoint {
type Snapshot struct {
tree *ART
cp arena.MemDBCheckpoint
}
func (t *ART) getSnapshotCheckpoint() arena.MemDBCheckpoint {
if len(t.stages) > 0 {
return t.stages[0]
}
return t.checkpoint()
}
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
func (t *ART) SnapshotGetter() *SnapGetter {
return &SnapGetter{
// GetSnapshot returns a Getter for a snapshot of MemBuffer's stage[0]
func (t *ART) GetSnapshot() *Snapshot {
return &Snapshot{
tree: t,
cp: t.getSnapshot(),
cp: t.getSnapshotCheckpoint(),
}
}
func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
func (s *Snapshot) NewSnapshotIterator(start, end []byte, desc bool) *SnapIter {
var (
inner *Iterator
err error
)
if desc {
inner, err = t.IterReverse(start, end)
inner, err = s.tree.IterReverse(start, end)
} else {
inner, err = t.Iter(start, end)
inner, err = s.tree.Iter(start, end)
}
if err != nil {
panic(err)
@ -54,7 +57,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
inner.ignoreSeqNo = true
it := &SnapIter{
Iterator: inner,
cp: t.getSnapshot(),
cp: s.cp,
}
it.tree.allocator.snapshotInc()
for !it.setValue() && it.Valid() {
@ -63,37 +66,6 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
return it
}
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (t *ART) SnapshotIter(start, end []byte) *SnapIter {
return t.newSnapshotIterator(start, end, false)
}
// SnapshotIterReverse returns a reverse Iterator for a snapshot of MemBuffer.
func (t *ART) SnapshotIterReverse(k, lowerBound []byte) *SnapIter {
return t.newSnapshotIterator(k, lowerBound, true)
}
type SnapGetter struct {
tree *ART
cp arena.MemDBCheckpoint
}
func (snap *SnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) {
addr, lf := snap.tree.traverse(key, false)
if addr.IsNull() {
return nil, tikverr.ErrNotExist
}
if lf.vLogAddr.IsNull() {
// A flags only key, act as value not exists
return nil, tikverr.ErrNotExist
}
v, ok := snap.tree.allocator.vlogAllocator.GetSnapshotValue(lf.vLogAddr, &snap.cp)
if !ok {
return nil, tikverr.ErrNotExist
}
return v, nil
}
type SnapIter struct {
*Iterator
value []byte
@ -134,3 +106,21 @@ func (i *SnapIter) setValue() bool {
}
return false
}
// Get returns the value stored for key as seen through the snapshot's checkpoint.
// It returns tikverr.ErrNotExist when the key has no leaf, when the leaf carries
// flags only (no value-log address), or when the value log has no value visible
// at the snapshot's checkpoint.
func (snap *Snapshot) Get(ctx context.Context, key []byte) ([]byte, error) {
	leafAddr, leaf := snap.tree.traverse(key, false)
	// The short-circuit keeps leaf untouched when no leaf address was found.
	// A flags-only key acts as "value does not exist".
	if leafAddr.IsNull() || leaf.vLogAddr.IsNull() {
		return nil, tikverr.ErrNotExist
	}
	if value, found := snap.tree.allocator.vlogAllocator.GetSnapshotValue(leaf.vLogAddr, &snap.cp); found {
		return value, nil
	}
	return nil, tikverr.ErrNotExist
}
// Close is a no-op; the method body is empty.
func (snap *Snapshot) Close() {}

View File

@ -39,7 +39,7 @@ func TestSnapshotIteratorPreventFreeNode(t *testing.T) {
default:
panic("unsupported num")
}
it := tree.SnapshotIter(nil, nil)
it := tree.GetSnapshot().NewSnapshotIterator(nil, nil, false)
require.Equal(t, 0, len(*unusedNodeSlice))
tree.Set([]byte{0, byte(num)}, []byte{0, byte(num)})
require.Equal(t, 1, len(*unusedNodeSlice))
@ -60,7 +60,7 @@ func TestConcurrentSnapshotIterNoRace(t *testing.T) {
}
const concurrency = 100
it := tree.SnapshotIter(nil, nil)
it := tree.GetSnapshot().NewSnapshotIterator(nil, nil, false)
tree.Set([]byte{0, byte(num)}, []byte{0, byte(num)})
@ -72,7 +72,7 @@ func TestConcurrentSnapshotIterNoRace(t *testing.T) {
}()
for i := 1; i < concurrency; i++ {
go func(it *SnapIter) {
concurrentIt := tree.SnapshotIter(nil, nil)
concurrentIt := tree.GetSnapshot().NewSnapshotIterator(nil, nil, false)
concurrentIt.Close()
wg.Done()
}(it)

View File

@ -0,0 +1,220 @@
// Copyright 2025 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package unionstore
import (
"context"
"sync"
)
// SnapshotWithMutex wraps a MemBuffer's snapshot with a mutex to ensure thread-safe access.
// A MemBuffer's snapshot can always read the already-committed data and bypass the staging data
// through the multi-version system provided by the memdb.
// It uses an RWMutex to prevent concurrent writes to the MemBuffer during read operations, since the
// underlying memdb (ART or RBT) is not thread-safe.
// The wrapper also saves the user from calling RLock and RUnlock manually.
// A sequence number check is also implemented to ensure the snapshot remains valid during access.
// While the MemBuffer doesn't support interleaving iterators and writes, the SnapshotWithMutex wrapper
// makes this possible.
type SnapshotWithMutex[S memdbSnapshot] struct {
	// mu is the lock held in read mode while the wrapped snapshot is being read.
	mu *sync.RWMutex
	// seqCheck returns a non-nil error when the snapshot must no longer be read.
	seqCheck func() error
	// snapshot is the wrapped memdb snapshot that serves the actual reads.
	snapshot S
}
var _ MemBufferSnapshot = (*SnapshotWithMutex[memdbSnapshot])(nil)
// Get reads k from the wrapped snapshot while holding the read lock.
// The sequence check runs first; if it fails, its error is returned and the
// lock is never taken.
func (s *SnapshotWithMutex[_]) Get(ctx context.Context, k []byte) ([]byte, error) {
	seqErr := s.seqCheck()
	if seqErr != nil {
		return nil, seqErr
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	value, err := s.snapshot.Get(ctx, k)
	return value, err
}
// snapshotBatchedIter iterates over a snapshot in batches: each refill copies a
// run of key/value slice headers out of a short-lived underlying iterator while
// holding the read lock, and the caller then consumes the batch without the lock.
type snapshotBatchedIter[S memdbSnapshot] struct {
	// mu is held in read mode for the duration of each batch refill.
	mu *sync.RWMutex
	// seqCheck returns a non-nil error when the snapshot must no longer be read.
	seqCheck func() error
	// snapshot creates the underlying iterator used to refill a batch.
	snapshot S
	// lower and upper are the range bounds supplied when the iterator was created.
	lower []byte
	upper []byte
	// reverse selects descending iteration.
	reverse bool
	// err is the error recorded for the iterator; a non-nil err makes it invalid.
	err error

	// current batch
	keys [][]byte
	values [][]byte
	// pos indexes the current entry inside keys/values.
	pos int
	// batchSize is the maximum number of entries fetched by the next refill.
	batchSize int
	// nextKey is the bound the next refill resumes from; nil means "use the original bound".
	nextKey []byte
}
// BatchedSnapshotIter creates a batched iterator over [lower, upper) of the
// wrapped snapshot and eagerly loads the first batch (at most 32 entries).
// Any error from that first load is stored on the iterator rather than returned.
func (s *SnapshotWithMutex[S]) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator {
	const initialBatchSize = 32

	batched := new(snapshotBatchedIter[S])
	batched.mu, batched.seqCheck, batched.snapshot = s.mu, s.seqCheck, s.snapshot
	batched.lower, batched.upper, batched.reverse = lower, upper, reverse
	batched.batchSize = initialBatchSize

	batched.err = batched.fillBatch()
	return batched
}
// fillBatch loads the next batch of at most batchSize entries from the snapshot.
// It resumes from nextKey when one was recorded by the previous batch, resets pos
// to 0, records the resume bound for the following batch and doubles batchSize
// (capped at 4096). An error from the sequence check or from the underlying
// iterator's Next is returned immediately.
func (it *snapshotBatchedIter[_]) fillBatch() error {
	// The sequence number check does not need the rwlock: the invariant is that
	// there cannot be concurrent writes to the seqNo variables.
	if err := it.seqCheck(); err != nil {
		return err
	}

	it.mu.RLock()
	defer it.mu.RUnlock()

	// Reuse the previous buffers only when both are allocated and large enough.
	reusable := it.keys != nil && it.values != nil &&
		cap(it.keys) >= it.batchSize && cap(it.values) >= it.batchSize
	if reusable {
		it.keys, it.values = it.keys[:0], it.values[:0]
	} else {
		it.keys = make([][]byte, 0, it.batchSize)
		it.values = make([][]byte, 0, it.batchSize)
	}

	// The first bound is where the scan starts: the lower bound going forward,
	// the upper bound going backward. A recorded nextKey overrides it.
	from, to := it.lower, it.upper
	if it.reverse {
		from, to = it.upper, it.lower
	}
	if it.nextKey != nil {
		from = it.nextKey
	}
	cursor := it.snapshot.NewSnapshotIterator(from, to, it.reverse)
	defer cursor.Close()

	// fill current batch
	// Further optimization: let the underlying memdb support batch iter.
	for len(it.keys) < it.batchSize && cursor.Valid() {
		it.keys = append(it.keys, cursor.Key())
		it.values = append(it.values, cursor.Value())
		if err := cursor.Next(); err != nil {
			return err
		}
	}

	// update state
	it.pos = 0
	if fetched := len(it.keys); fetched == 0 {
		it.nextKey = nil
	} else {
		last := it.keys[fetched-1]
		need := len(last)
		if !it.reverse {
			// Forward scans resume at last+0x00, the smallest key after last.
			need++
		}
		if cap(it.nextKey) < need {
			it.nextKey = make([]byte, need)
		} else {
			it.nextKey = it.nextKey[:need]
		}
		copy(it.nextKey, last)
		if !it.reverse {
			it.nextKey[need-1] = 0
		}
	}

	it.batchSize = min(it.batchSize*2, 4096)
	return nil
}
// Valid reports whether the iterator currently points at an entry: the sequence
// check must pass, pos must lie inside the current batch, and no error may have
// been recorded.
func (it *snapshotBatchedIter[_]) Valid() bool {
	if it.seqCheck() != nil {
		return false
	}
	if it.pos >= len(it.keys) {
		return false
	}
	return it.err == nil
}
// Next advances the iterator by one entry, refilling the batch when the current
// one is exhausted.
//
// It returns the recorded error if the iterator has already failed, or the
// sequence-check error if the snapshot is no longer readable. An error from a
// refill is recorded in it.err before being returned, so the iterator stays
// invalid afterwards instead of exposing a partially rebuilt batch through a
// stale position. This matches how the constructor records the first refill's
// error.
func (it *snapshotBatchedIter[_]) Next() error {
	if it.err != nil {
		return it.err
	}
	if err := it.seqCheck(); err != nil {
		return err
	}
	it.pos++
	if it.pos < len(it.keys) {
		return nil
	}
	// Current batch consumed: load the next one and remember any failure.
	it.err = it.fillBatch()
	return it.err
}
// Key returns the key at the current position, or nil when the iterator is not valid.
func (it *snapshotBatchedIter[_]) Key() []byte {
	if it.Valid() {
		return it.keys[it.pos]
	}
	return nil
}
// Value returns the value at the current position, or nil when the iterator is not valid.
func (it *snapshotBatchedIter[_]) Value() []byte {
	if it.Valid() {
		return it.values[it.pos]
	}
	return nil
}
// Close drops the iterator's references to the current batch and the resume key.
func (it *snapshotBatchedIter[_]) Close() {
	it.keys, it.values, it.nextKey = nil, nil, nil
}
// ForEachInSnapshotRange calls f for every key-value pair of the snapshot in
// [lower, upper), in descending order when reverse is true.
//
// The sequence check runs before the lock is taken, like in the other read
// paths of this type, so a snapshot that is no longer readable yields an error
// instead of being scanned. The read lock is held for the whole scan.
//
// Iteration ends when the range is exhausted, when f returns stop == true, or
// when f or the underlying iterator returns an error (which is returned as is).
// After f asks to stop the iterator is not advanced any further, so an error
// from an unneeded Next cannot turn a successful early stop into a failure.
func (s *SnapshotWithMutex[_]) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error {
	if err := s.seqCheck(); err != nil {
		return err
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	// The first bound is where the scan starts: lower going forward, upper going backward.
	start, end := lower, upper
	if reverse {
		start, end = upper, lower
	}
	iter := s.snapshot.NewSnapshotIterator(start, end, reverse)
	defer iter.Close()

	for iter.Valid() {
		stop, err := f(iter.Key(), iter.Value())
		if err != nil {
			return err
		}
		if stop {
			return nil
		}
		if err := iter.Next(); err != nil {
			return err
		}
	}
	return nil
}
// Close is a no-op; the method body is empty.
func (s *SnapshotWithMutex[_]) Close() {}

View File

@ -16,7 +16,6 @@ package unionstore
import (
"context"
"fmt"
"sync"
"github.com/pingcap/errors"
@ -154,203 +153,49 @@ func (db *artDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) {
return db.ART.IterReverse(upper, lower)
}
func (db *artDBWithContext) ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error {
db.RLock()
defer db.RUnlock()
var iter Iterator
if reverse {
iter = db.SnapshotIterReverse(upper, lower)
} else {
iter = db.SnapshotIter(lower, upper)
}
defer iter.Close()
for iter.Valid() {
stop, err := f(iter.Key(), iter.Value())
if err != nil {
return err
}
err = iter.Next()
if err != nil {
return err
}
if stop {
break
}
}
return nil
// SnapshotGetter implements the Getter interface, by wrapping GetSnapshot.
func (db *artDBWithContext) SnapshotGetter() Getter {
return db.ART.GetSnapshot()
}
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (db *artDBWithContext) SnapshotIter(lower, upper []byte) Iterator {
return db.ART.SnapshotIter(lower, upper)
return db.ART.GetSnapshot().NewSnapshotIterator(lower, upper, false)
}
// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
func (db *artDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator {
return db.ART.SnapshotIterReverse(upper, lower)
return db.ART.GetSnapshot().NewSnapshotIterator(upper, lower, true)
}
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
func (db *artDBWithContext) SnapshotGetter() Getter {
return db.ART.SnapshotGetter()
type artSnapshot struct {
*art.Snapshot
}
type snapshotBatchedIter struct {
db *artDBWithContext
snapshotSeqNo int
lower []byte
upper []byte
reverse bool
err error
// current batch
keys [][]byte
values [][]byte
pos int
batchSize int
nextKey []byte
// NewSnapshotIterator wraps `ART.NewSnapshotIterator` and cast the result into an `Iterator`.
func (a *artSnapshot) NewSnapshotIterator(start, end []byte, reverse bool) Iterator {
return a.Snapshot.NewSnapshotIterator(start, end, reverse)
}
func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator {
// GetSnapshot returns a snapshot of the ART.
func (db *artDBWithContext) GetSnapshot() MemBufferSnapshot {
if len(db.Stages()) == 0 {
logutil.BgLogger().Error("should not use BatchedSnapshotIter for a memdb without any staging buffer")
}
iter := &snapshotBatchedIter{
db: db,
snapshotSeqNo: db.SnapshotSeqNo,
lower: lower,
upper: upper,
reverse: reverse,
batchSize: 32,
}
iter.err = iter.fillBatch()
return iter
}
func (it *snapshotBatchedIter) fillBatch() error {
// The check of sequence numbers don't have to be protected by the rwlock, as the invariant is that
// there cannot be concurrent writes to the seqNo variables.
if it.snapshotSeqNo != it.db.SnapshotSeqNo {
return errors.Errorf(
"invalid iter: snapshotSeqNo changed, iter's=%d, db's=%d",
it.snapshotSeqNo,
it.db.SnapshotSeqNo,
)
}
it.db.RLock()
defer it.db.RUnlock()
if it.keys == nil || it.values == nil || cap(it.keys) < it.batchSize || cap(it.values) < it.batchSize {
it.keys = make([][]byte, 0, it.batchSize)
it.values = make([][]byte, 0, it.batchSize)
} else {
it.keys = it.keys[:0]
it.values = it.values[:0]
}
var snapshotIter Iterator
if it.reverse {
searchUpper := it.upper
if it.nextKey != nil {
searchUpper = it.nextKey
snapshotSeqNo := db.SnapshotSeqNo
seqCheck := func() error {
if snapshotSeqNo != db.SnapshotSeqNo {
return errors.Errorf(
"invalid iter: snapshotSeqNo changed, iter's=%d, db's=%d",
snapshotSeqNo,
db.SnapshotSeqNo,
)
}
snapshotIter = it.db.SnapshotIterReverse(searchUpper, it.lower)
} else {
searchLower := it.lower
if it.nextKey != nil {
searchLower = it.nextKey
}
snapshotIter = it.db.SnapshotIter(searchLower, it.upper)
}
defer snapshotIter.Close()
// fill current batch
// Further optimization: let the underlying memdb support batch iter.
for i := 0; i < it.batchSize && snapshotIter.Valid(); i++ {
it.keys = it.keys[:i+1]
it.values = it.values[:i+1]
it.keys[i] = snapshotIter.Key()
it.values[i] = snapshotIter.Value()
if err := snapshotIter.Next(); err != nil {
return err
}
}
// update state
it.pos = 0
if len(it.keys) > 0 {
lastKey := it.keys[len(it.keys)-1]
keyLen := len(lastKey)
if it.reverse {
if cap(it.nextKey) >= keyLen {
it.nextKey = it.nextKey[:keyLen]
} else {
it.nextKey = make([]byte, keyLen)
}
copy(it.nextKey, lastKey)
} else {
if cap(it.nextKey) >= keyLen+1 {
it.nextKey = it.nextKey[:keyLen+1]
} else {
it.nextKey = make([]byte, keyLen+1)
}
copy(it.nextKey, lastKey)
it.nextKey[keyLen] = 0
}
} else {
it.nextKey = nil
}
it.batchSize = min(it.batchSize*2, 4096)
return nil
}
func (it *snapshotBatchedIter) Valid() bool {
return it.snapshotSeqNo == it.db.SnapshotSeqNo &&
it.pos < len(it.keys) &&
it.err == nil
}
func (it *snapshotBatchedIter) Next() error {
if it.err != nil {
return it.err
}
if it.snapshotSeqNo != it.db.SnapshotSeqNo {
return errors.New(
fmt.Sprintf(
"invalid snapshotBatchedIter: snapshotSeqNo changed, iter's=%d, db's=%d",
it.snapshotSeqNo,
it.db.SnapshotSeqNo,
),
)
}
it.pos++
if it.pos >= len(it.keys) {
return it.fillBatch()
}
return nil
}
func (it *snapshotBatchedIter) Key() []byte {
if !it.Valid() {
return nil
}
return it.keys[it.pos]
}
func (it *snapshotBatchedIter) Value() []byte {
if !it.Valid() {
return nil
return &SnapshotWithMutex[*artSnapshot]{
mu: &db.RWMutex,
seqCheck: seqCheck,
snapshot: &artSnapshot{db.ART.GetSnapshot()},
}
return it.values[it.pos]
}
func (it *snapshotBatchedIter) Close() {
it.keys = nil
it.values = nil
it.nextKey = nil
}

View File

@ -198,7 +198,8 @@ func BenchmarkSnapshotIter(b *testing.B) {
}
b.Run("RBT-SnapshotIter", func(b *testing.B) { f(b, newRbtDBWithContext()) })
// unimplemented for RBT
b.Run("RBT-BatchedSnapshotIter", func(b *testing.B) { fBatched(b, newRbtDBWithContext()) })
b.Run("ART-ForEachInSnapshot", func(b *testing.B) { fForEach(b, newRbtDBWithContext()) })
b.Run("ART-SnapshotIter", func(b *testing.B) { f(b, newArtDBWithContext()) })
b.Run("ART-BatchedSnapshotIter", func(b *testing.B) { fBatched(b, newArtDBWithContext()) })
b.Run("ART-ForEachInSnapshot", func(b *testing.B) { fForEach(b, newArtDBWithContext()) })
@ -279,8 +280,9 @@ func benchBatchedSnapshotIter(b *testing.B, buffer MemBuffer) {
}
buffer.Staging()
b.ResetTimer()
snapshot := buffer.GetSnapshot()
for i := 0; i < b.N; i++ {
iter := buffer.BatchedSnapshotIter(nil, nil, false)
iter := snapshot.BatchedSnapshotIter(nil, nil, false)
for iter.Valid() {
iter.Next()
}
@ -297,8 +299,9 @@ func benchForEachInSnapshot(b *testing.B, buffer MemBuffer) {
f := func(key, value []byte) (bool, error) {
return false, nil
}
snapshot := buffer.GetSnapshot()
for i := 0; i < b.N; i++ {
err := buffer.ForEachInSnapshotRange(nil, nil, f, false)
err := snapshot.ForEachInSnapshotRange(nil, nil, f, false)
if err != nil {
b.Error(err)
}

View File

@ -189,17 +189,17 @@ func (db *rbtDBWithContext) ForEachInSnapshotRange(lower []byte, upper []byte, f
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (db *rbtDBWithContext) SnapshotIter(lower, upper []byte) Iterator {
return db.RBT.SnapshotIter(lower, upper)
return db.RBT.GetSnapshot().SnapshotIter(lower, upper)
}
// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
func (db *rbtDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator {
return db.RBT.SnapshotIterReverse(upper, lower)
return db.RBT.GetSnapshot().SnapshotIterReverse(upper, lower)
}
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
func (db *rbtDBWithContext) SnapshotGetter() Getter {
return db.RBT.SnapshotGetter()
return db.RBT.GetSnapshot()
}
func (db *rbtDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator {
@ -210,3 +210,29 @@ func (db *rbtDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse boo
return db.SnapshotIter(lower, upper)
}
}
// rbtSnapshot embeds *rbt.Snapshot so that a NewSnapshotIterator method returning
// the package's Iterator interface can be attached to it.
type rbtSnapshot struct {
	*rbt.Snapshot
}
// NewSnapshotIterator dispatches to the embedded snapshot's SnapshotIter (forward)
// or SnapshotIterReverse (reverse) and returns the result as an `Iterator`.
// start and end are forwarded unchanged in both directions.
func (a *rbtSnapshot) NewSnapshotIterator(start, end []byte, reverse bool) Iterator {
	if !reverse {
		return a.Snapshot.SnapshotIter(start, end)
	}
	return a.Snapshot.SnapshotIterReverse(start, end)
}
// GetSnapshot returns a snapshot of the RBT, wrapped so that reads take the
// db's RWMutex in read mode.
func (db *rbtDBWithContext) GetSnapshot() MemBufferSnapshot {
	wrapped := new(SnapshotWithMutex[*rbtSnapshot])
	wrapped.mu = &db.RWMutex
	// The RBT doesn't maintain the sequence number, so the check always succeeds.
	wrapped.seqCheck = func() error { return nil }
	wrapped.snapshot = &rbtSnapshot{db.RBT.GetSnapshot()}
	return wrapped
}

View File

@ -1373,9 +1373,11 @@ func TestBatchedSnapshotIter(t *testing.T) {
}
h := db.Staging()
defer db.Release(h)
snapshot := db.GetSnapshot()
defer snapshot.Close()
// Create iterator - should be positioned at first key
iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false)
iter := snapshot.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, false)
defer iter.Close()
// Should be able to read first key immediately
@ -1405,8 +1407,10 @@ func TestBatchedSnapshotIter(t *testing.T) {
}
h := db.Staging()
defer db.Release(h)
snapshot := db.GetSnapshot()
defer snapshot.Close()
iter := db.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, true)
iter := snapshot.BatchedSnapshotIter([]byte{0, 0}, []byte{0, 255}, true)
defer iter.Close()
// Should be positioned at last key
@ -1442,21 +1446,24 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
t.Run("EdgeCases", func(t *testing.T) {
db := newArtDBWithContext()
h := db.Staging()
snapshot := db.GetSnapshot()
// invalid range - should be invalid immediately
iter := db.BatchedSnapshotIter([]byte{1}, []byte{1}, false)
iter := snapshot.BatchedSnapshotIter([]byte{1}, []byte{1}, false)
require.False(t, iter.Valid())
iter.Close()
// empty range - should be invalid immediately
iter = db.BatchedSnapshotIter([]byte{0}, []byte{1}, false)
iter = snapshot.BatchedSnapshotIter([]byte{0}, []byte{1}, false)
require.False(t, iter.Valid())
iter.Close()
snapshot.Close()
// Single element range
_ = db.Set([]byte{1}, []byte{1})
db.Release(h)
h = db.Staging()
iter = db.BatchedSnapshotIter([]byte{1}, []byte{2}, false)
snapshot = db.GetSnapshot()
iter = snapshot.BatchedSnapshotIter([]byte{1}, []byte{2}, false)
require.True(t, iter.Valid())
require.Equal(t, []byte{1}, iter.Key())
require.NoError(t, iter.Next())
@ -1467,11 +1474,13 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
_ = db.Set([]byte{2}, []byte{2})
_ = db.Set([]byte{3}, []byte{3})
_ = db.Set([]byte{4}, []byte{4})
snapshot.Close()
db.Release(h)
_ = db.Staging()
// Forward iteration [2,4)
iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, false)
snapshot = db.GetSnapshot()
iter = snapshot.BatchedSnapshotIter([]byte{2}, []byte{4}, false)
vals := []byte{}
for iter.Valid() {
vals = append(vals, iter.Key()[0])
@ -1481,7 +1490,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
iter.Close()
// Reverse iteration [2,4)
iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, true)
iter = snapshot.BatchedSnapshotIter([]byte{2}, []byte{4}, true)
vals = []byte{}
for iter.Valid() {
vals = append(vals, iter.Key()[0])
@ -1503,7 +1512,9 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
// lower bound included
h := db.Staging()
defer db.Release(h)
iter := db.BatchedSnapshotIter([]byte{1, 2}, []byte{1, 9}, false)
snapshot := db.GetSnapshot()
defer snapshot.Close()
iter := snapshot.BatchedSnapshotIter([]byte{1, 2}, []byte{1, 9}, false)
vals := []byte{}
for iter.Valid() {
vals = append(vals, iter.Key()[1])
@ -1513,7 +1524,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
iter.Close()
// upper bound excluded
iter = db.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, false)
iter = snapshot.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, false)
vals = []byte{}
for iter.Valid() {
vals = append(vals, iter.Key()[1])
@ -1523,7 +1534,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
iter.Close()
// reverse
iter = db.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, true)
iter = snapshot.BatchedSnapshotIter([]byte{1, 0}, []byte{1, 6}, true)
vals = []byte{}
for iter.Valid() {
vals = append(vals, iter.Key()[1])
@ -1546,8 +1557,10 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
}
h := db.Staging()
defer db.Release(h)
snapshot := db.GetSnapshot()
defer snapshot.Close()
// forward
iter := db.BatchedSnapshotIter([]byte{2}, []byte{3}, false)
iter := snapshot.BatchedSnapshotIter([]byte{2}, []byte{3}, false)
count := 0
for iter.Valid() {
require.Equal(t, keys[count], iter.Key())
@ -1558,7 +1571,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
iter.Close()
// reverse
iter = db.BatchedSnapshotIter([]byte{2}, []byte{3}, true)
iter = snapshot.BatchedSnapshotIter([]byte{2}, []byte{3}, true)
count = len(keys) - 1
for iter.Valid() {
require.Equal(t, keys[count], iter.Key())
@ -1577,8 +1590,10 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
}
h := db.Staging()
defer db.Release(h)
snapshot := db.GetSnapshot()
defer snapshot.Close()
// forward
iter := db.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, false)
iter := snapshot.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, false)
count := 0
for iter.Valid() {
require.Equal(t, []byte{3, byte(count)}, iter.Key())
@ -1589,7 +1604,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
iter.Close()
// reverse
iter = db.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, true)
iter = snapshot.BatchedSnapshotIter([]byte{3, 0}, []byte{3, 255}, true)
count = 99
for iter.Valid() {
require.Equal(t, []byte{3, byte(count)}, iter.Key())
@ -1604,8 +1619,10 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
db := newArtDBWithContext()
_ = db.Set([]byte{0}, []byte{0})
h := db.Staging()
snapshot := db.GetSnapshot()
defer snapshot.Close()
_ = db.Set([]byte{byte(1)}, []byte{byte(1)})
iter := db.BatchedSnapshotIter([]byte{0}, []byte{255}, false)
iter := snapshot.BatchedSnapshotIter([]byte{0}, []byte{255}, false)
require.True(t, iter.Valid())
require.NoError(t, iter.Next())
db.Release(h)

View File

@ -558,3 +558,7 @@ func (p *PipelinedMemDB) MemHookSet() bool {
func (p *PipelinedMemDB) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator {
panic("BatchedSnapshotIter is not supported for PipelinedMemDB")
}
// GetSnapshot is not supported for PipelinedMemDB; calling it always panics.
func (*PipelinedMemDB) GetSnapshot() MemBufferSnapshot {
	panic("GetSnapshot is not supported for PipelinedMemDB")
}

View File

@ -41,56 +41,56 @@ import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
func (db *RBT) SnapshotGetter() *rbtSnapGetter {
return &rbtSnapGetter{
db: db,
cp: db.getSnapshot(),
}
type Snapshot struct {
db *RBT
cp arena.MemDBCheckpoint
}
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (db *RBT) SnapshotIter(start, end []byte) *rbtSnapIter {
it := &rbtSnapIter{
RBTIterator: &RBTIterator{
db: db,
start: start,
end: end,
},
cp: db.getSnapshot(),
}
it.init()
return it
}
// SnapshotIterReverse returns a reverse Iterator for a snapshot of MemBuffer.
func (db *RBT) SnapshotIterReverse(k, lowerBound []byte) *rbtSnapIter {
it := &rbtSnapIter{
RBTIterator: &RBTIterator{
db: db,
start: lowerBound,
end: k,
reverse: true,
},
cp: db.getSnapshot(),
}
it.init()
return it
}
func (db *RBT) getSnapshot() arena.MemDBCheckpoint {
func (db *RBT) getSnapshotCheckpoint() arena.MemDBCheckpoint {
if len(db.stages) > 0 {
return db.stages[0]
}
return db.vlog.Checkpoint()
}
type rbtSnapGetter struct {
db *RBT
cp arena.MemDBCheckpoint
// GetSnapshot returns a snapshot of MemBuffer.
func (db *RBT) GetSnapshot() *Snapshot {
return &Snapshot{
db: db,
cp: db.getSnapshotCheckpoint(),
}
}
func (snap *rbtSnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) {
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (snap *Snapshot) SnapshotIter(start, end []byte) *rbtSnapIter {
it := &rbtSnapIter{
RBTIterator: &RBTIterator{
db: snap.db,
start: start,
end: end,
},
cp: snap.cp,
}
it.init()
return it
}
// SnapshotIterReverse returns a reverse Iterator for a snapshot of MemBuffer.
func (snap *Snapshot) SnapshotIterReverse(k, lowerBound []byte) *rbtSnapIter {
it := &rbtSnapIter{
RBTIterator: &RBTIterator{
db: snap.db,
start: lowerBound,
end: k,
reverse: true,
},
cp: snap.cp,
}
it.init()
return it
}
func (snap *Snapshot) Get(ctx context.Context, key []byte) ([]byte, error) {
x := snap.db.traverse(key, false)
if x.isNull() {
return nil, tikverr.ErrNotExist

View File

@ -202,37 +202,17 @@ type MemBuffer interface {
// Any write operation to the memdb invalidates this iterator immediately after its creation.
// Attempting to use such an invalidated iterator will result in a panic.
IterReverse([]byte, []byte) (Iterator, error)
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
// Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead.
// Deprecated: use GetSnapshot instead.
SnapshotIter([]byte, []byte) Iterator
// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
// Deprecated: use ForEachInSnapshotRange or BatchedSnapshotIter instead.
// Deprecated: use GetSnapshot instead.
SnapshotIterReverse([]byte, []byte) Iterator
// ForEachInSnapshotRange scans the key-value pairs in the state[0] snapshot if it exists,
// otherwise it uses the current checkpoint as snapshot.
//
// NOTE: returned kv-pairs are only valid during the iteration. If you want to use them after the iteration,
// you need to make a copy.
//
// The method is protected by a RWLock to prevent potential iterator invalidation, i.e.
// You cannot modify the MemBuffer during the iteration.
//
// Use it when you need to scan the whole range, otherwise consider using BatchedSnapshotIter.
ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error
// BatchedSnapshotIter returns an iterator of the "snapshot", namely stage[0].
// It iterates in batches and prevents iterator invalidation.
//
// Use it when you need on-demand "next", otherwise consider using ForEachInSnapshotRange.
// NOTE: you should never use it when there are no stages.
//
// The iterator becomes invalid when any operation that may modify the "snapshot",
// e.g. RevertToCheckpoint or releasing stage[0].
BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator
//SnapshotGetter returns a Getter for a snapshot of MemBuffer.
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
// Deprecated: use GetSnapshot instead.
SnapshotGetter() Getter
// InspectStage iterates all buffered keys and values in MemBuffer.
InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte))
// SetEntrySizeLimit sets the size limit for each entry and total buffer.
@ -270,6 +250,41 @@ type MemBuffer interface {
FlushWait() error
// GetMetrics returns the metrics related to flushing
GetMetrics() Metrics
// GetSnapshot returns a snapshot of the MemBuffer.
// The snapshot acquired using this function represents a version without any staging data written.
// The snapshot is valid until all the existing stagings are cleaned up or released.
// e.g.
// ┌───────────┐
// │ MemBuffer │
// │ (k, v1) │
// └─────┬─────┘
// staging
// │
// │ GetSnapshot ┌────────┐
// ├─────────────►│snapshot│
// │ └────┬───┘
// set(k, v2) │
// │ │
// ┌───────▼───────┐ │
// │ MemBuffer │ │
// │ (k, v1) │ │
// │staging(k, v2) │ ┌────▼───┐ get(k) ┌──────┐
// └───────┬───────┘ │snapshot├──────────►│ v1 │
// │ └────┬───┘ └──────┘
// release │
// staging │
// │ │
// ┌─────▼─────┐ ┌────▼───┐ get(k) ┌──────┐
// │ MemBuffer │ │invalid │──────────►│error │
// │ (k, v2) │ │snapshot│ └──────┘
// └───────────┘ └────────┘
// The snapshot returned by this function is protected by an `RWMutex` to ensure thread safety.
// This snapshot can fully replace `SnapshotGetter`, `SnapshotIter`, and `SnapshotIterReverse`.
// Additionally, it provides two iteration methods, `ForEachInSnapshotRange` and `BatchedSnapshotIter`,
// which tolerate interleaved reads and writes and are therefore simpler to use.
// The snapshot also verifies the snapshot sequence number to prevent reading from an invalid snapshot.
GetSnapshot() MemBufferSnapshot
}
type Metrics struct {
@ -284,3 +299,37 @@ var (
_ MemBuffer = &rbtDBWithContext{}
_ MemBuffer = &artDBWithContext{}
)
// memdbSnapshot is the behaviour a concrete memdb snapshot has to provide:
// point reads through Getter plus creation of a ranged iterator.
type memdbSnapshot interface {
	Getter
	// NewSnapshotIterator returns an iterator over the snapshot; desc requests
	// descending order. NOTE(review): for descending scans callers in this package
	// pass the upper bound as start and the lower bound as end; confirm against
	// each implementation.
	NewSnapshotIterator(start, end []byte, desc bool) Iterator
}
// MemBufferSnapshot is the read interface of a MemBuffer snapshot: point reads
// through Getter, two ways of scanning a key range, and Close.
type MemBufferSnapshot interface {
	Getter
	// ForEachInSnapshotRange scans the key-value pairs in the stage[0] snapshot if it exists,
	// otherwise it uses the current checkpoint as snapshot.
	//
	// NOTE: returned kv-pairs are only valid during the iteration. If you want to use them after the iteration,
	// you need to make a copy.
	//
	// The method is protected by a RWLock to prevent potential iterator invalidation, i.e.
	// You cannot modify the MemBuffer during the iteration.
	//
	// Use it when you need to scan the whole range, otherwise consider using BatchedSnapshotIter.
	ForEachInSnapshotRange(lower []byte, upper []byte, f func(k, v []byte) (stop bool, err error), reverse bool) error
	// BatchedSnapshotIter returns an iterator of the "snapshot", namely stage[0].
	// It iterates in batches and prevents iterator invalidation.
	//
	// Use it when you need on-demand "next", otherwise consider using ForEachInSnapshotRange.
	// NOTE: you should never use it when there are no stages.
	//
	// The iterator becomes invalid after any operation that may modify the "snapshot",
	// e.g. RevertToCheckpoint or releasing stage[0].
	BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator
	// Close releases the snapshot.
	Close()
}

View File

@ -60,3 +60,6 @@ type MemDBCheckpoint = unionstore.MemDBCheckpoint
// Metrics is the metrics of unionstore.
type Metrics = unionstore.Metrics
// MemBufferSnapshot is the snapshot of MemBuffer.
type MemBufferSnapshot = unionstore.MemBufferSnapshot

View File

@ -94,3 +94,51 @@ func (b *BufferBatchGetter) BatchGet(ctx context.Context, keys [][]byte) (map[st
}
return bufferValues, nil
}
// BufferSnapshotBatchGetter is the type for BatchGet with a buffer layered over a snapshot.
type BufferSnapshotBatchGetter struct {
	// buffer is consulted first for every requested key.
	buffer BatchSnapshotBufferGetter
	// snapshot serves the keys the buffer has no entry for.
	snapshot BatchGetter
}
// BatchSnapshotBufferGetter is the interface required of the buffer side of
// BufferSnapshotBatchGetter: single-key Get plus BatchGet.
type BatchSnapshotBufferGetter interface {
	unionstore.Getter
	BatchGetter
}
// NewBufferSnapshotBatchGetter creates a new BufferSnapshotBatchGetter that reads
// from buffer first and falls back to snapshot.
func NewBufferSnapshotBatchGetter(buffer BatchSnapshotBufferGetter, snapshot BatchGetter) *BufferSnapshotBatchGetter {
	return &BufferSnapshotBatchGetter{buffer: buffer, snapshot: snapshot}
}
// BatchGet gets a batch of values.
//
// Keys are looked up in the buffer first. Keys the buffer has no entry for are
// then read from the snapshot, and both results are merged into one map. A
// buffer entry with an empty value marks a deleted key: it is removed from the
// result and is not read from the snapshot.
//
// The keys to fetch from the snapshot are collected before any tombstone is
// removed from the buffer result. Otherwise a deleted key that appears more
// than once in keys would look "missing" on its second occurrence and be read
// back from the snapshot.
//
// Any error from the buffer or the snapshot is returned with a nil map.
func (b *BufferSnapshotBatchGetter) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) {
	bufferValues, err := b.buffer.BatchGet(ctx, keys)
	if err != nil {
		return nil, err
	}
	if len(bufferValues) == 0 {
		return b.snapshot.BatchGet(ctx, keys)
	}

	// The hint is only an estimate; clamp it so an unexpectedly large buffer
	// result cannot produce a negative capacity.
	capHint := len(keys) - len(bufferValues)
	if capHint < 0 {
		capHint = 0
	}
	shrinkKeys := make([][]byte, 0, capHint)
	for _, key := range keys {
		if _, ok := bufferValues[string(key)]; !ok {
			shrinkKeys = append(shrinkKeys, key)
		}
	}

	// the deleted key should be removed from the result, and also no need to snapshot read it again.
	for _, key := range keys {
		if val, ok := bufferValues[string(key)]; ok && len(val) == 0 {
			delete(bufferValues, string(key))
		}
	}

	storageValues, err := b.snapshot.BatchGet(ctx, shrinkKeys)
	if err != nil {
		return nil, err
	}
	for key, val := range storageValues {
		bufferValues[key] = val
	}
	return bufferValues, nil
}