Use context everywhere (#23)

* rawkv

Signed-off-by: disksing <i@disksing.com>

* txnkv wip

Signed-off-by: disksing <i@disksing.com>

* txnkv wip

Signed-off-by: disksing <i@disksing.com>

* txnkv update get & batchGet

Signed-off-by: disksing <i@disksing.com>

* txnkv iterators

Signed-off-by: disksing <i@disksing.com>
This commit is contained in:
disksing 2019-06-20 01:18:29 +08:00 committed by GitHub
parent 44b82dcc9f
commit 77a15fcd87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 394 additions and 333 deletions

View File

@ -14,6 +14,7 @@
package main
import (
"context"
"fmt"
"github.com/tikv/client-go/config"
@ -21,7 +22,7 @@ import (
)
func main() {
cli, err := rawkv.NewClient([]string{"127.0.0.1:2379"}, config.Default())
cli, err := rawkv.NewClient(context.TODO(), []string{"127.0.0.1:2379"}, config.Default())
if err != nil {
panic(err)
}
@ -33,28 +34,28 @@ func main() {
val := []byte("PingCAP")
// put key into tikv
err = cli.Put(key, val)
err = cli.Put(context.TODO(), key, val)
if err != nil {
panic(err)
}
fmt.Printf("Successfully put %s:%s to tikv\n", key, val)
// get key from tikv
val, err = cli.Get(key)
val, err = cli.Get(context.TODO(), key)
if err != nil {
panic(err)
}
fmt.Printf("found val: %s for key: %s\n", val, key)
// delete key from tikv
err = cli.Delete(key)
err = cli.Delete(context.TODO(), key)
if err != nil {
panic(err)
}
fmt.Printf("key: %s deleted\n", key)
// get key again from tikv
val, err = cli.Get(key)
val, err = cli.Get(context.TODO(), key)
if err != nil {
panic(err)
}

View File

@ -41,7 +41,7 @@ var (
// Init initializes information.
func initStore() {
var err error
client, err = txnkv.NewClient([]string{*pdAddr}, config.Default())
client, err = txnkv.NewClient(context.TODO(), []string{*pdAddr}, config.Default())
if err != nil {
panic(err)
}
@ -49,7 +49,7 @@ func initStore() {
// key1 val1 key2 val2 ...
func puts(args ...[]byte) error {
tx, err := client.Begin()
tx, err := client.Begin(context.TODO())
if err != nil {
return err
}
@ -65,11 +65,11 @@ func puts(args ...[]byte) error {
}
func get(k []byte) (KV, error) {
tx, err := client.Begin()
tx, err := client.Begin(context.TODO())
if err != nil {
return KV{}, err
}
v, err := tx.Get(k)
v, err := tx.Get(context.TODO(), k)
if err != nil {
return KV{}, err
}
@ -77,7 +77,7 @@ func get(k []byte) (KV, error) {
}
func dels(keys ...[]byte) error {
tx, err := client.Begin()
tx, err := client.Begin(context.TODO())
if err != nil {
return err
}
@ -91,11 +91,11 @@ func dels(keys ...[]byte) error {
}
func scan(keyPrefix []byte, limit int) ([]KV, error) {
tx, err := client.Begin()
tx, err := client.Begin(context.TODO())
if err != nil {
return nil, err
}
it, err := tx.Iter(key.Key(keyPrefix), nil)
it, err := tx.Iter(context.TODO(), key.Key(keyPrefix), nil)
if err != nil {
return nil, err
}
@ -104,7 +104,7 @@ func scan(keyPrefix []byte, limit int) ([]KV, error) {
for it.Valid() && limit > 0 {
ret = append(ret, KV{K: it.Key()[:], V: it.Value()[:]})
limit--
it.Next()
it.Next(context.TODO())
}
return ret, nil
}

View File

@ -14,7 +14,9 @@
package httpproxy
import (
"context"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/tikv/client-go/proxy"
@ -69,3 +71,19 @@ func NewHTTPProxyHandler() http.Handler {
return router
}
var defaultTimeout = 20 * time.Second
func reqContext(vars map[string]string) (context.Context, context.CancelFunc) {
ctx := context.Background()
if id := vars["id"]; id != "" {
ctx = context.WithValue(ctx, proxy.UUIDKey, proxy.UUID(id))
}
d, err := time.ParseDuration(vars["timeout"])
if err != nil {
d = defaultTimeout
}
return context.WithTimeout(ctx, d)
}

View File

@ -14,6 +14,7 @@
package httpproxy
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
@ -47,94 +48,94 @@ type RawResponse struct {
Values [][]byte `json:"values,omitempty"` // for batchGet
}
func (h rawkvHandler) New(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
id, err := h.p.New(r.PDAddrs, config.Default())
func (h rawkvHandler) New(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
id, err := h.p.New(ctx, r.PDAddrs, config.Default())
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{ID: string(id)}, http.StatusCreated, nil
}
func (h rawkvHandler) Close(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
if err := h.p.Close(proxy.UUID(vars["id"])); err != nil {
func (h rawkvHandler) Close(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
if err := h.p.Close(ctx); err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{}, http.StatusOK, nil
}
func (h rawkvHandler) Get(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
val, err := h.p.Get(proxy.UUID(vars["id"]), r.Key)
func (h rawkvHandler) Get(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
val, err := h.p.Get(ctx, r.Key)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{Value: val}, http.StatusOK, nil
}
func (h rawkvHandler) BatchGet(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
vals, err := h.p.BatchGet(proxy.UUID(vars["id"]), r.Keys)
func (h rawkvHandler) BatchGet(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
vals, err := h.p.BatchGet(ctx, r.Keys)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{Values: vals}, http.StatusOK, nil
}
func (h rawkvHandler) Put(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
err := h.p.Put(proxy.UUID(vars["id"]), r.Key, r.Value)
func (h rawkvHandler) Put(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
err := h.p.Put(ctx, r.Key, r.Value)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{}, http.StatusOK, nil
}
func (h rawkvHandler) BatchPut(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
err := h.p.BatchPut(proxy.UUID(vars["id"]), r.Keys, r.Values)
func (h rawkvHandler) BatchPut(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
err := h.p.BatchPut(ctx, r.Keys, r.Values)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{}, http.StatusOK, nil
}
func (h rawkvHandler) Delete(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
err := h.p.Delete(proxy.UUID(vars["id"]), r.Key)
func (h rawkvHandler) Delete(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
err := h.p.Delete(ctx, r.Key)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{}, http.StatusOK, nil
}
func (h rawkvHandler) BatchDelete(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
err := h.p.BatchDelete(proxy.UUID(vars["id"]), r.Keys)
func (h rawkvHandler) BatchDelete(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
err := h.p.BatchDelete(ctx, r.Keys)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{}, http.StatusOK, nil
}
func (h rawkvHandler) DeleteRange(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
err := h.p.DeleteRange(proxy.UUID(vars["id"]), r.StartKey, r.EndKey)
func (h rawkvHandler) DeleteRange(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
err := h.p.DeleteRange(ctx, r.StartKey, r.EndKey)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{}, http.StatusOK, nil
}
func (h rawkvHandler) Scan(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
keys, values, err := h.p.Scan(proxy.UUID(vars["id"]), r.StartKey, r.EndKey, r.Limit)
func (h rawkvHandler) Scan(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
keys, values, err := h.p.Scan(ctx, r.StartKey, r.EndKey, r.Limit)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{Keys: keys, Values: values}, http.StatusOK, nil
}
func (h rawkvHandler) ReverseScan(vars map[string]string, r *RawRequest) (*RawResponse, int, error) {
keys, values, err := h.p.ReverseScan(proxy.UUID(vars["id"]), r.StartKey, r.EndKey, r.Limit)
func (h rawkvHandler) ReverseScan(ctx context.Context, r *RawRequest) (*RawResponse, int, error) {
keys, values, err := h.p.ReverseScan(ctx, r.StartKey, r.EndKey, r.Limit)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &RawResponse{Keys: keys, Values: values}, http.StatusOK, nil
}
func (h rawkvHandler) handlerFunc(f func(map[string]string, *RawRequest) (*RawResponse, int, error)) func(http.ResponseWriter, *http.Request) {
func (h rawkvHandler) handlerFunc(f func(context.Context, *RawRequest) (*RawResponse, int, error)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
data, err := ioutil.ReadAll(r.Body)
if err != nil {
@ -146,7 +147,9 @@ func (h rawkvHandler) handlerFunc(f func(map[string]string, *RawRequest) (*RawRe
sendError(w, err, http.StatusBadRequest)
return
}
res, status, err := f(mux.Vars(r), &req)
ctx, cancel := reqContext(mux.Vars(r))
res, status, err := f(ctx, &req)
cancel()
if err != nil {
sendError(w, err, status)
return

View File

@ -14,6 +14,7 @@
package httpproxy
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
@ -51,56 +52,56 @@ type TxnResponse struct {
Length int `json:"length,omitempty"` // for length
}
func (h txnkvHandler) New(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
id, err := h.p.New(r.PDAddrs, config.Default())
func (h txnkvHandler) New(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
id, err := h.p.New(ctx, r.PDAddrs, config.Default())
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{ID: string(id)}, http.StatusOK, nil
}
func (h txnkvHandler) Close(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.Close(proxy.UUID(vars["id"]))
func (h txnkvHandler) Close(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.Close(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) Begin(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
txnID, err := h.p.Begin(proxy.UUID(vars["id"]))
func (h txnkvHandler) Begin(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
txnID, err := h.p.Begin(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{ID: string(txnID)}, http.StatusCreated, nil
}
func (h txnkvHandler) BeginWithTS(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
txnID, err := h.p.BeginWithTS(proxy.UUID(vars["id"]), r.TS)
func (h txnkvHandler) BeginWithTS(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
txnID, err := h.p.BeginWithTS(ctx, r.TS)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{ID: string(txnID)}, http.StatusOK, nil
}
func (h txnkvHandler) GetTS(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
ts, err := h.p.GetTS(proxy.UUID(vars["id"]))
func (h txnkvHandler) GetTS(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
ts, err := h.p.GetTS(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{TS: ts}, http.StatusOK, nil
}
func (h txnkvHandler) TxnGet(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
val, err := h.p.TxnGet(proxy.UUID(vars["id"]), r.Key)
func (h txnkvHandler) TxnGet(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
val, err := h.p.TxnGet(ctx, r.Key)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{Value: val}, http.StatusOK, nil
}
func (h txnkvHandler) TxnBatchGet(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
m, err := h.p.TxnBatchGet(proxy.UUID(vars["id"]), r.Keys)
func (h txnkvHandler) TxnBatchGet(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
m, err := h.p.TxnBatchGet(ctx, r.Keys)
if err != nil {
return nil, http.StatusInternalServerError, err
}
@ -112,135 +113,135 @@ func (h txnkvHandler) TxnBatchGet(vars map[string]string, r *TxnRequest) (*TxnRe
return &TxnResponse{Keys: keys, Values: values}, http.StatusOK, nil
}
func (h txnkvHandler) TxnSet(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnSet(proxy.UUID(vars["id"]), r.Key, r.Value)
func (h txnkvHandler) TxnSet(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnSet(ctx, r.Key, r.Value)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) TxnIter(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
iterID, err := h.p.TxnIter(proxy.UUID(vars["id"]), r.Key, r.UpperBound)
func (h txnkvHandler) TxnIter(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
iterID, err := h.p.TxnIter(ctx, r.Key, r.UpperBound)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{ID: string(iterID)}, http.StatusCreated, nil
}
func (h txnkvHandler) TxnIterReverse(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
iterID, err := h.p.TxnIterReverse(proxy.UUID(vars["id"]), r.Key)
func (h txnkvHandler) TxnIterReverse(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
iterID, err := h.p.TxnIterReverse(ctx, r.Key)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{ID: string(iterID)}, http.StatusCreated, nil
}
func (h txnkvHandler) TxnIsReadOnly(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
readonly, err := h.p.TxnIsReadOnly(proxy.UUID(vars["id"]))
func (h txnkvHandler) TxnIsReadOnly(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
readonly, err := h.p.TxnIsReadOnly(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{IsReadOnly: readonly}, http.StatusOK, nil
}
func (h txnkvHandler) TxnDelete(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnDelete(proxy.UUID(vars["id"]), r.Key)
func (h txnkvHandler) TxnDelete(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnDelete(ctx, r.Key)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) TxnCommit(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnCommit(proxy.UUID(vars["id"]))
func (h txnkvHandler) TxnCommit(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnCommit(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) TxnRollback(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnRollback(proxy.UUID(vars["id"]))
func (h txnkvHandler) TxnRollback(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnRollback(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) TxnLockKeys(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnLockKeys(proxy.UUID(vars["id"]), r.Keys)
func (h txnkvHandler) TxnLockKeys(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.TxnLockKeys(ctx, r.Keys)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) TxnValid(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
valid, err := h.p.TxnValid(proxy.UUID(vars["id"]))
func (h txnkvHandler) TxnValid(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
valid, err := h.p.TxnValid(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{IsValid: valid}, http.StatusOK, nil
}
func (h txnkvHandler) TxnLen(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
length, err := h.p.TxnLen(proxy.UUID(vars["id"]))
func (h txnkvHandler) TxnLen(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
length, err := h.p.TxnLen(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{Length: length}, http.StatusOK, nil
}
func (h txnkvHandler) TxnSize(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
size, err := h.p.TxnSize(proxy.UUID(vars["id"]))
func (h txnkvHandler) TxnSize(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
size, err := h.p.TxnSize(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{Size: size}, http.StatusOK, nil
}
func (h txnkvHandler) IterValid(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
valid, err := h.p.IterValid(proxy.UUID(vars["id"]))
func (h txnkvHandler) IterValid(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
valid, err := h.p.IterValid(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{IsValid: valid}, http.StatusOK, nil
}
func (h txnkvHandler) IterKey(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
key, err := h.p.IterKey(proxy.UUID(vars["id"]))
func (h txnkvHandler) IterKey(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
key, err := h.p.IterKey(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{Key: key}, http.StatusOK, nil
}
func (h txnkvHandler) IterValue(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
val, err := h.p.IterValue(proxy.UUID(vars["id"]))
func (h txnkvHandler) IterValue(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
val, err := h.p.IterValue(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{Value: val}, http.StatusOK, nil
}
func (h txnkvHandler) IterNext(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.IterNext(proxy.UUID(vars["id"]))
func (h txnkvHandler) IterNext(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.IterNext(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) IterClose(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.IterClose(proxy.UUID(vars["id"]))
func (h txnkvHandler) IterClose(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) {
err := h.p.IterClose(ctx)
if err != nil {
return nil, http.StatusInternalServerError, err
}
return &TxnResponse{}, http.StatusOK, nil
}
func (h txnkvHandler) handlerFunc(f func(map[string]string, *TxnRequest) (*TxnResponse, int, error)) func(http.ResponseWriter, *http.Request) {
func (h txnkvHandler) handlerFunc(f func(context.Context, *TxnRequest) (*TxnResponse, int, error)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
data, err := ioutil.ReadAll(r.Body)
if err != nil {
@ -252,7 +253,9 @@ func (h txnkvHandler) handlerFunc(f func(map[string]string, *TxnRequest) (*TxnRe
sendError(w, err, http.StatusBadRequest)
return
}
res, status, err := f(mux.Vars(r), &req)
ctx, cancel := reqContext(mux.Vars(r))
res, status, err := f(ctx, &req)
cancel()
if err != nil {
sendError(w, err, status)
return

View File

@ -14,6 +14,7 @@
package proxy
import (
"context"
"sync"
"github.com/pkg/errors"
@ -35,8 +36,8 @@ func NewRaw() RawKVProxy {
}
// New creates a new client and returns the client's UUID.
func (p RawKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) {
client, err := rawkv.NewClient(pdAddrs, conf)
func (p RawKVProxy) New(ctx context.Context, pdAddrs []string, conf config.Config) (UUID, error) {
client, err := rawkv.NewClient(ctx, pdAddrs, conf)
if err != nil {
return "", err
}
@ -44,7 +45,8 @@ func (p RawKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) {
}
// Close releases a rawkv client.
func (p RawKVProxy) Close(id UUID) error {
func (p RawKVProxy) Close(ctx context.Context) error {
id := uuidFromContext(ctx)
client, ok := p.clients.Load(id)
if !ok {
return errors.WithStack(ErrClientNotFound)
@ -57,83 +59,83 @@ func (p RawKVProxy) Close(id UUID) error {
}
// Get queries value with the key.
func (p RawKVProxy) Get(id UUID, key []byte) ([]byte, error) {
client, ok := p.clients.Load(id)
func (p RawKVProxy) Get(ctx context.Context, key []byte) ([]byte, error) {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return nil, errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).Get(key)
return client.(*rawkv.Client).Get(ctx, key)
}
// BatchGet queries values with the keys.
func (p RawKVProxy) BatchGet(id UUID, keys [][]byte) ([][]byte, error) {
client, ok := p.clients.Load(id)
func (p RawKVProxy) BatchGet(ctx context.Context, keys [][]byte) ([][]byte, error) {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return nil, errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).BatchGet(keys)
return client.(*rawkv.Client).BatchGet(ctx, keys)
}
// Put stores a key-value pair to TiKV.
func (p RawKVProxy) Put(id UUID, key, value []byte) error {
client, ok := p.clients.Load(id)
func (p RawKVProxy) Put(ctx context.Context, key, value []byte) error {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).Put(key, value)
return client.(*rawkv.Client).Put(ctx, key, value)
}
// BatchPut stores key-value pairs to TiKV.
func (p RawKVProxy) BatchPut(id UUID, keys, values [][]byte) error {
client, ok := p.clients.Load(id)
func (p RawKVProxy) BatchPut(ctx context.Context, keys, values [][]byte) error {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).BatchPut(keys, values)
return client.(*rawkv.Client).BatchPut(ctx, keys, values)
}
// Delete deletes a key-value pair from TiKV.
func (p RawKVProxy) Delete(id UUID, key []byte) error {
client, ok := p.clients.Load(id)
func (p RawKVProxy) Delete(ctx context.Context, key []byte) error {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).Delete(key)
return client.(*rawkv.Client).Delete(ctx, key)
}
// BatchDelete deletes key-value pairs from TiKV.
func (p RawKVProxy) BatchDelete(id UUID, keys [][]byte) error {
client, ok := p.clients.Load(id)
func (p RawKVProxy) BatchDelete(ctx context.Context, keys [][]byte) error {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).BatchDelete(keys)
return client.(*rawkv.Client).BatchDelete(ctx, keys)
}
// DeleteRange deletes all key-value pairs in a range from TiKV.
func (p RawKVProxy) DeleteRange(id UUID, startKey []byte, endKey []byte) error {
client, ok := p.clients.Load(id)
func (p RawKVProxy) DeleteRange(ctx context.Context, startKey []byte, endKey []byte) error {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).DeleteRange(startKey, endKey)
return client.(*rawkv.Client).DeleteRange(ctx, startKey, endKey)
}
// Scan queries continuous kv pairs in range [startKey, endKey), up to limit pairs.
func (p RawKVProxy) Scan(id UUID, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) {
client, ok := p.clients.Load(id)
func (p RawKVProxy) Scan(ctx context.Context, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return nil, nil, errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).Scan(startKey, endKey, limit)
return client.(*rawkv.Client).Scan(ctx, startKey, endKey, limit)
}
// ReverseScan queries continuous kv pairs in range [endKey, startKey), up to limit pairs.
// Direction is different from Scan, upper to lower.
func (p RawKVProxy) ReverseScan(id UUID, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) {
client, ok := p.clients.Load(id)
func (p RawKVProxy) ReverseScan(ctx context.Context, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return nil, nil, errors.WithStack(ErrClientNotFound)
}
return client.(*rawkv.Client).ReverseScan(startKey, endKey, limit)
return client.(*rawkv.Client).ReverseScan(ctx, startKey, endKey, limit)
}

View File

@ -43,8 +43,8 @@ func NewTxn() TxnKVProxy {
}
// New creates a new client and returns the client's UUID.
func (p TxnKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) {
client, err := txnkv.NewClient(pdAddrs, conf)
func (p TxnKVProxy) New(ctx context.Context, pdAddrs []string, conf config.Config) (UUID, error) {
client, err := txnkv.NewClient(ctx, pdAddrs, conf)
if err != nil {
return "", err
}
@ -52,7 +52,8 @@ func (p TxnKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) {
}
// Close releases a txnkv client.
func (p TxnKVProxy) Close(id UUID) error {
func (p TxnKVProxy) Close(ctx context.Context) error {
id := uuidFromContext(ctx)
client, ok := p.clients.Load(id)
if !ok {
return errors.WithStack(ErrClientNotFound)
@ -65,12 +66,12 @@ func (p TxnKVProxy) Close(id UUID) error {
}
// Begin starts a new transaction and returns its UUID.
func (p TxnKVProxy) Begin(id UUID) (UUID, error) {
client, ok := p.clients.Load(id)
func (p TxnKVProxy) Begin(ctx context.Context) (UUID, error) {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return "", errors.WithStack(ErrClientNotFound)
}
txn, err := client.(*txnkv.Client).Begin()
txn, err := client.(*txnkv.Client).Begin(ctx)
if err != nil {
return "", err
}
@ -78,45 +79,45 @@ func (p TxnKVProxy) Begin(id UUID) (UUID, error) {
}
// BeginWithTS starts a new transaction with given ts and returns its UUID.
func (p TxnKVProxy) BeginWithTS(id UUID, ts uint64) (UUID, error) {
client, ok := p.clients.Load(id)
func (p TxnKVProxy) BeginWithTS(ctx context.Context, ts uint64) (UUID, error) {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return "", errors.WithStack(ErrClientNotFound)
}
return insertWithRetry(p.txns, client.(*txnkv.Client).BeginWithTS(ts)), nil
return insertWithRetry(p.txns, client.(*txnkv.Client).BeginWithTS(ctx, ts)), nil
}
// GetTS returns a latest timestamp.
func (p TxnKVProxy) GetTS(id UUID) (uint64, error) {
client, ok := p.clients.Load(id)
func (p TxnKVProxy) GetTS(ctx context.Context) (uint64, error) {
client, ok := p.clients.Load(uuidFromContext(ctx))
if !ok {
return 0, errors.WithStack(ErrClientNotFound)
}
return client.(*txnkv.Client).GetTS()
return client.(*txnkv.Client).GetTS(ctx)
}
// TxnGet queries value for the given key from TiKV server.
func (p TxnKVProxy) TxnGet(id UUID, key []byte) ([]byte, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnGet(ctx context.Context, key []byte) ([]byte, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return nil, errors.WithStack(ErrTxnNotFound)
}
return txn.(*txnkv.Transaction).Get(key)
return txn.(*txnkv.Transaction).Get(ctx, key)
}
// TxnBatchGet gets a batch of values from TiKV server.
func (p TxnKVProxy) TxnBatchGet(id UUID, keys [][]byte) (map[string][]byte, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnBatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return nil, errors.WithStack(ErrTxnNotFound)
}
ks := *(*[]key.Key)(unsafe.Pointer(&keys))
return txn.(*txnkv.Transaction).BatchGet(ks)
return txn.(*txnkv.Transaction).BatchGet(ctx, ks)
}
// TxnSet sets the value for key k as v into TiKV server.
func (p TxnKVProxy) TxnSet(id UUID, k []byte, v []byte) error {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnSet(ctx context.Context, k []byte, v []byte) error {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrTxnNotFound)
}
@ -125,12 +126,12 @@ func (p TxnKVProxy) TxnSet(id UUID, k []byte, v []byte) error {
// TxnIter creates an Iterator positioned on the first entry that key <= entry's
// key and returns the Iterator's UUID.
func (p TxnKVProxy) TxnIter(id UUID, key []byte, upperBound []byte) (UUID, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnIter(ctx context.Context, key []byte, upperBound []byte) (UUID, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return "", errors.WithStack(ErrTxnNotFound)
}
iter, err := txn.(*txnkv.Transaction).Iter(key, upperBound)
iter, err := txn.(*txnkv.Transaction).Iter(ctx, key, upperBound)
if err != nil {
return "", err
}
@ -139,12 +140,12 @@ func (p TxnKVProxy) TxnIter(id UUID, key []byte, upperBound []byte) (UUID, error
// TxnIterReverse creates a reversed Iterator positioned on the first entry
// which key is less than key and returns the Iterator's UUID.
func (p TxnKVProxy) TxnIterReverse(id UUID, key []byte) (UUID, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnIterReverse(ctx context.Context, key []byte) (UUID, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return "", errors.WithStack(ErrTxnNotFound)
}
iter, err := txn.(*txnkv.Transaction).IterReverse(key)
iter, err := txn.(*txnkv.Transaction).IterReverse(ctx, key)
if err != nil {
return "", err
}
@ -152,8 +153,8 @@ func (p TxnKVProxy) TxnIterReverse(id UUID, key []byte) (UUID, error) {
}
// TxnIsReadOnly returns if there are pending key-value to commit in the transaction.
func (p TxnKVProxy) TxnIsReadOnly(id UUID) (bool, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnIsReadOnly(ctx context.Context) (bool, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return false, errors.WithStack(ErrTxnNotFound)
}
@ -161,8 +162,8 @@ func (p TxnKVProxy) TxnIsReadOnly(id UUID) (bool, error) {
}
// TxnDelete removes the entry for key from TiKV server.
func (p TxnKVProxy) TxnDelete(id UUID, key []byte) error {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnDelete(ctx context.Context, key []byte) error {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrTxnNotFound)
}
@ -170,7 +171,8 @@ func (p TxnKVProxy) TxnDelete(id UUID, key []byte) error {
}
// TxnCommit commits the transaction operations to TiKV server.
func (p TxnKVProxy) TxnCommit(id UUID) error {
func (p TxnKVProxy) TxnCommit(ctx context.Context) error {
id := uuidFromContext(ctx)
txn, ok := p.txns.Load(id)
if !ok {
return errors.WithStack(ErrTxnNotFound)
@ -180,7 +182,8 @@ func (p TxnKVProxy) TxnCommit(id UUID) error {
}
// TxnRollback undoes the transaction operations to TiKV server.
func (p TxnKVProxy) TxnRollback(id UUID) error {
func (p TxnKVProxy) TxnRollback(ctx context.Context) error {
id := uuidFromContext(ctx)
txn, ok := p.txns.Load(id)
if !ok {
return errors.WithStack(ErrTxnNotFound)
@ -190,8 +193,8 @@ func (p TxnKVProxy) TxnRollback(id UUID) error {
}
// TxnLockKeys tries to lock the entries with the keys in TiKV server.
func (p TxnKVProxy) TxnLockKeys(id UUID, keys [][]byte) error {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnLockKeys(ctx context.Context, keys [][]byte) error {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrTxnNotFound)
}
@ -200,8 +203,8 @@ func (p TxnKVProxy) TxnLockKeys(id UUID, keys [][]byte) error {
}
// TxnValid returns if the transaction is valid.
func (p TxnKVProxy) TxnValid(id UUID) (bool, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnValid(ctx context.Context) (bool, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return false, errors.WithStack(ErrTxnNotFound)
}
@ -209,8 +212,8 @@ func (p TxnKVProxy) TxnValid(id UUID) (bool, error) {
}
// TxnLen returns the count of key-value pairs in the transaction's memory buffer.
func (p TxnKVProxy) TxnLen(id UUID) (int, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnLen(ctx context.Context) (int, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return 0, errors.WithStack(ErrTxnNotFound)
}
@ -218,8 +221,8 @@ func (p TxnKVProxy) TxnLen(id UUID) (int, error) {
}
// TxnSize returns the length (in bytes) of the transaction's memory buffer.
func (p TxnKVProxy) TxnSize(id UUID) (int, error) {
txn, ok := p.txns.Load(id)
func (p TxnKVProxy) TxnSize(ctx context.Context) (int, error) {
txn, ok := p.txns.Load(uuidFromContext(ctx))
if !ok {
return 0, errors.WithStack(ErrTxnNotFound)
}
@ -227,8 +230,8 @@ func (p TxnKVProxy) TxnSize(id UUID) (int, error) {
}
// IterValid returns if the iterator is valid to use.
func (p TxnKVProxy) IterValid(id UUID) (bool, error) {
iter, ok := p.iterators.Load(id)
func (p TxnKVProxy) IterValid(ctx context.Context) (bool, error) {
iter, ok := p.iterators.Load(uuidFromContext(ctx))
if !ok {
return false, errors.WithStack(ErrIterNotFound)
}
@ -236,8 +239,8 @@ func (p TxnKVProxy) IterValid(id UUID) (bool, error) {
}
// IterKey returns the key which the iterator points to.
func (p TxnKVProxy) IterKey(id UUID) ([]byte, error) {
iter, ok := p.iterators.Load(id)
func (p TxnKVProxy) IterKey(ctx context.Context) ([]byte, error) {
iter, ok := p.iterators.Load(uuidFromContext(ctx))
if !ok {
return nil, errors.WithStack(ErrIterNotFound)
}
@ -245,8 +248,8 @@ func (p TxnKVProxy) IterKey(id UUID) ([]byte, error) {
}
// IterValue returns the value which the iterator points to.
func (p TxnKVProxy) IterValue(id UUID) ([]byte, error) {
iter, ok := p.iterators.Load(id)
func (p TxnKVProxy) IterValue(ctx context.Context) ([]byte, error) {
iter, ok := p.iterators.Load(uuidFromContext(ctx))
if !ok {
return nil, errors.WithStack(ErrIterNotFound)
}
@ -254,16 +257,17 @@ func (p TxnKVProxy) IterValue(id UUID) ([]byte, error) {
}
// IterNext moves the iterator to next entry.
func (p TxnKVProxy) IterNext(id UUID) error {
iter, ok := p.iterators.Load(id)
func (p TxnKVProxy) IterNext(ctx context.Context) error {
iter, ok := p.iterators.Load(uuidFromContext(ctx))
if !ok {
return errors.WithStack(ErrIterNotFound)
}
return iter.(kv.Iterator).Next()
return iter.(kv.Iterator).Next(ctx)
}
// IterClose releases an iterator.
func (p TxnKVProxy) IterClose(id UUID) error {
func (p TxnKVProxy) IterClose(ctx context.Context) error {
id := uuidFromContext(ctx)
iter, ok := p.iterators.Load(id)
if !ok {
return errors.WithStack(ErrIterNotFound)

View File

@ -14,6 +14,7 @@
package proxy
import (
"context"
"errors"
"sync"
@ -27,6 +28,10 @@ var (
ErrIterNotFound = errors.New("iterator not found")
)
type ContextKey int
var UUIDKey ContextKey = 1
// UUID is a global unique ID to identify clients, transactions, or iterators.
type UUID string
@ -38,3 +43,10 @@ func insertWithRetry(m *sync.Map, d interface{}) UUID {
}
}
}
func uuidFromContext(ctx context.Context) UUID {
if id := ctx.Value(UUIDKey); id != nil {
return id.(UUID)
}
return ""
}

View File

@ -44,7 +44,7 @@ type Client struct {
}
// NewClient creates a client with PD cluster addrs.
func NewClient(pdAddrs []string, conf config.Config) (*Client, error) {
func NewClient(ctx context.Context, pdAddrs []string, conf config.Config) (*Client, error) {
pdCli, err := pd.NewClient(pdAddrs, pd.SecurityOption{
CAPath: conf.RPC.Security.SSLCA,
CertPath: conf.RPC.Security.SSLCert,
@ -54,7 +54,7 @@ func NewClient(pdAddrs []string, conf config.Config) (*Client, error) {
return nil, err
}
return &Client{
clusterID: pdCli.GetClusterID(context.TODO()),
clusterID: pdCli.GetClusterID(ctx),
conf: &conf,
regionCache: locate.NewRegionCache(pdCli, &conf.RegionCache),
pdClient: pdCli,
@ -74,7 +74,7 @@ func (c *Client) ClusterID() uint64 {
}
// Get queries value with the key. When the key does not exist, it returns `nil, nil`.
func (c *Client) Get(key []byte) ([]byte, error) {
func (c *Client) Get(ctx context.Context, key []byte) ([]byte, error) {
start := time.Now()
defer func() { metrics.RawkvCmdHistogram.WithLabelValues("get").Observe(time.Since(start).Seconds()) }()
@ -84,7 +84,7 @@ func (c *Client) Get(key []byte) ([]byte, error) {
Key: key,
},
}
resp, _, err := c.sendReq(key, req)
resp, _, err := c.sendReq(ctx, key, req)
if err != nil {
return nil, err
}
@ -102,13 +102,13 @@ func (c *Client) Get(key []byte) ([]byte, error) {
}
// BatchGet queries values with the keys.
func (c *Client) BatchGet(keys [][]byte) ([][]byte, error) {
func (c *Client) BatchGet(ctx context.Context, keys [][]byte) ([][]byte, error) {
start := time.Now()
defer func() {
metrics.RawkvCmdHistogram.WithLabelValues("batch_get").Observe(time.Since(start).Seconds())
}()
bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff)
bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff)
resp, err := c.sendBatchReq(bo, keys, rpc.CmdRawBatchGet)
if err != nil {
return nil, err
@ -132,7 +132,7 @@ func (c *Client) BatchGet(keys [][]byte) ([][]byte, error) {
}
// Put stores a key-value pair to TiKV.
func (c *Client) Put(key, value []byte) error {
func (c *Client) Put(ctx context.Context, key, value []byte) error {
start := time.Now()
defer func() { metrics.RawkvCmdHistogram.WithLabelValues("put").Observe(time.Since(start).Seconds()) }()
metrics.RawkvSizeHistogram.WithLabelValues("key").Observe(float64(len(key)))
@ -149,7 +149,7 @@ func (c *Client) Put(key, value []byte) error {
Value: value,
},
}
resp, _, err := c.sendReq(key, req)
resp, _, err := c.sendReq(ctx, key, req)
if err != nil {
return err
}
@ -164,7 +164,7 @@ func (c *Client) Put(key, value []byte) error {
}
// BatchPut stores key-value pairs to TiKV.
func (c *Client) BatchPut(keys, values [][]byte) error {
func (c *Client) BatchPut(ctx context.Context, keys, values [][]byte) error {
start := time.Now()
defer func() {
metrics.RawkvCmdHistogram.WithLabelValues("batch_put").Observe(time.Since(start).Seconds())
@ -178,12 +178,12 @@ func (c *Client) BatchPut(keys, values [][]byte) error {
return errors.New("empty value is not supported")
}
}
bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff)
bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff)
return c.sendBatchPut(bo, keys, values)
}
// Delete deletes a key-value pair from TiKV.
func (c *Client) Delete(key []byte) error {
func (c *Client) Delete(ctx context.Context, key []byte) error {
start := time.Now()
defer func() { metrics.RawkvCmdHistogram.WithLabelValues("delete").Observe(time.Since(start).Seconds()) }()
@ -193,7 +193,7 @@ func (c *Client) Delete(key []byte) error {
Key: key,
},
}
resp, _, err := c.sendReq(key, req)
resp, _, err := c.sendReq(ctx, key, req)
if err != nil {
return err
}
@ -208,13 +208,13 @@ func (c *Client) Delete(key []byte) error {
}
// BatchDelete deletes key-value pairs from TiKV.
func (c *Client) BatchDelete(keys [][]byte) error {
func (c *Client) BatchDelete(ctx context.Context, keys [][]byte) error {
start := time.Now()
defer func() {
metrics.RawkvCmdHistogram.WithLabelValues("batch_delete").Observe(time.Since(start).Seconds())
}()
bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff)
bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff)
resp, err := c.sendBatchReq(bo, keys, rpc.CmdRawBatchDelete)
if err != nil {
return err
@ -230,7 +230,7 @@ func (c *Client) BatchDelete(keys [][]byte) error {
}
// DeleteRange deletes all key-value pairs in a range from TiKV
func (c *Client) DeleteRange(startKey []byte, endKey []byte) error {
func (c *Client) DeleteRange(ctx context.Context, startKey []byte, endKey []byte) error {
start := time.Now()
var err error
defer func() {
@ -245,7 +245,7 @@ func (c *Client) DeleteRange(startKey []byte, endKey []byte) error {
for !bytes.Equal(startKey, endKey) {
var resp *rpc.Response
var actualEndKey []byte
resp, actualEndKey, err = c.sendDeleteRangeReq(startKey, endKey)
resp, actualEndKey, err = c.sendDeleteRangeReq(ctx, startKey, endKey)
if err != nil {
return err
}
@ -267,7 +267,7 @@ func (c *Client) DeleteRange(startKey []byte, endKey []byte) error {
// If you want to exclude the startKey or include the endKey, append a '\0' to the key. For example, to scan
// (startKey, endKey], you can write:
// `Scan(append(startKey, '\0'), append(endKey, '\0'), limit)`.
func (c *Client) Scan(startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) {
func (c *Client) Scan(ctx context.Context, startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) {
start := time.Now()
defer func() { metrics.RawkvCmdHistogram.WithLabelValues("raw_scan").Observe(time.Since(start).Seconds()) }()
@ -284,7 +284,7 @@ func (c *Client) Scan(startKey, endKey []byte, limit int) (keys [][]byte, values
Limit: uint32(limit - len(keys)),
},
}
resp, loc, err := c.sendReq(startKey, req)
resp, loc, err := c.sendReq(ctx, startKey, req)
if err != nil {
return nil, nil, err
}
@ -311,7 +311,7 @@ func (c *Client) Scan(startKey, endKey []byte, limit int) (keys [][]byte, values
// (endKey, startKey], you can write:
// `ReverseScan(append(startKey, '\0'), append(endKey, '\0'), limit)`.
// It doesn't support Scanning from "", because locating the last Region is not yet implemented.
func (c *Client) ReverseScan(startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) {
func (c *Client) ReverseScan(ctx context.Context, startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) {
start := time.Now()
defer func() {
metrics.RawkvCmdHistogram.WithLabelValues("raw_reverse_scan").Observe(time.Since(start).Seconds())
@ -331,7 +331,7 @@ func (c *Client) ReverseScan(startKey, endKey []byte, limit int) (keys [][]byte,
Reverse: true,
},
}
resp, loc, err := c.sendReq(startKey, req)
resp, loc, err := c.sendReq(ctx, startKey, req)
if err != nil {
return nil, nil, err
}
@ -351,8 +351,8 @@ func (c *Client) ReverseScan(startKey, endKey []byte, limit int) (keys [][]byte,
return
}
func (c *Client) sendReq(key []byte, req *rpc.Request) (*rpc.Response, *locate.KeyLocation, error) {
bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff)
func (c *Client) sendReq(ctx context.Context, key []byte, req *rpc.Request) (*rpc.Response, *locate.KeyLocation, error) {
bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff)
sender := rpc.NewRegionRequestSender(c.regionCache, c.rpcClient)
for {
loc, err := c.regionCache.LocateKey(bo, key)
@ -491,8 +491,8 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch batch, cmdType rpc.CmdTyp
// If the given range spans over more than one regions, the actual endKey is the end of the first region.
// We can't use sendReq directly, because we need to know the end of the region before we send the request
// TODO: Is there any better way to avoid duplicating code with func `sendReq` ?
func (c *Client) sendDeleteRangeReq(startKey []byte, endKey []byte) (*rpc.Response, []byte, error) {
bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff)
func (c *Client) sendDeleteRangeReq(ctx context.Context, startKey []byte, endKey []byte) (*rpc.Response, []byte, error) {
bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff)
sender := rpc.NewRegionRequestSender(c.regionCache, c.rpcClient)
for {
loc, err := c.regionCache.LocateKey(bo, startKey)

View File

@ -59,13 +59,13 @@ func (s *testRawKVSuite) TearDownTest(c *C) {
}
func (s *testRawKVSuite) mustNotExist(c *C, key []byte) {
v, err := s.client.Get(key)
v, err := s.client.Get(context.TODO(), key)
c.Assert(err, IsNil)
c.Assert(v, IsNil)
}
func (s *testRawKVSuite) mustBatchNotExist(c *C, keys [][]byte) {
values, err := s.client.BatchGet(keys)
values, err := s.client.BatchGet(context.TODO(), keys)
c.Assert(err, IsNil)
c.Assert(values, NotNil)
c.Assert(len(keys), Equals, len(values))
@ -75,14 +75,14 @@ func (s *testRawKVSuite) mustBatchNotExist(c *C, keys [][]byte) {
}
func (s *testRawKVSuite) mustGet(c *C, key, value []byte) {
v, err := s.client.Get(key)
v, err := s.client.Get(context.TODO(), key)
c.Assert(err, IsNil)
c.Assert(v, NotNil)
c.Assert(v, BytesEquals, value)
}
func (s *testRawKVSuite) mustBatchGet(c *C, keys, values [][]byte) {
checkValues, err := s.client.BatchGet(keys)
checkValues, err := s.client.BatchGet(context.TODO(), keys)
c.Assert(err, IsNil)
c.Assert(checkValues, NotNil)
c.Assert(len(keys), Equals, len(checkValues))
@ -92,27 +92,27 @@ func (s *testRawKVSuite) mustBatchGet(c *C, keys, values [][]byte) {
}
func (s *testRawKVSuite) mustPut(c *C, key, value []byte) {
err := s.client.Put(key, value)
err := s.client.Put(context.TODO(), key, value)
c.Assert(err, IsNil)
}
func (s *testRawKVSuite) mustBatchPut(c *C, keys, values [][]byte) {
err := s.client.BatchPut(keys, values)
err := s.client.BatchPut(context.TODO(), keys, values)
c.Assert(err, IsNil)
}
func (s *testRawKVSuite) mustDelete(c *C, key []byte) {
err := s.client.Delete(key)
err := s.client.Delete(context.TODO(), key)
c.Assert(err, IsNil)
}
func (s *testRawKVSuite) mustBatchDelete(c *C, keys [][]byte) {
err := s.client.BatchDelete(keys)
err := s.client.BatchDelete(context.TODO(), keys)
c.Assert(err, IsNil)
}
func (s *testRawKVSuite) mustScan(c *C, startKey string, limit int, expect ...string) {
keys, values, err := s.client.Scan([]byte(startKey), nil, limit)
keys, values, err := s.client.Scan(context.TODO(), []byte(startKey), nil, limit)
c.Assert(err, IsNil)
c.Assert(len(keys)*2, Equals, len(expect))
for i := range keys {
@ -122,7 +122,7 @@ func (s *testRawKVSuite) mustScan(c *C, startKey string, limit int, expect ...st
}
func (s *testRawKVSuite) mustScanRange(c *C, startKey string, endKey string, limit int, expect ...string) {
keys, values, err := s.client.Scan([]byte(startKey), []byte(endKey), limit)
keys, values, err := s.client.Scan(context.TODO(), []byte(startKey), []byte(endKey), limit)
c.Assert(err, IsNil)
c.Assert(len(keys)*2, Equals, len(expect))
for i := range keys {
@ -132,7 +132,7 @@ func (s *testRawKVSuite) mustScanRange(c *C, startKey string, endKey string, lim
}
func (s *testRawKVSuite) mustReverseScan(c *C, startKey []byte, limit int, expect ...string) {
keys, values, err := s.client.ReverseScan(startKey, nil, limit)
keys, values, err := s.client.ReverseScan(context.TODO(), startKey, nil, limit)
c.Assert(err, IsNil)
c.Assert(len(keys)*2, Equals, len(expect))
for i := range keys {
@ -142,7 +142,7 @@ func (s *testRawKVSuite) mustReverseScan(c *C, startKey []byte, limit int, expec
}
func (s *testRawKVSuite) mustReverseScanRange(c *C, startKey, endKey []byte, limit int, expect ...string) {
keys, values, err := s.client.ReverseScan(startKey, endKey, limit)
keys, values, err := s.client.ReverseScan(context.TODO(), startKey, endKey, limit)
c.Assert(err, IsNil)
c.Assert(len(keys)*2, Equals, len(expect))
for i := range keys {
@ -152,7 +152,7 @@ func (s *testRawKVSuite) mustReverseScanRange(c *C, startKey, endKey []byte, lim
}
func (s *testRawKVSuite) mustDeleteRange(c *C, startKey, endKey []byte, expected map[string]string) {
err := s.client.DeleteRange(startKey, endKey)
err := s.client.DeleteRange(context.TODO(), startKey, endKey)
c.Assert(err, IsNil)
for keyStr := range expected {
@ -166,7 +166,7 @@ func (s *testRawKVSuite) mustDeleteRange(c *C, startKey, endKey []byte, expected
}
func (s *testRawKVSuite) checkData(c *C, expected map[string]string) {
keys, values, err := s.client.Scan([]byte(""), nil, len(expected)+1)
keys, values, err := s.client.Scan(context.TODO(), []byte(""), nil, len(expected)+1)
c.Assert(err, IsNil)
c.Assert(len(expected), Equals, len(keys))
@ -192,11 +192,12 @@ func (s *testRawKVSuite) TestSimple(c *C) {
s.mustGet(c, []byte("key"), []byte("value"))
s.mustDelete(c, []byte("key"))
s.mustNotExist(c, []byte("key"))
err := s.client.Put([]byte("key"), []byte(""))
err := s.client.Put(context.TODO(), []byte("key"), []byte(""))
c.Assert(err, NotNil)
}
func (s *testRawKVSuite) TestRawBatch(c *C) {
testNum := 0
size := 0
var testKeys [][]byte

View File

@ -27,8 +27,8 @@ type Client struct {
}
// NewClient creates a client with PD addresses.
func NewClient(pdAddrs []string, config config.Config) (*Client, error) {
tikvStore, err := store.NewStore(pdAddrs, config)
func NewClient(ctx context.Context, pdAddrs []string, config config.Config) (*Client, error) {
tikvStore, err := store.NewStore(ctx, pdAddrs, config)
if err != nil {
return nil, err
}
@ -43,20 +43,20 @@ func (c *Client) Close() error {
}
// Begin creates a transaction for read/write.
func (c *Client) Begin() (*Transaction, error) {
ts, err := c.GetTS()
func (c *Client) Begin(ctx context.Context) (*Transaction, error) {
ts, err := c.GetTS(ctx)
if err != nil {
return nil, err
}
return c.BeginWithTS(ts), nil
return c.BeginWithTS(ctx, ts), nil
}
// BeginWithTS creates a transaction which is normally readonly.
func (c *Client) BeginWithTS(ts uint64) *Transaction {
func (c *Client) BeginWithTS(ctx context.Context, ts uint64) *Transaction {
return newTransaction(c.tikvStore, ts)
}
// GetTS returns a latest timestamp.
func (c *Client) GetTS() (uint64, error) {
return c.tikvStore.GetTimestampWithRetry(retry.NewBackoffer(context.TODO(), retry.TsoMaxBackoff))
func (c *Client) GetTS(ctx context.Context) (uint64, error) {
return c.tikvStore.GetTimestampWithRetry(retry.NewBackoffer(ctx, retry.TsoMaxBackoff))
}

View File

@ -14,6 +14,8 @@
package kv
import (
"context"
"github.com/tikv/client-go/config"
"github.com/tikv/client-go/key"
)
@ -49,10 +51,10 @@ func (s *BufferStore) SetCap(cap int) {
}
// Get implements the Retriever interface.
func (s *BufferStore) Get(k key.Key) ([]byte, error) {
val, err := s.MemBuffer.Get(k)
func (s *BufferStore) Get(ctx context.Context, k key.Key) ([]byte, error) {
val, err := s.MemBuffer.Get(ctx, k)
if IsErrNotFound(err) {
val, err = s.r.Get(k)
val, err = s.r.Get(ctx, k)
}
if err != nil {
return nil, err
@ -64,29 +66,29 @@ func (s *BufferStore) Get(k key.Key) ([]byte, error) {
}
// Iter implements the Retriever interface.
func (s *BufferStore) Iter(k key.Key, upperBound key.Key) (Iterator, error) {
bufferIt, err := s.MemBuffer.Iter(k, upperBound)
func (s *BufferStore) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) {
bufferIt, err := s.MemBuffer.Iter(ctx, k, upperBound)
if err != nil {
return nil, err
}
retrieverIt, err := s.r.Iter(k, upperBound)
retrieverIt, err := s.r.Iter(ctx, k, upperBound)
if err != nil {
return nil, err
}
return NewUnionIter(bufferIt, retrieverIt, false)
return NewUnionIter(ctx, bufferIt, retrieverIt, false)
}
// IterReverse implements the Retriever interface.
func (s *BufferStore) IterReverse(k key.Key) (Iterator, error) {
bufferIt, err := s.MemBuffer.IterReverse(k)
func (s *BufferStore) IterReverse(ctx context.Context, k key.Key) (Iterator, error) {
bufferIt, err := s.MemBuffer.IterReverse(ctx, k)
if err != nil {
return nil, err
}
retrieverIt, err := s.r.IterReverse(k)
retrieverIt, err := s.r.IterReverse(ctx, k)
if err != nil {
return nil, err
}
return NewUnionIter(bufferIt, retrieverIt, true)
return NewUnionIter(ctx, bufferIt, retrieverIt, true)
}
// WalkBuffer iterates all buffered kv pairs.

View File

@ -15,6 +15,7 @@ package kv
import (
"bytes"
"context"
"fmt"
"testing"
@ -35,13 +36,13 @@ func (s testBufferStoreSuite) TestGetSet(c *C) {
conf := config.DefaultTxn()
bs := NewBufferStore(&mockSnapshot{NewMemDbBuffer(&conf, 0)}, &conf)
key := key.Key("key")
_, err := bs.Get(key)
_, err := bs.Get(context.TODO(), key)
c.Check(err, NotNil)
err = bs.Set(key, []byte("value"))
c.Check(err, IsNil)
value, err := bs.Get(key)
value, err := bs.Get(context.TODO(), key)
c.Check(err, IsNil)
c.Check(bytes.Compare(value, []byte("value")), Equals, 0)
}
@ -62,11 +63,11 @@ func (s testBufferStoreSuite) TestSaveTo(c *C) {
err := bs.SaveTo(mutator)
c.Check(err, IsNil)
iter, err := mutator.Iter(nil, nil)
iter, err := mutator.Iter(context.TODO(), nil, nil)
c.Check(err, IsNil)
for iter.Valid() {
cmp := bytes.Compare(iter.Key(), iter.Value())
c.Check(cmp, Equals, 0)
iter.Next()
iter.Next(context.TODO())
}
}

View File

@ -13,7 +13,11 @@
package kv
import "github.com/tikv/client-go/key"
import (
"context"
"github.com/tikv/client-go/key"
)
// Priority value for transaction priority.
const (
@ -36,18 +40,18 @@ const (
type Retriever interface {
// Get gets the value for key k from kv store.
// If corresponding kv pair does not exist, it returns nil and ErrNotExist.
Get(k key.Key) ([]byte, error)
Get(ctx context.Context, k key.Key) ([]byte, error)
// Iter creates an Iterator positioned on the first entry that k <= entry's key.
// If such entry is not found, it returns an invalid Iterator with no error.
// It yields only keys that < upperBound. If upperBound is nil, it means the upperBound is unbounded.
// The Iterator must be closed after use.
Iter(k key.Key, upperBound key.Key) (Iterator, error)
Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error)
// IterReverse creates a reversed Iterator positioned on the first entry which key is less than k.
// The returned iterator will iterate from greater key to smaller key.
// If k is nil, the returned iterator will be positioned at the last key.
// TODO: Add lower bound limit
IterReverse(k key.Key) (Iterator, error)
IterReverse(ctx context.Context, k key.Key) (Iterator, error)
}
// Mutator is the interface wraps the basic Set and Delete methods.
@ -83,7 +87,7 @@ type MemBuffer interface {
type Snapshot interface {
Retriever
// BatchGet gets a batch of values from snapshot.
BatchGet(keys []key.Key) (map[string][]byte, error)
BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error)
// SetPriority snapshot set the priority
SetPriority(priority int)
}
@ -93,7 +97,7 @@ type Iterator interface {
Valid() bool
Key() key.Key
Value() []byte
Next() error
Next(context.Context) error
Close()
}

View File

@ -16,6 +16,7 @@
package kv
import (
"context"
"fmt"
"math/rand"
"testing"
@ -73,7 +74,7 @@ func valToStr(c *C, iter Iterator) string {
func checkNewIterator(c *C, buffer MemBuffer) {
for i := startIndex; i < testCount; i++ {
val := encodeInt(i * indexStep)
iter, err := buffer.Iter(val, nil)
iter, err := buffer.Iter(context.TODO(), val, nil)
c.Assert(err, IsNil)
c.Assert([]byte(iter.Key()), BytesEquals, val)
c.Assert(decodeInt([]byte(valToStr(c, iter))), Equals, i*indexStep)
@ -83,12 +84,12 @@ func checkNewIterator(c *C, buffer MemBuffer) {
// Test iterator Next()
for i := startIndex; i < testCount-1; i++ {
val := encodeInt(i * indexStep)
iter, err := buffer.Iter(val, nil)
iter, err := buffer.Iter(context.TODO(), val, nil)
c.Assert(err, IsNil)
c.Assert([]byte(iter.Key()), BytesEquals, val)
c.Assert(valToStr(c, iter), Equals, string(val))
err = iter.Next()
err = iter.Next(context.TODO())
c.Assert(err, IsNil)
c.Assert(iter.Valid(), IsTrue)
@ -99,7 +100,7 @@ func checkNewIterator(c *C, buffer MemBuffer) {
}
// Non exist and beyond maximum seek test
iter, err := buffer.Iter(encodeInt(testCount*indexStep), nil)
iter, err := buffer.Iter(context.TODO(), encodeInt(testCount*indexStep), nil)
c.Assert(err, IsNil)
c.Assert(iter.Valid(), IsFalse)
@ -107,7 +108,7 @@ func checkNewIterator(c *C, buffer MemBuffer) {
// it returns the smallest key that larger than the one we are seeking
inBetween := encodeInt((testCount-1)*indexStep - 1)
last := encodeInt((testCount - 1) * indexStep)
iter, err = buffer.Iter(inBetween, nil)
iter, err = buffer.Iter(context.TODO(), inBetween, nil)
c.Assert(err, IsNil)
c.Assert(iter.Valid(), IsTrue)
c.Assert([]byte(iter.Key()), Not(BytesEquals), inBetween)
@ -118,7 +119,7 @@ func checkNewIterator(c *C, buffer MemBuffer) {
func mustGet(c *C, buffer MemBuffer) {
for i := startIndex; i < testCount; i++ {
s := encodeInt(i * indexStep)
val, err := buffer.Get(s)
val, err := buffer.Get(context.TODO(), s)
c.Assert(err, IsNil)
c.Assert(string(val), Equals, string(s))
}
@ -135,7 +136,7 @@ func (s *testKVSuite) TestGetSet(c *C) {
func (s *testKVSuite) TestNewIterator(c *C) {
for _, buffer := range s.bs {
// should be invalid
iter, err := buffer.Iter(nil, nil)
iter, err := buffer.Iter(context.TODO(), nil, nil)
c.Assert(err, IsNil)
c.Assert(iter.Valid(), IsFalse)
@ -147,7 +148,7 @@ func (s *testKVSuite) TestNewIterator(c *C) {
func (s *testKVSuite) TestBasicNewIterator(c *C) {
for _, buffer := range s.bs {
it, err := buffer.Iter([]byte("2"), nil)
it, err := buffer.Iter(context.TODO(), []byte("2"), nil)
c.Assert(err, IsNil)
c.Assert(it.Valid(), IsFalse)
}
@ -171,15 +172,15 @@ func (s *testKVSuite) TestNewIteratorMin(c *C) {
}
cnt := 0
it, err := buffer.Iter(nil, nil)
it, err := buffer.Iter(context.TODO(), nil, nil)
c.Assert(err, IsNil)
for it.Valid() {
cnt++
it.Next()
it.Next(context.TODO())
}
c.Assert(cnt, Equals, 6)
it, err = buffer.Iter([]byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000"), nil)
it, err = buffer.Iter(context.TODO(), []byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000"), nil)
c.Assert(err, IsNil)
c.Assert(string(it.Key()), Equals, "DATA_test_main_db_tbl_tbl_test_record__00000000000000000001")
}
@ -265,7 +266,7 @@ func benchmarkSetGet(b *testing.B, buffer MemBuffer, data [][]byte) {
buffer.Set(k, k)
}
for _, k := range data {
buffer.Get(k)
buffer.Get(context.TODO(), k)
}
}
}
@ -276,12 +277,12 @@ func benchIterator(b *testing.B, buffer MemBuffer) {
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
iter, err := buffer.Iter(nil, nil)
iter, err := buffer.Iter(context.TODO(), nil, nil)
if err != nil {
b.Error(err)
}
for iter.Valid() {
iter.Next()
iter.Next(context.TODO())
}
iter.Close()
}

View File

@ -16,6 +16,7 @@
package kv
import (
"context"
"fmt"
"github.com/pingcap/goleveldb/leveldb"
@ -55,10 +56,10 @@ func NewMemDbBuffer(conf *config.Txn, cap int) MemBuffer {
}
// Iter creates an Iterator.
func (m *memDbBuffer) Iter(k key.Key, upperBound key.Key) (Iterator, error) {
func (m *memDbBuffer) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) {
i := &memDbIter{iter: m.db.NewIterator(&util.Range{Start: []byte(k), Limit: []byte(upperBound)}), reverse: false}
err := i.Next()
err := i.Next(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
@ -69,7 +70,7 @@ func (m *memDbBuffer) SetCap(cap int) {
}
func (m *memDbBuffer) IterReverse(k key.Key) (Iterator, error) {
func (m *memDbBuffer) IterReverse(ctx context.Context, k key.Key) (Iterator, error) {
var i *memDbIter
if k == nil {
i = &memDbIter{iter: m.db.NewIterator(&util.Range{}), reverse: true}
@ -81,7 +82,7 @@ func (m *memDbBuffer) IterReverse(k key.Key) (Iterator, error) {
}
// Get returns the value associated with key.
func (m *memDbBuffer) Get(k key.Key) ([]byte, error) {
func (m *memDbBuffer) Get(ctx context.Context, k key.Key) ([]byte, error) {
v, err := m.db.Get(k)
if err == leveldb.ErrNotFound {
return nil, ErrNotExist
@ -130,7 +131,7 @@ func (m *memDbBuffer) Reset() {
}
// Next implements the Iterator Next.
func (i *memDbIter) Next() error {
func (i *memDbIter) Next(context.Context) error {
if i.reverse {
i.iter.Prev()
} else {
@ -161,7 +162,7 @@ func (i *memDbIter) Close() {
// WalkMemBuffer iterates all buffered kv pairs in memBuf
func WalkMemBuffer(memBuf MemBuffer, f func(k key.Key, v []byte) error) error {
iter, err := memBuf.Iter(nil, nil)
iter, err := memBuf.Iter(context.Background(), nil, nil)
if err != nil {
return errors.WithStack(err)
}
@ -171,7 +172,7 @@ func WalkMemBuffer(memBuf MemBuffer, f func(k key.Key, v []byte) error) error {
if err = f(iter.Key(), iter.Value()); err != nil {
return errors.WithStack(err)
}
err = iter.Next()
err = iter.Next(context.Background())
if err != nil {
return errors.WithStack(err)
}

View File

@ -14,6 +14,8 @@
package kv
import (
"context"
"github.com/tikv/client-go/key"
)
@ -21,18 +23,18 @@ type mockSnapshot struct {
store MemBuffer
}
func (s *mockSnapshot) Get(k key.Key) ([]byte, error) {
return s.store.Get(k)
func (s *mockSnapshot) Get(ctx context.Context, k key.Key) ([]byte, error) {
return s.store.Get(ctx, k)
}
func (s *mockSnapshot) SetPriority(priority int) {
}
func (s *mockSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) {
func (s *mockSnapshot) BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error) {
m := make(map[string][]byte)
for _, k := range keys {
v, err := s.store.Get(k)
v, err := s.store.Get(ctx, k)
if IsErrNotFound(err) {
continue
}
@ -44,10 +46,10 @@ func (s *mockSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) {
return m, nil
}
func (s *mockSnapshot) Iter(k key.Key, upperBound key.Key) (Iterator, error) {
return s.store.Iter(k, upperBound)
func (s *mockSnapshot) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) {
return s.store.Iter(ctx, k, upperBound)
}
func (s *mockSnapshot) IterReverse(k key.Key) (Iterator, error) {
return s.store.IterReverse(k)
func (s *mockSnapshot) IterReverse(ctx context.Context, k key.Key) (Iterator, error) {
return s.store.IterReverse(ctx, k)
}

View File

@ -14,6 +14,8 @@
package kv
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/tikv/client-go/key"
)
@ -32,7 +34,7 @@ type UnionIter struct {
}
// NewUnionIter returns a union iterator for BufferStore.
func NewUnionIter(dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIter, error) {
func NewUnionIter(ctx context.Context, dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIter, error) {
it := &UnionIter{
dirtyIt: dirtyIt,
snapshotIt: snapshotIt,
@ -40,7 +42,7 @@ func NewUnionIter(dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIt
snapshotValid: snapshotIt.Valid(),
reverse: reverse,
}
err := it.updateCur()
err := it.updateCur(ctx)
if err != nil {
return nil, err
}
@ -48,20 +50,20 @@ func NewUnionIter(dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIt
}
// dirtyNext makes iter.dirtyIt go and update valid status.
func (iter *UnionIter) dirtyNext() error {
err := iter.dirtyIt.Next()
func (iter *UnionIter) dirtyNext(ctx context.Context) error {
err := iter.dirtyIt.Next(ctx)
iter.dirtyValid = iter.dirtyIt.Valid()
return err
}
// snapshotNext makes iter.snapshotIt go and update valid status.
func (iter *UnionIter) snapshotNext() error {
err := iter.snapshotIt.Next()
func (iter *UnionIter) snapshotNext(ctx context.Context) error {
err := iter.snapshotIt.Next(ctx)
iter.snapshotValid = iter.snapshotIt.Valid()
return err
}
func (iter *UnionIter) updateCur() error {
func (iter *UnionIter) updateCur(ctx context.Context) error {
iter.isValid = true
for {
if !iter.dirtyValid && !iter.snapshotValid {
@ -78,7 +80,7 @@ func (iter *UnionIter) updateCur() error {
iter.curIsDirty = true
// if delete it
if len(iter.dirtyIt.Value()) == 0 {
if err := iter.dirtyNext(); err != nil {
if err := iter.dirtyNext(ctx); err != nil {
return err
}
continue
@ -99,15 +101,15 @@ func (iter *UnionIter) updateCur() error {
if len(iter.dirtyIt.Value()) == 0 {
// snapshot has a record, but txn says we have deleted it
// just go next
if err := iter.dirtyNext(); err != nil {
if err := iter.dirtyNext(ctx); err != nil {
return err
}
if err := iter.snapshotNext(); err != nil {
if err := iter.snapshotNext(ctx); err != nil {
return err
}
continue
}
if err := iter.snapshotNext(); err != nil {
if err := iter.snapshotNext(ctx); err != nil {
return err
}
iter.curIsDirty = true
@ -121,7 +123,7 @@ func (iter *UnionIter) updateCur() error {
if len(iter.dirtyIt.Value()) == 0 {
log.Warnf("[kv] delete a record not exists? k = %q", iter.dirtyIt.Key())
// jump over this deletion
if err := iter.dirtyNext(); err != nil {
if err := iter.dirtyNext(ctx); err != nil {
return err
}
continue
@ -135,17 +137,17 @@ func (iter *UnionIter) updateCur() error {
}
// Next implements the Iterator Next interface.
func (iter *UnionIter) Next() error {
func (iter *UnionIter) Next(ctx context.Context) error {
var err error
if !iter.curIsDirty {
err = iter.snapshotNext()
err = iter.snapshotNext(ctx)
} else {
err = iter.dirtyNext()
err = iter.dirtyNext(ctx)
}
if err != nil {
return err
}
return iter.updateCur()
return iter.updateCur(ctx)
}
// Value implements the Iterator Value interface.

View File

@ -14,6 +14,8 @@
package kv
import (
"context"
"github.com/tikv/client-go/config"
"github.com/tikv/client-go/key"
)
@ -89,7 +91,7 @@ func (it invalidIterator) Valid() bool {
return false
}
func (it invalidIterator) Next() error {
func (it invalidIterator) Next(context.Context) error {
return nil
}
@ -110,12 +112,12 @@ type lazyMemBuffer struct {
conf *config.Txn
}
func (lmb *lazyMemBuffer) Get(k key.Key) ([]byte, error) {
func (lmb *lazyMemBuffer) Get(ctx context.Context, k key.Key) ([]byte, error) {
if lmb.mb == nil {
return nil, ErrNotExist
}
return lmb.mb.Get(k)
return lmb.mb.Get(ctx, k)
}
func (lmb *lazyMemBuffer) Set(key key.Key, value []byte) error {
@ -134,18 +136,18 @@ func (lmb *lazyMemBuffer) Delete(k key.Key) error {
return lmb.mb.Delete(k)
}
func (lmb *lazyMemBuffer) Iter(k key.Key, upperBound key.Key) (Iterator, error) {
func (lmb *lazyMemBuffer) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) {
if lmb.mb == nil {
return invalidIterator{}, nil
}
return lmb.mb.Iter(k, upperBound)
return lmb.mb.Iter(ctx, k, upperBound)
}
func (lmb *lazyMemBuffer) IterReverse(k key.Key) (Iterator, error) {
func (lmb *lazyMemBuffer) IterReverse(ctx context.Context, k key.Key) (Iterator, error) {
if lmb.mb == nil {
return invalidIterator{}, nil
}
return lmb.mb.IterReverse(k)
return lmb.mb.IterReverse(ctx, k)
}
func (lmb *lazyMemBuffer) Size() int {
@ -173,8 +175,8 @@ func (lmb *lazyMemBuffer) SetCap(cap int) {
}
// Get implements the Retriever interface.
func (us *unionStore) Get(k key.Key) ([]byte, error) {
v, err := us.MemBuffer.Get(k)
func (us *unionStore) Get(ctx context.Context, k key.Key) ([]byte, error) {
v, err := us.MemBuffer.Get(ctx, k)
if IsErrNotFound(err) {
if _, ok := us.opts.Get(PresumeKeyNotExists); ok {
e, ok := us.opts.Get(PresumeKeyNotExistsError)
@ -185,7 +187,7 @@ func (us *unionStore) Get(k key.Key) ([]byte, error) {
}
return nil, ErrNotExist
}
v, err = us.BufferStore.r.Get(k)
v, err = us.BufferStore.r.Get(ctx, k)
}
if err != nil {
return v, err

View File

@ -14,6 +14,8 @@
package kv
import (
"context"
. "github.com/pingcap/check"
"github.com/pkg/errors"
"github.com/tikv/client-go/config"
@ -34,11 +36,11 @@ func (s *testUnionStoreSuite) SetUpTest(c *C) {
func (s *testUnionStoreSuite) TestGetSet(c *C) {
s.store.Set([]byte("1"), []byte("1"))
v, err := s.us.Get([]byte("1"))
v, err := s.us.Get(context.TODO(), []byte("1"))
c.Assert(err, IsNil)
c.Assert(v, BytesEquals, []byte("1"))
s.us.Set([]byte("1"), []byte("2"))
v, err = s.us.Get([]byte("1"))
v, err = s.us.Get(context.TODO(), []byte("1"))
c.Assert(err, IsNil)
c.Assert(v, BytesEquals, []byte("2"))
}
@ -47,11 +49,11 @@ func (s *testUnionStoreSuite) TestDelete(c *C) {
s.store.Set([]byte("1"), []byte("1"))
err := s.us.Delete([]byte("1"))
c.Assert(err, IsNil)
_, err = s.us.Get([]byte("1"))
_, err = s.us.Get(context.TODO(), []byte("1"))
c.Assert(IsErrNotFound(err), IsTrue)
s.us.Set([]byte("1"), []byte("2"))
v, err := s.us.Get([]byte("1"))
v, err := s.us.Get(context.TODO(), []byte("1"))
c.Assert(err, IsNil)
c.Assert(v, BytesEquals, []byte("2"))
}
@ -61,21 +63,21 @@ func (s *testUnionStoreSuite) TestSeek(c *C) {
s.store.Set([]byte("2"), []byte("2"))
s.store.Set([]byte("3"), []byte("3"))
iter, err := s.us.Iter(nil, nil)
iter, err := s.us.Iter(context.TODO(), nil, nil)
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("1"), []byte("2"), []byte("3")}, [][]byte{[]byte("1"), []byte("2"), []byte("3")})
iter, err = s.us.Iter([]byte("2"), nil)
iter, err = s.us.Iter(context.TODO(), []byte("2"), nil)
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("2"), []byte("3")}, [][]byte{[]byte("2"), []byte("3")})
s.us.Set([]byte("4"), []byte("4"))
iter, err = s.us.Iter([]byte("2"), nil)
iter, err = s.us.Iter(context.TODO(), []byte("2"), nil)
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("2"), []byte("3"), []byte("4")}, [][]byte{[]byte("2"), []byte("3"), []byte("4")})
s.us.Delete([]byte("3"))
iter, err = s.us.Iter([]byte("2"), nil)
iter, err = s.us.Iter(context.TODO(), []byte("2"), nil)
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("2"), []byte("4")}, [][]byte{[]byte("2"), []byte("4")})
}
@ -86,21 +88,21 @@ func (s *testUnionStoreSuite) TestIterReverse(c *C) {
s.store.Set([]byte("2"), []byte("2"))
s.store.Set([]byte("3"), []byte("3"))
iter, err := s.us.IterReverse(nil)
iter, err := s.us.IterReverse(context.TODO(), nil)
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("3"), []byte("2"), []byte("1")}, [][]byte{[]byte("3"), []byte("2"), []byte("1")})
iter, err = s.us.IterReverse([]byte("3"))
iter, err = s.us.IterReverse(context.TODO(), []byte("3"))
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("2"), []byte("1")}, [][]byte{[]byte("2"), []byte("1")})
s.us.Set([]byte("0"), []byte("0"))
iter, err = s.us.IterReverse([]byte("3"))
iter, err = s.us.IterReverse(context.TODO(), []byte("3"))
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("2"), []byte("1"), []byte("0")}, [][]byte{[]byte("2"), []byte("1"), []byte("0")})
s.us.Delete([]byte("1"))
iter, err = s.us.IterReverse([]byte("3"))
iter, err = s.us.IterReverse(context.TODO(), []byte("3"))
c.Assert(err, IsNil)
checkIterator(c, iter, [][]byte{[]byte("2"), []byte("0")}, [][]byte{[]byte("2"), []byte("0")})
}
@ -110,13 +112,13 @@ func (s *testUnionStoreSuite) TestLazyConditionCheck(c *C) {
s.store.Set([]byte("1"), []byte("1"))
s.store.Set([]byte("2"), []byte("2"))
v, err := s.us.Get([]byte("1"))
v, err := s.us.Get(context.TODO(), []byte("1"))
c.Assert(err, IsNil)
c.Assert(v, BytesEquals, []byte("1"))
s.us.SetOption(PresumeKeyNotExists, nil)
s.us.SetOption(PresumeKeyNotExistsError, ErrNotExist)
_, err = s.us.Get([]byte("2"))
_, err = s.us.Get(context.TODO(), []byte("2"))
c.Assert(errors.Cause(err) == ErrNotExist, IsTrue, Commentf("err %v", err))
condionPair1 := s.us.LookupConditionPair([]byte("1"))
@ -138,7 +140,7 @@ func checkIterator(c *C, iter Iterator, keys [][]byte, values [][]byte) {
c.Assert(iter.Valid(), IsTrue)
c.Assert([]byte(iter.Key()), BytesEquals, k)
c.Assert(iter.Value(), BytesEquals, v)
c.Assert(iter.Next(), IsNil)
c.Assert(iter.Next(context.TODO()), IsNil)
}
c.Assert(iter.Valid(), IsFalse)
}

View File

@ -18,7 +18,7 @@ import (
"sync/atomic"
"time"
"github.com/pingcap/pd/client"
pd "github.com/pingcap/pd/client"
log "github.com/sirupsen/logrus"
"github.com/tikv/client-go/config"
"github.com/tikv/client-go/metrics"

View File

@ -57,8 +57,8 @@ var _ = NewLockResolver
// NewLockResolver creates a LockResolver.
// It is exported for other pkg to use. For instance, binlog service needs
// to determine a transaction's commit state.
func NewLockResolver(etcdAddrs []string, conf config.Config) (*LockResolver, error) {
s, err := NewStore(etcdAddrs, conf)
func NewLockResolver(ctx context.Context, etcdAddrs []string, conf config.Config) (*LockResolver, error) {
s, err := NewStore(ctx, etcdAddrs, conf)
if err != nil {
return nil, err
}
@ -259,8 +259,8 @@ func (lr *LockResolver) ResolveLocks(bo *retry.Backoffer, locks []*Lock) (ok boo
// If the primary key is still locked, it will launch a Rollback to abort it.
// To avoid unnecessarily aborting too many txns, it is wiser to wait a few
// seconds before calling it after Prewrite.
func (lr *LockResolver) GetTxnStatus(txnID uint64, primary []byte) (TxnStatus, error) {
bo := retry.NewBackoffer(context.Background(), retry.CleanupMaxBackoff)
func (lr *LockResolver) GetTxnStatus(ctx context.Context, txnID uint64, primary []byte) (TxnStatus, error) {
bo := retry.NewBackoffer(ctx, retry.CleanupMaxBackoff)
return lr.getTxnStatus(bo, txnID, primary)
}

View File

@ -40,7 +40,7 @@ type Scanner struct {
eof bool
}
func newScanner(snapshot *TiKVSnapshot, startKey []byte, endKey []byte, batchSize int) (*Scanner, error) {
func newScanner(ctx context.Context, snapshot *TiKVSnapshot, startKey []byte, endKey []byte, batchSize int) (*Scanner, error) {
// It must be > 1. Otherwise scanner won't skipFirst.
if batchSize <= 1 {
batchSize = snapshot.conf.Txn.ScanBatchSize
@ -53,7 +53,7 @@ func newScanner(snapshot *TiKVSnapshot, startKey []byte, endKey []byte, batchSiz
nextStartKey: startKey,
endKey: endKey,
}
err := scanner.Next()
err := scanner.Next(ctx)
if kv.IsErrNotFound(err) {
return scanner, nil
}
@ -82,8 +82,8 @@ func (s *Scanner) Value() []byte {
}
// Next return next element.
func (s *Scanner) Next() error {
bo := retry.NewBackoffer(context.Background(), retry.ScannerNextMaxBackoff)
func (s *Scanner) Next(ctx context.Context) error {
bo := retry.NewBackoffer(ctx, retry.ScannerNextMaxBackoff)
if !s.valid {
return errors.New("scanner iterator is invalid")
}

View File

@ -55,7 +55,7 @@ func newTiKVSnapshot(store *TiKVStore, ts uint64) *TiKVSnapshot {
// BatchGet gets all the keys' value from kv-server and returns a map contains key/value pairs.
// The map will not contain nonexistent keys.
func (s *TiKVSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) {
func (s *TiKVSnapshot) BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error) {
m := make(map[string][]byte)
if len(keys) == 0 {
return m, nil
@ -66,7 +66,7 @@ func (s *TiKVSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) {
// We want [][]byte instead of []key.Key, use some magic to save memory.
bytesKeys := *(*[][]byte)(unsafe.Pointer(&keys))
bo := retry.NewBackoffer(context.Background(), retry.BatchGetMaxBackoff)
bo := retry.NewBackoffer(ctx, retry.BatchGetMaxBackoff)
// Create a map to collect key-values from region servers.
var mu sync.Mutex
@ -198,8 +198,8 @@ func (s *TiKVSnapshot) batchGetSingleRegion(bo *retry.Backoffer, batch batchKeys
}
// Get gets the value for key k from snapshot.
func (s *TiKVSnapshot) Get(k key.Key) ([]byte, error) {
val, err := s.get(retry.NewBackoffer(context.Background(), retry.GetMaxBackoff), k)
func (s *TiKVSnapshot) Get(ctx context.Context, k key.Key) ([]byte, error) {
val, err := s.get(retry.NewBackoffer(ctx, retry.GetMaxBackoff), k)
if err != nil {
return nil, err
}
@ -270,13 +270,13 @@ func (s *TiKVSnapshot) get(bo *retry.Backoffer, k key.Key) ([]byte, error) {
}
// Iter returns a list of key-value pair after `k`.
func (s *TiKVSnapshot) Iter(k key.Key, upperBound key.Key) (kv.Iterator, error) {
scanner, err := newScanner(s, k, upperBound, s.conf.Txn.ScanBatchSize)
func (s *TiKVSnapshot) Iter(ctx context.Context, k key.Key, upperBound key.Key) (kv.Iterator, error) {
scanner, err := newScanner(ctx, s, k, upperBound, s.conf.Txn.ScanBatchSize)
return scanner, err
}
// IterReverse creates a reversed Iterator positioned on the first entry which key is less than k.
func (s *TiKVSnapshot) IterReverse(k key.Key) (kv.Iterator, error) {
func (s *TiKVSnapshot) IterReverse(ctx context.Context, k key.Key) (kv.Iterator, error) {
return nil, ErrNotImplemented
}

View File

@ -27,9 +27,9 @@ import (
// SplitRegion splits the region contains splitKey into 2 regions: [start,
// splitKey) and [splitKey, end).
func SplitRegion(store *TiKVStore, splitKey key.Key) error {
func SplitRegion(ctx context.Context, store *TiKVStore, splitKey key.Key) error {
log.Infof("start split_region at %q", splitKey)
bo := retry.NewBackoffer(context.Background(), retry.SplitRegionBackoff)
bo := retry.NewBackoffer(ctx, retry.SplitRegionBackoff)
sender := rpc.NewRegionRequestSender(store.GetRegionCache(), store.GetRPCClient())
req := &rpc.Request{
Type: rpc.CmdSplitRegion,

View File

@ -55,7 +55,7 @@ type TiKVStore struct {
}
// NewStore creates a TiKVStore instance.
func NewStore(pdAddrs []string, conf config.Config) (*TiKVStore, error) {
func NewStore(ctx context.Context, pdAddrs []string, conf config.Config) (*TiKVStore, error) {
pdCli, err := pd.NewClient(pdAddrs, pd.SecurityOption{
CAPath: conf.RPC.Security.SSLCA,
CertPath: conf.RPC.Security.SSLCert,
@ -82,7 +82,7 @@ func NewStore(pdAddrs []string, conf config.Config) (*TiKVStore, error) {
return nil, err
}
clusterID := pdCli.GetClusterID(context.TODO())
clusterID := pdCli.GetClusterID(ctx)
store := &TiKVStore{
conf: &conf,

View File

@ -57,12 +57,12 @@ func newTransaction(tikvStore *store.TiKVStore, ts uint64) *Transaction {
}
// Get implements transaction interface.
func (txn *Transaction) Get(k key.Key) ([]byte, error) {
func (txn *Transaction) Get(ctx context.Context, k key.Key) ([]byte, error) {
metrics.TxnCmdCounter.WithLabelValues("get").Inc()
start := time.Now()
defer func() { metrics.TxnCmdHistogram.WithLabelValues("get").Observe(time.Since(start).Seconds()) }()
ret, err := txn.us.Get(k)
ret, err := txn.us.Get(ctx, k)
if kv.IsErrNotFound(err) {
return nil, err
}
@ -79,14 +79,14 @@ func (txn *Transaction) Get(k key.Key) ([]byte, error) {
}
// BatchGet gets a batch of values from TiKV server.
func (txn *Transaction) BatchGet(keys []key.Key) (map[string][]byte, error) {
func (txn *Transaction) BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error) {
if txn.IsReadOnly() {
return txn.snapshot.BatchGet(keys)
return txn.snapshot.BatchGet(ctx, keys)
}
bufferValues := make([][]byte, len(keys))
shrinkKeys := make([]key.Key, 0, len(keys))
for i, key := range keys {
val, err := txn.us.GetMemBuffer().Get(key)
val, err := txn.us.GetMemBuffer().Get(ctx, key)
if kv.IsErrNotFound(err) {
shrinkKeys = append(shrinkKeys, key)
continue
@ -98,7 +98,7 @@ func (txn *Transaction) BatchGet(keys []key.Key) (map[string][]byte, error) {
bufferValues[i] = val
}
}
storageValues, err := txn.snapshot.BatchGet(shrinkKeys)
storageValues, err := txn.snapshot.BatchGet(ctx, shrinkKeys)
if err != nil {
return nil, err
}
@ -125,23 +125,23 @@ func (txn *Transaction) String() string {
// If such entry is not found, it returns an invalid Iterator with no error.
// It yields only keys that < upperBound. If upperBound is nil, it means the upperBound is unbounded.
// The Iterator must be closed after use.
func (txn *Transaction) Iter(k key.Key, upperBound key.Key) (kv.Iterator, error) {
func (txn *Transaction) Iter(ctx context.Context, k key.Key, upperBound key.Key) (kv.Iterator, error) {
metrics.TxnCmdCounter.WithLabelValues("seek").Inc()
start := time.Now()
defer func() { metrics.TxnCmdHistogram.WithLabelValues("seek").Observe(time.Since(start).Seconds()) }()
return txn.us.Iter(k, upperBound)
return txn.us.Iter(ctx, k, upperBound)
}
// IterReverse creates a reversed Iterator positioned on the first entry which key is less than k.
func (txn *Transaction) IterReverse(k key.Key) (kv.Iterator, error) {
func (txn *Transaction) IterReverse(ctx context.Context, k key.Key) (kv.Iterator, error) {
metrics.TxnCmdCounter.WithLabelValues("seek_reverse").Inc()
start := time.Now()
defer func() {
metrics.TxnCmdHistogram.WithLabelValues("seek_reverse").Observe(time.Since(start).Seconds())
}()
return txn.us.IterReverse(k)
return txn.us.IterReverse(ctx, k)
}
// IsReadOnly returns if there are pending key-value to commit in the transaction.