membuffer: refactor the memdb to support multi implementations (#1426)

ref pingcap/tidb#55287

Signed-off-by: you06 <you1474600@gmail.com>
you06 2024-08-23 12:46:43 +09:00 committed by GitHub
parent 75e3705e58
commit 41d133b6b6
24 changed files with 2308 additions and 1331 deletions

View File

@ -32,19 +32,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package unionstore
package arena
import (
"encoding/binary"
"math"
"unsafe"
"github.com/tikv/client-go/v2/kv"
"go.uber.org/atomic"
)
const (
alignMask = 1<<32 - 8 // 29 bits of 1 and 3 bits of 0.
alignMask = 0xFFFFFFF8 // 29 bits of 1 and 3 bits of 0
nullBlockOffset = math.MaxUint32
maxBlockSize = 128 << 20
@ -52,34 +51,39 @@ const (
)
var (
nullAddr = memdbArenaAddr{math.MaxUint32, math.MaxUint32}
nullNodeAddr = memdbNodeAddr{nil, nullAddr}
endian = binary.LittleEndian
Tombstone = []byte{}
NullAddr = MemdbArenaAddr{math.MaxUint32, math.MaxUint32}
BadAddr = MemdbArenaAddr{math.MaxUint32 - 1, math.MaxUint32}
endian = binary.LittleEndian
)
type memdbArenaAddr struct {
type MemdbArenaAddr struct {
idx uint32
off uint32
}
func (addr memdbArenaAddr) isNull() bool {
func (addr MemdbArenaAddr) IsNull() bool {
return addr == nullAddr
return addr == NullAddr
}
func (addr MemdbArenaAddr) ToHandle() MemKeyHandle {
return MemKeyHandle{idx: uint16(addr.idx), off: addr.off}
}
// store and load are used by vlog, because pointers in vlog are not aligned.
func (addr memdbArenaAddr) store(dst []byte) {
func (addr MemdbArenaAddr) store(dst []byte) {
endian.PutUint32(dst, addr.idx)
endian.PutUint32(dst[4:], addr.off)
}
func (addr *memdbArenaAddr) load(src []byte) {
func (addr *MemdbArenaAddr) load(src []byte) {
addr.idx = endian.Uint32(src)
addr.off = endian.Uint32(src[4:])
}
type memdbArena struct {
type MemdbArena struct {
blockSize int
blocks []memdbArenaBlock
// the total size of all blocks, also the approximate memory footprint of the arena.
@ -88,7 +92,7 @@ type memdbArena struct {
memChangeHook atomic.Pointer[func()]
}
func (a *memdbArena) alloc(size int, align bool) (memdbArenaAddr, []byte) {
func (a *MemdbArena) Alloc(size int, align bool) (MemdbArenaAddr, []byte) {
if size > maxBlockSize {
panic("alloc size is larger than max block size")
}
@ -98,7 +102,7 @@ func (a *memdbArena) alloc(size int, align bool) (memdbArenaAddr, []byte) {
}
addr, data := a.allocInLastBlock(size, align)
if !addr.isNull() {
if !addr.IsNull() {
return addr, data
}
@ -106,7 +110,7 @@ func (a *memdbArena) alloc(size int, align bool) (memdbArenaAddr, []byte) {
return a.allocInLastBlock(size, align)
}
func (a *memdbArena) enlarge(allocSize, blockSize int) {
func (a *MemdbArena) enlarge(allocSize, blockSize int) {
a.blockSize = blockSize
for a.blockSize <= allocSize {
a.blockSize <<= 1
@ -119,36 +123,59 @@ func (a *memdbArena) enlarge(allocSize, blockSize int) {
buf: make([]byte, a.blockSize),
})
a.capacity += uint64(a.blockSize)
// We shall not call a.onMemChange() here, since it will make the latest block empty, which breaks a precondition
// for some operations (e.g. revertToCheckpoint)
// We shall not call a.OnMemChange() here, since it will make the latest block empty, which breaks a precondition
// for some operations (e.g. RevertToCheckpoint)
}
// onMemChange should only be called right before exiting memdb.
func (a *MemdbArena) Blocks() int {
return len(a.blocks)
}
func (a *MemdbArena) Capacity() uint64 {
return a.capacity
}
// SetMemChangeHook sets the hook function that will be called when the memory footprint of the arena changes.
func (a *MemdbArena) SetMemChangeHook(hook func()) {
a.memChangeHook.Store(&hook)
}
// MemHookSet returns whether the memory change hook is set.
func (a *MemdbArena) MemHookSet() bool {
return a.memChangeHook.Load() != nil
}
// OnMemChange should only be called right before exiting memdb.
// This is because the hook can lead to a panic, and leave memdb in an inconsistent state.
func (a *memdbArena) onMemChange() {
func (a *MemdbArena) OnMemChange() {
hook := a.memChangeHook.Load()
if hook != nil {
(*hook)()
}
}
func (a *memdbArena) allocInLastBlock(size int, align bool) (memdbArenaAddr, []byte) {
func (a *MemdbArena) allocInLastBlock(size int, align bool) (MemdbArenaAddr, []byte) {
idx := len(a.blocks) - 1
offset, data := a.blocks[idx].alloc(size, align)
if offset == nullBlockOffset {
return nullAddr, nil
return NullAddr, nil
}
return memdbArenaAddr{uint32(idx), offset}, data
return MemdbArenaAddr{uint32(idx), offset}, data
}
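allocInLastBlock forwards the align flag down to the per-block allocator, whose code is not part of this diff. As a hedged sketch (assuming the block rounds its current write offset up to a multiple of 8 using alignMask, which keeps the high 29 bits and clears the low 3):

package main

import "fmt"

const alignMask = 0xFFFFFFF8 // 29 bits of 1 and 3 bits of 0

// alignUp rounds an offset up to the next multiple of 8, so that node
// structs written at the returned offset stay 8-byte aligned.
func alignUp(offset uint32) uint32 {
	return (offset + 7) & alignMask
}

func main() {
	fmt.Println(alignUp(13)) // 16
	fmt.Println(alignUp(16)) // 16
}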
func (a *memdbArena) reset() {
// GetData gets the data slice at the given addr; DO NOT access data belonging to other addresses.
func (a *MemdbArena) GetData(addr MemdbArenaAddr) []byte {
return a.blocks[addr.idx].buf[addr.off:]
}
func (a *MemdbArena) Reset() {
for i := range a.blocks {
a.blocks[i].reset()
}
a.blocks = a.blocks[:0]
a.blockSize = 0
a.capacity = 0
a.onMemChange()
a.OnMemChange()
}
type memdbArenaBlock struct {
@ -176,18 +203,18 @@ func (a *memdbArenaBlock) reset() {
a.length = 0
}
// MemDBCheckpoint is the checkpoint of memory DB.
// MemDBCheckpoint is the checkpoint of the memory DB.
type MemDBCheckpoint struct {
blockSize int
blocks int
offsetInBlock int
}
func (cp *MemDBCheckpoint) isSamePosition(other *MemDBCheckpoint) bool {
func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool {
return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock
}
func (a *memdbArena) checkpoint() MemDBCheckpoint {
func (a *MemdbArena) Checkpoint() MemDBCheckpoint {
snap := MemDBCheckpoint{
blockSize: a.blockSize,
blocks: len(a.blocks),
@ -198,7 +225,7 @@ func (a *memdbArena) checkpoint() MemDBCheckpoint {
return snap
}
func (a *memdbArena) truncate(snap *MemDBCheckpoint) {
func (a *MemdbArena) Truncate(snap *MemDBCheckpoint) {
for i := snap.blocks; i < len(a.blocks); i++ {
a.blocks[i] = memdbArenaBlock{}
}
@ -212,203 +239,137 @@ func (a *memdbArena) truncate(snap *MemDBCheckpoint) {
for _, block := range a.blocks {
a.capacity += uint64(block.length)
}
// We shall not call a.onMemChange() here, since it may cause a panic and leave memdb in an inconsistent state
// We shall not call a.OnMemChange() here, since it may cause a panic and leave memdb in an inconsistent state
}
type nodeAllocator struct {
memdbArena
// Dummy node, so that we can make X.left.up = X.
// We then use this instead of NULL to mean the top or bottom
// end of the rb tree. It is a black node.
nullNode memdbNode
// KeyFlagsGetter is an interface for getting a key and its key flags, usually implemented by a leaf or node.
type KeyFlagsGetter interface {
GetKey() []byte
GetKeyFlags() kv.KeyFlags
}
func (a *nodeAllocator) init() {
a.nullNode = memdbNode{
up: nullAddr,
left: nullAddr,
right: nullAddr,
vptr: nullAddr,
}
// VlogMemDB is the interface a memory buffer implements so that the vlog can revert and inspect its nodes.
type VlogMemDB[G KeyFlagsGetter] interface {
RevertNode(hdr *MemdbVlogHdr)
InspectNode(addr MemdbArenaAddr) (G, MemdbArenaAddr)
}
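To make the generics concrete: G is the leaf/node type InspectNode hands back, and M is the memdb that knows how to revert and inspect those nodes. A minimal hypothetical implementation (mirroring the dummyMemDB in the test file below; countingDB is not part of this diff):

// countingDB is a hypothetical VlogMemDB implementation that only counts
// reverted nodes; it exists purely to show how the type parameters line up.
type countingDB struct{ reverted int }

func (db *countingDB) RevertNode(hdr *MemdbVlogHdr) { db.reverted++ }

func (db *countingDB) InspectNode(addr MemdbArenaAddr) (KeyFlagsGetter, MemdbArenaAddr) {
	return nil, NullAddr
}

// The vlog is then instantiated with the getter type and the memdb type:
var vlog MemdbVlog[KeyFlagsGetter, *countingDB]

The real instantiation introduced by this PR is arena.MemdbVlog[*artLeaf, *ART] in the art allocator below.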
func (a *nodeAllocator) getNode(addr memdbArenaAddr) *memdbNode {
if addr.isNull() {
return &a.nullNode
}
return (*memdbNode)(unsafe.Pointer(&a.blocks[addr.idx].buf[addr.off]))
}
func (a *nodeAllocator) allocNode(key []byte) (memdbArenaAddr, *memdbNode) {
nodeSize := 8*4 + 2 + kv.FlagBytes + len(key)
prevBlocks := len(a.blocks)
addr, mem := a.alloc(nodeSize, true)
n := (*memdbNode)(unsafe.Pointer(&mem[0]))
n.vptr = nullAddr
n.klen = uint16(len(key))
copy(n.getKey(), key)
if prevBlocks != len(a.blocks) {
a.onMemChange()
}
return addr, n
}
var testMode = false
func (a *nodeAllocator) freeNode(addr memdbArenaAddr) {
if testMode {
// Make it easier to debug.
n := a.getNode(addr)
badAddr := nullAddr
badAddr.idx--
n.left = badAddr
n.right = badAddr
n.up = badAddr
n.vptr = badAddr
return
}
// TODO: reuse freed nodes. Need to fix lastTraversedNode when implementing this.
}
func (a *nodeAllocator) reset() {
a.memdbArena.reset()
a.init()
}
type memdbVlog struct {
memdbArena
memdb *MemDB
type MemdbVlog[G KeyFlagsGetter, M VlogMemDB[G]] struct {
MemdbArena
}
const memdbVlogHdrSize = 8 + 8 + 4
type memdbVlogHdr struct {
nodeAddr memdbArenaAddr
oldValue memdbArenaAddr
valueLen uint32
type MemdbVlogHdr struct {
NodeAddr MemdbArenaAddr
OldValue MemdbArenaAddr
ValueLen uint32
}
func (hdr *memdbVlogHdr) store(dst []byte) {
func (hdr *MemdbVlogHdr) store(dst []byte) {
cursor := 0
endian.PutUint32(dst[cursor:], hdr.valueLen)
endian.PutUint32(dst[cursor:], hdr.ValueLen)
cursor += 4
hdr.oldValue.store(dst[cursor:])
hdr.OldValue.store(dst[cursor:])
cursor += 8
hdr.nodeAddr.store(dst[cursor:])
hdr.NodeAddr.store(dst[cursor:])
}
func (hdr *memdbVlogHdr) load(src []byte) {
func (hdr *MemdbVlogHdr) load(src []byte) {
cursor := 0
hdr.valueLen = endian.Uint32(src[cursor:])
hdr.ValueLen = endian.Uint32(src[cursor:])
cursor += 4
hdr.oldValue.load(src[cursor:])
hdr.OldValue.load(src[cursor:])
cursor += 8
hdr.nodeAddr.load(src[cursor:])
hdr.NodeAddr.load(src[cursor:])
}
func (l *memdbVlog) appendValue(nodeAddr memdbArenaAddr, oldValue memdbArenaAddr, value []byte) memdbArenaAddr {
// AppendValue appends a value and its vlog header to the vlog.
func (l *MemdbVlog[G, M]) AppendValue(nodeAddr MemdbArenaAddr, oldValue MemdbArenaAddr, value []byte) MemdbArenaAddr {
size := memdbVlogHdrSize + len(value)
prevBlocks := len(l.blocks)
addr, mem := l.alloc(size, false)
addr, mem := l.Alloc(size, false)
copy(mem, value)
hdr := memdbVlogHdr{nodeAddr, oldValue, uint32(len(value))}
hdr := MemdbVlogHdr{nodeAddr, oldValue, uint32(len(value))}
hdr.store(mem[len(value):])
addr.off += uint32(size)
if prevBlocks != len(l.blocks) {
l.onMemChange()
l.OnMemChange()
}
return addr
}
// A pure function that gets a value.
func (l *memdbVlog) getValue(addr memdbArenaAddr) []byte {
// GetValue is a pure function that gets a value.
func (l *MemdbVlog[G, M]) GetValue(addr MemdbArenaAddr) []byte {
lenOff := addr.off - memdbVlogHdrSize
block := l.blocks[addr.idx].buf
valueLen := endian.Uint32(block[lenOff:])
if valueLen == 0 {
return tombstone
return Tombstone
}
valueOff := lenOff - valueLen
return block[valueOff:lenOff:lenOff]
}
func (l *memdbVlog) getSnapshotValue(addr memdbArenaAddr, snap *MemDBCheckpoint) ([]byte, bool) {
result := l.selectValueHistory(addr, func(addr memdbArenaAddr) bool {
return !l.canModify(snap, addr)
func (l *MemdbVlog[G, M]) GetSnapshotValue(addr MemdbArenaAddr, snap *MemDBCheckpoint) ([]byte, bool) {
result := l.SelectValueHistory(addr, func(addr MemdbArenaAddr) bool {
return !l.CanModify(snap, addr)
})
if result.isNull() {
if result.IsNull() {
return nil, false
}
return l.getValue(result), true
return l.GetValue(result), true
}
func (l *memdbVlog) selectValueHistory(addr memdbArenaAddr, predicate func(memdbArenaAddr) bool) memdbArenaAddr {
for !addr.isNull() {
func (l *MemdbVlog[G, M]) SelectValueHistory(addr MemdbArenaAddr, predicate func(MemdbArenaAddr) bool) MemdbArenaAddr {
for !addr.IsNull() {
if predicate(addr) {
return addr
}
var hdr memdbVlogHdr
var hdr MemdbVlogHdr
hdr.load(l.blocks[addr.idx].buf[addr.off-memdbVlogHdrSize:])
addr = hdr.oldValue
addr = hdr.OldValue
}
return nullAddr
return NullAddr
}
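Each vlog header stores the address of the previous value for the same key, so following OldValue walks from the newest version toward older ones. A toy, self-contained sketch of the same walk over an in-memory chain:

package main

import "fmt"

// rec is a toy version record: prev indexes the older version, -1 means none.
type rec struct {
	val  string
	prev int
}

// selectValueHistory mirrors MemdbVlog.SelectValueHistory: starting at the
// newest record, return the first one the predicate accepts.
func selectValueHistory(chain []rec, head int, pred func(string) bool) (string, bool) {
	for i := head; i >= 0; i = chain[i].prev {
		if pred(chain[i].val) {
			return chain[i].val, true
		}
	}
	return "", false
}

func main() {
	chain := []rec{{"v1", -1}, {"v2", 0}, {"v3", 1}}
	v, ok := selectValueHistory(chain, 2, func(s string) bool { return s != "v3" })
	fmt.Println(v, ok) // v2 true
}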
func (l *memdbVlog) revertToCheckpoint(db *MemDB, cp *MemDBCheckpoint) {
cursor := l.checkpoint()
for !cp.isSamePosition(&cursor) {
func (l *MemdbVlog[G, M]) RevertToCheckpoint(m M, cp *MemDBCheckpoint) {
cursor := l.Checkpoint()
for !cp.IsSamePosition(&cursor) {
hdrOff := cursor.offsetInBlock - memdbVlogHdrSize
block := l.blocks[cursor.blocks-1].buf
var hdr memdbVlogHdr
var hdr MemdbVlogHdr
hdr.load(block[hdrOff:])
node := db.getNode(hdr.nodeAddr)
node.vptr = hdr.oldValue
db.size -= int(hdr.valueLen)
// oldValue.isNull() == true means this is a newly added value.
if hdr.oldValue.isNull() {
// If there are no flags associated with this key, we need to delete this node.
keptFlags := node.getKeyFlags().AndPersistent()
if keptFlags == 0 {
db.deleteNode(node)
} else {
node.setKeyFlags(keptFlags)
db.dirty = true
}
} else {
db.size += len(l.getValue(hdr.oldValue))
}
m.RevertNode(&hdr)
l.moveBackCursor(&cursor, &hdr)
}
}
func (l *memdbVlog) inspectKVInLog(db *MemDB, head, tail *MemDBCheckpoint, f func([]byte, kv.KeyFlags, []byte)) {
func (l *MemdbVlog[G, M]) InspectKVInLog(m M, head, tail *MemDBCheckpoint, f func([]byte, kv.KeyFlags, []byte)) {
cursor := *tail
for !head.isSamePosition(&cursor) {
cursorAddr := memdbArenaAddr{idx: uint32(cursor.blocks - 1), off: uint32(cursor.offsetInBlock)}
for !head.IsSamePosition(&cursor) {
cursorAddr := MemdbArenaAddr{idx: uint32(cursor.blocks - 1), off: uint32(cursor.offsetInBlock)}
hdrOff := cursorAddr.off - memdbVlogHdrSize
block := l.blocks[cursorAddr.idx].buf
var hdr memdbVlogHdr
var hdr MemdbVlogHdr
hdr.load(block[hdrOff:])
node := db.allocator.getNode(hdr.nodeAddr)
node, vptr := m.InspectNode(hdr.NodeAddr)
// Skip older versions.
if node.vptr == cursorAddr {
value := block[hdrOff-hdr.valueLen : hdrOff]
f(node.getKey(), node.getKeyFlags(), value)
if vptr == cursorAddr {
value := block[hdrOff-hdr.ValueLen : hdrOff]
f(node.GetKey(), node.GetKeyFlags(), value)
}
l.moveBackCursor(&cursor, &hdr)
}
}
func (l *memdbVlog) moveBackCursor(cursor *MemDBCheckpoint, hdr *memdbVlogHdr) {
cursor.offsetInBlock -= (memdbVlogHdrSize + int(hdr.valueLen))
func (l *MemdbVlog[G, M]) moveBackCursor(cursor *MemDBCheckpoint, hdr *MemdbVlogHdr) {
cursor.offsetInBlock -= (memdbVlogHdrSize + int(hdr.ValueLen))
if cursor.offsetInBlock == 0 {
cursor.blocks--
if cursor.blocks > 0 {
@ -417,7 +378,7 @@ func (l *memdbVlog) moveBackCursor(cursor *MemDBCheckpoint, hdr *memdbVlogHdr) {
}
}
func (l *memdbVlog) canModify(cp *MemDBCheckpoint, addr memdbArenaAddr) bool {
func (l *MemdbVlog[G, M]) CanModify(cp *MemDBCheckpoint, addr MemdbArenaAddr) bool {
if cp == nil {
return true
}
@ -429,3 +390,15 @@ func (l *memdbVlog) canModify(cp *MemDBCheckpoint, addr memdbArenaAddr) bool {
}
return false
}
// MemKeyHandle represents a pointer to a key in MemBuffer.
type MemKeyHandle struct {
// Opaque user data
UserData uint16
idx uint16
off uint32
}
func (h MemKeyHandle) ToAddr() MemdbArenaAddr {
return MemdbArenaAddr{idx: uint32(h.idx), off: h.off}
}
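To make the record layout concrete: AppendValue writes the raw value bytes first, then a 20-byte little-endian header (valueLen, oldValue, nodeAddr), and returns an address one past the record's end, which is why GetValue steps backwards over the header. A minimal single-buffer sketch (not the library's API):

package main

import (
	"encoding/binary"
	"fmt"
)

const hdrSize = 4 + 8 + 8 // valueLen + oldValue + nodeAddr, as in memdbVlogHdrSize

type addr struct{ idx, off uint32 }

// appendValue mimics MemdbVlog.AppendValue on one flat buffer: value bytes,
// then the header; the returned offset points just past the record's end.
func appendValue(buf []byte, node, old addr, value []byte) ([]byte, uint32) {
	buf = append(buf, value...)
	var hdr [hdrSize]byte
	le := binary.LittleEndian
	le.PutUint32(hdr[0:], uint32(len(value)))
	le.PutUint32(hdr[4:], old.idx)
	le.PutUint32(hdr[8:], old.off)
	le.PutUint32(hdr[12:], node.idx)
	le.PutUint32(hdr[16:], node.off)
	return append(buf, hdr[:]...), uint32(len(buf) + hdrSize)
}

// getValue mimics MemdbVlog.GetValue (minus tombstone handling): step back
// over the header to read the length, then slice the value out before it.
func getValue(buf []byte, off uint32) []byte {
	lenOff := off - hdrSize
	valueLen := binary.LittleEndian.Uint32(buf[lenOff:])
	return buf[lenOff-valueLen : lenOff]
}

func main() {
	buf, off := appendValue(nil, addr{0, 0}, addr{}, []byte("hello"))
	fmt.Printf("%q\n", getValue(buf, off)) // "hello"
}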

View File

@ -0,0 +1,79 @@
// Copyright 2021 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE: The code in this file is based on code from the
// TiDB project, licensed under the Apache License v 2.0
//
// https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/unionstore/memdb_arena.go
//
// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package arena
import (
"testing"
"github.com/stretchr/testify/assert"
)
type dummyMemDB struct{}
func (m *dummyMemDB) RevertNode(hdr *MemdbVlogHdr) {}
func (m *dummyMemDB) InspectNode(addr MemdbArenaAddr) (KeyFlagsGetter, MemdbArenaAddr) {
return nil, NullAddr
}
func TestBigValue(t *testing.T) {
assert := assert.New(t)
var vlog MemdbVlog[KeyFlagsGetter, *dummyMemDB]
vlog.AppendValue(MemdbArenaAddr{0, 0}, NullAddr, make([]byte, 80<<20))
assert.Equal(vlog.blockSize, maxBlockSize)
assert.Equal(len(vlog.blocks), 1)
cp := vlog.Checkpoint()
vlog.AppendValue(MemdbArenaAddr{0, 1}, NullAddr, make([]byte, 127<<20))
vlog.RevertToCheckpoint(&dummyMemDB{}, &cp)
assert.Equal(vlog.blockSize, maxBlockSize)
assert.Equal(len(vlog.blocks), 2)
assert.PanicsWithValue("alloc size is larger than max block size", func() {
vlog.AppendValue(MemdbArenaAddr{0, 2}, NullAddr, make([]byte, maxBlockSize+1))
})
}
func TestValueLargeThanBlock(t *testing.T) {
assert := assert.New(t)
var vlog MemdbVlog[KeyFlagsGetter, *dummyMemDB]
vlog.AppendValue(MemdbArenaAddr{0, 0}, NullAddr, make([]byte, 1))
vlog.AppendValue(MemdbArenaAddr{0, 1}, NullAddr, make([]byte, 4096))
assert.Equal(len(vlog.blocks), 2)
vAddr := vlog.AppendValue(MemdbArenaAddr{0, 2}, NullAddr, make([]byte, 3000))
assert.Equal(len(vlog.blocks), 2)
val := vlog.GetValue(vAddr)
assert.Equal(len(val), 3000)
}

View File

@ -0,0 +1,176 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//nolint:unused
package art
import (
"math"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
)
type ART struct {
allocator artAllocator
root artNode
stages []arena.MemDBCheckpoint
vlogInvalid bool
dirty bool
entrySizeLimit uint64
bufferSizeLimit uint64
len int
size int
}
func New() *ART {
var t ART
t.root = nullArtNode
t.stages = make([]arena.MemDBCheckpoint, 0, 2)
t.entrySizeLimit = math.MaxUint64
t.bufferSizeLimit = math.MaxUint64
t.allocator.nodeAllocator.freeNode4 = make([]arena.MemdbArenaAddr, 0, 1<<4)
t.allocator.nodeAllocator.freeNode16 = make([]arena.MemdbArenaAddr, 0, 1<<3)
t.allocator.nodeAllocator.freeNode48 = make([]arena.MemdbArenaAddr, 0, 1<<2)
return &t
}
func (t *ART) Get(key []byte) ([]byte, error) {
panic("unimplemented")
}
// GetFlags returns the latest flags associated with key.
func (t *ART) GetFlags(key []byte) (kv.KeyFlags, error) {
panic("unimplemented")
}
func (t *ART) Set(key artKey, value []byte, ops []kv.FlagsOp) error {
panic("unimplemented")
}
func (t *ART) search(key artKey) (arena.MemdbArenaAddr, *artLeaf) {
panic("unimplemented")
}
func (t *ART) Dirty() bool {
panic("unimplemented")
}
// Mem returns the memory usage of MemBuffer.
func (t *ART) Mem() uint64 {
panic("unimplemented")
}
// Len returns the count of entries in the MemBuffer.
func (t *ART) Len() int {
panic("unimplemented")
}
// Size returns the size of the MemBuffer.
func (t *ART) Size() int {
panic("unimplemented")
}
func (t *ART) checkpoint() arena.MemDBCheckpoint {
panic("unimplemented")
}
func (t *ART) RevertNode(hdr *arena.MemdbVlogHdr) {
panic("unimplemented")
}
func (t *ART) InspectNode(addr arena.MemdbArenaAddr) (*artLeaf, arena.MemdbArenaAddr) {
panic("unimplemented")
}
// Checkpoint returns a checkpoint of ART.
func (t *ART) Checkpoint() *arena.MemDBCheckpoint {
panic("unimplemented")
}
// RevertToCheckpoint reverts the ART to the checkpoint.
func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) {
panic("unimplemented")
}
func (t *ART) Stages() []arena.MemDBCheckpoint {
panic("unimplemented")
}
func (t *ART) Staging() int {
panic("unimplemented")
}
func (t *ART) Release(h int) {
panic("unimplemented")
}
func (t *ART) Cleanup(h int) {
panic("unimplemented")
}
func (t *ART) revertToCheckpoint(cp *arena.MemDBCheckpoint) {
panic("unimplemented")
}
func (t *ART) moveBackCursor(cursor *arena.MemDBCheckpoint, hdr *arena.MemdbVlogHdr) {
panic("unimplemented")
}
func (t *ART) truncate(snap *arena.MemDBCheckpoint) {
panic("unimplemented")
}
// DiscardValues releases the memory used by all values.
// NOTE: any operation that needs values will panic after this function is called.
func (t *ART) DiscardValues() {
panic("unimplemented")
}
// InspectStage is used to inspect the value updates in the given stage.
func (t *ART) InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) {
panic("unimplemented")
}
// SelectValueHistory selects the latest value for which `predicate` returns true from the modification history.
func (t *ART) SelectValueHistory(key []byte, predicate func(value []byte) bool) ([]byte, error) {
panic("unimplemented")
}
func (t *ART) SetMemoryFootprintChangeHook(fn func(uint64)) {
panic("unimplemented")
}
// MemHookSet implements the MemBuffer interface.
func (t *ART) MemHookSet() bool {
panic("unimplemented")
}
// GetKeyByHandle returns key by handle.
func (t *ART) GetKeyByHandle(handle arena.MemKeyHandle) []byte {
panic("unimplemented")
}
// GetValueByHandle returns value by handle.
func (t *ART) GetValueByHandle(handle arena.MemKeyHandle) ([]byte, bool) {
panic("unimplemented")
}
func (t *ART) SetEntrySizeLimit(entryLimit, bufferLimit uint64) {
panic("unimplemented")
}
func (t *ART) RemoveFromBuffer(key []byte) {
panic("unimplemented")
}

View File

@ -0,0 +1,35 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//nolint:unused
package art
import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)
// nodeArena is a fixed-size arena allocator for ART nodes.
// Because each node type has a fixed size, discarded nodes can be reused,
// and reusing them reduces memory fragmentation.
type nodeArena struct {
arena.MemdbArena
freeNode4 []arena.MemdbArenaAddr
freeNode16 []arena.MemdbArenaAddr
freeNode48 []arena.MemdbArenaAddr
}
type artAllocator struct {
vlogAllocator arena.MemdbVlog[*artLeaf, *ART]
nodeAllocator nodeArena
}
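New (in art.go above) pre-sizes freeNode4/freeNode16/freeNode48, one free list per node size. A hedged sketch of the reuse idea (assuming freed addresses are popped before fresh arena memory is allocated; freeList is illustrative, not part of the diff):

// freeList is a hypothetical per-size free list, for illustration only.
type freeList struct {
	free []arena.MemdbArenaAddr
}

// alloc reuses a freed address if one is available, otherwise it asks
// the arena for fresh memory via the supplied callback.
func (f *freeList) alloc(fresh func() arena.MemdbArenaAddr) arena.MemdbArenaAddr {
	if n := len(f.free); n > 0 {
		addr := f.free[n-1]
		f.free = f.free[:n-1]
		return addr
	}
	return fresh()
}

// release returns a node's address to the list for later reuse.
func (f *freeList) release(addr arena.MemdbArenaAddr) {
	f.free = append(f.free, addr)
}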

View File

@ -0,0 +1,31 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package art
func (*ART) Iter([]byte, []byte) (*Iterator, error) {
panic("unimplemented")
}
func (*ART) IterReverse([]byte, []byte) (*Iterator, error) {
panic("unimplemented")
}
type Iterator struct{}
func (i *Iterator) Valid() bool { panic("unimplemented") }
func (i *Iterator) Key() []byte { panic("unimplemented") }
func (i *Iterator) Value() []byte { panic("unimplemented") }
func (i *Iterator) Next() error { panic("unimplemented") }
func (i *Iterator) Close() { panic("unimplemented") }

View File

@ -0,0 +1,58 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//nolint:unused
package art
import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
)
type artNodeKind uint16
const (
typeARTInvalid artNodeKind = 0
//nolint:unused
typeARTNode4 artNodeKind = 1
typeARTNode16 artNodeKind = 2
typeARTNode48 artNodeKind = 3
typeARTNode256 artNodeKind = 4
typeARTLeaf artNodeKind = 5
)
var nullArtNode = artNode{kind: typeARTInvalid, addr: arena.NullAddr}
type artKey []byte
type artNode struct {
kind artNodeKind
addr arena.MemdbArenaAddr
}
type artLeaf struct {
vAddr arena.MemdbArenaAddr
klen uint16
flags uint16
}
// GetKey gets the full key of the leaf
func (l *artLeaf) GetKey() []byte {
panic("unimplemented")
}
// GetKeyFlags gets the flags of the leaf
func (l *artLeaf) GetKeyFlags() kv.KeyFlags {
panic("unimplemented")
}

View File

@ -0,0 +1,43 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package art
import "context"
func (*ART) SnapshotGetter() *SnapshotGetter {
panic("unimplemented")
}
func (*ART) SnapshotIter([]byte, []byte) *SnapshotIter {
panic("unimplemented")
}
func (*ART) SnapshotIterReverse([]byte, []byte) *SnapshotIter {
panic("unimplemented")
}
type SnapshotGetter struct{}
func (s *SnapshotGetter) Get(context.Context, []byte) ([]byte, error) {
panic("unimplemented")
}
type SnapshotIter struct{}
func (i *SnapshotIter) Valid() bool { panic("unimplemented") }
func (i *SnapshotIter) Key() []byte { panic("unimplemented") }
func (i *SnapshotIter) Value() []byte { panic("unimplemented") }
func (i *SnapshotIter) Next() error { panic("unimplemented") }
func (i *SnapshotIter) Close() { panic("unimplemented") }

View File

@ -35,919 +35,21 @@
package unionstore
import (
"bytes"
"fmt"
"math"
"sync"
"sync/atomic"
"unsafe"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)
var tombstone = []byte{}
const unlimitedSize = math.MaxUint64
// IsTombstone returns whether the value is a tombstone.
func IsTombstone(val []byte) bool { return len(val) == 0 }
// MemKeyHandle represents a pointer to a key in MemBuffer.
type MemKeyHandle struct {
// Opaque user data
UserData uint16
idx uint16
off uint32
}
type MemDBCheckpoint = arena.MemDBCheckpoint
func (h MemKeyHandle) toAddr() memdbArenaAddr {
return memdbArenaAddr{idx: uint32(h.idx), off: h.off}
}
type MemKeyHandle = arena.MemKeyHandle
// MemDB is a rollbackable Red-Black Tree optimized for TiDB's transaction-state buffering scenario.
// You can think of MemDB as a combination of two separate tree maps, one for key => value and another for key => keyFlags.
//
// The value map is rollbackable, which means you can use the `Staging`, `Release` and `Cleanup` APIs to safely modify KVs.
//
// The flags map is not rollbackable. There are two types of flag, persistent and non-persistent.
// When discarding a newly added KV in `Cleanup`, the non-persistent flags will be cleared.
// If there are persistent flags associated with the key, we keep the key in the node without a value.
type MemDB struct {
// This RWMutex is only used to ensure memdbSnapGetter.Get will not race with
// concurrent memdb.Set, memdb.SetWithFlags, memdb.Delete and memdb.UpdateFlags.
sync.RWMutex
root memdbArenaAddr
allocator nodeAllocator
vlog memdbVlog
type MemDB = rbtDBWithContext
entrySizeLimit uint64
bufferSizeLimit uint64
count int
size int
vlogInvalid bool
dirty bool
stages []MemDBCheckpoint
// when the MemDB is wrapped by an upper RWMutex, we can skip the internal mutex.
skipMutex bool
// The lastTraversedNode must exist
lastTraversedNode atomic.Pointer[memdbNodeAddr]
hitCount atomic.Uint64
missCount atomic.Uint64
}
const unlimitedSize = math.MaxUint64
func newMemDB() *MemDB {
db := new(MemDB)
db.allocator.init()
db.root = nullAddr
db.stages = make([]MemDBCheckpoint, 0, 2)
db.entrySizeLimit = unlimitedSize
db.bufferSizeLimit = unlimitedSize
db.vlog.memdb = db
db.skipMutex = false
db.lastTraversedNode.Store(&nullNodeAddr)
return db
}
// updateLastTraversed updates the last traversed node atomically
func (db *MemDB) updateLastTraversed(node memdbNodeAddr) {
db.lastTraversedNode.Store(&node)
}
// checkKeyInCache retrieves the last traversed node if the key matches
func (db *MemDB) checkKeyInCache(key []byte) (memdbNodeAddr, bool) {
nodePtr := db.lastTraversedNode.Load()
if nodePtr == nil || nodePtr.isNull() {
return nullNodeAddr, false
}
if bytes.Equal(key, nodePtr.memdbNode.getKey()) {
return *nodePtr, true
}
return nullNodeAddr, false
}
// Staging creates a new staging buffer inside the MemBuffer.
// Subsequent writes will be temporarily stored in this new staging buffer.
// When you think all modifications look good, you can call `Release` to publish all of them to the upper level buffer.
func (db *MemDB) Staging() int {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
db.stages = append(db.stages, db.vlog.checkpoint())
return len(db.stages)
}
// Release publishes all modifications in the latest staging buffer to the upper level.
func (db *MemDB) Release(h int) {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
if h != len(db.stages) {
// This should never happen in a production environment.
// Use panic to make debug easier.
panic("cannot release staging buffer")
}
if h == 1 {
tail := db.vlog.checkpoint()
if !db.stages[0].isSamePosition(&tail) {
db.dirty = true
}
}
db.stages = db.stages[:h-1]
}
// Cleanup cleans up the resources referenced by the StagingHandle.
// If the changes are not published by `Release`, they will be discarded.
func (db *MemDB) Cleanup(h int) {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
if h > len(db.stages) {
return
}
if h < len(db.stages) {
// This should never happen in a production environment.
// Use panic to make debug easier.
panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(db.stages)=%v", h, len(db.stages)))
}
cp := &db.stages[h-1]
if !db.vlogInvalid {
curr := db.vlog.checkpoint()
if !curr.isSamePosition(cp) {
db.vlog.revertToCheckpoint(db, cp)
db.vlog.truncate(cp)
}
}
db.stages = db.stages[:h-1]
db.vlog.onMemChange()
}
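Taken together, Staging, Release and Cleanup give transactional editing of the buffer. A hedged usage sketch (a fragment inside this package, using the NewMemDB constructor this diff introduces):

func stagingExample() {
	db := NewMemDB()
	_ = db.Set([]byte("k"), []byte("v1"))

	h := db.Staging()
	_ = db.Set([]byte("k"), []byte("v2"))
	db.Cleanup(h) // discard the staging buffer: "k" reverts to "v1"

	h = db.Staging()
	_ = db.Set([]byte("k"), []byte("v3"))
	db.Release(h) // publish to the upper level: "k" is now "v3"
}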
// Checkpoint returns a checkpoint of MemDB.
func (db *MemDB) Checkpoint() *MemDBCheckpoint {
cp := db.vlog.checkpoint()
return &cp
}
// RevertToCheckpoint reverts the MemDB to the checkpoint.
func (db *MemDB) RevertToCheckpoint(cp *MemDBCheckpoint) {
db.vlog.revertToCheckpoint(db, cp)
db.vlog.truncate(cp)
db.vlog.onMemChange()
}
// Reset resets the MemBuffer to its initial state.
func (db *MemDB) Reset() {
db.root = nullAddr
db.stages = db.stages[:0]
db.dirty = false
db.vlogInvalid = false
db.size = 0
db.count = 0
db.vlog.reset()
db.allocator.reset()
}
// DiscardValues releases the memory used by all values.
// NOTE: any operation that needs values will panic after this function is called.
func (db *MemDB) DiscardValues() {
db.vlogInvalid = true
db.vlog.reset()
}
// InspectStage is used to inspect the value updates in the given stage.
func (db *MemDB) InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) {
idx := handle - 1
tail := db.vlog.checkpoint()
head := db.stages[idx]
db.vlog.inspectKVInLog(db, &head, &tail, f)
}
// Get gets the value for key k from kv store.
// If the corresponding kv pair does not exist, it returns nil and ErrNotExist.
func (db *MemDB) Get(key []byte) ([]byte, error) {
if db.vlogInvalid {
// panic for easier debugging.
panic("vlog is resetted")
}
x := db.traverse(key, false)
if x.isNull() {
return nil, tikverr.ErrNotExist
}
if x.vptr.isNull() {
// A flags-only key acts as if the value does not exist
return nil, tikverr.ErrNotExist
}
return db.vlog.getValue(x.vptr), nil
}
// SelectValueHistory selects the latest value for which `predicate` returns true from the modification history.
func (db *MemDB) SelectValueHistory(key []byte, predicate func(value []byte) bool) ([]byte, error) {
x := db.traverse(key, false)
if x.isNull() {
return nil, tikverr.ErrNotExist
}
if x.vptr.isNull() {
// A flags-only key acts as if the value does not exist
return nil, tikverr.ErrNotExist
}
result := db.vlog.selectValueHistory(x.vptr, func(addr memdbArenaAddr) bool {
return predicate(db.vlog.getValue(addr))
})
if result.isNull() {
return nil, nil
}
return db.vlog.getValue(result), nil
}
// GetFlags returns the latest flags associated with key.
func (db *MemDB) GetFlags(key []byte) (kv.KeyFlags, error) {
x := db.traverse(key, false)
if x.isNull() {
return 0, tikverr.ErrNotExist
}
return x.getKeyFlags(), nil
}
// UpdateFlags updates the flags associated with key.
func (db *MemDB) UpdateFlags(key []byte, ops ...kv.FlagsOp) {
err := db.set(key, nil, ops...)
_ = err // set without value will never fail
}
// Set sets the value for key k as v into kv store.
// v must NOT be nil or empty, otherwise it returns ErrCannotSetNilValue.
func (db *MemDB) Set(key []byte, value []byte) error {
if len(value) == 0 {
return tikverr.ErrCannotSetNilValue
}
return db.set(key, value)
}
// SetWithFlags puts key-value into the last active staging buffer with the given KeyFlags.
func (db *MemDB) SetWithFlags(key []byte, value []byte, ops ...kv.FlagsOp) error {
if len(value) == 0 {
return tikverr.ErrCannotSetNilValue
}
return db.set(key, value, ops...)
}
// Delete removes the entry for key k from kv store.
func (db *MemDB) Delete(key []byte) error {
return db.set(key, tombstone)
}
// DeleteWithFlags deletes the key with the given KeyFlags
func (db *MemDB) DeleteWithFlags(key []byte, ops ...kv.FlagsOp) error {
return db.set(key, tombstone, ops...)
}
// GetKeyByHandle returns key by handle.
func (db *MemDB) GetKeyByHandle(handle MemKeyHandle) []byte {
x := db.getNode(handle.toAddr())
return x.getKey()
}
// GetValueByHandle returns value by handle.
func (db *MemDB) GetValueByHandle(handle MemKeyHandle) ([]byte, bool) {
if db.vlogInvalid {
return nil, false
}
x := db.getNode(handle.toAddr())
if x.vptr.isNull() {
return nil, false
}
return db.vlog.getValue(x.vptr), true
}
// Len returns the number of entries in the DB.
func (db *MemDB) Len() int {
return db.count
}
// Size returns sum of keys and values length.
func (db *MemDB) Size() int {
return db.size
}
// Dirty returns whether the root staging buffer is updated.
func (db *MemDB) Dirty() bool {
return db.dirty
}
func (db *MemDB) set(key []byte, value []byte, ops ...kv.FlagsOp) error {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
if db.vlogInvalid {
// panic for easier debugging.
panic("vlog is reset")
}
if value != nil {
if size := uint64(len(key) + len(value)); size > db.entrySizeLimit {
return &tikverr.ErrEntryTooLarge{
Limit: db.entrySizeLimit,
Size: size,
}
}
}
if len(db.stages) == 0 {
db.dirty = true
}
x := db.traverse(key, true)
// the NeedConstraintCheckInPrewrite flag is temporary,
// every write to the node removes the flag unless it's explicitly set.
// This set must be in the latest stage so no special processing is needed.
var flags kv.KeyFlags
if value != nil {
flags = kv.ApplyFlagsOps(x.getKeyFlags(), append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...)
} else {
// an UpdateFlag operation, do not delete the NeedConstraintCheckInPrewrite flag.
flags = kv.ApplyFlagsOps(x.getKeyFlags(), ops...)
}
if flags.AndPersistent() != 0 {
db.dirty = true
}
x.setKeyFlags(flags)
if value == nil {
return nil
}
db.setValue(x, value)
if uint64(db.Size()) > db.bufferSizeLimit {
return &tikverr.ErrTxnTooLarge{Size: db.Size()}
}
return nil
}
func (db *MemDB) setValue(x memdbNodeAddr, value []byte) {
var activeCp *MemDBCheckpoint
if len(db.stages) > 0 {
activeCp = &db.stages[len(db.stages)-1]
}
var oldVal []byte
if !x.vptr.isNull() {
oldVal = db.vlog.getValue(x.vptr)
}
if len(oldVal) > 0 && db.vlog.canModify(activeCp, x.vptr) {
// To keep the implementation simple, we only consider this case.
// It is the most common usage in TiDB's transaction buffers.
if len(oldVal) == len(value) {
copy(oldVal, value)
return
}
}
x.vptr = db.vlog.appendValue(x.addr, x.vptr, value)
db.size = db.size - len(oldVal) + len(value)
}
// traverse searches for the key; if it is not found and insert is true, a new node is added.
// Returns a pointer to the new node, or to the node found.
func (db *MemDB) traverse(key []byte, insert bool) memdbNodeAddr {
if node, found := db.checkKeyInCache(key); found {
db.hitCount.Add(1)
return node
}
db.missCount.Add(1)
x := db.getRoot()
y := memdbNodeAddr{nil, nullAddr}
found := false
// walk x down the tree
for !x.isNull() && !found {
cmp := bytes.Compare(key, x.getKey())
if cmp < 0 {
if insert && x.left.isNull() {
y = x
}
x = x.getLeft(db)
} else if cmp > 0 {
if insert && x.right.isNull() {
y = x
}
x = x.getRight(db)
} else {
found = true
}
}
if found {
db.updateLastTraversed(x)
}
if found || !insert {
return x
}
z := db.allocNode(key)
z.up = y.addr
if y.isNull() {
db.root = z.addr
} else {
cmp := bytes.Compare(z.getKey(), y.getKey())
if cmp < 0 {
y.left = z.addr
} else {
y.right = z.addr
}
}
z.left = nullAddr
z.right = nullAddr
// colour this new node red
z.setRed()
// Having added a red node, we must now walk back up the tree balancing it,
// by a series of rotations and changing of colours
x = z
// While we are not at the top and our parent node is red
// NOTE: Since the root node is guaranteed black, we will
// also stop if we are the child of the root
for x.addr != db.root {
xUp := x.getUp(db)
if xUp.isBlack() {
break
}
xUpUp := xUp.getUp(db)
// if our parent is on the left side of our grandparent
if x.up == xUpUp.left {
// get the right side of our grandparent (uncle?)
y = xUpUp.getRight(db)
if y.isRed() {
// make our parent black
xUp.setBlack()
// make our uncle black
y.setBlack()
// make our grandparent red
xUpUp.setRed()
// now consider our grandparent
x = xUp.getUp(db)
} else {
// if we are on the right side of our parent
if x.addr == xUp.right {
// Move up to our parent
x = x.getUp(db)
db.leftRotate(x)
xUp = x.getUp(db)
xUpUp = xUp.getUp(db)
}
xUp.setBlack()
xUpUp.setRed()
db.rightRotate(xUpUp)
}
} else {
// everything here is the same as above, but exchanging left for right
y = xUpUp.getLeft(db)
if y.isRed() {
xUp.setBlack()
y.setBlack()
xUpUp.setRed()
x = xUp.getUp(db)
} else {
if x.addr == xUp.left {
x = x.getUp(db)
db.rightRotate(x)
xUp = x.getUp(db)
xUpUp = xUp.getUp(db)
}
xUp.setBlack()
xUpUp.setRed()
db.leftRotate(xUpUp)
}
}
}
// Set the root node black
db.getRoot().setBlack()
db.updateLastTraversed(z)
return z
}
//
// Rotate our tree thus:-
//
//             X        leftRotate(X)--->           Y
//           /   \                                /   \
//          A     Y    <---rightRotate(Y)        X     C
//              /   \                          /   \
//             B     C                        A     B
//
// NOTE: This does not change the ordering.
//
// We assume that neither X nor Y is NULL
//
func (db *MemDB) leftRotate(x memdbNodeAddr) {
y := x.getRight(db)
// Turn Y's left subtree into X's right subtree (move B)
x.right = y.left
// If B is not null, set its parent to be X
if !y.left.isNull() {
left := y.getLeft(db)
left.up = x.addr
}
// Set Y's parent to be what X's parent was
y.up = x.up
// if X was the root
if x.up.isNull() {
db.root = y.addr
} else {
xUp := x.getUp(db)
// Set X's parent's left or right pointer to be Y
if x.addr == xUp.left {
xUp.left = y.addr
} else {
xUp.right = y.addr
}
}
// Put X on Y's left
y.left = x.addr
// Set X's parent to be Y
x.up = y.addr
}
func (db *MemDB) rightRotate(y memdbNodeAddr) {
x := y.getLeft(db)
// Turn X's right subtree into Y's left subtree (move B)
y.left = x.right
// If B is not null, set its parent to be Y
if !x.right.isNull() {
right := x.getRight(db)
right.up = y.addr
}
// Set X's parent to be what Y's parent was
x.up = y.up
// if Y was the root
if y.up.isNull() {
db.root = x.addr
} else {
yUp := y.getUp(db)
// Set Y's parent's left or right pointer to be X
if y.addr == yUp.left {
yUp.left = x.addr
} else {
yUp.right = x.addr
}
}
// Put Y on X's right
x.right = y.addr
// Set Y's parent to be X
y.up = x.addr
}
func (db *MemDB) deleteNode(z memdbNodeAddr) {
var x, y memdbNodeAddr
if db.lastTraversedNode.Load().addr == z.addr {
db.lastTraversedNode.Store(&nullNodeAddr)
}
db.count--
db.size -= int(z.klen)
if z.left.isNull() || z.right.isNull() {
y = z
} else {
y = db.successor(z)
}
if !y.left.isNull() {
x = y.getLeft(db)
} else {
x = y.getRight(db)
}
x.up = y.up
if y.up.isNull() {
db.root = x.addr
} else {
yUp := y.getUp(db)
if y.addr == yUp.left {
yUp.left = x.addr
} else {
yUp.right = x.addr
}
}
needFix := y.isBlack()
// NOTE: a traditional red-black tree would copy the key from Y to Z and free Y.
// We cannot do that here, because Y's pointer is stored in the vlog and the space in Z may not be suitable for Y.
// So we copy state from Z to Y, and relink all nodes formerly connected to Z.
if y != z {
db.replaceNode(z, y)
}
if needFix {
db.deleteNodeFix(x)
}
db.allocator.freeNode(z.addr)
}
func (db *MemDB) replaceNode(old memdbNodeAddr, new memdbNodeAddr) {
if !old.up.isNull() {
oldUp := old.getUp(db)
if old.addr == oldUp.left {
oldUp.left = new.addr
} else {
oldUp.right = new.addr
}
} else {
db.root = new.addr
}
new.up = old.up
left := old.getLeft(db)
left.up = new.addr
new.left = old.left
right := old.getRight(db)
right.up = new.addr
new.right = old.right
if old.isBlack() {
new.setBlack()
} else {
new.setRed()
}
}
func (db *MemDB) deleteNodeFix(x memdbNodeAddr) {
for x.addr != db.root && x.isBlack() {
xUp := x.getUp(db)
if x.addr == xUp.left {
w := xUp.getRight(db)
if w.isRed() {
w.setBlack()
xUp.setRed()
db.leftRotate(xUp)
w = x.getUp(db).getRight(db)
}
if w.getLeft(db).isBlack() && w.getRight(db).isBlack() {
w.setRed()
x = x.getUp(db)
} else {
if w.getRight(db).isBlack() {
w.getLeft(db).setBlack()
w.setRed()
db.rightRotate(w)
w = x.getUp(db).getRight(db)
}
xUp := x.getUp(db)
if xUp.isBlack() {
w.setBlack()
} else {
w.setRed()
}
xUp.setBlack()
w.getRight(db).setBlack()
db.leftRotate(xUp)
x = db.getRoot()
}
} else {
w := xUp.getLeft(db)
if w.isRed() {
w.setBlack()
xUp.setRed()
db.rightRotate(xUp)
w = x.getUp(db).getLeft(db)
}
if w.getRight(db).isBlack() && w.getLeft(db).isBlack() {
w.setRed()
x = x.getUp(db)
} else {
if w.getLeft(db).isBlack() {
w.getRight(db).setBlack()
w.setRed()
db.leftRotate(w)
w = x.getUp(db).getLeft(db)
}
xUp := x.getUp(db)
if xUp.isBlack() {
w.setBlack()
} else {
w.setRed()
}
xUp.setBlack()
w.getLeft(db).setBlack()
db.rightRotate(xUp)
x = db.getRoot()
}
}
}
x.setBlack()
}
func (db *MemDB) successor(x memdbNodeAddr) (y memdbNodeAddr) {
if !x.right.isNull() {
// If right is not NULL then go right one and
// then keep going left until we find a node with
// no left pointer.
y = x.getRight(db)
for !y.left.isNull() {
y = y.getLeft(db)
}
return
}
// Go up the tree until we get to a node that is on the
// left of its parent (or the root) and then return the
// parent.
y = x.getUp(db)
for !y.isNull() && x.addr == y.right {
x = y
y = y.getUp(db)
}
return y
}
func (db *MemDB) predecessor(x memdbNodeAddr) (y memdbNodeAddr) {
if !x.left.isNull() {
// If left is not NULL then go left one and
// then keep going right until we find a node with
// no right pointer.
y = x.getLeft(db)
for !y.right.isNull() {
y = y.getRight(db)
}
return
}
// Go up the tree until we get to a node that is on the
// right of its parent (or the root) and then return the
// parent.
y = x.getUp(db)
for !y.isNull() && x.addr == y.left {
x = y
y = y.getUp(db)
}
return y
}
func (db *MemDB) getNode(x memdbArenaAddr) memdbNodeAddr {
return memdbNodeAddr{db.allocator.getNode(x), x}
}
func (db *MemDB) getRoot() memdbNodeAddr {
return db.getNode(db.root)
}
func (db *MemDB) allocNode(key []byte) memdbNodeAddr {
db.size += len(key)
db.count++
x, xn := db.allocator.allocNode(key)
return memdbNodeAddr{xn, x}
}
type memdbNodeAddr struct {
*memdbNode
addr memdbArenaAddr
}
func (a *memdbNodeAddr) isNull() bool {
return a.addr.isNull()
}
func (a memdbNodeAddr) getUp(db *MemDB) memdbNodeAddr {
return db.getNode(a.up)
}
func (a memdbNodeAddr) getLeft(db *MemDB) memdbNodeAddr {
return db.getNode(a.left)
}
func (a memdbNodeAddr) getRight(db *MemDB) memdbNodeAddr {
return db.getNode(a.right)
}
type memdbNode struct {
up memdbArenaAddr
left memdbArenaAddr
right memdbArenaAddr
vptr memdbArenaAddr
klen uint16
flags uint16
}
func (n *memdbNode) isRed() bool {
return n.flags&nodeColorBit != 0
}
func (n *memdbNode) isBlack() bool {
return !n.isRed()
}
func (n *memdbNode) setRed() {
n.flags |= nodeColorBit
}
func (n *memdbNode) setBlack() {
n.flags &= ^nodeColorBit
}
func (n *memdbNode) getKey() []byte {
base := unsafe.Add(unsafe.Pointer(&n.flags), kv.FlagBytes)
return unsafe.Slice((*byte)(base), int(n.klen))
}
const (
// bit 1 => red, bit 0 => black
nodeColorBit uint16 = 0x8000
nodeFlagsMask = ^nodeColorBit
)
func (n *memdbNode) getKeyFlags() kv.KeyFlags {
return kv.KeyFlags(n.flags & nodeFlagsMask)
}
func (n *memdbNode) setKeyFlags(f kv.KeyFlags) {
n.flags = (^nodeFlagsMask & n.flags) | uint16(f)
}
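The color bit and the key flags share a single uint16: bit 15 carries the node color and the low 15 bits carry kv.KeyFlags, which is why setKeyFlags masks the color bit back in. A small self-contained sketch of the packing arithmetic:

package main

import "fmt"

const (
	nodeColorBit  uint16 = 0x8000 // bit 15: 1 => red, 0 => black
	nodeFlagsMask        = ^nodeColorBit
)

func main() {
	flags := uint16(0x0003)              // some key flags
	flags |= nodeColorBit                // setRed
	fmt.Println(flags&nodeColorBit != 0) // true: the node is red
	fmt.Println(flags & nodeFlagsMask)   // 3: key flags are untouched
}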
// RemoveFromBuffer removes a record from the mem buffer. It should only be used in tests.
func (db *MemDB) RemoveFromBuffer(key []byte) {
x := db.traverse(key, false)
if x.isNull() {
return
}
db.size -= len(db.vlog.getValue(x.vptr))
db.deleteNode(x)
}
// SetMemoryFootprintChangeHook sets the hook function that is triggered when memdb grows.
func (db *MemDB) SetMemoryFootprintChangeHook(hook func(uint64)) {
innerHook := func() {
hook(db.allocator.capacity + db.vlog.capacity)
}
db.allocator.memChangeHook.Store(&innerHook)
db.vlog.memChangeHook.Store(&innerHook)
}
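A hedged usage sketch of the hook (a fragment in this package; it assumes, as the code here shows, that the callback receives the combined allocator and vlog capacity whenever a new block is allocated):

func hookExample() {
	db := NewMemDB()
	var footprint uint64
	db.SetMemoryFootprintChangeHook(func(bytes uint64) {
		footprint = bytes // updated as blocks are added
	})
	_ = db.Set([]byte("k"), make([]byte, 1<<20)) // a large value forces new blocks
	_ = footprint
}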
// Mem returns the current memory footprint
func (db *MemDB) Mem() uint64 {
return db.allocator.capacity + db.vlog.capacity
}
// SetEntrySizeLimit sets the size limit for each entry and total buffer.
func (db *MemDB) SetEntrySizeLimit(entryLimit, bufferLimit uint64) {
db.entrySizeLimit = entryLimit
db.bufferSizeLimit = bufferLimit
}
func (db *MemDB) setSkipMutex(skip bool) {
db.skipMutex = skip
}
// MemHookSet implements the MemBuffer interface.
func (db *MemDB) MemHookSet() bool {
return db.allocator.memChangeHook.Load() != nil
}
var NewMemDB = newRbtDBWithContext
var NewMemDBWithContext = newRbtDBWithContext
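Because type MemDB = rbtDBWithContext is an alias rather than a new defined type, and NewMemDB now points at newRbtDBWithContext, existing callers keep compiling unchanged. A hypothetical caller in another package:

// useBuffer names *unionstore.MemDB, which under the alias is exactly
// *rbtDBWithContext; hypothetical example, not part of the diff.
func useBuffer(db *unionstore.MemDB) error {
	return db.Set([]byte("k"), []byte("v"))
}

func aliasExample() error {
	db := unionstore.NewMemDB() // returns the rbt-backed implementation
	return useBuffer(db)
}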

View File

@ -0,0 +1,169 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//nolint:unused
package unionstore
import (
"context"
"sync"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/internal/unionstore/art"
"github.com/tikv/client-go/v2/kv"
)
// artDBWithContext wraps ART to satisfy the MemBuffer interface.
type artDBWithContext struct {
// This RWMutex is only used to ensure the snapshot getter's Get will not race with
// concurrent MemBuffer.Set, MemBuffer.SetWithFlags, MemBuffer.Delete and MemBuffer.UpdateFlags.
sync.RWMutex
*art.ART
// when the ART is wrapped by an upper RWMutex, we can skip the internal mutex.
skipMutex bool
}
//nolint:unused
func newArtDBWithContext() *artDBWithContext {
return &artDBWithContext{ART: art.New()}
}
func (db *artDBWithContext) setSkipMutex(skip bool) {
db.skipMutex = skip
}
func (db *artDBWithContext) set(key, value []byte, ops []kv.FlagsOp) error {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
return db.ART.Set(key, value, ops)
}
func (db *artDBWithContext) Set(key, value []byte) error {
if len(value) == 0 {
return tikverr.ErrCannotSetNilValue
}
return db.set(key, value, nil)
}
// SetWithFlags puts key-value into the last active staging buffer with the given KeyFlags.
func (db *artDBWithContext) SetWithFlags(key []byte, value []byte, ops ...kv.FlagsOp) error {
if len(value) == 0 {
return tikverr.ErrCannotSetNilValue
}
return db.set(key, value, ops)
}
func (db *artDBWithContext) UpdateFlags(key []byte, ops ...kv.FlagsOp) {
_ = db.set(key, nil, ops)
}
func (db *artDBWithContext) Delete(key []byte) error {
return db.set(key, arena.Tombstone, nil)
}
func (db *artDBWithContext) DeleteWithFlags(key []byte, ops ...kv.FlagsOp) error {
return db.set(key, arena.Tombstone, ops)
}
func (db *artDBWithContext) Staging() int {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
return db.ART.Staging()
}
func (db *artDBWithContext) Cleanup(handle int) {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
db.ART.Cleanup(handle)
}
func (db *artDBWithContext) Release(handle int) {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
db.ART.Release(handle)
}
func (db *artDBWithContext) Get(_ context.Context, k []byte) ([]byte, error) {
return db.ART.Get(k)
}
func (db *artDBWithContext) GetLocal(_ context.Context, k []byte) ([]byte, error) {
return db.ART.Get(k)
}
func (db *artDBWithContext) Flush(bool) (bool, error) { return false, nil }
func (db *artDBWithContext) FlushWait() error { return nil }
// GetMemDB implements the MemBuffer interface.
func (db *artDBWithContext) GetMemDB() *MemDB {
panic("unimplemented")
}
// BatchGet returns the values for given keys from the MemBuffer.
func (db *artDBWithContext) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) {
if db.Len() == 0 {
return map[string][]byte{}, nil
}
m := make(map[string][]byte, len(keys))
for _, k := range keys {
v, err := db.Get(ctx, k)
if err != nil {
if tikverr.IsErrNotFound(err) {
continue
}
return nil, err
}
m[string(k)] = v
}
return m, nil
}
// GetMetrics implements the MemBuffer interface.
func (db *artDBWithContext) GetMetrics() Metrics { return Metrics{} }
// Iter implements the Retriever interface.
func (db *artDBWithContext) Iter(lower, upper []byte) (Iterator, error) {
return db.ART.Iter(lower, upper)
}
// IterReverse implements the Retriever interface.
func (db *artDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) {
return db.ART.IterReverse(upper, lower)
}
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (db *artDBWithContext) SnapshotIter(lower, upper []byte) Iterator {
return db.ART.SnapshotIter(lower, upper)
}
// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
func (db *artDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator {
return db.ART.SnapshotIterReverse(upper, lower)
}
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
func (db *artDBWithContext) SnapshotGetter() Getter {
return db.ART.SnapshotGetter()
}
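artDBWithContext and the rbtDBWithContext below share the same shape: embed the tree, add an RWMutex, and skip it when an outer lock already serializes writers. Sketched generically (lockedDB is a hypothetical helper, not part of the diff; assumes sync is imported):

// lockedDB captures the locking pattern shared by both wrappers.
type lockedDB[T any] struct {
	sync.RWMutex
	inner     T
	skipMutex bool // set when an outer RWMutex already serializes writers
}

// withLock runs fn under the write lock unless the caller opted out.
func (w *lockedDB[T]) withLock(fn func(inner T)) {
	if !w.skipMutex {
		w.Lock()
		defer w.Unlock()
	}
	fn(w.inner)
}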

View File

@ -35,6 +35,7 @@
package unionstore
import (
"context"
"encoding/binary"
"math/rand"
"testing"
@ -50,7 +51,7 @@ func BenchmarkLargeIndex(b *testing.B) {
for i := range buf {
binary.LittleEndian.PutUint32(buf[i][:], uint32(i))
}
db := newMemDB()
db := NewMemDB()
b.ResetTimer()
for i := range buf {
@ -64,7 +65,7 @@ func BenchmarkPut(b *testing.B) {
binary.BigEndian.PutUint32(buf[i][:], uint32(i))
}
p := newMemDB()
p := NewMemDB()
b.ResetTimer()
for i := range buf {
@ -78,7 +79,7 @@ func BenchmarkPutRandom(b *testing.B) {
binary.LittleEndian.PutUint32(buf[i][:], uint32(rand.Int()))
}
p := newMemDB()
p := NewMemDB()
b.ResetTimer()
for i := range buf {
@ -92,14 +93,15 @@ func BenchmarkGet(b *testing.B) {
binary.BigEndian.PutUint32(buf[i][:], uint32(i))
}
p := newMemDB()
p := NewMemDB()
for i := range buf {
p.Set(buf[i][:keySize], buf[i][:])
}
ctx := context.Background()
b.ResetTimer()
for i := range buf {
p.Get(buf[i][:keySize])
p.Get(ctx, buf[i][:keySize])
}
}
@ -109,14 +111,15 @@ func BenchmarkGetRandom(b *testing.B) {
binary.LittleEndian.PutUint32(buf[i][:], uint32(rand.Int()))
}
p := newMemDB()
p := NewMemDB()
for i := range buf {
p.Set(buf[i][:keySize], buf[i][:])
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
p.Get(buf[i][:keySize])
p.Get(ctx, buf[i][:keySize])
}
}
@ -127,7 +130,7 @@ func BenchmarkMemDbBufferSequential(b *testing.B) {
for i := 0; i < opCnt; i++ {
data[i] = encodeInt(i)
}
buffer := newMemDB()
buffer := NewMemDB()
benchmarkSetGet(b, buffer, data)
b.ReportAllocs()
}
@ -138,20 +141,20 @@ func BenchmarkMemDbBufferRandom(b *testing.B) {
data[i] = encodeInt(i)
}
shuffle(data)
buffer := newMemDB()
buffer := NewMemDB()
benchmarkSetGet(b, buffer, data)
b.ReportAllocs()
}
func BenchmarkMemDbIter(b *testing.B) {
buffer := newMemDB()
buffer := NewMemDB()
benchIterator(b, buffer)
b.ReportAllocs()
}
func BenchmarkMemDbCreation(b *testing.B) {
for i := 0; i < b.N; i++ {
newMemDB()
NewMemDB()
}
b.ReportAllocs()
}
@ -165,13 +168,14 @@ func shuffle(slc [][]byte) {
}
}
func benchmarkSetGet(b *testing.B, buffer *MemDB, data [][]byte) {
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, k := range data {
buffer.Set(k, k)
}
for _, k := range data {
buffer.Get(k)
buffer.Get(ctx, k)
}
}
}

View File

@ -59,7 +59,7 @@ func TestRandom(t *testing.T) {
rand2.Read(keys[i])
}
p1 := newMemDB()
p1 := NewMemDB()
p2 := leveldb.New(comparer.DefaultComparer, 4*1024)
for _, k := range keys {
p1.Set(k, k)
@ -88,7 +88,7 @@ func TestRandom(t *testing.T) {
// The test takes too long under the race detector.
func TestRandomDerive(t *testing.T) {
db := newMemDB()
db := NewMemDB()
golden := leveldb.New(comparer.DefaultComparer, 4*1024)
testRandomDeriveRecur(t, db, golden, 0)
}

View File

@ -0,0 +1,176 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package unionstore
import (
"context"
"sync"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/internal/unionstore/rbt"
"github.com/tikv/client-go/v2/kv"
)
// rbtDBWithContext wraps RBT to satisfy the MemBuffer interface.
type rbtDBWithContext struct {
// This RWMutex is only used to ensure rbtSnapGetter.Get will not race with
// concurrent MemBuffer.Set, MemBuffer.SetWithFlags, MemBuffer.Delete and MemBuffer.UpdateFlags.
sync.RWMutex
*rbt.RBT
// when the RBT is wrapped by an upper RWMutex, we can skip the internal mutex.
skipMutex bool
}
func newRbtDBWithContext() *rbtDBWithContext {
return &rbtDBWithContext{
skipMutex: false,
RBT: rbt.New(),
}
}
func (db *rbtDBWithContext) setSkipMutex(skip bool) {
db.skipMutex = skip
}
func (db *rbtDBWithContext) set(key, value []byte, ops ...kv.FlagsOp) error {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
return db.RBT.Set(key, value, ops...)
}
// UpdateFlags updates the flags associated with key.
func (db *rbtDBWithContext) UpdateFlags(key []byte, ops ...kv.FlagsOp) {
err := db.set(key, nil, ops...)
_ = err // set without value will never fail
}
// Set sets the value for key k as v into kv store.
// v must NOT be nil or empty, otherwise it returns ErrCannotSetNilValue.
func (db *rbtDBWithContext) Set(key []byte, value []byte) error {
if len(value) == 0 {
return tikverr.ErrCannotSetNilValue
}
return db.set(key, value)
}
// SetWithFlags puts key-value into the last active staging buffer with the given KeyFlags.
func (db *rbtDBWithContext) SetWithFlags(key []byte, value []byte, ops ...kv.FlagsOp) error {
if len(value) == 0 {
return tikverr.ErrCannotSetNilValue
}
return db.set(key, value, ops...)
}
// Delete removes the entry for key k from kv store.
func (db *rbtDBWithContext) Delete(key []byte) error {
return db.set(key, arena.Tombstone)
}
// DeleteWithFlags deletes the key with the given KeyFlags
func (db *rbtDBWithContext) DeleteWithFlags(key []byte, ops ...kv.FlagsOp) error {
return db.set(key, arena.Tombstone, ops...)
}
func (db *rbtDBWithContext) Staging() int {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
return db.RBT.Staging()
}
func (db *rbtDBWithContext) Cleanup(handle int) {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
db.RBT.Cleanup(handle)
}
func (db *rbtDBWithContext) Release(handle int) {
if !db.skipMutex {
db.Lock()
defer db.Unlock()
}
db.RBT.Release(handle)
}
func (db *rbtDBWithContext) Get(_ context.Context, k []byte) ([]byte, error) {
return db.RBT.Get(k)
}
func (db *rbtDBWithContext) GetLocal(_ context.Context, k []byte) ([]byte, error) {
return db.RBT.Get(k)
}
func (db *rbtDBWithContext) Flush(bool) (bool, error) { return false, nil }
func (db *rbtDBWithContext) FlushWait() error { return nil }
// GetMemDB implements the MemBuffer interface.
func (db *rbtDBWithContext) GetMemDB() *MemDB {
return db
}
// BatchGet returns the values for given keys from the MemBuffer.
func (db *rbtDBWithContext) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) {
if db.Len() == 0 {
return map[string][]byte{}, nil
}
m := make(map[string][]byte, len(keys))
for _, k := range keys {
v, err := db.Get(ctx, k)
if err != nil {
if tikverr.IsErrNotFound(err) {
continue
}
return nil, err
}
m[string(k)] = v
}
return m, nil
}
// GetMetrics implements the MemBuffer interface.
func (db *rbtDBWithContext) GetMetrics() Metrics { return Metrics{} }
// Iter implements the Retriever interface.
func (db *rbtDBWithContext) Iter(lower, upper []byte) (Iterator, error) {
return db.RBT.Iter(lower, upper)
}
// IterReverse implements the Retriever interface.
func (db *rbtDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) {
return db.RBT.IterReverse(upper, lower)
}
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (db *rbtDBWithContext) SnapshotIter(lower, upper []byte) Iterator {
return db.RBT.SnapshotIter(lower, upper)
}
// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
func (db *rbtDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator {
return db.RBT.SnapshotIterReverse(upper, lower)
}
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
func (db *rbtDBWithContext) SnapshotGetter() Getter {
return db.RBT.SnapshotGetter()
}
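
To make the wrapper's contract concrete, here is a minimal usage sketch (not part of the diff; it assumes it sits in package unionstore next to the unexported constructor, and the function name is hypothetical):

// exampleRbtDBWithContext sketches the MemBuffer surface of rbtDBWithContext.
func exampleRbtDBWithContext() error {
	db := newRbtDBWithContext()
	// Writes outside a staging buffer take effect immediately.
	if err := db.Set([]byte("k1"), []byte("v1")); err != nil {
		return err
	}
	// Staged writes can be published with Release or discarded with Cleanup.
	h := db.Staging()
	if err := db.Set([]byte("k2"), []byte("v2")); err != nil {
		return err
	}
	db.Release(h)
	// Get takes a context to satisfy MemBuffer; the RBT implementation ignores it.
	v, err := db.Get(context.Background(), []byte("k2"))
	if err != nil {
		return err
	}
	_ = v // holds "v2"
	return nil
}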

View File

@ -50,43 +50,33 @@ import (
type KeyFlags = kv.KeyFlags
func init() {
testMode = true
func TestGetSet(t *testing.T) {
testGetSet(t, newRbtDBWithContext())
}
func TestGetSet(t *testing.T) {
func testGetSet(t *testing.T, db MemBuffer) {
require := require.New(t)
const cnt = 10000
p := fillDB(cnt)
fillDB(db, cnt)
var buf [4]byte
for i := 0; i < cnt; i++ {
binary.BigEndian.PutUint32(buf[:], uint32(i))
v, err := p.Get(buf[:])
v, err := db.Get(context.Background(), buf[:])
require.Nil(err)
require.Equal(v, buf[:])
}
}
func TestBigKV(t *testing.T) {
assert := assert.New(t)
db := newMemDB()
db.Set([]byte{1}, make([]byte, 80<<20))
assert.Equal(db.vlog.blockSize, maxBlockSize)
assert.Equal(len(db.vlog.blocks), 1)
h := db.Staging()
db.Set([]byte{2}, make([]byte, 127<<20))
db.Release(h)
assert.Equal(db.vlog.blockSize, maxBlockSize)
assert.Equal(len(db.vlog.blocks), 2)
assert.PanicsWithValue("alloc size is larger than max block size", func() { db.Set([]byte{3}, make([]byte, maxBlockSize+1)) })
func TestIterator(t *testing.T) {
testIterator(t, newRbtDBWithContext())
}
func TestIterator(t *testing.T) {
func testIterator(t *testing.T, db MemBuffer) {
assert := assert.New(t)
const cnt = 10000
db := fillDB(cnt)
fillDB(db, cnt)
var buf [4]byte
var i int
@ -130,14 +120,17 @@ func TestIterator(t *testing.T) {
}
func TestDiscard(t *testing.T) {
testDiscard(t, newRbtDBWithContext())
}
func testDiscard(t *testing.T, db MemBuffer) {
assert := assert.New(t)
const cnt = 10000
db := newMemDB()
base := deriveAndFill(0, cnt, 0, db)
base := deriveAndFill(db, 0, cnt, 0)
sz := db.Size()
db.Cleanup(deriveAndFill(0, cnt, 1, db))
db.Cleanup(deriveAndFill(db, 0, cnt, 1))
assert.Equal(db.Len(), cnt)
assert.Equal(db.Size(), sz)
@ -145,7 +138,7 @@ func TestDiscard(t *testing.T) {
for i := 0; i < cnt; i++ {
binary.BigEndian.PutUint32(buf[:], uint32(i))
v, err := db.Get(buf[:])
v, err := db.Get(context.Background(), buf[:])
assert.Nil(err)
assert.Equal(v, buf[:])
}
@ -171,28 +164,25 @@ func TestDiscard(t *testing.T) {
db.Cleanup(base)
for i := 0; i < cnt; i++ {
binary.BigEndian.PutUint32(buf[:], uint32(i))
_, err := db.Get(buf[:])
_, err := db.Get(context.Background(), buf[:])
assert.NotNil(err)
}
it1, _ := db.Iter(nil, nil)
it := it1.(*MemdbIterator)
it.seekToFirst()
assert.False(it.Valid())
it.seekToLast()
assert.False(it.Valid())
it.seek([]byte{0xff})
it, _ := db.Iter(nil, nil)
assert.False(it.Valid())
}
func TestFlushOverwrite(t *testing.T) {
testFlushOverwrite(t, newRbtDBWithContext())
}
func testFlushOverwrite(t *testing.T, db MemBuffer) {
assert := assert.New(t)
const cnt = 10000
db := newMemDB()
db.Release(deriveAndFill(0, cnt, 0, db))
db.Release(deriveAndFill(db, 0, cnt, 0))
sz := db.Size()
db.Release(deriveAndFill(0, cnt, 1, db))
db.Release(deriveAndFill(db, 0, cnt, 1))
assert.Equal(db.Len(), cnt)
assert.Equal(db.Size(), sz)
@ -202,7 +192,7 @@ func TestFlushOverwrite(t *testing.T) {
for i := 0; i < cnt; i++ {
binary.BigEndian.PutUint32(kbuf[:], uint32(i))
binary.BigEndian.PutUint32(vbuf[:], uint32(i+1))
v, err := db.Get(kbuf[:])
v, err := db.Get(context.Background(), kbuf[:])
assert.Nil(err)
assert.Equal(v, vbuf[:])
}
@ -229,6 +219,10 @@ func TestFlushOverwrite(t *testing.T) {
}
func TestComplexUpdate(t *testing.T) {
testComplexUpdate(t, newRbtDBWithContext())
}
func testComplexUpdate(t *testing.T, db MemBuffer) {
assert := assert.New(t)
const (
@ -237,10 +231,9 @@ func TestComplexUpdate(t *testing.T) {
insert = 9000
)
db := newMemDB()
db.Release(deriveAndFill(0, overwrite, 0, db))
db.Release(deriveAndFill(db, 0, overwrite, 0))
assert.Equal(db.Len(), overwrite)
db.Release(deriveAndFill(keep, insert, 1, db))
db.Release(deriveAndFill(db, keep, insert, 1))
assert.Equal(db.Len(), insert)
var kbuf, vbuf [4]byte
@ -251,20 +244,23 @@ func TestComplexUpdate(t *testing.T) {
if i >= keep {
binary.BigEndian.PutUint32(vbuf[:], uint32(i+1))
}
v, err := db.Get(kbuf[:])
v, err := db.Get(context.Background(), kbuf[:])
assert.Nil(err)
assert.Equal(v, vbuf[:])
}
}
func TestNestedSandbox(t *testing.T) {
testNestedSandbox(t, newRbtDBWithContext())
}
func testNestedSandbox(t *testing.T, db MemBuffer) {
assert := assert.New(t)
db := newMemDB()
h0 := deriveAndFill(0, 200, 0, db)
h1 := deriveAndFill(0, 100, 1, db)
h2 := deriveAndFill(50, 150, 2, db)
h3 := deriveAndFill(100, 120, 3, db)
h4 := deriveAndFill(0, 150, 4, db)
h0 := deriveAndFill(db, 0, 200, 0)
h1 := deriveAndFill(db, 0, 100, 1)
h2 := deriveAndFill(db, 50, 150, 2)
h3 := deriveAndFill(db, 100, 120, 3)
h4 := deriveAndFill(db, 0, 150, 4)
db.Cleanup(h4) // Discard (0..150 -> 4)
db.Release(h3) // Flush (100..120 -> 3)
db.Cleanup(h2) // Discard (100..120 -> 3) & (50..150 -> 2)
@ -280,7 +276,7 @@ func TestNestedSandbox(t *testing.T) {
if i < 100 {
binary.BigEndian.PutUint32(vbuf[:], uint32(i+1))
}
v, err := db.Get(kbuf[:])
v, err := db.Get(context.Background(), kbuf[:])
assert.Nil(err)
assert.Equal(v, vbuf[:])
}
@ -314,10 +310,14 @@ func TestNestedSandbox(t *testing.T) {
}
func TestOverwrite(t *testing.T) {
testOverwrite(t, newRbtDBWithContext())
}
func testOverwrite(t *testing.T, db MemBuffer) {
assert := assert.New(t)
const cnt = 10000
db := fillDB(cnt)
fillDB(db, cnt)
var buf [4]byte
sz := db.Size()
@ -332,7 +332,7 @@ func TestOverwrite(t *testing.T) {
for i := 0; i < cnt; i++ {
binary.BigEndian.PutUint32(buf[:], uint32(i))
val, _ := db.Get(buf[:])
val, _ := db.Get(context.Background(), buf[:])
v := binary.BigEndian.Uint32(val)
if i%3 == 0 {
assert.Equal(v, uint32(i*10))
@ -371,56 +371,32 @@ func TestOverwrite(t *testing.T) {
assert.Equal(i, -1)
}
func TestKVLargeThanBlock(t *testing.T) {
assert := assert.New(t)
db := newMemDB()
db.Set([]byte{1}, make([]byte, 1))
db.Set([]byte{2}, make([]byte, 4096))
assert.Equal(len(db.vlog.blocks), 2)
db.Set([]byte{3}, make([]byte, 3000))
assert.Equal(len(db.vlog.blocks), 2)
val, err := db.Get([]byte{3})
assert.Nil(err)
assert.Equal(len(val), 3000)
}
func TestEmptyDB(t *testing.T) {
assert := assert.New(t)
db := newMemDB()
_, err := db.Get([]byte{0})
assert.NotNil(err)
it1, _ := db.Iter(nil, nil)
it := it1.(*MemdbIterator)
it.seekToFirst()
assert.False(it.Valid())
it.seekToLast()
assert.False(it.Valid())
it.seek([]byte{0xff})
assert.False(it.Valid())
}
func TestReset(t *testing.T) {
testReset(t, newRbtDBWithContext())
}
func testReset(t *testing.T, db interface {
MemBuffer
Reset()
}) {
assert := assert.New(t)
db := fillDB(1000)
fillDB(db, 1000)
db.Reset()
_, err := db.Get([]byte{0, 0, 0, 0})
_, err := db.Get(context.Background(), []byte{0, 0, 0, 0})
assert.NotNil(err)
it1, _ := db.Iter(nil, nil)
it := it1.(*MemdbIterator)
it.seekToFirst()
assert.False(it.Valid())
it.seekToLast()
assert.False(it.Valid())
it.seek([]byte{0xff})
it, _ := db.Iter(nil, nil)
assert.False(it.Valid())
}
func TestInspectStage(t *testing.T) {
testInspectStage(t, newRbtDBWithContext())
}
func testInspectStage(t *testing.T, db MemBuffer) {
assert := assert.New(t)
db := newMemDB()
h1 := deriveAndFill(0, 1000, 0, db)
h2 := deriveAndFill(500, 1000, 1, db)
h1 := deriveAndFill(db, 0, 1000, 0)
h2 := deriveAndFill(db, 500, 1000, 1)
for i := 500; i < 1500; i++ {
var kbuf [4]byte
// don't update in place
@ -429,7 +405,7 @@ func TestInspectStage(t *testing.T) {
binary.BigEndian.PutUint32(vbuf[:], uint32(i+2))
db.Set(kbuf[:], vbuf[:])
}
h3 := deriveAndFill(1000, 2000, 3, db)
h3 := deriveAndFill(db, 1000, 2000, 3)
db.InspectStage(h3, func(key []byte, _ KeyFlags, val []byte) {
k := int(binary.BigEndian.Uint32(key))
@ -470,13 +446,17 @@ func TestInspectStage(t *testing.T) {
}
func TestDirty(t *testing.T) {
testDirty(t, func() MemBuffer { return newRbtDBWithContext() })
}
func testDirty(t *testing.T, createDb func() MemBuffer) {
assert := assert.New(t)
db := newMemDB()
db := createDb()
db.Set([]byte{1}, []byte{1})
assert.True(db.Dirty())
db = newMemDB()
db = createDb()
h := db.Staging()
db.Set([]byte{1}, []byte{1})
db.Cleanup(h)
@ -488,14 +468,14 @@ func TestDirty(t *testing.T) {
assert.True(db.Dirty())
// persistent flags will make memdb dirty.
db = newMemDB()
db = createDb()
h = db.Staging()
db.SetWithFlags([]byte{1}, []byte{1}, kv.SetKeyLocked)
db.Cleanup(h)
assert.True(db.Dirty())
// non-persistent flags will not make memdb dirty.
db = newMemDB()
db = createDb()
h = db.Staging()
db.SetWithFlags([]byte{1}, []byte{1}, kv.SetPresumeKeyNotExists)
db.Cleanup(h)
@ -503,10 +483,13 @@ func TestDirty(t *testing.T) {
}
func TestFlags(t *testing.T) {
testFlags(t, newRbtDBWithContext(), func(db MemBuffer) Iterator { return db.(*rbtDBWithContext).IterWithFlags(nil, nil) })
}
func testFlags(t *testing.T, db MemBuffer, iterWithFlags func(db MemBuffer) Iterator) {
assert := assert.New(t)
const cnt = 10000
db := newMemDB()
h := db.Staging()
for i := uint32(0); i < cnt; i++ {
var buf [4]byte
@ -522,7 +505,7 @@ func TestFlags(t *testing.T) {
for i := uint32(0); i < cnt; i++ {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], i)
_, err := db.Get(buf[:])
_, err := db.Get(context.Background(), buf[:])
assert.NotNil(err)
flags, err := db.GetFlags(buf[:])
if i%2 == 0 {
@ -537,13 +520,10 @@ func TestFlags(t *testing.T) {
assert.Equal(db.Len(), 5000)
assert.Equal(db.Size(), 20000)
it1, _ := db.Iter(nil, nil)
it := it1.(*MemdbIterator)
it, _ := db.Iter(nil, nil)
assert.False(it.Valid())
it.includeFlags = true
it.init()
it = iterWithFlags(db)
for ; it.Valid(); it.Next() {
k := binary.BigEndian.Uint32(it.Key())
assert.True(k%2 == 0)
@ -557,7 +537,7 @@ func TestFlags(t *testing.T) {
for i := uint32(0); i < cnt; i++ {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], i)
_, err := db.Get(buf[:])
_, err := db.Get(context.Background(), buf[:])
assert.NotNil(err)
// UpdateFlags will create missing node.
@ -578,7 +558,7 @@ func checkConsist(t *testing.T, p1 *MemDB, p2 *leveldb.DB) {
var prevKey, prevVal []byte
for it2.First(); it2.Valid(); it2.Next() {
v, err := p1.Get(it2.Key())
v, err := p1.Get(context.Background(), it2.Key())
assert.Nil(err)
assert.Equal(v, it2.Value())
@ -608,14 +588,12 @@ func checkConsist(t *testing.T, p1 *MemDB, p2 *leveldb.DB) {
}
}
func fillDB(cnt int) *MemDB {
db := newMemDB()
h := deriveAndFill(0, cnt, 0, db)
func fillDB(db MemBuffer, cnt int) {
h := deriveAndFill(db, 0, cnt, 0)
db.Release(h)
return db
}
func deriveAndFill(start, end, valueBase int, db *MemDB) int {
func deriveAndFill(db MemBuffer, start, end, valueBase int) int {
h := db.Staging()
var kbuf, vbuf [4]byte
for i := start; i < end; i++ {
@ -704,21 +682,21 @@ func checkNewIterator(t *testing.T, buffer *MemDB) {
func mustGet(t *testing.T, buffer *MemDB) {
for i := startIndex; i < testCount; i++ {
s := encodeInt(i * indexStep)
val, err := buffer.Get(s)
val, err := buffer.Get(context.Background(), s)
assert.Nil(t, err)
assert.Equal(t, string(val), string(s))
}
}
func TestKVGetSet(t *testing.T) {
buffer := newMemDB()
buffer := NewMemDB()
insertData(t, buffer)
mustGet(t, buffer)
}
func TestNewIterator(t *testing.T) {
assert := assert.New(t)
buffer := newMemDB()
buffer := NewMemDB()
// should be invalid
iter, err := buffer.Iter(nil, nil)
assert.Nil(err)
@ -746,7 +724,7 @@ func NextUntil(it Iterator, fn FnKeyCmp) error {
func TestIterNextUntil(t *testing.T) {
assert := assert.New(t)
buffer := newMemDB()
buffer := NewMemDB()
insertData(t, buffer)
iter, err := buffer.Iter(nil, nil)
@ -761,7 +739,7 @@ func TestIterNextUntil(t *testing.T) {
func TestBasicNewIterator(t *testing.T) {
assert := assert.New(t)
buffer := newMemDB()
buffer := NewMemDB()
it, err := buffer.Iter([]byte("2"), nil)
assert.Nil(err)
assert.False(it.Valid())
@ -780,7 +758,7 @@ func TestNewIteratorMin(t *testing.T) {
{"DATA_test_main_db_tbl_tbl_test_record__00000000000000000002_0002", "2"},
{"DATA_test_main_db_tbl_tbl_test_record__00000000000000000002_0003", "hello"},
}
buffer := newMemDB()
buffer := NewMemDB()
for _, kv := range kvs {
err := buffer.Set([]byte(kv.key), []byte(kv.value))
assert.Nil(err)
@ -803,7 +781,7 @@ func TestNewIteratorMin(t *testing.T) {
func TestMemDBStaging(t *testing.T) {
assert := assert.New(t)
buffer := newMemDB()
buffer := NewMemDB()
err := buffer.Set([]byte("x"), make([]byte, 2))
assert.Nil(err)
@ -815,25 +793,27 @@ func TestMemDBStaging(t *testing.T) {
err = buffer.Set([]byte("yz"), make([]byte, 1))
assert.Nil(err)
v, _ := buffer.Get([]byte("x"))
v, _ := buffer.Get(context.Background(), []byte("x"))
assert.Equal(len(v), 3)
buffer.Release(h2)
v, _ = buffer.Get([]byte("yz"))
v, _ = buffer.Get(context.Background(), []byte("yz"))
assert.Equal(len(v), 1)
buffer.Cleanup(h1)
v, _ = buffer.Get([]byte("x"))
v, _ = buffer.Get(context.Background(), []byte("x"))
assert.Equal(len(v), 2)
}
func TestBufferLimit(t *testing.T) {
testBufferLimit(t, newRbtDBWithContext())
}
func testBufferLimit(t *testing.T, buffer MemBuffer) {
assert := assert.New(t)
buffer := newMemDB()
buffer.bufferSizeLimit = 1000
buffer.entrySizeLimit = 500
buffer.SetEntrySizeLimit(500, 1000)
err := buffer.Set([]byte("x"), make([]byte, 500))
assert.NotNil(err) // entry size limit
@ -852,7 +832,7 @@ func TestBufferLimit(t *testing.T) {
func TestUnsetTemporaryFlag(t *testing.T) {
require := require.New(t)
db := newMemDB()
db := NewMemDB()
key := []byte{1}
value := []byte{2}
db.SetWithFlags(key, value, kv.SetNeedConstraintCheckInPrewrite)
@ -864,7 +844,7 @@ func TestUnsetTemporaryFlag(t *testing.T) {
func TestSnapshotGetIter(t *testing.T) {
assert := assert.New(t)
buffer := newMemDB()
buffer := NewMemDB()
var getters []Getter
var iters []Iterator
for i := 0; i < 100; i++ {

View File

@ -44,18 +44,18 @@ type mockSnapshot struct {
store *MemDB
}
func (s *mockSnapshot) Get(_ context.Context, k []byte) ([]byte, error) {
return s.store.Get(k)
func (s *mockSnapshot) Get(ctx context.Context, k []byte) ([]byte, error) {
return s.store.Get(ctx, k)
}
func (s *mockSnapshot) SetPriority(priority int) {
}
func (s *mockSnapshot) BatchGet(_ context.Context, keys [][]byte) (map[string][]byte, error) {
func (s *mockSnapshot) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) {
m := make(map[string][]byte, len(keys))
for _, k := range keys {
v, err := s.store.Get(k)
v, err := s.store.Get(ctx, k)
if tikverr.IsErrNotFound(err) {
continue
}

View File

@ -24,6 +24,7 @@ import (
"github.com/pingcap/errors"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/internal/logutil"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/util"
@ -35,7 +36,7 @@ import (
// - an immutable onflushing buffer for read
// - like MemDB, PipelinedMemDB also CANNOT be used concurrently
type PipelinedMemDB struct {
// Like MemDB, this RWMutex only used to ensure memdbSnapGetter.Get will not race with
// Like MemDB, this RWMutex is only used to ensure rbtSnapGetter.Get will not race with
// concurrent memdb.Set, memdb.SetWithFlags, memdb.Delete and memdb.UpdateFlags.
sync.RWMutex
onFlushing atomic.Bool
@ -102,8 +103,9 @@ type FlushFunc func(uint64, *MemDB) error
type BufferBatchGetter func(ctx context.Context, keys [][]byte) (map[string][]byte, error)
func NewPipelinedMemDB(bufferBatchGetter BufferBatchGetter, flushFunc FlushFunc) *PipelinedMemDB {
memdb := newMemDB()
memdb := NewMemDB()
memdb.setSkipMutex(true)
entryLimit, _ := memdb.GetEntrySizeLimit()
flushOpt := newFlushOption()
return &PipelinedMemDB{
memDB: memdb,
@ -112,7 +114,7 @@ func NewPipelinedMemDB(bufferBatchGetter BufferBatchGetter, flushFunc FlushFunc)
bufferBatchGetter: bufferBatchGetter,
generation: 0,
// keep entryLimit and bufferLimit the same as the memdb's default values.
entryLimit: memdb.entrySizeLimit,
entryLimit: entryLimit,
flushOption: flushOpt,
startTime: time.Now(),
}
@ -129,7 +131,7 @@ func (p *PipelinedMemDB) GetMemDB() *MemDB {
}
func (p *PipelinedMemDB) get(ctx context.Context, k []byte, skipRemoteBuffer bool) ([]byte, error) {
v, err := p.memDB.Get(k)
v, err := p.memDB.Get(ctx, k)
if err == nil {
return v, nil
}
@ -137,7 +139,7 @@ func (p *PipelinedMemDB) get(ctx context.Context, k []byte, skipRemoteBuffer boo
return nil, err
}
if p.flushingMemDB != nil {
v, err = p.flushingMemDB.Get(k)
v, err = p.flushingMemDB.Get(ctx, k)
if err == nil {
return v, nil
}
@ -288,7 +290,7 @@ func (p *PipelinedMemDB) Flush(force bool) (bool, error) {
// invalidate the batch get cache whether the flush is really triggered.
p.batchGetCache = nil
if len(p.memDB.stages) > 0 {
if p.memDB.IsStaging() {
return false, errors.New("there are stages unreleased when Flush is called")
}
@ -309,9 +311,9 @@ func (p *PipelinedMemDB) Flush(force bool) (bool, error) {
p.flushingMemDB = p.memDB
p.len += p.flushingMemDB.Len()
p.size += p.flushingMemDB.Size()
p.missCount += p.memDB.missCount.Load()
p.hitCount += p.memDB.hitCount.Load()
p.memDB = newMemDB()
p.missCount += p.memDB.GetCacheMissCount()
p.hitCount += p.memDB.GetCacheHitCount()
p.memDB = NewMemDB()
// buffer size is limited by ForceFlushMemSizeThreshold. Do not set bufferLimit
p.memDB.SetEntrySizeLimit(p.entryLimit, unlimitedSize)
p.memDB.setSkipMutex(true)
@ -384,7 +386,7 @@ func (p *PipelinedMemDB) FlushWait() error {
func (p *PipelinedMemDB) handleAlreadyExistErr(err error) error {
var existErr *tikverr.ErrKeyExist
if stderrors.As(err, &existErr) {
v, err2 := p.flushingMemDB.Get(existErr.GetKey())
v, err2 := p.flushingMemDB.Get(context.Background(), existErr.GetKey())
if err2 != nil {
// TODO: log more info like start_ts, also for other logs
logutil.BgLogger().Warn(
@ -518,12 +520,12 @@ func (p *PipelinedMemDB) Release(h int) {
}
// Checkpoint implements MemBuffer interface.
func (p *PipelinedMemDB) Checkpoint() *MemDBCheckpoint {
func (p *PipelinedMemDB) Checkpoint() *arena.MemDBCheckpoint {
panic("Checkpoint is not supported for PipelinedMemDB")
}
// RevertToCheckpoint implements MemBuffer interface.
func (p *PipelinedMemDB) RevertToCheckpoint(*MemDBCheckpoint) {
func (p *PipelinedMemDB) RevertToCheckpoint(*arena.MemDBCheckpoint) {
panic("RevertToCheckpoint is not supported for PipelinedMemDB")
}
@ -533,8 +535,8 @@ func (p *PipelinedMemDB) GetMetrics() Metrics {
hitCount := p.hitCount
missCount := p.missCount
if p.memDB != nil {
hitCount += p.memDB.hitCount.Load()
missCount += p.memDB.missCount.Load()
hitCount += p.memDB.GetCacheHitCount()
missCount += p.memDB.GetCacheMissCount()
}
return Metrics{
WaitDuration: p.flushWaitDuration,

View File

@ -208,7 +208,7 @@ func TestPipelinedFlushGet(t *testing.T) {
require.True(t, memdb.OnFlushing())
// The key is in the flushing memdb instead of the current mutable memdb.
_, err = memdb.memDB.Get([]byte("key"))
_, err = memdb.memDB.Get(context.Background(), []byte("key"))
require.True(t, tikverr.IsErrNotFound(err))
// But we still can get the value by PipelinedMemDB.Get.
value, err = memdb.Get(context.Background(), []byte("key"))

View File

@ -0,0 +1,926 @@
// Copyright 2021 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE: The code in this file is based on code from the
// TiDB project, licensed under the Apache License v 2.0
//
// https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/unionstore/memdb.go
//
// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rbt
import (
"bytes"
"fmt"
"math"
"sync/atomic"
"unsafe"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
)
const unlimitedSize = math.MaxUint64
var testMode = false
// RBT is a rollbackable Red-Black Tree optimized for TiDB's transaction-state buffering scenario.
// You can think of RBT as a combination of two separate tree maps, one for key => value and another for key => keyFlags.
//
// The value map is rollbackable, which means you can use the `Staging`, `Release` and `Cleanup` APIs to safely modify KVs.
//
// The flags map is not rollbackable. There are two types of flags, persistent and non-persistent.
// When discarding a newly added KV in `Cleanup`, the non-persistent flags will be cleared.
// If there are persistent flags associated with the key, we keep the key in a node without a value.
type RBT struct {
root arena.MemdbArenaAddr
allocator nodeAllocator
vlog arena.MemdbVlog[*memdbNode, *RBT]
entrySizeLimit uint64
bufferSizeLimit uint64
count int
size int
vlogInvalid bool
dirty bool
stages []arena.MemDBCheckpoint
// lastTraversedNode caches the most recently traversed node; the stored pointer must always be non-nil.
lastTraversedNode atomic.Pointer[MemdbNodeAddr]
hitCount atomic.Uint64
missCount atomic.Uint64
}
func New() *RBT {
db := new(RBT)
db.allocator.init()
db.root = arena.NullAddr
db.stages = make([]arena.MemDBCheckpoint, 0, 2)
db.entrySizeLimit = unlimitedSize
db.bufferSizeLimit = unlimitedSize
db.lastTraversedNode.Store(&nullNodeAddr)
return db
}
// updateLastTraversed updates the last traversed node atomically
func (db *RBT) updateLastTraversed(node MemdbNodeAddr) {
db.lastTraversedNode.Store(&node)
}
// checkKeyInCache retrieves the last traversed node if the key matches
func (db *RBT) checkKeyInCache(key []byte) (MemdbNodeAddr, bool) {
nodePtr := db.lastTraversedNode.Load()
if nodePtr == nil || nodePtr.isNull() {
return nullNodeAddr, false
}
if bytes.Equal(key, nodePtr.memdbNode.getKey()) {
return *nodePtr, true
}
return nullNodeAddr, false
}
func (db *RBT) RevertNode(hdr *arena.MemdbVlogHdr) {
node := db.getNode(hdr.NodeAddr)
node.vptr = hdr.OldValue
db.size -= int(hdr.ValueLen)
// oldValue.isNull() == true means this is a newly added value.
if hdr.OldValue.IsNull() {
// If there are no flags associated with this key, we need to delete this node.
keptFlags := node.getKeyFlags().AndPersistent()
if keptFlags == 0 {
db.deleteNode(node)
} else {
node.setKeyFlags(keptFlags)
db.dirty = true
}
} else {
db.size += len(db.vlog.GetValue(hdr.OldValue))
}
}
func (db *RBT) InspectNode(addr arena.MemdbArenaAddr) (*memdbNode, arena.MemdbArenaAddr) {
node := db.allocator.getNode(addr)
return node, node.vptr
}
// IsStaging returns whether the MemBuffer is in staging status.
func (db *RBT) IsStaging() bool {
return len(db.stages) > 0
}
// Staging creates a new staging buffer inside the MemBuffer.
// Subsequent writes will be temporarily stored in this new staging buffer.
// When you think all modifications look good, you can call `Release` to publish all of them to the upper-level buffer.
func (db *RBT) Staging() int {
db.stages = append(db.stages, db.vlog.Checkpoint())
return len(db.stages)
}
// Release publishes all modifications in the latest staging buffer to the upper level.
func (db *RBT) Release(h int) {
if h != len(db.stages) {
// This should never happen in a production environment.
// Use panic to make debugging easier.
panic("cannot release staging buffer")
}
if h == 1 {
tail := db.vlog.Checkpoint()
if !db.stages[0].IsSamePosition(&tail) {
db.dirty = true
}
}
db.stages = db.stages[:h-1]
}
// Cleanup cleans up the resources referenced by the given staging handle.
// If the changes are not published by `Release`, they will be discarded.
func (db *RBT) Cleanup(h int) {
if h > len(db.stages) {
return
}
if h < len(db.stages) {
// This should never happen in a production environment.
// Use panic to make debugging easier.
panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(db.stages)=%v", h, len(db.stages)))
}
cp := &db.stages[h-1]
if !db.vlogInvalid {
curr := db.vlog.Checkpoint()
if !curr.IsSamePosition(cp) {
db.vlog.RevertToCheckpoint(db, cp)
db.vlog.Truncate(cp)
}
}
db.stages = db.stages[:h-1]
db.vlog.OnMemChange()
}
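The handle returned by Staging is simply the 1-based depth of the stages stack, which is what the panics above enforce. A small in-package sketch of the LIFO discipline (hypothetical helper):

// stagingDiscipline sketches the handle rules enforced by Release and Cleanup.
func stagingDiscipline() {
	db := New()
	h1 := db.Staging() // h1 == 1
	h2 := db.Staging() // h2 == 2
	db.Release(h2)     // OK: h2 is the latest stage; its writes merge into stage 1
	db.Cleanup(h1)     // OK: discards everything stage 1 buffered
	// db.Release(h1) here would panic: the stage no longer exists.
}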
// Checkpoint returns a checkpoint of RBT.
func (db *RBT) Checkpoint() *arena.MemDBCheckpoint {
cp := db.vlog.Checkpoint()
return &cp
}
// RevertToCheckpoint reverts the RBT to the checkpoint.
func (db *RBT) RevertToCheckpoint(cp *arena.MemDBCheckpoint) {
db.vlog.RevertToCheckpoint(db, cp)
db.vlog.Truncate(cp)
db.vlog.OnMemChange()
}
// Reset resets the MemBuffer to initial states.
func (db *RBT) Reset() {
db.root = arena.NullAddr
db.stages = db.stages[:0]
db.dirty = false
db.vlogInvalid = false
db.size = 0
db.count = 0
db.vlog.Reset()
db.allocator.reset()
}
// DiscardValues releases the memory used by all values.
// NOTE: any operation that needs a value will panic after this function.
func (db *RBT) DiscardValues() {
db.vlogInvalid = true
db.vlog.Reset()
}
// InspectStage is used to inspect the value updates in the given stage.
func (db *RBT) InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) {
idx := handle - 1
tail := db.vlog.Checkpoint()
head := db.stages[idx]
db.vlog.InspectKVInLog(db, &head, &tail, f)
}
// Get gets the value for key k from the kv store.
// If the corresponding kv pair does not exist, it returns nil and ErrNotExist.
func (db *RBT) Get(key []byte) ([]byte, error) {
if db.vlogInvalid {
// panic for easier debugging.
panic("vlog is resetted")
}
x := db.traverse(key, false)
if x.isNull() {
return nil, tikverr.ErrNotExist
}
if x.vptr.IsNull() {
// A flags-only key acts as if the value does not exist
return nil, tikverr.ErrNotExist
}
return db.vlog.GetValue(x.vptr), nil
}
// SelectValueHistory selects the latest value in the modification history for which `predicate` returns true.
func (db *RBT) SelectValueHistory(key []byte, predicate func(value []byte) bool) ([]byte, error) {
x := db.traverse(key, false)
if x.isNull() {
return nil, tikverr.ErrNotExist
}
if x.vptr.IsNull() {
// A flags-only key acts as if the value does not exist
return nil, tikverr.ErrNotExist
}
result := db.vlog.SelectValueHistory(x.vptr, func(addr arena.MemdbArenaAddr) bool {
return predicate(db.vlog.GetValue(addr))
})
if result.IsNull() {
return nil, nil
}
return db.vlog.GetValue(result), nil
}
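A sketch of how the history walk behaves, assuming each different-length Set appends a new vlog version (in-package, hypothetical helper):

// historyLookup sketches the newest-to-oldest predicate walk.
func historyLookup() {
	db := New()
	_ = db.Set([]byte("k"), []byte("a"))
	h := db.Staging()
	_ = db.Set([]byte("k"), []byte("bb")) // different length appends a second version
	v, _ := db.SelectValueHistory([]byte("k"), func(val []byte) bool {
		return len(val) == 1
	})
	_ = v // "a": "bb" fails the predicate, so the walk falls back to the older version
	db.Cleanup(h)
}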
// GetFlags returns the latest flags associated with key.
func (db *RBT) GetFlags(key []byte) (kv.KeyFlags, error) {
x := db.traverse(key, false)
if x.isNull() {
return 0, tikverr.ErrNotExist
}
return x.getKeyFlags(), nil
}
// GetKeyByHandle returns key by handle.
func (db *RBT) GetKeyByHandle(handle arena.MemKeyHandle) []byte {
x := db.getNode(handle.ToAddr())
return x.getKey()
}
// GetValueByHandle returns value by handle.
func (db *RBT) GetValueByHandle(handle arena.MemKeyHandle) ([]byte, bool) {
if db.vlogInvalid {
return nil, false
}
x := db.getNode(handle.ToAddr())
if x.vptr.IsNull() {
return nil, false
}
return db.vlog.GetValue(x.vptr), true
}
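Handles give callers a stable reference back into the arena without another tree traversal. A sketch using the iterator's Handle method defined later in this commit (in-package, hypothetical helper):

// handleRoundTrip sketches re-reading an entry via its handle.
func handleRoundTrip() {
	db := New()
	_ = db.Set([]byte("k"), []byte("v"))
	it, _ := db.Iter(nil, nil)
	h := it.Handle()                // arena.MemKeyHandle for the current node
	k := db.GetKeyByHandle(h)       // same bytes as it.Key()
	v, ok := db.GetValueByHandle(h) // ok == false for flags-only keys or after DiscardValues
	_, _, _ = k, v, ok
}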
// Len returns the number of entries in the DB.
func (db *RBT) Len() int {
return db.count
}
// Size returns sum of keys and values length.
func (db *RBT) Size() int {
return db.size
}
// Dirty returns whether the root staging buffer is updated.
func (db *RBT) Dirty() bool {
return db.dirty
}
func (db *RBT) Set(key []byte, value []byte, ops ...kv.FlagsOp) error {
if db.vlogInvalid {
// panic for easier debugging.
panic("vlog is reset")
}
if value != nil {
if size := uint64(len(key) + len(value)); size > db.entrySizeLimit {
return &tikverr.ErrEntryTooLarge{
Limit: db.entrySizeLimit,
Size: size,
}
}
}
if len(db.stages) == 0 {
db.dirty = true
}
x := db.traverse(key, true)
// the NeedConstraintCheckInPrewrite flag is temporary:
// every write to the node removes the flag unless it is explicitly set again.
// This Set must happen in the latest stage, so no special processing is needed.
var flags kv.KeyFlags
if value != nil {
flags = kv.ApplyFlagsOps(x.getKeyFlags(), append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...)
} else {
// an UpdateFlag operation, do not delete the NeedConstraintCheckInPrewrite flag.
flags = kv.ApplyFlagsOps(x.getKeyFlags(), ops...)
}
if flags.AndPersistent() != 0 {
db.dirty = true
}
x.setKeyFlags(flags)
if value == nil {
return nil
}
db.setValue(x, value)
if uint64(db.Size()) > db.bufferSizeLimit {
return &tikverr.ErrTxnTooLarge{Size: db.Size()}
}
return nil
}
func (db *RBT) setValue(x MemdbNodeAddr, value []byte) {
var activeCp *arena.MemDBCheckpoint
if len(db.stages) > 0 {
activeCp = &db.stages[len(db.stages)-1]
}
var oldVal []byte
if !x.vptr.IsNull() {
oldVal = db.vlog.GetValue(x.vptr)
}
if len(oldVal) > 0 && db.vlog.CanModify(activeCp, x.vptr) {
// For ease of implementation, we only consider this case,
// which is the most common usage pattern in TiDB's transaction buffers.
if len(oldVal) == len(value) {
copy(oldVal, value)
return
}
}
x.vptr = db.vlog.AppendValue(x.addr, x.vptr, value)
db.size = db.size - len(oldVal) + len(value)
}
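The fast path above rewrites a same-length value in the vlog slot it already occupies, so no new version is appended and Size() is unchanged. A sketch (in-package, hypothetical helper; byte counts are illustrative, and it assumes CanModify permits in-place rewrites when no stage is active):

// inPlaceOverwrite sketches the same-length fast path of setValue.
func inPlaceOverwrite() {
	db := New()
	_ = db.Set([]byte("k"), []byte("aaaa")) // size = 1 (key) + 4 (value) = 5
	_ = db.Set([]byte("k"), []byte("bbbb")) // same length: copied in place, size still 5
	_ = db.Set([]byte("k"), []byte("cc"))   // new length: new vlog version, size = 1 + 2 = 3
}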
// traverse searches for the given key; if it is not found and insert is true, a new node is added.
// Returns the node found, or the newly inserted node.
func (db *RBT) traverse(key []byte, insert bool) MemdbNodeAddr {
if node, found := db.checkKeyInCache(key); found {
db.hitCount.Add(1)
return node
}
db.missCount.Add(1)
x := db.getRoot()
y := MemdbNodeAddr{nil, arena.NullAddr}
found := false
// walk x down the tree
for !x.isNull() && !found {
cmp := bytes.Compare(key, x.getKey())
if cmp < 0 {
if insert && x.left.IsNull() {
y = x
}
x = x.getLeft(db)
} else if cmp > 0 {
if insert && x.right.IsNull() {
y = x
}
x = x.getRight(db)
} else {
found = true
}
}
if found {
db.updateLastTraversed(x)
}
if found || !insert {
return x
}
z := db.allocNode(key)
z.up = y.addr
if y.isNull() {
db.root = z.addr
} else {
cmp := bytes.Compare(z.getKey(), y.getKey())
if cmp < 0 {
y.left = z.addr
} else {
y.right = z.addr
}
}
z.left = arena.NullAddr
z.right = arena.NullAddr
// colour this new node red
z.setRed()
// Having added a red node, we must now walk back up the tree balancing it,
// by a series of rotations and changing of colours
x = z
// While we are not at the top and our parent node is red
// NOTE: since the root node is guaranteed black, we will
// also stop if we are a child of the root
for x.addr != db.root {
xUp := x.getUp(db)
if xUp.isBlack() {
break
}
xUpUp := xUp.getUp(db)
// if our parent is on the left side of our grandparent
if x.up == xUpUp.left {
// get the right side of our grandparent (uncle?)
y = xUpUp.getRight(db)
if y.isRed() {
// make our parent black
xUp.setBlack()
// make our uncle black
y.setBlack()
// make our grandparent red
xUpUp.setRed()
// now consider our grandparent
x = xUp.getUp(db)
} else {
// if we are on the right side of our parent
if x.addr == xUp.right {
// Move up to our parent
x = x.getUp(db)
db.leftRotate(x)
xUp = x.getUp(db)
xUpUp = xUp.getUp(db)
}
xUp.setBlack()
xUpUp.setRed()
db.rightRotate(xUpUp)
}
} else {
// everything here is the same as above, but exchanging left for right
y = xUpUp.getLeft(db)
if y.isRed() {
xUp.setBlack()
y.setBlack()
xUpUp.setRed()
x = xUp.getUp(db)
} else {
if x.addr == xUp.left {
x = x.getUp(db)
db.rightRotate(x)
xUp = x.getUp(db)
xUpUp = xUp.getUp(db)
}
xUp.setBlack()
xUpUp.setRed()
db.leftRotate(xUpUp)
}
}
}
// Set the root node black
db.getRoot().setBlack()
db.updateLastTraversed(z)
return z
}
//
// Rotate our tree thus:-
//
// X leftRotate(X)---> Y
// / \ / \
// A Y <---rightRotate(Y) X C
// / \ / \
// B C A B
//
// NOTE: This does not change the ordering.
//
// We assume that neither X nor Y is NULL
//
func (db *RBT) leftRotate(x MemdbNodeAddr) {
y := x.getRight(db)
// Turn Y's left subtree into X's right subtree (move B)
x.right = y.left
// If B is not null, set its parent to be X
if !y.left.IsNull() {
left := y.getLeft(db)
left.up = x.addr
}
// Set Y's parent to be what X's parent was
y.up = x.up
// if X was the root
if x.up.IsNull() {
db.root = y.addr
} else {
xUp := x.getUp(db)
// Set X's parent's left or right pointer to be Y
if x.addr == xUp.left {
xUp.left = y.addr
} else {
xUp.right = y.addr
}
}
// Put X on Y's left
y.left = x.addr
// Set X's parent to be Y
x.up = y.addr
}
func (db *RBT) rightRotate(y MemdbNodeAddr) {
x := y.getLeft(db)
// Turn X's right subtree into Y's left subtree (move B)
y.left = x.right
// If B is not null, set its parent to be Y
if !x.right.IsNull() {
right := x.getRight(db)
right.up = y.addr
}
// Set X's parent to be what Y's parent was
x.up = y.up
// if Y was the root
if y.up.IsNull() {
db.root = x.addr
} else {
yUp := y.getUp(db)
// Set Y's parent's left or right pointer to be X
if y.addr == yUp.left {
yUp.left = x.addr
} else {
yUp.right = x.addr
}
}
// Put Y on X's right
x.right = y.addr
// Set Y's parent to be X
y.up = x.addr
}
func (db *RBT) deleteNode(z MemdbNodeAddr) {
var x, y MemdbNodeAddr
if db.lastTraversedNode.Load().addr == z.addr {
db.lastTraversedNode.Store(&nullNodeAddr)
}
db.count--
db.size -= int(z.klen)
if z.left.IsNull() || z.right.IsNull() {
y = z
} else {
y = db.successor(z)
}
if !y.left.IsNull() {
x = y.getLeft(db)
} else {
x = y.getRight(db)
}
x.up = y.up
if y.up.IsNull() {
db.root = x.addr
} else {
yUp := y.getUp(db)
if y.addr == yUp.left {
yUp.left = x.addr
} else {
yUp.right = x.addr
}
}
needFix := y.isBlack()
// NOTE: a traditional red-black tree would copy the key from Y to Z and free Y.
// We cannot do that here, because Y's address is stored in the vlog and the space allocated for Z may not be suitable for Y.
// So we copy the state from Z to Y instead, and relink all nodes formerly connected to Z.
if y != z {
db.replaceNode(z, y)
}
if needFix {
db.deleteNodeFix(x)
}
db.allocator.freeNode(z.addr)
}
func (db *RBT) replaceNode(old MemdbNodeAddr, new MemdbNodeAddr) {
if !old.up.IsNull() {
oldUp := old.getUp(db)
if old.addr == oldUp.left {
oldUp.left = new.addr
} else {
oldUp.right = new.addr
}
} else {
db.root = new.addr
}
new.up = old.up
left := old.getLeft(db)
left.up = new.addr
new.left = old.left
right := old.getRight(db)
right.up = new.addr
new.right = old.right
if old.isBlack() {
new.setBlack()
} else {
new.setRed()
}
}
func (db *RBT) deleteNodeFix(x MemdbNodeAddr) {
for x.addr != db.root && x.isBlack() {
xUp := x.getUp(db)
if x.addr == xUp.left {
w := xUp.getRight(db)
if w.isRed() {
w.setBlack()
xUp.setRed()
db.leftRotate(xUp)
w = x.getUp(db).getRight(db)
}
if w.getLeft(db).isBlack() && w.getRight(db).isBlack() {
w.setRed()
x = x.getUp(db)
} else {
if w.getRight(db).isBlack() {
w.getLeft(db).setBlack()
w.setRed()
db.rightRotate(w)
w = x.getUp(db).getRight(db)
}
xUp := x.getUp(db)
if xUp.isBlack() {
w.setBlack()
} else {
w.setRed()
}
xUp.setBlack()
w.getRight(db).setBlack()
db.leftRotate(xUp)
x = db.getRoot()
}
} else {
w := xUp.getLeft(db)
if w.isRed() {
w.setBlack()
xUp.setRed()
db.rightRotate(xUp)
w = x.getUp(db).getLeft(db)
}
if w.getRight(db).isBlack() && w.getLeft(db).isBlack() {
w.setRed()
x = x.getUp(db)
} else {
if w.getLeft(db).isBlack() {
w.getRight(db).setBlack()
w.setRed()
db.leftRotate(w)
w = x.getUp(db).getLeft(db)
}
xUp := x.getUp(db)
if xUp.isBlack() {
w.setBlack()
} else {
w.setRed()
}
xUp.setBlack()
w.getLeft(db).setBlack()
db.rightRotate(xUp)
x = db.getRoot()
}
}
}
x.setBlack()
}
func (db *RBT) successor(x MemdbNodeAddr) (y MemdbNodeAddr) {
if !x.right.IsNull() {
// If right is not NULL then go right one and
// then keep going left until we find a node with
// no left pointer.
y = x.getRight(db)
for !y.left.IsNull() {
y = y.getLeft(db)
}
return
}
// Go up the tree until we get to a node that is on the
// left of its parent (or the root) and then return the
// parent.
y = x.getUp(db)
for !y.isNull() && x.addr == y.right {
x = y
y = y.getUp(db)
}
return y
}
func (db *RBT) predecessor(x MemdbNodeAddr) (y MemdbNodeAddr) {
if !x.left.IsNull() {
// If left is not NULL then go left one and
// then keep going right until we find a node with
// no right pointer.
y = x.getLeft(db)
for !y.right.IsNull() {
y = y.getRight(db)
}
return
}
// Go up the tree until we get to a node that is on the
// right of its parent (or the root) and then return the
// parent.
y = x.getUp(db)
for !y.isNull() && x.addr == y.left {
x = y
y = y.getUp(db)
}
return y
}
func (db *RBT) getNode(x arena.MemdbArenaAddr) MemdbNodeAddr {
return MemdbNodeAddr{db.allocator.getNode(x), x}
}
func (db *RBT) getRoot() MemdbNodeAddr {
return db.getNode(db.root)
}
func (db *RBT) allocNode(key []byte) MemdbNodeAddr {
db.size += len(key)
db.count++
x, xn := db.allocator.allocNode(key)
return MemdbNodeAddr{xn, x}
}
var nullNodeAddr = MemdbNodeAddr{nil, arena.NullAddr}
type MemdbNodeAddr struct {
*memdbNode
addr arena.MemdbArenaAddr
}
func (a *MemdbNodeAddr) isNull() bool {
return a.addr.IsNull()
}
func (a MemdbNodeAddr) getUp(db *RBT) MemdbNodeAddr {
return db.getNode(a.up)
}
func (a MemdbNodeAddr) getLeft(db *RBT) MemdbNodeAddr {
return db.getNode(a.left)
}
func (a MemdbNodeAddr) getRight(db *RBT) MemdbNodeAddr {
return db.getNode(a.right)
}
type memdbNode struct {
up arena.MemdbArenaAddr
left arena.MemdbArenaAddr
right arena.MemdbArenaAddr
vptr arena.MemdbArenaAddr
klen uint16
flags uint16
}
func (n *memdbNode) isRed() bool {
return n.flags&nodeColorBit != 0
}
func (n *memdbNode) isBlack() bool {
return !n.isRed()
}
func (n *memdbNode) setRed() {
n.flags |= nodeColorBit
}
func (n *memdbNode) setBlack() {
n.flags &= ^nodeColorBit
}
func (n *memdbNode) GetKey() []byte {
return n.getKey()
}
func (n *memdbNode) getKey() []byte {
base := unsafe.Add(unsafe.Pointer(&n.flags), kv.FlagBytes)
return unsafe.Slice((*byte)(base), int(n.klen))
}
const (
// bit 1 => red, bit 0 => black
nodeColorBit uint16 = 0x8000
nodeFlagsMask = ^nodeColorBit
)
func (n *memdbNode) GetKeyFlags() kv.KeyFlags {
return n.getKeyFlags()
}
func (n *memdbNode) getKeyFlags() kv.KeyFlags {
return kv.KeyFlags(n.flags & nodeFlagsMask)
}
func (n *memdbNode) setKeyFlags(f kv.KeyFlags) {
n.flags = (^nodeFlagsMask & n.flags) | uint16(f)
}
// RemoveFromBuffer removes a record from the mem buffer. It should only be used in tests.
func (db *RBT) RemoveFromBuffer(key []byte) {
x := db.traverse(key, false)
if x.isNull() {
return
}
db.size -= len(db.vlog.GetValue(x.vptr))
db.deleteNode(x)
}
// SetMemoryFootprintChangeHook sets the hook function that is triggered when memdb grows.
func (db *RBT) SetMemoryFootprintChangeHook(hook func(uint64)) {
innerHook := func() {
hook(db.allocator.Capacity() + db.vlog.Capacity())
}
db.allocator.SetMemChangeHook(innerHook)
db.vlog.SetMemChangeHook(innerHook)
}
// Mem returns the current memory footprint
func (db *RBT) Mem() uint64 {
return db.allocator.Capacity() + db.vlog.Capacity()
}
// GetEntrySizeLimit gets the size limit for each entry and total buffer.
func (db *RBT) GetEntrySizeLimit() (uint64, uint64) {
return db.entrySizeLimit, db.bufferSizeLimit
}
// SetEntrySizeLimit sets the size limit for each entry and total buffer.
func (db *RBT) SetEntrySizeLimit(entryLimit, bufferLimit uint64) {
db.entrySizeLimit = entryLimit
db.bufferSizeLimit = bufferLimit
}
// MemHookSet implements the MemBuffer interface.
func (db *RBT) MemHookSet() bool {
return db.allocator.MemHookSet()
}
func (db *RBT) GetCacheHitCount() uint64 {
return db.hitCount.Load()
}
func (db *RBT) GetCacheMissCount() uint64 {
return db.missCount.Load()
}
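
Because lastTraversedNode is a single-entry cache, repeated point lookups of the same key skip the tree walk entirely; the counters above make that observable. A rough in-package sketch (hypothetical helper):

// cacheCounters sketches the single-entry traversal cache.
func cacheCounters() {
	db := New()
	_ = db.Set([]byte("hot"), []byte("v")) // cache empty: one miss, then the cache is primed
	for i := 0; i < 3; i++ {
		_, _ = db.Get([]byte("hot")) // each lookup is a cache hit
	}
	hits, misses := db.GetCacheHitCount(), db.GetCacheMissCount() // expected: 3 and 1
	_, _ = hits, misses
}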

View File

@ -0,0 +1,101 @@
// Copyright 2021 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE: The code in this file is based on code from the
// TiDB project, licensed under the Apache License v 2.0
//
// https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/unionstore/memdb_arena.go
//
// Copyright 2020 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rbt
import (
"unsafe"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)
type nodeAllocator struct {
arena.MemdbArena
// Dummy node, so that we can make X.left.up = X.
// We then use this instead of NULL to mean the top or bottom
// end of the rb tree. It is a black node.
nullNode memdbNode
}
func (a *nodeAllocator) init() {
a.nullNode = memdbNode{
up: arena.NullAddr,
left: arena.NullAddr,
right: arena.NullAddr,
vptr: arena.NullAddr,
}
}
func (a *nodeAllocator) getNode(addr arena.MemdbArenaAddr) *memdbNode {
if addr.IsNull() {
return &a.nullNode
}
data := a.GetData(addr)
return (*memdbNode)(unsafe.Pointer(&data[0]))
}
const memdbNodeSize = int(unsafe.Sizeof(memdbNode{}))
func (a *nodeAllocator) allocNode(key []byte) (arena.MemdbArenaAddr, *memdbNode) {
nodeSize := memdbNodeSize + len(key)
prevBlocks := a.Blocks()
addr, mem := a.Alloc(nodeSize, true)
n := (*memdbNode)(unsafe.Pointer(&mem[0]))
n.vptr = arena.NullAddr
n.klen = uint16(len(key))
copy(n.getKey(), key)
if prevBlocks != a.Blocks() {
a.OnMemChange()
}
return addr, n
}
func (a *nodeAllocator) freeNode(addr arena.MemdbArenaAddr) {
if testMode {
// Make debugging easier.
n := a.getNode(addr)
n.left = arena.BadAddr
n.right = arena.BadAddr
n.up = arena.BadAddr
n.vptr = arena.BadAddr
return
}
// TODO: reuse freed nodes. Need to fix lastTraversedNode when implementing this.
}
func (a *nodeAllocator) reset() {
a.MemdbArena.Reset()
a.init()
}
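
The allocator stores each key inline, directly after the fixed-size node header, which is why getKey can be derived from the flags field with pointer arithmetic. A sketch of the size calculation allocNode performs (in-package, hypothetical helper):

// nodeFootprint sketches the inline-key allocation size.
func nodeFootprint() int {
	key := []byte("example-key")
	// One aligned arena allocation holds the fixed node header plus the key bytes;
	// Alloc(nodeSize, true) aligns to 8 bytes so the unsafe field accesses are valid.
	return memdbNodeSize + len(key)
}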

View File

@ -32,18 +32,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package unionstore
package rbt
import (
"bytes"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
"github.com/tikv/client-go/v2/kv"
)
// MemdbIterator is an Iterator with KeyFlags related functions.
type MemdbIterator struct {
db *MemDB
curr memdbNodeAddr
// RBTIterator is an Iterator with KeyFlags related functions.
type RBTIterator struct {
db *RBT
curr MemdbNodeAddr
start []byte
end []byte
reverse bool
@ -54,8 +55,8 @@ type MemdbIterator struct {
// If no such entry is found, it returns an invalid Iterator with no error.
// It yields only keys less than upperBound. If upperBound is nil, the upper bound is unbounded.
// The Iterator must be Closed after use.
func (db *MemDB) Iter(k []byte, upperBound []byte) (Iterator, error) {
i := &MemdbIterator{
func (db *RBT) Iter(k []byte, upperBound []byte) (*RBTIterator, error) {
i := &RBTIterator{
db: db,
start: k,
end: upperBound,
@ -68,8 +69,8 @@ func (db *MemDB) Iter(k []byte, upperBound []byte) (Iterator, error) {
// The returned iterator will iterate from greater key to smaller key.
// If k is nil, the returned iterator will be positioned at the last key.
// It yields only keys greater than or equal to lowerBound. If lowerBound is nil, the lower bound is unbounded.
func (db *MemDB) IterReverse(k []byte, lowerBound []byte) (Iterator, error) {
i := &MemdbIterator{
func (db *RBT) IterReverse(k []byte, lowerBound []byte) (*RBTIterator, error) {
i := &RBTIterator{
db: db,
start: lowerBound,
end: k,
@ -79,9 +80,9 @@ func (db *MemDB) IterReverse(k []byte, lowerBound []byte) (Iterator, error) {
return i, nil
}
// IterWithFlags returns a MemdbIterator.
func (db *MemDB) IterWithFlags(k []byte, upperBound []byte) *MemdbIterator {
i := &MemdbIterator{
// IterWithFlags returns an RBTIterator.
func (db *RBT) IterWithFlags(k []byte, upperBound []byte) *RBTIterator {
i := &RBTIterator{
db: db,
start: k,
end: upperBound,
@ -91,9 +92,9 @@ func (db *MemDB) IterWithFlags(k []byte, upperBound []byte) *MemdbIterator {
return i
}
// IterReverseWithFlags returns a reversed MemdbIterator.
func (db *MemDB) IterReverseWithFlags(k []byte) *MemdbIterator {
i := &MemdbIterator{
// IterReverseWithFlags returns a reversed RBTIterator.
func (db *RBT) IterReverseWithFlags(k []byte) *RBTIterator {
i := &RBTIterator{
db: db,
end: k,
reverse: true,
@ -103,7 +104,7 @@ func (db *MemDB) IterReverseWithFlags(k []byte) *MemdbIterator {
return i
}
func (i *MemdbIterator) init() {
func (i *RBTIterator) init() {
if i.reverse {
if len(i.end) == 0 {
i.seekToLast()
@ -125,7 +126,7 @@ func (i *MemdbIterator) init() {
}
// Valid returns true if the current iterator is valid.
func (i *MemdbIterator) Valid() bool {
func (i *RBTIterator) Valid() bool {
if !i.reverse {
return !i.curr.isNull() && (i.end == nil || bytes.Compare(i.Key(), i.end) < 0)
}
@ -133,42 +134,39 @@ func (i *MemdbIterator) Valid() bool {
}
// Flags returns the flags belonging to the current entry.
func (i *MemdbIterator) Flags() kv.KeyFlags {
func (i *RBTIterator) Flags() kv.KeyFlags {
return i.curr.getKeyFlags()
}
// UpdateFlags applies the given flag operations to the current entry.
func (i *MemdbIterator) UpdateFlags(ops ...kv.FlagsOp) {
func (i *RBTIterator) UpdateFlags(ops ...kv.FlagsOp) {
origin := i.curr.getKeyFlags()
n := kv.ApplyFlagsOps(origin, ops...)
i.curr.setKeyFlags(n)
}
// HasValue returns false if it is flags only.
func (i *MemdbIterator) HasValue() bool {
func (i *RBTIterator) HasValue() bool {
return !i.isFlagsOnly()
}
// Key returns current key.
func (i *MemdbIterator) Key() []byte {
func (i *RBTIterator) Key() []byte {
return i.curr.getKey()
}
// Handle returns MemKeyHandle with the current position.
func (i *MemdbIterator) Handle() MemKeyHandle {
return MemKeyHandle{
idx: uint16(i.curr.addr.idx),
off: i.curr.addr.off,
}
func (i *RBTIterator) Handle() arena.MemKeyHandle {
return i.curr.addr.ToHandle()
}
// Value returns the value.
func (i *MemdbIterator) Value() []byte {
return i.db.vlog.getValue(i.curr.vptr)
func (i *RBTIterator) Value() []byte {
return i.db.vlog.GetValue(i.curr.vptr)
}
// Next goes the next position.
func (i *MemdbIterator) Next() error {
func (i *RBTIterator) Next() error {
for {
if i.reverse {
i.curr = i.db.predecessor(i.curr)
@ -185,10 +183,10 @@ func (i *MemdbIterator) Next() error {
}
// Close closes the current iterator.
func (i *MemdbIterator) Close() {}
func (i *RBTIterator) Close() {}
func (i *MemdbIterator) seekToFirst() {
y := memdbNodeAddr{nil, nullAddr}
func (i *RBTIterator) seekToFirst() {
y := MemdbNodeAddr{nil, arena.NullAddr}
x := i.db.getNode(i.db.root)
for !x.isNull() {
@ -199,8 +197,8 @@ func (i *MemdbIterator) seekToFirst() {
i.curr = y
}
func (i *MemdbIterator) seekToLast() {
y := memdbNodeAddr{nil, nullAddr}
func (i *RBTIterator) seekToLast() {
y := MemdbNodeAddr{nil, arena.NullAddr}
x := i.db.getNode(i.db.root)
for !x.isNull() {
@ -211,8 +209,8 @@ func (i *MemdbIterator) seekToLast() {
i.curr = y
}
func (i *MemdbIterator) seek(key []byte) {
y := memdbNodeAddr{nil, nullAddr}
func (i *RBTIterator) seek(key []byte) {
y := MemdbNodeAddr{nil, arena.NullAddr}
x := i.db.getNode(i.db.root)
var cmp int
@ -246,6 +244,6 @@ func (i *MemdbIterator) seek(key []byte) {
i.curr = y
}
func (i *MemdbIterator) isFlagsOnly() bool {
return !i.curr.isNull() && i.curr.vptr.isNull()
func (i *RBTIterator) isFlagsOnly() bool {
return !i.curr.isNull() && i.curr.vptr.IsNull()
}
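
A typical ascending scan with the renamed iterator (in-package sketch, hypothetical helper; nil bounds cover the whole key space):

// orderedScan sketches a full forward scan.
func orderedScan() {
	db := New()
	_ = db.Set([]byte("a"), []byte("1"))
	_ = db.Set([]byte("b"), []byte("2"))
	it, _ := db.Iter(nil, nil)
	for it.Valid() {
		k, v := it.Key(), it.Value() // "a"/"1", then "b"/"2"
		_, _ = k, v
		_ = it.Next()
	}
	it.Close()
}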

View File

@ -32,26 +32,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package unionstore
package rbt
import (
"context"
tikverr "github.com/tikv/client-go/v2/error"
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)
// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
func (db *MemDB) SnapshotGetter() Getter {
return &memdbSnapGetter{
func (db *RBT) SnapshotGetter() *rbtSnapGetter {
return &rbtSnapGetter{
db: db,
cp: db.getSnapshot(),
}
}
// SnapshotIter returns a Iterator for a snapshot of MemBuffer.
func (db *MemDB) SnapshotIter(start, end []byte) Iterator {
it := &memdbSnapIter{
MemdbIterator: &MemdbIterator{
// SnapshotIter returns an Iterator for a snapshot of MemBuffer.
func (db *RBT) SnapshotIter(start, end []byte) *rbtSnapIter {
it := &rbtSnapIter{
RBTIterator: &RBTIterator{
db: db,
start: start,
end: end,
@ -63,9 +64,9 @@ func (db *MemDB) SnapshotIter(start, end []byte) Iterator {
}
// SnapshotIterReverse returns a reverse Iterator for a snapshot of MemBuffer.
func (db *MemDB) SnapshotIterReverse(k, lowerBound []byte) Iterator {
it := &memdbSnapIter{
MemdbIterator: &MemdbIterator{
func (db *RBT) SnapshotIterReverse(k, lowerBound []byte) *rbtSnapIter {
it := &rbtSnapIter{
RBTIterator: &RBTIterator{
db: db,
start: lowerBound,
end: k,
@ -77,48 +78,48 @@ func (db *MemDB) SnapshotIterReverse(k, lowerBound []byte) Iterator {
return it
}
func (db *MemDB) getSnapshot() MemDBCheckpoint {
func (db *RBT) getSnapshot() arena.MemDBCheckpoint {
if len(db.stages) > 0 {
return db.stages[0]
}
return db.vlog.checkpoint()
return db.vlog.Checkpoint()
}
type memdbSnapGetter struct {
db *MemDB
cp MemDBCheckpoint
type rbtSnapGetter struct {
db *RBT
cp arena.MemDBCheckpoint
}
func (snap *memdbSnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) {
func (snap *rbtSnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) {
x := snap.db.traverse(key, false)
if x.isNull() {
return nil, tikverr.ErrNotExist
}
if x.vptr.isNull() {
if x.vptr.IsNull() {
// A flags-only key acts as if the value does not exist
return nil, tikverr.ErrNotExist
}
v, ok := snap.db.vlog.getSnapshotValue(x.vptr, &snap.cp)
v, ok := snap.db.vlog.GetSnapshotValue(x.vptr, &snap.cp)
if !ok {
return nil, tikverr.ErrNotExist
}
return v, nil
}
type memdbSnapIter struct {
*MemdbIterator
type rbtSnapIter struct {
*RBTIterator
value []byte
cp MemDBCheckpoint
cp arena.MemDBCheckpoint
}
func (i *memdbSnapIter) Value() []byte {
func (i *rbtSnapIter) Value() []byte {
return i.value
}
func (i *memdbSnapIter) Next() error {
func (i *rbtSnapIter) Next() error {
i.value = nil
for i.Valid() {
if err := i.MemdbIterator.Next(); err != nil {
if err := i.RBTIterator.Next(); err != nil {
return err
}
if i.setValue() {
@ -128,18 +129,18 @@ func (i *memdbSnapIter) Next() error {
return nil
}
func (i *memdbSnapIter) setValue() bool {
func (i *rbtSnapIter) setValue() bool {
if !i.Valid() {
return false
}
if v, ok := i.db.vlog.getSnapshotValue(i.curr.vptr, &i.cp); ok {
if v, ok := i.db.vlog.GetSnapshotValue(i.curr.vptr, &i.cp); ok {
i.value = v
return true
}
return false
}
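The checkpoint captured in getSnapshot is what keeps later writes invisible through the snapshot. A sketch (in-package, hypothetical helper; context import assumed, and the values have different lengths so the second write appends a new vlog version rather than overwriting in place):

// snapshotStability sketches snapshot reads surviving later writes.
func snapshotStability() {
	db := New()
	_ = db.Set([]byte("k"), []byte("old"))
	snap := db.SnapshotGetter()              // checkpoint is taken here
	_ = db.Set([]byte("k"), []byte("newer")) // a later, longer write appends a new version
	v, _ := snap.Get(context.Background(), []byte("k"))
	_ = v // still "old": GetSnapshotValue walks versions back to the checkpoint
}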
func (i *memdbSnapIter) init() {
func (i *rbtSnapIter) init() {
if i.reverse {
if len(i.end) == 0 {
i.seekToLast()

View File

@ -0,0 +1,170 @@
// Copyright 2024 TiKV Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rbt
import (
"encoding/binary"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tikv/client-go/v2/kv"
)
func init() {
testMode = true
}
func deriveAndFill(start, end, valueBase int, db *RBT) int {
h := db.Staging()
var kbuf, vbuf [4]byte
for i := start; i < end; i++ {
binary.BigEndian.PutUint32(kbuf[:], uint32(i))
binary.BigEndian.PutUint32(vbuf[:], uint32(i+valueBase))
db.Set(kbuf[:], vbuf[:])
}
return h
}
func TestDiscard(t *testing.T) {
assert := assert.New(t)
const cnt = 10000
db := New()
base := deriveAndFill(0, cnt, 0, db)
sz := db.Size()
db.Cleanup(deriveAndFill(0, cnt, 1, db))
assert.Equal(db.Len(), cnt)
assert.Equal(db.Size(), sz)
var buf [4]byte
for i := 0; i < cnt; i++ {
binary.BigEndian.PutUint32(buf[:], uint32(i))
v, err := db.Get(buf[:])
assert.Nil(err)
assert.Equal(v, buf[:])
}
var i int
for it, _ := db.Iter(nil, nil); it.Valid(); it.Next() {
binary.BigEndian.PutUint32(buf[:], uint32(i))
assert.Equal(it.Key(), buf[:])
assert.Equal(it.Value(), buf[:])
i++
}
assert.Equal(i, cnt)
i--
for it, _ := db.IterReverse(nil, nil); it.Valid(); it.Next() {
binary.BigEndian.PutUint32(buf[:], uint32(i))
assert.Equal(it.Key(), buf[:])
assert.Equal(it.Value(), buf[:])
i--
}
assert.Equal(i, -1)
db.Cleanup(base)
for i := 0; i < cnt; i++ {
binary.BigEndian.PutUint32(buf[:], uint32(i))
_, err := db.Get(buf[:])
assert.NotNil(err)
}
it, _ := db.Iter(nil, nil)
it.seekToFirst()
assert.False(it.Valid())
it.seekToLast()
assert.False(it.Valid())
it.seek([]byte{0xff})
assert.False(it.Valid())
}
func TestEmptyDB(t *testing.T) {
assert := assert.New(t)
db := New()
_, err := db.Get([]byte{0})
assert.NotNil(err)
it, _ := db.Iter(nil, nil)
it.seekToFirst()
assert.False(it.Valid())
it.seekToLast()
assert.False(it.Valid())
it.seek([]byte{0xff})
assert.False(it.Valid())
}
func TestFlags(t *testing.T) {
assert := assert.New(t)
const cnt = 10000
db := New()
h := db.Staging()
for i := uint32(0); i < cnt; i++ {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], i)
if i%2 == 0 {
db.Set(buf[:], buf[:], kv.SetPresumeKeyNotExists, kv.SetKeyLocked)
} else {
db.Set(buf[:], buf[:], kv.SetPresumeKeyNotExists)
}
}
db.Cleanup(h)
for i := uint32(0); i < cnt; i++ {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], i)
_, err := db.Get(buf[:])
assert.NotNil(err)
flags, err := db.GetFlags(buf[:])
if i%2 == 0 {
assert.Nil(err)
assert.True(flags.HasLocked())
assert.False(flags.HasPresumeKeyNotExists())
} else {
assert.NotNil(err)
}
}
assert.Equal(db.Len(), 5000)
assert.Equal(db.Size(), 20000)
it, _ := db.Iter(nil, nil)
assert.False(it.Valid())
it.includeFlags = true
it.init()
for ; it.Valid(); it.Next() {
k := binary.BigEndian.Uint32(it.Key())
assert.True(k%2 == 0)
}
for i := uint32(0); i < cnt; i++ {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], i)
db.Set(buf[:], nil, kv.DelKeyLocked)
}
for i := uint32(0); i < cnt; i++ {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], i)
_, err := db.Get(buf[:])
assert.NotNil(err)
// UpdateFlags will create missing node.
flags, err := db.GetFlags(buf[:])
assert.Nil(err)
assert.False(flags.HasLocked())
}
}

View File

@ -250,54 +250,7 @@ type Metrics struct {
}
var (
_ MemBuffer = &MemDBWithContext{}
_ MemBuffer = &PipelinedMemDB{}
_ MemBuffer = &rbtDBWithContext{}
_ MemBuffer = &artDBWithContext{}
)
// MemDBWithContext wraps MemDB to satisfy the MemBuffer interface.
type MemDBWithContext struct {
*MemDB
}
func NewMemDBWithContext() *MemDBWithContext {
return &MemDBWithContext{MemDB: newMemDB()}
}
func (db *MemDBWithContext) Get(_ context.Context, k []byte) ([]byte, error) {
return db.MemDB.Get(k)
}
func (db *MemDBWithContext) GetLocal(_ context.Context, k []byte) ([]byte, error) {
return db.MemDB.Get(k)
}
func (db *MemDBWithContext) Flush(bool) (bool, error) { return false, nil }
func (db *MemDBWithContext) FlushWait() error { return nil }
// GetMemDB returns the inner MemDB
func (db *MemDBWithContext) GetMemDB() *MemDB {
return db.MemDB
}
// BatchGet returns the values for given keys from the MemBuffer.
func (db *MemDBWithContext) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) {
if db.Len() == 0 {
return map[string][]byte{}, nil
}
m := make(map[string][]byte, len(keys))
for _, k := range keys {
v, err := db.Get(ctx, k)
if err != nil {
if tikverr.IsErrNotFound(err) {
continue
}
return nil, err
}
m[string(k)] = v
}
return m, nil
}
// GetFlushMetrisc implements the MemBuffer interface.
func (db *MemDBWithContext) GetMetrics() Metrics { return Metrics{} }

View File

@ -44,7 +44,7 @@ import (
func TestUnionStoreGetSet(t *testing.T) {
assert := assert.New(t)
store := newMemDB()
store := NewMemDB()
us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store})
err := store.Set([]byte("1"), []byte("1"))
@ -63,7 +63,7 @@ func TestUnionStoreGetSet(t *testing.T) {
func TestUnionStoreDelete(t *testing.T) {
assert := assert.New(t)
store := newMemDB()
store := NewMemDB()
us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store})
err := store.Set([]byte("1"), []byte("1"))
@ -82,7 +82,7 @@ func TestUnionStoreDelete(t *testing.T) {
func TestUnionStoreSeek(t *testing.T) {
assert := assert.New(t)
store := newMemDB()
store := NewMemDB()
us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store})
err := store.Set([]byte("1"), []byte("1"))
@ -115,7 +115,7 @@ func TestUnionStoreSeek(t *testing.T) {
func TestUnionStoreIterReverse(t *testing.T) {
assert := assert.New(t)
store := newMemDB()
store := NewMemDB()
us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store})
err := store.Set([]byte("1"), []byte("1"))

View File

@ -186,7 +186,7 @@ func NewTiKVTxn(store kvstore, snapshot *txnsnapshot.KVSnapshot, startTS uint64,
RequestSource: snapshot.RequestSource,
}
if !options.PipelinedMemDB {
newTiKVTxn.us = unionstore.NewUnionStore(unionstore.NewMemDBWithContext(), snapshot)
newTiKVTxn.us = unionstore.NewUnionStore(unionstore.NewMemDB(), snapshot)
return newTiKVTxn, nil
}
if err := newTiKVTxn.InitPipelinedMemDB(); err != nil {