mirror of https://github.com/tikv/client-go.git
tso: merge lastTS and lastArrivalTS into an atomic pointer (#1054)
* fix the issue that stale timestamp may be a future one Signed-off-by: you06 <you1474600@gmail.com> * add regression test Signed-off-by: you06 <you1474600@gmail.com> * lazy init lastTSO Signed-off-by: you06 <you1474600@gmail.com> * fix panic Signed-off-by: you06 <you1474600@gmail.com> * address comment Signed-off-by: you06 <you1474600@gmail.com> --------- Signed-off-by: you06 <you1474600@gmail.com>
This commit is contained in:
parent
6659170644
commit
7c96dfd783
|
|
@ -63,20 +63,8 @@ func NewEmptyPDOracle() oracle.Oracle {
|
|||
func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
|
||||
switch o := oc.(type) {
|
||||
case *pdOracle:
|
||||
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
|
||||
lastTSPointer := lastTSInterface.(*uint64)
|
||||
atomic.StoreUint64(lastTSPointer, ts)
|
||||
lasTSArrivalInterface, _ := o.lastArrivalTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
|
||||
lasTSArrivalPointer := lasTSArrivalInterface.(*uint64)
|
||||
atomic.StoreUint64(lasTSArrivalPointer, uint64(time.Now().Unix()*1000))
|
||||
}
|
||||
setEmptyPDOracleLastArrivalTs(oc, ts)
|
||||
}
|
||||
|
||||
// setEmptyPDOracleLastArrivalTs exports PD oracle's global last ts to test.
|
||||
func setEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) {
|
||||
switch o := oc.(type) {
|
||||
case *pdOracle:
|
||||
o.setLastArrivalTS(ts, oracle.GlobalTxnScope)
|
||||
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{})
|
||||
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
|
||||
lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,11 +56,15 @@ const slowDist = 30 * time.Millisecond
|
|||
// pdOracle is an Oracle that uses a placement driver client as source.
|
||||
type pdOracle struct {
|
||||
c pd.Client
|
||||
// txn_scope (string) -> lastTSPointer (*uint64)
|
||||
// txn_scope (string) -> lastTSPointer (*atomic.Pointer[lastTSO])
|
||||
lastTSMap sync.Map
|
||||
// txn_scope (string) -> lastArrivalTSPointer (*uint64)
|
||||
lastArrivalTSMap sync.Map
|
||||
quit chan struct{}
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
// lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched.
|
||||
type lastTSO struct {
|
||||
tso uint64
|
||||
arrival uint64
|
||||
}
|
||||
|
||||
// NewPdOracle create an Oracle that uses a pd client source.
|
||||
|
|
@ -163,63 +167,51 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) {
|
|||
if txnScope == "" {
|
||||
txnScope = oracle.GlobalTxnScope
|
||||
}
|
||||
current := &lastTSO{
|
||||
tso: ts,
|
||||
arrival: o.getArrivalTimestamp(),
|
||||
}
|
||||
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
|
||||
if !ok {
|
||||
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64))
|
||||
pointer := &atomic.Pointer[lastTSO]{}
|
||||
pointer.Store(current)
|
||||
// do not handle the stored case, because it only runs once.
|
||||
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, pointer)
|
||||
}
|
||||
lastTSPointer := lastTSInterface.(*uint64)
|
||||
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
|
||||
for {
|
||||
lastTS := atomic.LoadUint64(lastTSPointer)
|
||||
if ts <= lastTS {
|
||||
last := lastTSPointer.Load()
|
||||
if current.tso <= last.tso || current.arrival <= last.arrival {
|
||||
return
|
||||
}
|
||||
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
|
||||
break
|
||||
}
|
||||
}
|
||||
o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope)
|
||||
}
|
||||
|
||||
func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) {
|
||||
if txnScope == "" {
|
||||
txnScope = oracle.GlobalTxnScope
|
||||
}
|
||||
lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
|
||||
if !ok {
|
||||
lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64))
|
||||
}
|
||||
lastTSPointer := lastTSInterface.(*uint64)
|
||||
for {
|
||||
lastTS := atomic.LoadUint64(lastTSPointer)
|
||||
if ts <= lastTS {
|
||||
return
|
||||
}
|
||||
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
|
||||
if lastTSPointer.CompareAndSwap(last, current) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) {
|
||||
last, exist := o.getLastTSWithArrivalTS(txnScope)
|
||||
if !exist {
|
||||
return 0, false
|
||||
}
|
||||
return last.tso, true
|
||||
}
|
||||
|
||||
func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) {
|
||||
if txnScope == "" {
|
||||
txnScope = oracle.GlobalTxnScope
|
||||
}
|
||||
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
|
||||
if !ok {
|
||||
return 0, false
|
||||
return nil, false
|
||||
}
|
||||
return atomic.LoadUint64(lastTSInterface.(*uint64)), true
|
||||
}
|
||||
|
||||
func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) {
|
||||
if txnScope == "" {
|
||||
txnScope = oracle.GlobalTxnScope
|
||||
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
|
||||
last := lastTSPointer.Load()
|
||||
if last == nil {
|
||||
return nil, false
|
||||
}
|
||||
lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true
|
||||
return last, true
|
||||
}
|
||||
|
||||
func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
|
||||
|
|
@ -293,22 +285,18 @@ func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *orac
|
|||
}
|
||||
|
||||
func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) {
|
||||
ts, ok := o.getLastTS(txnScope)
|
||||
last, ok := o.getLastTSWithArrivalTS(txnScope)
|
||||
if !ok {
|
||||
return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope)
|
||||
}
|
||||
arrivalTS, ok := o.getLastArrivalTS(txnScope)
|
||||
if !ok {
|
||||
return 0, errors.Errorf("get stale arrival timestamp fail, txnScope: %s", txnScope)
|
||||
}
|
||||
ts, arrivalTS := last.tso, last.arrival
|
||||
arrivalTime := oracle.GetTimeFromTS(arrivalTS)
|
||||
physicalTime := oracle.GetTimeFromTS(ts)
|
||||
if uint64(physicalTime.Unix()) <= prevSecond {
|
||||
return 0, errors.Errorf("invalid prevSecond %v", prevSecond)
|
||||
}
|
||||
|
||||
staleTime := physicalTime.Add(-arrivalTime.Sub(time.Now().Add(-time.Duration(prevSecond) * time.Second)))
|
||||
|
||||
staleTime := physicalTime.Add(time.Now().Add(-time.Duration(prevSecond) * time.Second).Sub(arrivalTime))
|
||||
return oracle.GoTimeToTS(staleTime), nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -72,3 +72,35 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) {
|
|||
assert.NotNil(t, err)
|
||||
assert.Regexp(t, ".*invalid prevSecond.*", err.Error())
|
||||
}
|
||||
|
||||
func TestNonFutureStaleTSO(t *testing.T) {
|
||||
o := oracles.NewEmptyPDOracle()
|
||||
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now()))
|
||||
for i := 0; i < 100; i++ {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
now := time.Now()
|
||||
upperBound := now.Add(5 * time.Millisecond) // allow 5ms time drift
|
||||
|
||||
closeCh := make(chan struct{})
|
||||
go func() {
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now))
|
||||
close(closeCh)
|
||||
}()
|
||||
CHECK:
|
||||
for {
|
||||
select {
|
||||
case <-closeCh:
|
||||
break CHECK
|
||||
default:
|
||||
ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 0)
|
||||
assert.Nil(t, err)
|
||||
staleTime := oracle.GetTimeFromTS(ts)
|
||||
if staleTime.After(upperBound) && time.Since(now) < time.Millisecond /* only check staleTime within 1ms */ {
|
||||
assert.Less(t, staleTime, upperBound, i)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue