diff --git a/internal/unionstore/memdb.go b/internal/unionstore/memdb.go index 76fe3794..46475043 100644 --- a/internal/unionstore/memdb.go +++ b/internal/unionstore/memdb.go @@ -39,6 +39,7 @@ import ( "fmt" "math" "sync" + "sync/atomic" "unsafe" tikverr "github.com/tikv/client-go/v2/error" @@ -88,6 +89,11 @@ type MemDB struct { stages []MemDBCheckpoint // when the MemDB is wrapper by upper RWMutex, we can skip the internal mutex. skipMutex bool + + // lastTraversedNode caches the node last found or inserted by traverse. It must never be nil: &nullNodeAddr marks an empty cache. + lastTraversedNode atomic.Pointer[memdbNodeAddr] + hitCount atomic.Uint64 + missCount atomic.Uint64 } const unlimitedSize = math.MaxUint64 @@ -101,9 +107,29 @@ func newMemDB() *MemDB { db.bufferSizeLimit = unlimitedSize db.vlog.memdb = db db.skipMutex = false + db.lastTraversedNode.Store(&nullNodeAddr) return db } +// updateLastTraversed updates the last traversed node atomically +func (db *MemDB) updateLastTraversed(node memdbNodeAddr) { + db.lastTraversedNode.Store(&node) +} + +// checkKeyInCache retrieves the last traversed node if the key matches +func (db *MemDB) checkKeyInCache(key []byte) (memdbNodeAddr, bool) { + nodePtr := db.lastTraversedNode.Load() + if nodePtr == nil || nodePtr.isNull() { + return nullNodeAddr, false + } + + if bytes.Equal(key, nodePtr.memdbNode.getKey()) { + return *nodePtr, true + } + + return nullNodeAddr, false +} + // Staging create a new staging buffer inside the MemBuffer. // Subsequent writes will be temporarily stored in this new staging buffer. // When you think all modifications looks good, you can call `Release` to public all of them to the upper level buffer. @@ -329,7 +355,7 @@ func (db *MemDB) set(key []byte, value []byte, ops ...kv.FlagsOp) error { if db.vlogInvalid { // panic for easier debugging. - panic("vlog is resetted") + panic("vlog is reset") } if value != nil { @@ -398,23 +424,38 @@ func (db *MemDB) setValue(x memdbNodeAddr, value []byte) { // traverse search for and if not found and insert is true, will add a new node in. 
// Returns a pointer to the new node, or the node found. func (db *MemDB) traverse(key []byte, insert bool) memdbNodeAddr { + if node, found := db.checkKeyInCache(key); found { + db.hitCount.Add(1) + return node + } + db.missCount.Add(1) + x := db.getRoot() y := memdbNodeAddr{nil, nullAddr} found := false // walk x down the tree for !x.isNull() && !found { - y = x cmp := bytes.Compare(key, x.getKey()) if cmp < 0 { + if insert && x.left.isNull() { + y = x + } x = x.getLeft(db) } else if cmp > 0 { + if insert && x.right.isNull() { + y = x + } x = x.getRight(db) } else { found = true } } + if found { + db.updateLastTraversed(x) + } + if found || !insert { return x } @@ -508,6 +549,8 @@ func (db *MemDB) traverse(key []byte, insert bool) memdbNodeAddr { // Set the root node black db.getRoot().setBlack() + db.updateLastTraversed(z) + return z } @@ -595,6 +638,9 @@ func (db *MemDB) rightRotate(y memdbNodeAddr) { func (db *MemDB) deleteNode(z memdbNodeAddr) { var x, y memdbNodeAddr + if db.lastTraversedNode.Load().addr == z.addr { + db.lastTraversedNode.Store(&nullNodeAddr) + } db.count-- db.size -= int(z.klen) diff --git a/internal/unionstore/memdb_arena.go b/internal/unionstore/memdb_arena.go index 146d6a0f..e92a98e0 100644 --- a/internal/unionstore/memdb_arena.go +++ b/internal/unionstore/memdb_arena.go @@ -39,10 +39,8 @@ import ( "math" "unsafe" - "github.com/tikv/client-go/v2/internal/logutil" "github.com/tikv/client-go/v2/kv" "go.uber.org/atomic" - "go.uber.org/zap" ) const ( @@ -54,8 +52,9 @@ const ( ) var ( - nullAddr = memdbArenaAddr{math.MaxUint32, math.MaxUint32} - endian = binary.LittleEndian + nullAddr = memdbArenaAddr{math.MaxUint32, math.MaxUint32} + nullNodeAddr = memdbNodeAddr{nil, nullAddr} + endian = binary.LittleEndian ) type memdbArenaAddr struct { @@ -64,17 +63,8 @@ type memdbArenaAddr struct { } func (addr memdbArenaAddr) isNull() bool { - if addr == nullAddr { - return true - } - if addr.idx == math.MaxUint32 || addr.off == math.MaxUint32 { - // 
defensive programming, the code should never run to here. - // it always means something wrong... (maybe caused by data race?) - // because we never set part of idx/off to math.MaxUint64 - logutil.BgLogger().Warn("Invalid memdbArenaAddr", zap.Uint32("idx", addr.idx), zap.Uint32("off", addr.off)) - return true - } - return false + // An address with only one of idx/off equal to math.MaxUint32 is malformed and is also treated as null. + return addr == nullAddr || addr.idx == math.MaxUint32 || addr.off == math.MaxUint32 } // store and load is used by vlog, due to pointer in vlog is not aligned. @@ -279,7 +269,7 @@ func (a *nodeAllocator) freeNode(addr memdbArenaAddr) { n.vptr = badAddr return } - // TODO: reuse freed nodes. + // TODO: reuse freed nodes. Need to fix lastTraversedNode when implementing this. } func (a *nodeAllocator) reset() { diff --git a/internal/unionstore/pipelined_memdb.go b/internal/unionstore/pipelined_memdb.go index 77538491..47518aed 100644 --- a/internal/unionstore/pipelined_memdb.go +++ b/internal/unionstore/pipelined_memdb.go @@ -58,6 +58,8 @@ type PipelinedMemDB struct { // metrics flushWaitDuration time.Duration + hitCount uint64 + missCount uint64 startTime time.Time } @@ -307,6 +309,8 @@ func (p *PipelinedMemDB) Flush(force bool) (bool, error) { p.flushingMemDB = p.memDB p.len += p.flushingMemDB.Len() p.size += p.flushingMemDB.Size() + p.missCount += p.memDB.missCount.Load() + p.hitCount += p.memDB.hitCount.Load() p.memDB = newMemDB() // buffer size is limited by ForceFlushMemSizeThreshold. Do not set bufferLimit p.memDB.SetEntrySizeLimit(p.entryLimit, unlimitedSize) @@ -523,11 +527,20 @@ func (p *PipelinedMemDB) RevertToCheckpoint(*MemDBCheckpoint) { panic("RevertToCheckpoint is not supported for PipelinedMemDB") } -// GetFlushMetrics implements MemBuffer interface. -func (p *PipelinedMemDB) GetFlushMetrics() FlushMetrics { - return FlushMetrics{ - WaitDuration: p.flushWaitDuration, - TotalDuration: time.Since(p.startTime), +// GetMetrics implements MemBuffer interface. 
+// DO NOT call it during execution, otherwise data race may occur +func (p *PipelinedMemDB) GetMetrics() Metrics { + hitCount := p.hitCount + missCount := p.missCount + if p.memDB != nil { + hitCount += p.memDB.hitCount.Load() + missCount += p.memDB.missCount.Load() + } + return Metrics{ + WaitDuration: p.flushWaitDuration, + TotalDuration: time.Since(p.startTime), + MemDBHitCount: hitCount, + MemDBMissCount: missCount, } } diff --git a/internal/unionstore/union_store.go b/internal/unionstore/union_store.go index 537513dd..19c10093 100644 --- a/internal/unionstore/union_store.go +++ b/internal/unionstore/union_store.go @@ -238,13 +238,15 @@ type MemBuffer interface { Flush(force bool) (bool, error) // FlushWait waits for the flushing task done and return error. FlushWait() error - // GetFlushDetails returns the metrics related to flushing - GetFlushMetrics() FlushMetrics + // GetMetrics returns the flush durations and the memdb traversal cache hit/miss counts. + GetMetrics() Metrics } -type FlushMetrics struct { - WaitDuration time.Duration - TotalDuration time.Duration +type Metrics struct { + WaitDuration time.Duration + TotalDuration time.Duration + MemDBHitCount uint64 + MemDBMissCount uint64 } var ( @@ -298,4 +300,4 @@ func (db *MemDBWithContext) BatchGet(ctx context.Context, keys [][]byte) (map[st } // GetFlushMetrisc implements the MemBuffer interface. -func (db *MemDBWithContext) GetFlushMetrics() FlushMetrics { return FlushMetrics{} } +func (db *MemDBWithContext) GetMetrics() Metrics { return Metrics{} } diff --git a/tikv/unionstore_export.go b/tikv/unionstore_export.go index 7ec5cdb0..80ee88f4 100644 --- a/tikv/unionstore_export.go +++ b/tikv/unionstore_export.go @@ -57,3 +57,6 @@ type MemBuffer = unionstore.MemBuffer // MemDBCheckpoint is the checkpoint of memory DB. type MemDBCheckpoint = unionstore.MemDBCheckpoint + +// Metrics holds the flush and memdb traversal cache metrics of a MemBuffer. 
+type Metrics = unionstore.Metrics diff --git a/txnkv/transaction/pessimistic.go b/txnkv/transaction/pessimistic.go index ac6050dd..637044c2 100644 --- a/txnkv/transaction/pessimistic.go +++ b/txnkv/transaction/pessimistic.go @@ -174,7 +174,7 @@ func (action actionPessimisticLock) handleSingleBatch( for _, m := range mutations { keys = append(keys, hex.EncodeToString(m.Key)) } - logutil.BgLogger().Info( + logutil.BgLogger().Debug( "[failpoint] injected lock ttl = 1 on pessimistic lock", zap.Uint64("txnStartTS", c.startTS), zap.Strings("keys", keys), ) diff --git a/txnkv/transaction/pipelined_flush.go b/txnkv/transaction/pipelined_flush.go index b481dfc2..3ba39d27 100644 --- a/txnkv/transaction/pipelined_flush.go +++ b/txnkv/transaction/pipelined_flush.go @@ -304,8 +304,10 @@ func (c *twoPhaseCommitter) commitFlushedMutations(bo *retry.Backoffer) error { logutil.Logger(bo.GetCtx()).Info( "[pipelined dml] start to commit transaction", zap.Int("keys", c.txn.GetMemBuffer().Len()), - zap.Duration("flush_wait_duration", c.txn.GetMemBuffer().GetFlushMetrics().WaitDuration), - zap.Duration("total_duration", c.txn.GetMemBuffer().GetFlushMetrics().TotalDuration), + zap.Duration("flush_wait_duration", c.txn.GetMemBuffer().GetMetrics().WaitDuration), + zap.Duration("total_duration", c.txn.GetMemBuffer().GetMetrics().TotalDuration), + zap.Uint64("memdb traversal cache hit", c.txn.GetMemBuffer().GetMetrics().MemDBHitCount), + zap.Uint64("memdb traversal cache miss", c.txn.GetMemBuffer().GetMetrics().MemDBMissCount), zap.String("size", units.HumanSize(float64(c.txn.GetMemBuffer().Size()))), zap.Uint64("startTS", c.startTS), ) diff --git a/txnkv/transaction/txn.go b/txnkv/transaction/txn.go index 02cf5035..2496d3f3 100644 --- a/txnkv/transaction/txn.go +++ b/txnkv/transaction/txn.go @@ -876,7 +876,7 @@ func (txn *KVTxn) onCommitted(err error) { AsyncCommitFallback: txn.committer.hasTriedAsyncCommit && !isAsyncCommit, OnePCFallback: txn.committer.hasTriedOnePC && !isOnePC, 
Pipelined: txn.IsPipelined(), - FlushWaitMs: txn.GetMemBuffer().GetFlushMetrics().WaitDuration.Milliseconds(), + FlushWaitMs: txn.GetMemBuffer().GetMetrics().WaitDuration.Milliseconds(), } if err != nil { info.ErrMsg = err.Error()