tso: merge lastTS and lastArrivalTS into an atomic pointer (#1054)

* fix the issue that stale timestamp may be a future one

Signed-off-by: you06 <you1474600@gmail.com>

* add regression test

Signed-off-by: you06 <you1474600@gmail.com>

* lazy init lastTSO

Signed-off-by: you06 <you1474600@gmail.com>

* fix panic

Signed-off-by: you06 <you1474600@gmail.com>

* address comment

Signed-off-by: you06 <you1474600@gmail.com>

---------

Signed-off-by: you06 <you1474600@gmail.com>
This commit is contained in:
you06 2023-11-15 17:34:14 +09:00 committed by GitHub
parent 6659170644
commit 7c96dfd783
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 64 deletions

View File

@ -63,20 +63,8 @@ func NewEmptyPDOracle() oracle.Oracle {
func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lastTSPointer := lastTSInterface.(*uint64)
atomic.StoreUint64(lastTSPointer, ts)
lasTSArrivalInterface, _ := o.lastArrivalTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lasTSArrivalPointer := lasTSArrivalInterface.(*uint64)
atomic.StoreUint64(lasTSArrivalPointer, uint64(time.Now().Unix()*1000))
}
setEmptyPDOracleLastArrivalTs(oc, ts)
}
// setEmptyPDOracleLastArrivalTs exports PD oracle's global last ts to test.
func setEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
o.setLastArrivalTS(ts, oracle.GlobalTxnScope)
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{})
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts})
}
}

View File

@ -56,11 +56,15 @@ const slowDist = 30 * time.Millisecond
// pdOracle is an Oracle that uses a placement driver client as source.
type pdOracle struct {
c pd.Client
// txn_scope (string) -> lastTSPointer (*uint64)
// txn_scope (string) -> lastTSPointer (*atomic.Pointer[lastTSO])
lastTSMap sync.Map
// txn_scope (string) -> lastArrivalTSPointer (*uint64)
lastArrivalTSMap sync.Map
quit chan struct{}
quit chan struct{}
}
// lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched.
type lastTSO struct {
tso uint64
arrival uint64
}
// NewPdOracle create an Oracle that uses a pd client source.
@ -163,63 +167,51 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
current := &lastTSO{
tso: ts,
arrival: o.getArrivalTimestamp(),
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64))
pointer := &atomic.Pointer[lastTSO]{}
pointer.Store(current)
// do not handle the stored case, because it only runs once.
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, pointer)
}
lastTSPointer := lastTSInterface.(*uint64)
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
last := lastTSPointer.Load()
if current.tso <= last.tso || current.arrival <= last.arrival {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
break
}
}
o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope)
}
func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64))
}
lastTSPointer := lastTSInterface.(*uint64)
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
if lastTSPointer.CompareAndSwap(last, current) {
return
}
}
}
func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) {
last, exist := o.getLastTSWithArrivalTS(txnScope)
if !exist {
return 0, false
}
return last.tso, true
}
func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
return 0, false
return nil, false
}
return atomic.LoadUint64(lastTSInterface.(*uint64)), true
}
func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
last := lastTSPointer.Load()
if last == nil {
return nil, false
}
lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
if !ok {
return 0, false
}
return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true
return last, true
}
func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
@ -293,22 +285,18 @@ func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *orac
}
func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) {
ts, ok := o.getLastTS(txnScope)
last, ok := o.getLastTSWithArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope)
}
arrivalTS, ok := o.getLastArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale arrival timestamp fail, txnScope: %s", txnScope)
}
ts, arrivalTS := last.tso, last.arrival
arrivalTime := oracle.GetTimeFromTS(arrivalTS)
physicalTime := oracle.GetTimeFromTS(ts)
if uint64(physicalTime.Unix()) <= prevSecond {
return 0, errors.Errorf("invalid prevSecond %v", prevSecond)
}
staleTime := physicalTime.Add(-arrivalTime.Sub(time.Now().Add(-time.Duration(prevSecond) * time.Second)))
staleTime := physicalTime.Add(time.Now().Add(-time.Duration(prevSecond) * time.Second).Sub(arrivalTime))
return oracle.GoTimeToTS(staleTime), nil
}

View File

@ -72,3 +72,35 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) {
assert.NotNil(t, err)
assert.Regexp(t, ".*invalid prevSecond.*", err.Error())
}
func TestNonFutureStaleTSO(t *testing.T) {
o := oracles.NewEmptyPDOracle()
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now()))
for i := 0; i < 100; i++ {
time.Sleep(10 * time.Millisecond)
now := time.Now()
upperBound := now.Add(5 * time.Millisecond) // allow 5ms time drift
closeCh := make(chan struct{})
go func() {
time.Sleep(100 * time.Microsecond)
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now))
close(closeCh)
}()
CHECK:
for {
select {
case <-closeCh:
break CHECK
default:
ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 0)
assert.Nil(t, err)
staleTime := oracle.GetTimeFromTS(ts)
if staleTime.After(upperBound) && time.Since(now) < time.Millisecond /* only check staleTime within 1ms */ {
assert.Less(t, staleTime, upperBound, i)
t.FailNow()
}
}
}
}
}