From 77a15fcd874abdb6ded1d1d0dc71f15cfce0a887 Mon Sep 17 00:00:00 2001 From: disksing Date: Thu, 20 Jun 2019 01:18:29 +0800 Subject: [PATCH] Use context everywhere (#23) * rawkv Signed-off-by: disksing * txnkv wip Signed-off-by: disksing * txnkv wip Signed-off-by: disksing * txnkv update get & batchGet Signed-off-by: disksing * txnkv iterators Signed-off-by: disksing --- examples/rawkv/rawkv.go | 11 ++-- examples/txnkv/txnkv.go | 16 +++--- proxy/httpproxy/handler.go | 18 ++++++ proxy/httpproxy/rawkv.go | 51 +++++++++-------- proxy/httpproxy/txnkv.go | 99 ++++++++++++++++---------------- proxy/rawkv.go | 62 ++++++++++---------- proxy/txnkv.go | 104 ++++++++++++++++++---------------- proxy/utils.go | 12 ++++ rawkv/rawkv.go | 48 ++++++++-------- rawkv/rawkv_test.go | 31 +++++----- txnkv/client.go | 16 +++--- txnkv/kv/buffer_store.go | 24 ++++---- txnkv/kv/buffer_store_test.go | 9 +-- txnkv/kv/kv.go | 16 ++++-- txnkv/kv/mem_buffer_test.go | 29 +++++----- txnkv/kv/memdb_buffer.go | 15 ++--- txnkv/kv/mock.go | 18 +++--- txnkv/kv/union_iter.go | 34 +++++------ txnkv/kv/union_store.go | 22 +++---- txnkv/kv/union_store_test.go | 32 ++++++----- txnkv/oracle/oracles/pd.go | 2 +- txnkv/store/lock_resolver.go | 8 +-- txnkv/store/scan.go | 8 +-- txnkv/store/snapshot.go | 14 ++--- txnkv/store/split_region.go | 4 +- txnkv/store/store.go | 4 +- txnkv/txn.go | 20 +++---- 27 files changed, 394 insertions(+), 333 deletions(-) diff --git a/examples/rawkv/rawkv.go b/examples/rawkv/rawkv.go index 2ea33c14..3914b52a 100644 --- a/examples/rawkv/rawkv.go +++ b/examples/rawkv/rawkv.go @@ -14,6 +14,7 @@ package main import ( + "context" "fmt" "github.com/tikv/client-go/config" @@ -21,7 +22,7 @@ import ( ) func main() { - cli, err := rawkv.NewClient([]string{"127.0.0.1:2379"}, config.Default()) + cli, err := rawkv.NewClient(context.TODO(), []string{"127.0.0.1:2379"}, config.Default()) if err != nil { panic(err) } @@ -33,28 +34,28 @@ func main() { val := []byte("PingCAP") // put key into tikv - err = cli.Put(key, val) + err = cli.Put(context.TODO(), key, val) if err != nil { panic(err) } fmt.Printf("Successfully put %s:%s to tikv\n", key, val) // get key from tikv - val, err = cli.Get(key) + val, err = cli.Get(context.TODO(), key) if err != nil { panic(err) } fmt.Printf("found val: %s for key: %s\n", val, key) // delete key from tikv - err = cli.Delete(key) + err = cli.Delete(context.TODO(), key) if err != nil { panic(err) } fmt.Printf("key: %s deleted\n", key) // get key again from tikv - val, err = cli.Get(key) + val, err = cli.Get(context.TODO(), key) if err != nil { panic(err) } diff --git a/examples/txnkv/txnkv.go b/examples/txnkv/txnkv.go index d3169fb7..c8e2093b 100644 --- a/examples/txnkv/txnkv.go +++ b/examples/txnkv/txnkv.go @@ -41,7 +41,7 @@ var ( // Init initializes information. func initStore() { var err error - client, err = txnkv.NewClient([]string{*pdAddr}, config.Default()) + client, err = txnkv.NewClient(context.TODO(), []string{*pdAddr}, config.Default()) if err != nil { panic(err) } @@ -49,7 +49,7 @@ func initStore() { // key1 val1 key2 val2 ... func puts(args ...[]byte) error { - tx, err := client.Begin() + tx, err := client.Begin(context.TODO()) if err != nil { return err } @@ -65,11 +65,11 @@ func puts(args ...[]byte) error { } func get(k []byte) (KV, error) { - tx, err := client.Begin() + tx, err := client.Begin(context.TODO()) if err != nil { return KV{}, err } - v, err := tx.Get(k) + v, err := tx.Get(context.TODO(), k) if err != nil { return KV{}, err } @@ -77,7 +77,7 @@ func get(k []byte) (KV, error) { } func dels(keys ...[]byte) error { - tx, err := client.Begin() + tx, err := client.Begin(context.TODO()) if err != nil { return err } @@ -91,11 +91,11 @@ func dels(keys ...[]byte) error { } func scan(keyPrefix []byte, limit int) ([]KV, error) { - tx, err := client.Begin() + tx, err := client.Begin(context.TODO()) if err != nil { return nil, err } - it, err := tx.Iter(key.Key(keyPrefix), nil) + it, err := tx.Iter(context.TODO(), key.Key(keyPrefix), nil) if err != nil { return nil, err } @@ -104,7 +104,7 @@ func scan(keyPrefix []byte, limit int) ([]KV, error) { for it.Valid() && limit > 0 { ret = append(ret, KV{K: it.Key()[:], V: it.Value()[:]}) limit-- - it.Next() + it.Next(context.TODO()) } return ret, nil } diff --git a/proxy/httpproxy/handler.go b/proxy/httpproxy/handler.go index ad3f9396..85876f25 100644 --- a/proxy/httpproxy/handler.go +++ b/proxy/httpproxy/handler.go @@ -14,7 +14,9 @@ package httpproxy import ( + "context" "net/http" + "time" "github.com/gorilla/mux" "github.com/tikv/client-go/proxy" @@ -69,3 +71,19 @@ func NewHTTPProxyHandler() http.Handler { return router } + +var defaultTimeout = 20 * time.Second + +func reqContext(vars map[string]string) (context.Context, context.CancelFunc) { + ctx := context.Background() + if id := vars["id"]; id != "" { + ctx = context.WithValue(ctx, proxy.UUIDKey, proxy.UUID(id)) + } + + d, err := time.ParseDuration(vars["timeout"]) + if err != nil { + d = defaultTimeout + } + + return context.WithTimeout(ctx, d) +} diff --git a/proxy/httpproxy/rawkv.go b/proxy/httpproxy/rawkv.go index 797c6c0c..8de62a03 100644 --- a/proxy/httpproxy/rawkv.go +++ b/proxy/httpproxy/rawkv.go @@ -14,6 +14,7 @@ package httpproxy import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -47,94 +48,94 @@ type RawResponse struct { Values [][]byte `json:"values,omitempty"` // for batchGet } -func (h rawkvHandler) New(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - id, err := h.p.New(r.PDAddrs, config.Default()) +func (h rawkvHandler) New(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + id, err := h.p.New(ctx, r.PDAddrs, config.Default()) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{ID: string(id)}, http.StatusCreated, nil } -func (h rawkvHandler) Close(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - if err := h.p.Close(proxy.UUID(vars["id"])); err != nil { +func (h rawkvHandler) Close(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + if err := h.p.Close(ctx); err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{}, http.StatusOK, nil } -func (h rawkvHandler) Get(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - val, err := h.p.Get(proxy.UUID(vars["id"]), r.Key) +func (h rawkvHandler) Get(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + val, err := h.p.Get(ctx, r.Key) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{Value: val}, http.StatusOK, nil } -func (h rawkvHandler) BatchGet(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - vals, err := h.p.BatchGet(proxy.UUID(vars["id"]), r.Keys) +func (h rawkvHandler) BatchGet(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + vals, err := h.p.BatchGet(ctx, r.Keys) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{Values: vals}, http.StatusOK, nil } -func (h rawkvHandler) Put(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - err := h.p.Put(proxy.UUID(vars["id"]), r.Key, r.Value) +func (h rawkvHandler) Put(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + err := h.p.Put(ctx, r.Key, r.Value) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{}, http.StatusOK, nil } -func (h rawkvHandler) BatchPut(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - err := h.p.BatchPut(proxy.UUID(vars["id"]), r.Keys, r.Values) +func (h rawkvHandler) BatchPut(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + err := h.p.BatchPut(ctx, r.Keys, r.Values) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{}, http.StatusOK, nil } -func (h rawkvHandler) Delete(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - err := h.p.Delete(proxy.UUID(vars["id"]), r.Key) +func (h rawkvHandler) Delete(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + err := h.p.Delete(ctx, r.Key) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{}, http.StatusOK, nil } -func (h rawkvHandler) BatchDelete(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - err := h.p.BatchDelete(proxy.UUID(vars["id"]), r.Keys) +func (h rawkvHandler) BatchDelete(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + err := h.p.BatchDelete(ctx, r.Keys) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{}, http.StatusOK, nil } -func (h rawkvHandler) DeleteRange(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - err := h.p.DeleteRange(proxy.UUID(vars["id"]), r.StartKey, r.EndKey) +func (h rawkvHandler) DeleteRange(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + err := h.p.DeleteRange(ctx, r.StartKey, r.EndKey) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{}, http.StatusOK, nil } -func (h rawkvHandler) Scan(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - keys, values, err := h.p.Scan(proxy.UUID(vars["id"]), r.StartKey, r.EndKey, r.Limit) +func (h rawkvHandler) Scan(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + keys, values, err := h.p.Scan(ctx, r.StartKey, r.EndKey, r.Limit) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{Keys: keys, Values: values}, http.StatusOK, nil } -func (h rawkvHandler) ReverseScan(vars map[string]string, r *RawRequest) (*RawResponse, int, error) { - keys, values, err := h.p.ReverseScan(proxy.UUID(vars["id"]), r.StartKey, r.EndKey, r.Limit) +func (h rawkvHandler) ReverseScan(ctx context.Context, r *RawRequest) (*RawResponse, int, error) { + keys, values, err := h.p.ReverseScan(ctx, r.StartKey, r.EndKey, r.Limit) if err != nil { return nil, http.StatusInternalServerError, err } return &RawResponse{Keys: keys, Values: values}, http.StatusOK, nil } -func (h rawkvHandler) handlerFunc(f func(map[string]string, *RawRequest) (*RawResponse, int, error)) func(http.ResponseWriter, *http.Request) { +func (h rawkvHandler) handlerFunc(f func(context.Context, *RawRequest) (*RawResponse, int, error)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { data, err := ioutil.ReadAll(r.Body) if err != nil { @@ -146,7 +147,9 @@ func (h rawkvHandler) handlerFunc(f func(map[string]string, *RawRequest) (*RawRe sendError(w, err, http.StatusBadRequest) return } - res, status, err := f(mux.Vars(r), &req) + ctx, cancel := reqContext(mux.Vars(r)) + res, status, err := f(ctx, &req) + cancel() if err != nil { sendError(w, err, status) return diff --git a/proxy/httpproxy/txnkv.go b/proxy/httpproxy/txnkv.go index 31fe73f4..940376d0 100644 --- a/proxy/httpproxy/txnkv.go +++ b/proxy/httpproxy/txnkv.go @@ -14,6 +14,7 @@ package httpproxy import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -51,56 +52,56 @@ type TxnResponse struct { Length int `json:"length,omitempty"` // for length } -func (h txnkvHandler) New(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - id, err := h.p.New(r.PDAddrs, config.Default()) +func (h txnkvHandler) New(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + id, err := h.p.New(ctx, r.PDAddrs, config.Default()) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{ID: string(id)}, http.StatusOK, nil } -func (h txnkvHandler) Close(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.Close(proxy.UUID(vars["id"])) +func (h txnkvHandler) Close(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.Close(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) Begin(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - txnID, err := h.p.Begin(proxy.UUID(vars["id"])) +func (h txnkvHandler) Begin(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + txnID, err := h.p.Begin(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{ID: string(txnID)}, http.StatusCreated, nil } -func (h txnkvHandler) BeginWithTS(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - txnID, err := h.p.BeginWithTS(proxy.UUID(vars["id"]), r.TS) +func (h txnkvHandler) BeginWithTS(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + txnID, err := h.p.BeginWithTS(ctx, r.TS) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{ID: string(txnID)}, http.StatusOK, nil } -func (h txnkvHandler) GetTS(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - ts, err := h.p.GetTS(proxy.UUID(vars["id"])) +func (h txnkvHandler) GetTS(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + ts, err := h.p.GetTS(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{TS: ts}, http.StatusOK, nil } -func (h txnkvHandler) TxnGet(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - val, err := h.p.TxnGet(proxy.UUID(vars["id"]), r.Key) +func (h txnkvHandler) TxnGet(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + val, err := h.p.TxnGet(ctx, r.Key) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{Value: val}, http.StatusOK, nil } -func (h txnkvHandler) TxnBatchGet(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - m, err := h.p.TxnBatchGet(proxy.UUID(vars["id"]), r.Keys) +func (h txnkvHandler) TxnBatchGet(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + m, err := h.p.TxnBatchGet(ctx, r.Keys) if err != nil { return nil, http.StatusInternalServerError, err } @@ -112,135 +113,135 @@ func (h txnkvHandler) TxnBatchGet(vars map[string]string, r *TxnRequest) (*TxnRe return &TxnResponse{Keys: keys, Values: values}, http.StatusOK, nil } -func (h txnkvHandler) TxnSet(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.TxnSet(proxy.UUID(vars["id"]), r.Key, r.Value) +func (h txnkvHandler) TxnSet(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.TxnSet(ctx, r.Key, r.Value) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) TxnIter(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - iterID, err := h.p.TxnIter(proxy.UUID(vars["id"]), r.Key, r.UpperBound) +func (h txnkvHandler) TxnIter(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + iterID, err := h.p.TxnIter(ctx, r.Key, r.UpperBound) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{ID: string(iterID)}, http.StatusCreated, nil } -func (h txnkvHandler) TxnIterReverse(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - iterID, err := h.p.TxnIterReverse(proxy.UUID(vars["id"]), r.Key) +func (h txnkvHandler) TxnIterReverse(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + iterID, err := h.p.TxnIterReverse(ctx, r.Key) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{ID: string(iterID)}, http.StatusCreated, nil } -func (h txnkvHandler) TxnIsReadOnly(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - readonly, err := h.p.TxnIsReadOnly(proxy.UUID(vars["id"])) +func (h txnkvHandler) TxnIsReadOnly(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + readonly, err := h.p.TxnIsReadOnly(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{IsReadOnly: readonly}, http.StatusOK, nil } -func (h txnkvHandler) TxnDelete(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.TxnDelete(proxy.UUID(vars["id"]), r.Key) +func (h txnkvHandler) TxnDelete(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.TxnDelete(ctx, r.Key) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) TxnCommit(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.TxnCommit(proxy.UUID(vars["id"])) +func (h txnkvHandler) TxnCommit(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.TxnCommit(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) TxnRollback(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.TxnRollback(proxy.UUID(vars["id"])) +func (h txnkvHandler) TxnRollback(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.TxnRollback(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) TxnLockKeys(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.TxnLockKeys(proxy.UUID(vars["id"]), r.Keys) +func (h txnkvHandler) TxnLockKeys(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.TxnLockKeys(ctx, r.Keys) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) TxnValid(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - valid, err := h.p.TxnValid(proxy.UUID(vars["id"])) +func (h txnkvHandler) TxnValid(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + valid, err := h.p.TxnValid(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{IsValid: valid}, http.StatusOK, nil } -func (h txnkvHandler) TxnLen(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - length, err := h.p.TxnLen(proxy.UUID(vars["id"])) +func (h txnkvHandler) TxnLen(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + length, err := h.p.TxnLen(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{Length: length}, http.StatusOK, nil } -func (h txnkvHandler) TxnSize(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - size, err := h.p.TxnSize(proxy.UUID(vars["id"])) +func (h txnkvHandler) TxnSize(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + size, err := h.p.TxnSize(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{Size: size}, http.StatusOK, nil } -func (h txnkvHandler) IterValid(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - valid, err := h.p.IterValid(proxy.UUID(vars["id"])) +func (h txnkvHandler) IterValid(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + valid, err := h.p.IterValid(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{IsValid: valid}, http.StatusOK, nil } -func (h txnkvHandler) IterKey(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - key, err := h.p.IterKey(proxy.UUID(vars["id"])) +func (h txnkvHandler) IterKey(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + key, err := h.p.IterKey(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{Key: key}, http.StatusOK, nil } -func (h txnkvHandler) IterValue(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - val, err := h.p.IterValue(proxy.UUID(vars["id"])) +func (h txnkvHandler) IterValue(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + val, err := h.p.IterValue(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{Value: val}, http.StatusOK, nil } -func (h txnkvHandler) IterNext(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.IterNext(proxy.UUID(vars["id"])) +func (h txnkvHandler) IterNext(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.IterNext(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) IterClose(vars map[string]string, r *TxnRequest) (*TxnResponse, int, error) { - err := h.p.IterClose(proxy.UUID(vars["id"])) +func (h txnkvHandler) IterClose(ctx context.Context, r *TxnRequest) (*TxnResponse, int, error) { + err := h.p.IterClose(ctx) if err != nil { return nil, http.StatusInternalServerError, err } return &TxnResponse{}, http.StatusOK, nil } -func (h txnkvHandler) handlerFunc(f func(map[string]string, *TxnRequest) (*TxnResponse, int, error)) func(http.ResponseWriter, *http.Request) { +func (h txnkvHandler) handlerFunc(f func(context.Context, *TxnRequest) (*TxnResponse, int, error)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { data, err := ioutil.ReadAll(r.Body) if err != nil { @@ -252,7 +253,9 @@ func (h txnkvHandler) handlerFunc(f func(map[string]string, *TxnRequest) (*TxnRe sendError(w, err, http.StatusBadRequest) return } - res, status, err := f(mux.Vars(r), &req) + ctx, cancel := reqContext(mux.Vars(r)) + res, status, err := f(ctx, &req) + cancel() if err != nil { sendError(w, err, status) return diff --git a/proxy/rawkv.go b/proxy/rawkv.go index 6ede9551..167209e5 100644 --- a/proxy/rawkv.go +++ b/proxy/rawkv.go @@ -14,6 +14,7 @@ package proxy import ( + "context" "sync" "github.com/pkg/errors" @@ -35,8 +36,8 @@ func NewRaw() RawKVProxy { } // New creates a new client and returns the client's UUID. -func (p RawKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) { - client, err := rawkv.NewClient(pdAddrs, conf) +func (p RawKVProxy) New(ctx context.Context, pdAddrs []string, conf config.Config) (UUID, error) { + client, err := rawkv.NewClient(ctx, pdAddrs, conf) if err != nil { return "", err } @@ -44,7 +45,8 @@ func (p RawKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) { } // Close releases a rawkv client. -func (p RawKVProxy) Close(id UUID) error { +func (p RawKVProxy) Close(ctx context.Context) error { + id := uuidFromContext(ctx) client, ok := p.clients.Load(id) if !ok { return errors.WithStack(ErrClientNotFound) @@ -57,83 +59,83 @@ func (p RawKVProxy) Close(id UUID) error { } // Get queries value with the key. -func (p RawKVProxy) Get(id UUID, key []byte) ([]byte, error) { - client, ok := p.clients.Load(id) +func (p RawKVProxy) Get(ctx context.Context, key []byte) ([]byte, error) { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return nil, errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).Get(key) + return client.(*rawkv.Client).Get(ctx, key) } // BatchGet queries values with the keys. -func (p RawKVProxy) BatchGet(id UUID, keys [][]byte) ([][]byte, error) { - client, ok := p.clients.Load(id) +func (p RawKVProxy) BatchGet(ctx context.Context, keys [][]byte) ([][]byte, error) { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return nil, errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).BatchGet(keys) + return client.(*rawkv.Client).BatchGet(ctx, keys) } // Put stores a key-value pair to TiKV. -func (p RawKVProxy) Put(id UUID, key, value []byte) error { - client, ok := p.clients.Load(id) +func (p RawKVProxy) Put(ctx context.Context, key, value []byte) error { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).Put(key, value) + return client.(*rawkv.Client).Put(ctx, key, value) } // BatchPut stores key-value pairs to TiKV. -func (p RawKVProxy) BatchPut(id UUID, keys, values [][]byte) error { - client, ok := p.clients.Load(id) +func (p RawKVProxy) BatchPut(ctx context.Context, keys, values [][]byte) error { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).BatchPut(keys, values) + return client.(*rawkv.Client).BatchPut(ctx, keys, values) } // Delete deletes a key-value pair from TiKV. -func (p RawKVProxy) Delete(id UUID, key []byte) error { - client, ok := p.clients.Load(id) +func (p RawKVProxy) Delete(ctx context.Context, key []byte) error { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).Delete(key) + return client.(*rawkv.Client).Delete(ctx, key) } // BatchDelete deletes key-value pairs from TiKV. -func (p RawKVProxy) BatchDelete(id UUID, keys [][]byte) error { - client, ok := p.clients.Load(id) +func (p RawKVProxy) BatchDelete(ctx context.Context, keys [][]byte) error { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).BatchDelete(keys) + return client.(*rawkv.Client).BatchDelete(ctx, keys) } // DeleteRange deletes all key-value pairs in a range from TiKV. -func (p RawKVProxy) DeleteRange(id UUID, startKey []byte, endKey []byte) error { - client, ok := p.clients.Load(id) +func (p RawKVProxy) DeleteRange(ctx context.Context, startKey []byte, endKey []byte) error { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).DeleteRange(startKey, endKey) + return client.(*rawkv.Client).DeleteRange(ctx, startKey, endKey) } // Scan queries continuous kv pairs in range [startKey, endKey), up to limit pairs. -func (p RawKVProxy) Scan(id UUID, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) { - client, ok := p.clients.Load(id) +func (p RawKVProxy) Scan(ctx context.Context, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return nil, nil, errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).Scan(startKey, endKey, limit) + return client.(*rawkv.Client).Scan(ctx, startKey, endKey, limit) } // ReverseScan queries continuous kv pairs in range [endKey, startKey), up to limit pairs. // Direction is different from Scan, upper to lower. -func (p RawKVProxy) ReverseScan(id UUID, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) { - client, ok := p.clients.Load(id) +func (p RawKVProxy) ReverseScan(ctx context.Context, startKey, endKey []byte, limit int) ([][]byte, [][]byte, error) { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return nil, nil, errors.WithStack(ErrClientNotFound) } - return client.(*rawkv.Client).ReverseScan(startKey, endKey, limit) + return client.(*rawkv.Client).ReverseScan(ctx, startKey, endKey, limit) } diff --git a/proxy/txnkv.go b/proxy/txnkv.go index f6e42930..d7b0aa61 100644 --- a/proxy/txnkv.go +++ b/proxy/txnkv.go @@ -43,8 +43,8 @@ func NewTxn() TxnKVProxy { } // New creates a new client and returns the client's UUID. -func (p TxnKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) { - client, err := txnkv.NewClient(pdAddrs, conf) +func (p TxnKVProxy) New(ctx context.Context, pdAddrs []string, conf config.Config) (UUID, error) { + client, err := txnkv.NewClient(ctx, pdAddrs, conf) if err != nil { return "", err } @@ -52,7 +52,8 @@ func (p TxnKVProxy) New(pdAddrs []string, conf config.Config) (UUID, error) { } // Close releases a txnkv client. -func (p TxnKVProxy) Close(id UUID) error { +func (p TxnKVProxy) Close(ctx context.Context) error { + id := uuidFromContext(ctx) client, ok := p.clients.Load(id) if !ok { return errors.WithStack(ErrClientNotFound) @@ -65,12 +66,12 @@ func (p TxnKVProxy) Close(id UUID) error { } // Begin starts a new transaction and returns its UUID. -func (p TxnKVProxy) Begin(id UUID) (UUID, error) { - client, ok := p.clients.Load(id) +func (p TxnKVProxy) Begin(ctx context.Context) (UUID, error) { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return "", errors.WithStack(ErrClientNotFound) } - txn, err := client.(*txnkv.Client).Begin() + txn, err := client.(*txnkv.Client).Begin(ctx) if err != nil { return "", err } @@ -78,45 +79,45 @@ func (p TxnKVProxy) Begin(id UUID) (UUID, error) { } // BeginWithTS starts a new transaction with given ts and returns its UUID. -func (p TxnKVProxy) BeginWithTS(id UUID, ts uint64) (UUID, error) { - client, ok := p.clients.Load(id) +func (p TxnKVProxy) BeginWithTS(ctx context.Context, ts uint64) (UUID, error) { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return "", errors.WithStack(ErrClientNotFound) } - return insertWithRetry(p.txns, client.(*txnkv.Client).BeginWithTS(ts)), nil + return insertWithRetry(p.txns, client.(*txnkv.Client).BeginWithTS(ctx, ts)), nil } // GetTS returns a latest timestamp. -func (p TxnKVProxy) GetTS(id UUID) (uint64, error) { - client, ok := p.clients.Load(id) +func (p TxnKVProxy) GetTS(ctx context.Context) (uint64, error) { + client, ok := p.clients.Load(uuidFromContext(ctx)) if !ok { return 0, errors.WithStack(ErrClientNotFound) } - return client.(*txnkv.Client).GetTS() + return client.(*txnkv.Client).GetTS(ctx) } // TxnGet queries value for the given key from TiKV server. -func (p TxnKVProxy) TxnGet(id UUID, key []byte) ([]byte, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnGet(ctx context.Context, key []byte) ([]byte, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return nil, errors.WithStack(ErrTxnNotFound) } - return txn.(*txnkv.Transaction).Get(key) + return txn.(*txnkv.Transaction).Get(ctx, key) } // TxnBatchGet gets a batch of values from TiKV server. -func (p TxnKVProxy) TxnBatchGet(id UUID, keys [][]byte) (map[string][]byte, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnBatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return nil, errors.WithStack(ErrTxnNotFound) } ks := *(*[]key.Key)(unsafe.Pointer(&keys)) - return txn.(*txnkv.Transaction).BatchGet(ks) + return txn.(*txnkv.Transaction).BatchGet(ctx, ks) } // TxnSet sets the value for key k as v into TiKV server. -func (p TxnKVProxy) TxnSet(id UUID, k []byte, v []byte) error { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnSet(ctx context.Context, k []byte, v []byte) error { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrTxnNotFound) } @@ -125,12 +126,12 @@ func (p TxnKVProxy) TxnSet(id UUID, k []byte, v []byte) error { // TxnIter creates an Iterator positioned on the first entry that key <= entry's // key and returns the Iterator's UUID. -func (p TxnKVProxy) TxnIter(id UUID, key []byte, upperBound []byte) (UUID, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnIter(ctx context.Context, key []byte, upperBound []byte) (UUID, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return "", errors.WithStack(ErrTxnNotFound) } - iter, err := txn.(*txnkv.Transaction).Iter(key, upperBound) + iter, err := txn.(*txnkv.Transaction).Iter(ctx, key, upperBound) if err != nil { return "", err } @@ -139,12 +140,12 @@ func (p TxnKVProxy) TxnIter(id UUID, key []byte, upperBound []byte) (UUID, error // TxnIterReverse creates a reversed Iterator positioned on the first entry // which key is less than key and returns the Iterator's UUID. -func (p TxnKVProxy) TxnIterReverse(id UUID, key []byte) (UUID, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnIterReverse(ctx context.Context, key []byte) (UUID, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return "", errors.WithStack(ErrTxnNotFound) } - iter, err := txn.(*txnkv.Transaction).IterReverse(key) + iter, err := txn.(*txnkv.Transaction).IterReverse(ctx, key) if err != nil { return "", err } @@ -152,8 +153,8 @@ func (p TxnKVProxy) TxnIterReverse(id UUID, key []byte) (UUID, error) { } // TxnIsReadOnly returns if there are pending key-value to commit in the transaction. -func (p TxnKVProxy) TxnIsReadOnly(id UUID) (bool, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnIsReadOnly(ctx context.Context) (bool, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return false, errors.WithStack(ErrTxnNotFound) } @@ -161,8 +162,8 @@ func (p TxnKVProxy) TxnIsReadOnly(id UUID) (bool, error) { } // TxnDelete removes the entry for key from TiKV server. -func (p TxnKVProxy) TxnDelete(id UUID, key []byte) error { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnDelete(ctx context.Context, key []byte) error { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrTxnNotFound) } @@ -170,7 +171,8 @@ func (p TxnKVProxy) TxnDelete(id UUID, key []byte) error { } // TxnCommit commits the transaction operations to TiKV server. -func (p TxnKVProxy) TxnCommit(id UUID) error { +func (p TxnKVProxy) TxnCommit(ctx context.Context) error { + id := uuidFromContext(ctx) txn, ok := p.txns.Load(id) if !ok { return errors.WithStack(ErrTxnNotFound) @@ -180,7 +182,8 @@ func (p TxnKVProxy) TxnCommit(id UUID) error { } // TxnRollback undoes the transaction operations to TiKV server. -func (p TxnKVProxy) TxnRollback(id UUID) error { +func (p TxnKVProxy) TxnRollback(ctx context.Context) error { + id := uuidFromContext(ctx) txn, ok := p.txns.Load(id) if !ok { return errors.WithStack(ErrTxnNotFound) @@ -190,8 +193,8 @@ func (p TxnKVProxy) TxnRollback(id UUID) error { } // TxnLockKeys tries to lock the entries with the keys in TiKV server. -func (p TxnKVProxy) TxnLockKeys(id UUID, keys [][]byte) error { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnLockKeys(ctx context.Context, keys [][]byte) error { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrTxnNotFound) } @@ -200,8 +203,8 @@ func (p TxnKVProxy) TxnLockKeys(id UUID, keys [][]byte) error { } // TxnValid returns if the transaction is valid. -func (p TxnKVProxy) TxnValid(id UUID) (bool, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnValid(ctx context.Context) (bool, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return false, errors.WithStack(ErrTxnNotFound) } @@ -209,8 +212,8 @@ func (p TxnKVProxy) TxnValid(id UUID) (bool, error) { } // TxnLen returns the count of key-value pairs in the transaction's memory buffer. -func (p TxnKVProxy) TxnLen(id UUID) (int, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnLen(ctx context.Context) (int, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return 0, errors.WithStack(ErrTxnNotFound) } @@ -218,8 +221,8 @@ func (p TxnKVProxy) TxnLen(id UUID) (int, error) { } // TxnSize returns the length (in bytes) of the transaction's memory buffer. -func (p TxnKVProxy) TxnSize(id UUID) (int, error) { - txn, ok := p.txns.Load(id) +func (p TxnKVProxy) TxnSize(ctx context.Context) (int, error) { + txn, ok := p.txns.Load(uuidFromContext(ctx)) if !ok { return 0, errors.WithStack(ErrTxnNotFound) } @@ -227,8 +230,8 @@ func (p TxnKVProxy) TxnSize(id UUID) (int, error) { } // IterValid returns if the iterator is valid to use. -func (p TxnKVProxy) IterValid(id UUID) (bool, error) { - iter, ok := p.iterators.Load(id) +func (p TxnKVProxy) IterValid(ctx context.Context) (bool, error) { + iter, ok := p.iterators.Load(uuidFromContext(ctx)) if !ok { return false, errors.WithStack(ErrIterNotFound) } @@ -236,8 +239,8 @@ func (p TxnKVProxy) IterValid(id UUID) (bool, error) { } // IterKey returns the key which the iterator points to. -func (p TxnKVProxy) IterKey(id UUID) ([]byte, error) { - iter, ok := p.iterators.Load(id) +func (p TxnKVProxy) IterKey(ctx context.Context) ([]byte, error) { + iter, ok := p.iterators.Load(uuidFromContext(ctx)) if !ok { return nil, errors.WithStack(ErrIterNotFound) } @@ -245,8 +248,8 @@ func (p TxnKVProxy) IterKey(id UUID) ([]byte, error) { } // IterValue returns the value which the iterator points to. -func (p TxnKVProxy) IterValue(id UUID) ([]byte, error) { - iter, ok := p.iterators.Load(id) +func (p TxnKVProxy) IterValue(ctx context.Context) ([]byte, error) { + iter, ok := p.iterators.Load(uuidFromContext(ctx)) if !ok { return nil, errors.WithStack(ErrIterNotFound) } @@ -254,16 +257,17 @@ func (p TxnKVProxy) IterValue(id UUID) ([]byte, error) { } // IterNext moves the iterator to next entry. -func (p TxnKVProxy) IterNext(id UUID) error { - iter, ok := p.iterators.Load(id) +func (p TxnKVProxy) IterNext(ctx context.Context) error { + iter, ok := p.iterators.Load(uuidFromContext(ctx)) if !ok { return errors.WithStack(ErrIterNotFound) } - return iter.(kv.Iterator).Next() + return iter.(kv.Iterator).Next(ctx) } // IterClose releases an iterator. -func (p TxnKVProxy) IterClose(id UUID) error { +func (p TxnKVProxy) IterClose(ctx context.Context) error { + id := uuidFromContext(ctx) iter, ok := p.iterators.Load(id) if !ok { return errors.WithStack(ErrIterNotFound) diff --git a/proxy/utils.go b/proxy/utils.go index edece822..3fcabe00 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -14,6 +14,7 @@ package proxy import ( + "context" "errors" "sync" @@ -27,6 +28,10 @@ var ( ErrIterNotFound = errors.New("iterator not found") ) +type ContextKey int + +var UUIDKey ContextKey = 1 + // UUID is a global unique ID to identify clients, transactions, or iterators. type UUID string @@ -38,3 +43,10 @@ func insertWithRetry(m *sync.Map, d interface{}) UUID { } } } + +func uuidFromContext(ctx context.Context) UUID { + if id := ctx.Value(UUIDKey); id != nil { + return id.(UUID) + } + return "" +} diff --git a/rawkv/rawkv.go b/rawkv/rawkv.go index 20845e01..08d40775 100644 --- a/rawkv/rawkv.go +++ b/rawkv/rawkv.go @@ -44,7 +44,7 @@ type Client struct { } // NewClient creates a client with PD cluster addrs. -func NewClient(pdAddrs []string, conf config.Config) (*Client, error) { +func NewClient(ctx context.Context, pdAddrs []string, conf config.Config) (*Client, error) { pdCli, err := pd.NewClient(pdAddrs, pd.SecurityOption{ CAPath: conf.RPC.Security.SSLCA, CertPath: conf.RPC.Security.SSLCert, @@ -54,7 +54,7 @@ func NewClient(pdAddrs []string, conf config.Config) (*Client, error) { return nil, err } return &Client{ - clusterID: pdCli.GetClusterID(context.TODO()), + clusterID: pdCli.GetClusterID(ctx), conf: &conf, regionCache: locate.NewRegionCache(pdCli, &conf.RegionCache), pdClient: pdCli, @@ -74,7 +74,7 @@ func (c *Client) ClusterID() uint64 { } // Get queries value with the key. When the key does not exist, it returns `nil, nil`. -func (c *Client) Get(key []byte) ([]byte, error) { +func (c *Client) Get(ctx context.Context, key []byte) ([]byte, error) { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("get").Observe(time.Since(start).Seconds()) }() @@ -84,7 +84,7 @@ func (c *Client) Get(key []byte) ([]byte, error) { Key: key, }, } - resp, _, err := c.sendReq(key, req) + resp, _, err := c.sendReq(ctx, key, req) if err != nil { return nil, err } @@ -102,13 +102,13 @@ func (c *Client) Get(key []byte) ([]byte, error) { } // BatchGet queries values with the keys. -func (c *Client) BatchGet(keys [][]byte) ([][]byte, error) { +func (c *Client) BatchGet(ctx context.Context, keys [][]byte) ([][]byte, error) { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("batch_get").Observe(time.Since(start).Seconds()) }() - bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff) + bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff) resp, err := c.sendBatchReq(bo, keys, rpc.CmdRawBatchGet) if err != nil { return nil, err @@ -132,7 +132,7 @@ func (c *Client) BatchGet(keys [][]byte) ([][]byte, error) { } // Put stores a key-value pair to TiKV. -func (c *Client) Put(key, value []byte) error { +func (c *Client) Put(ctx context.Context, key, value []byte) error { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("put").Observe(time.Since(start).Seconds()) }() metrics.RawkvSizeHistogram.WithLabelValues("key").Observe(float64(len(key))) @@ -149,7 +149,7 @@ func (c *Client) Put(key, value []byte) error { Value: value, }, } - resp, _, err := c.sendReq(key, req) + resp, _, err := c.sendReq(ctx, key, req) if err != nil { return err } @@ -164,7 +164,7 @@ func (c *Client) Put(key, value []byte) error { } // BatchPut stores key-value pairs to TiKV. -func (c *Client) BatchPut(keys, values [][]byte) error { +func (c *Client) BatchPut(ctx context.Context, keys, values [][]byte) error { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("batch_put").Observe(time.Since(start).Seconds()) @@ -178,12 +178,12 @@ func (c *Client) BatchPut(keys, values [][]byte) error { return errors.New("empty value is not supported") } } - bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff) + bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff) return c.sendBatchPut(bo, keys, values) } // Delete deletes a key-value pair from TiKV. -func (c *Client) Delete(key []byte) error { +func (c *Client) Delete(ctx context.Context, key []byte) error { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("delete").Observe(time.Since(start).Seconds()) }() @@ -193,7 +193,7 @@ func (c *Client) Delete(key []byte) error { Key: key, }, } - resp, _, err := c.sendReq(key, req) + resp, _, err := c.sendReq(ctx, key, req) if err != nil { return err } @@ -208,13 +208,13 @@ func (c *Client) Delete(key []byte) error { } // BatchDelete deletes key-value pairs from TiKV. -func (c *Client) BatchDelete(keys [][]byte) error { +func (c *Client) BatchDelete(ctx context.Context, keys [][]byte) error { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("batch_delete").Observe(time.Since(start).Seconds()) }() - bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff) + bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff) resp, err := c.sendBatchReq(bo, keys, rpc.CmdRawBatchDelete) if err != nil { return err @@ -230,7 +230,7 @@ func (c *Client) BatchDelete(keys [][]byte) error { } // DeleteRange deletes all key-value pairs in a range from TiKV -func (c *Client) DeleteRange(startKey []byte, endKey []byte) error { +func (c *Client) DeleteRange(ctx context.Context, startKey []byte, endKey []byte) error { start := time.Now() var err error defer func() { @@ -245,7 +245,7 @@ func (c *Client) DeleteRange(startKey []byte, endKey []byte) error { for !bytes.Equal(startKey, endKey) { var resp *rpc.Response var actualEndKey []byte - resp, actualEndKey, err = c.sendDeleteRangeReq(startKey, endKey) + resp, actualEndKey, err = c.sendDeleteRangeReq(ctx, startKey, endKey) if err != nil { return err } @@ -267,7 +267,7 @@ func (c *Client) DeleteRange(startKey []byte, endKey []byte) error { // If you want to exclude the startKey or include the endKey, append a '\0' to the key. For example, to scan // (startKey, endKey], you can write: // `Scan(append(startKey, '\0'), append(endKey, '\0'), limit)`. -func (c *Client) Scan(startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) { +func (c *Client) Scan(ctx context.Context, startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("raw_scan").Observe(time.Since(start).Seconds()) }() @@ -284,7 +284,7 @@ func (c *Client) Scan(startKey, endKey []byte, limit int) (keys [][]byte, values Limit: uint32(limit - len(keys)), }, } - resp, loc, err := c.sendReq(startKey, req) + resp, loc, err := c.sendReq(ctx, startKey, req) if err != nil { return nil, nil, err } @@ -311,7 +311,7 @@ func (c *Client) Scan(startKey, endKey []byte, limit int) (keys [][]byte, values // (endKey, startKey], you can write: // `ReverseScan(append(startKey, '\0'), append(endKey, '\0'), limit)`. // It doesn't support Scanning from "", because locating the last Region is not yet implemented. -func (c *Client) ReverseScan(startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) { +func (c *Client) ReverseScan(ctx context.Context, startKey, endKey []byte, limit int) (keys [][]byte, values [][]byte, err error) { start := time.Now() defer func() { metrics.RawkvCmdHistogram.WithLabelValues("raw_reverse_scan").Observe(time.Since(start).Seconds()) @@ -331,7 +331,7 @@ func (c *Client) ReverseScan(startKey, endKey []byte, limit int) (keys [][]byte, Reverse: true, }, } - resp, loc, err := c.sendReq(startKey, req) + resp, loc, err := c.sendReq(ctx, startKey, req) if err != nil { return nil, nil, err } @@ -351,8 +351,8 @@ func (c *Client) ReverseScan(startKey, endKey []byte, limit int) (keys [][]byte, return } -func (c *Client) sendReq(key []byte, req *rpc.Request) (*rpc.Response, *locate.KeyLocation, error) { - bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff) +func (c *Client) sendReq(ctx context.Context, key []byte, req *rpc.Request) (*rpc.Response, *locate.KeyLocation, error) { + bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff) sender := rpc.NewRegionRequestSender(c.regionCache, c.rpcClient) for { loc, err := c.regionCache.LocateKey(bo, key) @@ -491,8 +491,8 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch batch, cmdType rpc.CmdTyp // If the given range spans over more than one regions, the actual endKey is the end of the first region. // We can't use sendReq directly, because we need to know the end of the region before we send the request // TODO: Is there any better way to avoid duplicating code with func `sendReq` ? -func (c *Client) sendDeleteRangeReq(startKey []byte, endKey []byte) (*rpc.Response, []byte, error) { - bo := retry.NewBackoffer(context.Background(), retry.RawkvMaxBackoff) +func (c *Client) sendDeleteRangeReq(ctx context.Context, startKey []byte, endKey []byte) (*rpc.Response, []byte, error) { + bo := retry.NewBackoffer(ctx, retry.RawkvMaxBackoff) sender := rpc.NewRegionRequestSender(c.regionCache, c.rpcClient) for { loc, err := c.regionCache.LocateKey(bo, startKey) diff --git a/rawkv/rawkv_test.go b/rawkv/rawkv_test.go index e03f360c..f648ef6b 100644 --- a/rawkv/rawkv_test.go +++ b/rawkv/rawkv_test.go @@ -59,13 +59,13 @@ func (s *testRawKVSuite) TearDownTest(c *C) { } func (s *testRawKVSuite) mustNotExist(c *C, key []byte) { - v, err := s.client.Get(key) + v, err := s.client.Get(context.TODO(), key) c.Assert(err, IsNil) c.Assert(v, IsNil) } func (s *testRawKVSuite) mustBatchNotExist(c *C, keys [][]byte) { - values, err := s.client.BatchGet(keys) + values, err := s.client.BatchGet(context.TODO(), keys) c.Assert(err, IsNil) c.Assert(values, NotNil) c.Assert(len(keys), Equals, len(values)) @@ -75,14 +75,14 @@ func (s *testRawKVSuite) mustBatchNotExist(c *C, keys [][]byte) { } func (s *testRawKVSuite) mustGet(c *C, key, value []byte) { - v, err := s.client.Get(key) + v, err := s.client.Get(context.TODO(), key) c.Assert(err, IsNil) c.Assert(v, NotNil) c.Assert(v, BytesEquals, value) } func (s *testRawKVSuite) mustBatchGet(c *C, keys, values [][]byte) { - checkValues, err := s.client.BatchGet(keys) + checkValues, err := s.client.BatchGet(context.TODO(), keys) c.Assert(err, IsNil) c.Assert(checkValues, NotNil) c.Assert(len(keys), Equals, len(checkValues)) @@ -92,27 +92,27 @@ func (s *testRawKVSuite) mustBatchGet(c *C, keys, values [][]byte) { } func (s *testRawKVSuite) mustPut(c *C, key, value []byte) { - err := s.client.Put(key, value) + err := s.client.Put(context.TODO(), key, value) c.Assert(err, IsNil) } func (s *testRawKVSuite) mustBatchPut(c *C, keys, values [][]byte) { - err := s.client.BatchPut(keys, values) + err := s.client.BatchPut(context.TODO(), keys, values) c.Assert(err, IsNil) } func (s *testRawKVSuite) mustDelete(c *C, key []byte) { - err := s.client.Delete(key) + err := s.client.Delete(context.TODO(), key) c.Assert(err, IsNil) } func (s *testRawKVSuite) mustBatchDelete(c *C, keys [][]byte) { - err := s.client.BatchDelete(keys) + err := s.client.BatchDelete(context.TODO(), keys) c.Assert(err, IsNil) } func (s *testRawKVSuite) mustScan(c *C, startKey string, limit int, expect ...string) { - keys, values, err := s.client.Scan([]byte(startKey), nil, limit) + keys, values, err := s.client.Scan(context.TODO(), []byte(startKey), nil, limit) c.Assert(err, IsNil) c.Assert(len(keys)*2, Equals, len(expect)) for i := range keys { @@ -122,7 +122,7 @@ func (s *testRawKVSuite) mustScan(c *C, startKey string, limit int, expect ...st } func (s *testRawKVSuite) mustScanRange(c *C, startKey string, endKey string, limit int, expect ...string) { - keys, values, err := s.client.Scan([]byte(startKey), []byte(endKey), limit) + keys, values, err := s.client.Scan(context.TODO(), []byte(startKey), []byte(endKey), limit) c.Assert(err, IsNil) c.Assert(len(keys)*2, Equals, len(expect)) for i := range keys { @@ -132,7 +132,7 @@ func (s *testRawKVSuite) mustScanRange(c *C, startKey string, endKey string, lim } func (s *testRawKVSuite) mustReverseScan(c *C, startKey []byte, limit int, expect ...string) { - keys, values, err := s.client.ReverseScan(startKey, nil, limit) + keys, values, err := s.client.ReverseScan(context.TODO(), startKey, nil, limit) c.Assert(err, IsNil) c.Assert(len(keys)*2, Equals, len(expect)) for i := range keys { @@ -142,7 +142,7 @@ func (s *testRawKVSuite) mustReverseScan(c *C, startKey []byte, limit int, expec } func (s *testRawKVSuite) mustReverseScanRange(c *C, startKey, endKey []byte, limit int, expect ...string) { - keys, values, err := s.client.ReverseScan(startKey, endKey, limit) + keys, values, err := s.client.ReverseScan(context.TODO(), startKey, endKey, limit) c.Assert(err, IsNil) c.Assert(len(keys)*2, Equals, len(expect)) for i := range keys { @@ -152,7 +152,7 @@ func (s *testRawKVSuite) mustReverseScanRange(c *C, startKey, endKey []byte, lim } func (s *testRawKVSuite) mustDeleteRange(c *C, startKey, endKey []byte, expected map[string]string) { - err := s.client.DeleteRange(startKey, endKey) + err := s.client.DeleteRange(context.TODO(), startKey, endKey) c.Assert(err, IsNil) for keyStr := range expected { @@ -166,7 +166,7 @@ func (s *testRawKVSuite) mustDeleteRange(c *C, startKey, endKey []byte, expected } func (s *testRawKVSuite) checkData(c *C, expected map[string]string) { - keys, values, err := s.client.Scan([]byte(""), nil, len(expected)+1) + keys, values, err := s.client.Scan(context.TODO(), []byte(""), nil, len(expected)+1) c.Assert(err, IsNil) c.Assert(len(expected), Equals, len(keys)) @@ -192,11 +192,12 @@ func (s *testRawKVSuite) TestSimple(c *C) { s.mustGet(c, []byte("key"), []byte("value")) s.mustDelete(c, []byte("key")) s.mustNotExist(c, []byte("key")) - err := s.client.Put([]byte("key"), []byte("")) + err := s.client.Put(context.TODO(), []byte("key"), []byte("")) c.Assert(err, NotNil) } func (s *testRawKVSuite) TestRawBatch(c *C) { + testNum := 0 size := 0 var testKeys [][]byte diff --git a/txnkv/client.go b/txnkv/client.go index 84e037cf..11d08d55 100644 --- a/txnkv/client.go +++ b/txnkv/client.go @@ -27,8 +27,8 @@ type Client struct { } // NewClient creates a client with PD addresses. -func NewClient(pdAddrs []string, config config.Config) (*Client, error) { - tikvStore, err := store.NewStore(pdAddrs, config) +func NewClient(ctx context.Context, pdAddrs []string, config config.Config) (*Client, error) { + tikvStore, err := store.NewStore(ctx, pdAddrs, config) if err != nil { return nil, err } @@ -43,20 +43,20 @@ func (c *Client) Close() error { } // Begin creates a transaction for read/write. -func (c *Client) Begin() (*Transaction, error) { - ts, err := c.GetTS() +func (c *Client) Begin(ctx context.Context) (*Transaction, error) { + ts, err := c.GetTS(ctx) if err != nil { return nil, err } - return c.BeginWithTS(ts), nil + return c.BeginWithTS(ctx, ts), nil } // BeginWithTS creates a transaction which is normally readonly. -func (c *Client) BeginWithTS(ts uint64) *Transaction { +func (c *Client) BeginWithTS(ctx context.Context, ts uint64) *Transaction { return newTransaction(c.tikvStore, ts) } // GetTS returns a latest timestamp. -func (c *Client) GetTS() (uint64, error) { - return c.tikvStore.GetTimestampWithRetry(retry.NewBackoffer(context.TODO(), retry.TsoMaxBackoff)) +func (c *Client) GetTS(ctx context.Context) (uint64, error) { + return c.tikvStore.GetTimestampWithRetry(retry.NewBackoffer(ctx, retry.TsoMaxBackoff)) } diff --git a/txnkv/kv/buffer_store.go b/txnkv/kv/buffer_store.go index d9043d3f..2da260b1 100644 --- a/txnkv/kv/buffer_store.go +++ b/txnkv/kv/buffer_store.go @@ -14,6 +14,8 @@ package kv import ( + "context" + "github.com/tikv/client-go/config" "github.com/tikv/client-go/key" ) @@ -49,10 +51,10 @@ func (s *BufferStore) SetCap(cap int) { } // Get implements the Retriever interface. -func (s *BufferStore) Get(k key.Key) ([]byte, error) { - val, err := s.MemBuffer.Get(k) +func (s *BufferStore) Get(ctx context.Context, k key.Key) ([]byte, error) { + val, err := s.MemBuffer.Get(ctx, k) if IsErrNotFound(err) { - val, err = s.r.Get(k) + val, err = s.r.Get(ctx, k) } if err != nil { return nil, err @@ -64,29 +66,29 @@ func (s *BufferStore) Get(k key.Key) ([]byte, error) { } // Iter implements the Retriever interface. -func (s *BufferStore) Iter(k key.Key, upperBound key.Key) (Iterator, error) { - bufferIt, err := s.MemBuffer.Iter(k, upperBound) +func (s *BufferStore) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) { + bufferIt, err := s.MemBuffer.Iter(ctx, k, upperBound) if err != nil { return nil, err } - retrieverIt, err := s.r.Iter(k, upperBound) + retrieverIt, err := s.r.Iter(ctx, k, upperBound) if err != nil { return nil, err } - return NewUnionIter(bufferIt, retrieverIt, false) + return NewUnionIter(ctx, bufferIt, retrieverIt, false) } // IterReverse implements the Retriever interface. -func (s *BufferStore) IterReverse(k key.Key) (Iterator, error) { - bufferIt, err := s.MemBuffer.IterReverse(k) +func (s *BufferStore) IterReverse(ctx context.Context, k key.Key) (Iterator, error) { + bufferIt, err := s.MemBuffer.IterReverse(ctx, k) if err != nil { return nil, err } - retrieverIt, err := s.r.IterReverse(k) + retrieverIt, err := s.r.IterReverse(ctx, k) if err != nil { return nil, err } - return NewUnionIter(bufferIt, retrieverIt, true) + return NewUnionIter(ctx, bufferIt, retrieverIt, true) } // WalkBuffer iterates all buffered kv pairs. diff --git a/txnkv/kv/buffer_store_test.go b/txnkv/kv/buffer_store_test.go index 06f2aea7..1e7ec4fa 100644 --- a/txnkv/kv/buffer_store_test.go +++ b/txnkv/kv/buffer_store_test.go @@ -15,6 +15,7 @@ package kv import ( "bytes" + "context" "fmt" "testing" @@ -35,13 +36,13 @@ func (s testBufferStoreSuite) TestGetSet(c *C) { conf := config.DefaultTxn() bs := NewBufferStore(&mockSnapshot{NewMemDbBuffer(&conf, 0)}, &conf) key := key.Key("key") - _, err := bs.Get(key) + _, err := bs.Get(context.TODO(), key) c.Check(err, NotNil) err = bs.Set(key, []byte("value")) c.Check(err, IsNil) - value, err := bs.Get(key) + value, err := bs.Get(context.TODO(), key) c.Check(err, IsNil) c.Check(bytes.Compare(value, []byte("value")), Equals, 0) } @@ -62,11 +63,11 @@ func (s testBufferStoreSuite) TestSaveTo(c *C) { err := bs.SaveTo(mutator) c.Check(err, IsNil) - iter, err := mutator.Iter(nil, nil) + iter, err := mutator.Iter(context.TODO(), nil, nil) c.Check(err, IsNil) for iter.Valid() { cmp := bytes.Compare(iter.Key(), iter.Value()) c.Check(cmp, Equals, 0) - iter.Next() + iter.Next(context.TODO()) } } diff --git a/txnkv/kv/kv.go b/txnkv/kv/kv.go index 7e0aa8c8..cd62bd5a 100644 --- a/txnkv/kv/kv.go +++ b/txnkv/kv/kv.go @@ -13,7 +13,11 @@ package kv -import "github.com/tikv/client-go/key" +import ( + "context" + + "github.com/tikv/client-go/key" +) // Priority value for transaction priority. const ( @@ -36,18 +40,18 @@ const ( type Retriever interface { // Get gets the value for key k from kv store. // If corresponding kv pair does not exist, it returns nil and ErrNotExist. - Get(k key.Key) ([]byte, error) + Get(ctx context.Context, k key.Key) ([]byte, error) // Iter creates an Iterator positioned on the first entry that k <= entry's key. // If such entry is not found, it returns an invalid Iterator with no error. // It yields only keys that < upperBound. If upperBound is nil, it means the upperBound is unbounded. // The Iterator must be closed after use. - Iter(k key.Key, upperBound key.Key) (Iterator, error) + Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) // IterReverse creates a reversed Iterator positioned on the first entry which key is less than k. // The returned iterator will iterate from greater key to smaller key. // If k is nil, the returned iterator will be positioned at the last key. // TODO: Add lower bound limit - IterReverse(k key.Key) (Iterator, error) + IterReverse(ctx context.Context, k key.Key) (Iterator, error) } // Mutator is the interface wraps the basic Set and Delete methods. @@ -83,7 +87,7 @@ type MemBuffer interface { type Snapshot interface { Retriever // BatchGet gets a batch of values from snapshot. - BatchGet(keys []key.Key) (map[string][]byte, error) + BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error) // SetPriority snapshot set the priority SetPriority(priority int) } @@ -93,7 +97,7 @@ type Iterator interface { Valid() bool Key() key.Key Value() []byte - Next() error + Next(context.Context) error Close() } diff --git a/txnkv/kv/mem_buffer_test.go b/txnkv/kv/mem_buffer_test.go index 4f4b3857..46810565 100644 --- a/txnkv/kv/mem_buffer_test.go +++ b/txnkv/kv/mem_buffer_test.go @@ -16,6 +16,7 @@ package kv import ( + "context" "fmt" "math/rand" "testing" @@ -73,7 +74,7 @@ func valToStr(c *C, iter Iterator) string { func checkNewIterator(c *C, buffer MemBuffer) { for i := startIndex; i < testCount; i++ { val := encodeInt(i * indexStep) - iter, err := buffer.Iter(val, nil) + iter, err := buffer.Iter(context.TODO(), val, nil) c.Assert(err, IsNil) c.Assert([]byte(iter.Key()), BytesEquals, val) c.Assert(decodeInt([]byte(valToStr(c, iter))), Equals, i*indexStep) @@ -83,12 +84,12 @@ func checkNewIterator(c *C, buffer MemBuffer) { // Test iterator Next() for i := startIndex; i < testCount-1; i++ { val := encodeInt(i * indexStep) - iter, err := buffer.Iter(val, nil) + iter, err := buffer.Iter(context.TODO(), val, nil) c.Assert(err, IsNil) c.Assert([]byte(iter.Key()), BytesEquals, val) c.Assert(valToStr(c, iter), Equals, string(val)) - err = iter.Next() + err = iter.Next(context.TODO()) c.Assert(err, IsNil) c.Assert(iter.Valid(), IsTrue) @@ -99,7 +100,7 @@ func checkNewIterator(c *C, buffer MemBuffer) { } // Non exist and beyond maximum seek test - iter, err := buffer.Iter(encodeInt(testCount*indexStep), nil) + iter, err := buffer.Iter(context.TODO(), encodeInt(testCount*indexStep), nil) c.Assert(err, IsNil) c.Assert(iter.Valid(), IsFalse) @@ -107,7 +108,7 @@ func checkNewIterator(c *C, buffer MemBuffer) { // it returns the smallest key that larger than the one we are seeking inBetween := encodeInt((testCount-1)*indexStep - 1) last := encodeInt((testCount - 1) * indexStep) - iter, err = buffer.Iter(inBetween, nil) + iter, err = buffer.Iter(context.TODO(), inBetween, nil) c.Assert(err, IsNil) c.Assert(iter.Valid(), IsTrue) c.Assert([]byte(iter.Key()), Not(BytesEquals), inBetween) @@ -118,7 +119,7 @@ func checkNewIterator(c *C, buffer MemBuffer) { func mustGet(c *C, buffer MemBuffer) { for i := startIndex; i < testCount; i++ { s := encodeInt(i * indexStep) - val, err := buffer.Get(s) + val, err := buffer.Get(context.TODO(), s) c.Assert(err, IsNil) c.Assert(string(val), Equals, string(s)) } @@ -135,7 +136,7 @@ func (s *testKVSuite) TestGetSet(c *C) { func (s *testKVSuite) TestNewIterator(c *C) { for _, buffer := range s.bs { // should be invalid - iter, err := buffer.Iter(nil, nil) + iter, err := buffer.Iter(context.TODO(), nil, nil) c.Assert(err, IsNil) c.Assert(iter.Valid(), IsFalse) @@ -147,7 +148,7 @@ func (s *testKVSuite) TestNewIterator(c *C) { func (s *testKVSuite) TestBasicNewIterator(c *C) { for _, buffer := range s.bs { - it, err := buffer.Iter([]byte("2"), nil) + it, err := buffer.Iter(context.TODO(), []byte("2"), nil) c.Assert(err, IsNil) c.Assert(it.Valid(), IsFalse) } @@ -171,15 +172,15 @@ func (s *testKVSuite) TestNewIteratorMin(c *C) { } cnt := 0 - it, err := buffer.Iter(nil, nil) + it, err := buffer.Iter(context.TODO(), nil, nil) c.Assert(err, IsNil) for it.Valid() { cnt++ - it.Next() + it.Next(context.TODO()) } c.Assert(cnt, Equals, 6) - it, err = buffer.Iter([]byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000"), nil) + it, err = buffer.Iter(context.TODO(), []byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000"), nil) c.Assert(err, IsNil) c.Assert(string(it.Key()), Equals, "DATA_test_main_db_tbl_tbl_test_record__00000000000000000001") } @@ -265,7 +266,7 @@ func benchmarkSetGet(b *testing.B, buffer MemBuffer, data [][]byte) { buffer.Set(k, k) } for _, k := range data { - buffer.Get(k) + buffer.Get(context.TODO(), k) } } } @@ -276,12 +277,12 @@ func benchIterator(b *testing.B, buffer MemBuffer) { } b.ResetTimer() for i := 0; i < b.N; i++ { - iter, err := buffer.Iter(nil, nil) + iter, err := buffer.Iter(context.TODO(), nil, nil) if err != nil { b.Error(err) } for iter.Valid() { - iter.Next() + iter.Next(context.TODO()) } iter.Close() } diff --git a/txnkv/kv/memdb_buffer.go b/txnkv/kv/memdb_buffer.go index c27cd0cc..15286086 100644 --- a/txnkv/kv/memdb_buffer.go +++ b/txnkv/kv/memdb_buffer.go @@ -16,6 +16,7 @@ package kv import ( + "context" "fmt" "github.com/pingcap/goleveldb/leveldb" @@ -55,10 +56,10 @@ func NewMemDbBuffer(conf *config.Txn, cap int) MemBuffer { } // Iter creates an Iterator. -func (m *memDbBuffer) Iter(k key.Key, upperBound key.Key) (Iterator, error) { +func (m *memDbBuffer) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) { i := &memDbIter{iter: m.db.NewIterator(&util.Range{Start: []byte(k), Limit: []byte(upperBound)}), reverse: false} - err := i.Next() + err := i.Next(ctx) if err != nil { return nil, errors.WithStack(err) } @@ -69,7 +70,7 @@ func (m *memDbBuffer) SetCap(cap int) { } -func (m *memDbBuffer) IterReverse(k key.Key) (Iterator, error) { +func (m *memDbBuffer) IterReverse(ctx context.Context, k key.Key) (Iterator, error) { var i *memDbIter if k == nil { i = &memDbIter{iter: m.db.NewIterator(&util.Range{}), reverse: true} @@ -81,7 +82,7 @@ func (m *memDbBuffer) IterReverse(k key.Key) (Iterator, error) { } // Get returns the value associated with key. -func (m *memDbBuffer) Get(k key.Key) ([]byte, error) { +func (m *memDbBuffer) Get(ctx context.Context, k key.Key) ([]byte, error) { v, err := m.db.Get(k) if err == leveldb.ErrNotFound { return nil, ErrNotExist @@ -130,7 +131,7 @@ func (m *memDbBuffer) Reset() { } // Next implements the Iterator Next. -func (i *memDbIter) Next() error { +func (i *memDbIter) Next(context.Context) error { if i.reverse { i.iter.Prev() } else { @@ -161,7 +162,7 @@ func (i *memDbIter) Close() { // WalkMemBuffer iterates all buffered kv pairs in memBuf func WalkMemBuffer(memBuf MemBuffer, f func(k key.Key, v []byte) error) error { - iter, err := memBuf.Iter(nil, nil) + iter, err := memBuf.Iter(context.Background(), nil, nil) if err != nil { return errors.WithStack(err) } @@ -171,7 +172,7 @@ func WalkMemBuffer(memBuf MemBuffer, f func(k key.Key, v []byte) error) error { if err = f(iter.Key(), iter.Value()); err != nil { return errors.WithStack(err) } - err = iter.Next() + err = iter.Next(context.Background()) if err != nil { return errors.WithStack(err) } diff --git a/txnkv/kv/mock.go b/txnkv/kv/mock.go index ea77a295..26ac2f43 100644 --- a/txnkv/kv/mock.go +++ b/txnkv/kv/mock.go @@ -14,6 +14,8 @@ package kv import ( + "context" + "github.com/tikv/client-go/key" ) @@ -21,18 +23,18 @@ type mockSnapshot struct { store MemBuffer } -func (s *mockSnapshot) Get(k key.Key) ([]byte, error) { - return s.store.Get(k) +func (s *mockSnapshot) Get(ctx context.Context, k key.Key) ([]byte, error) { + return s.store.Get(ctx, k) } func (s *mockSnapshot) SetPriority(priority int) { } -func (s *mockSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) { +func (s *mockSnapshot) BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error) { m := make(map[string][]byte) for _, k := range keys { - v, err := s.store.Get(k) + v, err := s.store.Get(ctx, k) if IsErrNotFound(err) { continue } @@ -44,10 +46,10 @@ func (s *mockSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) { return m, nil } -func (s *mockSnapshot) Iter(k key.Key, upperBound key.Key) (Iterator, error) { - return s.store.Iter(k, upperBound) +func (s *mockSnapshot) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) { + return s.store.Iter(ctx, k, upperBound) } -func (s *mockSnapshot) IterReverse(k key.Key) (Iterator, error) { - return s.store.IterReverse(k) +func (s *mockSnapshot) IterReverse(ctx context.Context, k key.Key) (Iterator, error) { + return s.store.IterReverse(ctx, k) } diff --git a/txnkv/kv/union_iter.go b/txnkv/kv/union_iter.go index d85f727a..7f16d538 100644 --- a/txnkv/kv/union_iter.go +++ b/txnkv/kv/union_iter.go @@ -14,6 +14,8 @@ package kv import ( + "context" + log "github.com/sirupsen/logrus" "github.com/tikv/client-go/key" ) @@ -32,7 +34,7 @@ type UnionIter struct { } // NewUnionIter returns a union iterator for BufferStore. -func NewUnionIter(dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIter, error) { +func NewUnionIter(ctx context.Context, dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIter, error) { it := &UnionIter{ dirtyIt: dirtyIt, snapshotIt: snapshotIt, @@ -40,7 +42,7 @@ func NewUnionIter(dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIt snapshotValid: snapshotIt.Valid(), reverse: reverse, } - err := it.updateCur() + err := it.updateCur(ctx) if err != nil { return nil, err } @@ -48,20 +50,20 @@ func NewUnionIter(dirtyIt Iterator, snapshotIt Iterator, reverse bool) (*UnionIt } // dirtyNext makes iter.dirtyIt go and update valid status. -func (iter *UnionIter) dirtyNext() error { - err := iter.dirtyIt.Next() +func (iter *UnionIter) dirtyNext(ctx context.Context) error { + err := iter.dirtyIt.Next(ctx) iter.dirtyValid = iter.dirtyIt.Valid() return err } // snapshotNext makes iter.snapshotIt go and update valid status. -func (iter *UnionIter) snapshotNext() error { - err := iter.snapshotIt.Next() +func (iter *UnionIter) snapshotNext(ctx context.Context) error { + err := iter.snapshotIt.Next(ctx) iter.snapshotValid = iter.snapshotIt.Valid() return err } -func (iter *UnionIter) updateCur() error { +func (iter *UnionIter) updateCur(ctx context.Context) error { iter.isValid = true for { if !iter.dirtyValid && !iter.snapshotValid { @@ -78,7 +80,7 @@ func (iter *UnionIter) updateCur() error { iter.curIsDirty = true // if delete it if len(iter.dirtyIt.Value()) == 0 { - if err := iter.dirtyNext(); err != nil { + if err := iter.dirtyNext(ctx); err != nil { return err } continue @@ -99,15 +101,15 @@ func (iter *UnionIter) updateCur() error { if len(iter.dirtyIt.Value()) == 0 { // snapshot has a record, but txn says we have deleted it // just go next - if err := iter.dirtyNext(); err != nil { + if err := iter.dirtyNext(ctx); err != nil { return err } - if err := iter.snapshotNext(); err != nil { + if err := iter.snapshotNext(ctx); err != nil { return err } continue } - if err := iter.snapshotNext(); err != nil { + if err := iter.snapshotNext(ctx); err != nil { return err } iter.curIsDirty = true @@ -121,7 +123,7 @@ func (iter *UnionIter) updateCur() error { if len(iter.dirtyIt.Value()) == 0 { log.Warnf("[kv] delete a record not exists? k = %q", iter.dirtyIt.Key()) // jump over this deletion - if err := iter.dirtyNext(); err != nil { + if err := iter.dirtyNext(ctx); err != nil { return err } continue @@ -135,17 +137,17 @@ func (iter *UnionIter) updateCur() error { } // Next implements the Iterator Next interface. -func (iter *UnionIter) Next() error { +func (iter *UnionIter) Next(ctx context.Context) error { var err error if !iter.curIsDirty { - err = iter.snapshotNext() + err = iter.snapshotNext(ctx) } else { - err = iter.dirtyNext() + err = iter.dirtyNext(ctx) } if err != nil { return err } - return iter.updateCur() + return iter.updateCur(ctx) } // Value implements the Iterator Value interface. diff --git a/txnkv/kv/union_store.go b/txnkv/kv/union_store.go index 4045a422..e08c9bee 100644 --- a/txnkv/kv/union_store.go +++ b/txnkv/kv/union_store.go @@ -14,6 +14,8 @@ package kv import ( + "context" + "github.com/tikv/client-go/config" "github.com/tikv/client-go/key" ) @@ -89,7 +91,7 @@ func (it invalidIterator) Valid() bool { return false } -func (it invalidIterator) Next() error { +func (it invalidIterator) Next(context.Context) error { return nil } @@ -110,12 +112,12 @@ type lazyMemBuffer struct { conf *config.Txn } -func (lmb *lazyMemBuffer) Get(k key.Key) ([]byte, error) { +func (lmb *lazyMemBuffer) Get(ctx context.Context, k key.Key) ([]byte, error) { if lmb.mb == nil { return nil, ErrNotExist } - return lmb.mb.Get(k) + return lmb.mb.Get(ctx, k) } func (lmb *lazyMemBuffer) Set(key key.Key, value []byte) error { @@ -134,18 +136,18 @@ func (lmb *lazyMemBuffer) Delete(k key.Key) error { return lmb.mb.Delete(k) } -func (lmb *lazyMemBuffer) Iter(k key.Key, upperBound key.Key) (Iterator, error) { +func (lmb *lazyMemBuffer) Iter(ctx context.Context, k key.Key, upperBound key.Key) (Iterator, error) { if lmb.mb == nil { return invalidIterator{}, nil } - return lmb.mb.Iter(k, upperBound) + return lmb.mb.Iter(ctx, k, upperBound) } -func (lmb *lazyMemBuffer) IterReverse(k key.Key) (Iterator, error) { +func (lmb *lazyMemBuffer) IterReverse(ctx context.Context, k key.Key) (Iterator, error) { if lmb.mb == nil { return invalidIterator{}, nil } - return lmb.mb.IterReverse(k) + return lmb.mb.IterReverse(ctx, k) } func (lmb *lazyMemBuffer) Size() int { @@ -173,8 +175,8 @@ func (lmb *lazyMemBuffer) SetCap(cap int) { } // Get implements the Retriever interface. -func (us *unionStore) Get(k key.Key) ([]byte, error) { - v, err := us.MemBuffer.Get(k) +func (us *unionStore) Get(ctx context.Context, k key.Key) ([]byte, error) { + v, err := us.MemBuffer.Get(ctx, k) if IsErrNotFound(err) { if _, ok := us.opts.Get(PresumeKeyNotExists); ok { e, ok := us.opts.Get(PresumeKeyNotExistsError) @@ -185,7 +187,7 @@ func (us *unionStore) Get(k key.Key) ([]byte, error) { } return nil, ErrNotExist } - v, err = us.BufferStore.r.Get(k) + v, err = us.BufferStore.r.Get(ctx, k) } if err != nil { return v, err diff --git a/txnkv/kv/union_store_test.go b/txnkv/kv/union_store_test.go index 1161a981..54db8898 100644 --- a/txnkv/kv/union_store_test.go +++ b/txnkv/kv/union_store_test.go @@ -14,6 +14,8 @@ package kv import ( + "context" + . "github.com/pingcap/check" "github.com/pkg/errors" "github.com/tikv/client-go/config" @@ -34,11 +36,11 @@ func (s *testUnionStoreSuite) SetUpTest(c *C) { func (s *testUnionStoreSuite) TestGetSet(c *C) { s.store.Set([]byte("1"), []byte("1")) - v, err := s.us.Get([]byte("1")) + v, err := s.us.Get(context.TODO(), []byte("1")) c.Assert(err, IsNil) c.Assert(v, BytesEquals, []byte("1")) s.us.Set([]byte("1"), []byte("2")) - v, err = s.us.Get([]byte("1")) + v, err = s.us.Get(context.TODO(), []byte("1")) c.Assert(err, IsNil) c.Assert(v, BytesEquals, []byte("2")) } @@ -47,11 +49,11 @@ func (s *testUnionStoreSuite) TestDelete(c *C) { s.store.Set([]byte("1"), []byte("1")) err := s.us.Delete([]byte("1")) c.Assert(err, IsNil) - _, err = s.us.Get([]byte("1")) + _, err = s.us.Get(context.TODO(), []byte("1")) c.Assert(IsErrNotFound(err), IsTrue) s.us.Set([]byte("1"), []byte("2")) - v, err := s.us.Get([]byte("1")) + v, err := s.us.Get(context.TODO(), []byte("1")) c.Assert(err, IsNil) c.Assert(v, BytesEquals, []byte("2")) } @@ -61,21 +63,21 @@ func (s *testUnionStoreSuite) TestSeek(c *C) { s.store.Set([]byte("2"), []byte("2")) s.store.Set([]byte("3"), []byte("3")) - iter, err := s.us.Iter(nil, nil) + iter, err := s.us.Iter(context.TODO(), nil, nil) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("1"), []byte("2"), []byte("3")}, [][]byte{[]byte("1"), []byte("2"), []byte("3")}) - iter, err = s.us.Iter([]byte("2"), nil) + iter, err = s.us.Iter(context.TODO(), []byte("2"), nil) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("2"), []byte("3")}, [][]byte{[]byte("2"), []byte("3")}) s.us.Set([]byte("4"), []byte("4")) - iter, err = s.us.Iter([]byte("2"), nil) + iter, err = s.us.Iter(context.TODO(), []byte("2"), nil) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("2"), []byte("3"), []byte("4")}, [][]byte{[]byte("2"), []byte("3"), []byte("4")}) s.us.Delete([]byte("3")) - iter, err = s.us.Iter([]byte("2"), nil) + iter, err = s.us.Iter(context.TODO(), []byte("2"), nil) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("2"), []byte("4")}, [][]byte{[]byte("2"), []byte("4")}) } @@ -86,21 +88,21 @@ func (s *testUnionStoreSuite) TestIterReverse(c *C) { s.store.Set([]byte("2"), []byte("2")) s.store.Set([]byte("3"), []byte("3")) - iter, err := s.us.IterReverse(nil) + iter, err := s.us.IterReverse(context.TODO(), nil) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("3"), []byte("2"), []byte("1")}, [][]byte{[]byte("3"), []byte("2"), []byte("1")}) - iter, err = s.us.IterReverse([]byte("3")) + iter, err = s.us.IterReverse(context.TODO(), []byte("3")) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("2"), []byte("1")}, [][]byte{[]byte("2"), []byte("1")}) s.us.Set([]byte("0"), []byte("0")) - iter, err = s.us.IterReverse([]byte("3")) + iter, err = s.us.IterReverse(context.TODO(), []byte("3")) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("2"), []byte("1"), []byte("0")}, [][]byte{[]byte("2"), []byte("1"), []byte("0")}) s.us.Delete([]byte("1")) - iter, err = s.us.IterReverse([]byte("3")) + iter, err = s.us.IterReverse(context.TODO(), []byte("3")) c.Assert(err, IsNil) checkIterator(c, iter, [][]byte{[]byte("2"), []byte("0")}, [][]byte{[]byte("2"), []byte("0")}) } @@ -110,13 +112,13 @@ func (s *testUnionStoreSuite) TestLazyConditionCheck(c *C) { s.store.Set([]byte("1"), []byte("1")) s.store.Set([]byte("2"), []byte("2")) - v, err := s.us.Get([]byte("1")) + v, err := s.us.Get(context.TODO(), []byte("1")) c.Assert(err, IsNil) c.Assert(v, BytesEquals, []byte("1")) s.us.SetOption(PresumeKeyNotExists, nil) s.us.SetOption(PresumeKeyNotExistsError, ErrNotExist) - _, err = s.us.Get([]byte("2")) + _, err = s.us.Get(context.TODO(), []byte("2")) c.Assert(errors.Cause(err) == ErrNotExist, IsTrue, Commentf("err %v", err)) condionPair1 := s.us.LookupConditionPair([]byte("1")) @@ -138,7 +140,7 @@ func checkIterator(c *C, iter Iterator, keys [][]byte, values [][]byte) { c.Assert(iter.Valid(), IsTrue) c.Assert([]byte(iter.Key()), BytesEquals, k) c.Assert(iter.Value(), BytesEquals, v) - c.Assert(iter.Next(), IsNil) + c.Assert(iter.Next(context.TODO()), IsNil) } c.Assert(iter.Valid(), IsFalse) } diff --git a/txnkv/oracle/oracles/pd.go b/txnkv/oracle/oracles/pd.go index ad161d8d..eb34d345 100644 --- a/txnkv/oracle/oracles/pd.go +++ b/txnkv/oracle/oracles/pd.go @@ -18,7 +18,7 @@ import ( "sync/atomic" "time" - "github.com/pingcap/pd/client" + pd "github.com/pingcap/pd/client" log "github.com/sirupsen/logrus" "github.com/tikv/client-go/config" "github.com/tikv/client-go/metrics" diff --git a/txnkv/store/lock_resolver.go b/txnkv/store/lock_resolver.go index 9ae8df94..00f09551 100644 --- a/txnkv/store/lock_resolver.go +++ b/txnkv/store/lock_resolver.go @@ -57,8 +57,8 @@ var _ = NewLockResolver // NewLockResolver creates a LockResolver. // It is exported for other pkg to use. For instance, binlog service needs // to determine a transaction's commit state. -func NewLockResolver(etcdAddrs []string, conf config.Config) (*LockResolver, error) { - s, err := NewStore(etcdAddrs, conf) +func NewLockResolver(ctx context.Context, etcdAddrs []string, conf config.Config) (*LockResolver, error) { + s, err := NewStore(ctx, etcdAddrs, conf) if err != nil { return nil, err } @@ -259,8 +259,8 @@ func (lr *LockResolver) ResolveLocks(bo *retry.Backoffer, locks []*Lock) (ok boo // If the primary key is still locked, it will launch a Rollback to abort it. // To avoid unnecessarily aborting too many txns, it is wiser to wait a few // seconds before calling it after Prewrite. -func (lr *LockResolver) GetTxnStatus(txnID uint64, primary []byte) (TxnStatus, error) { - bo := retry.NewBackoffer(context.Background(), retry.CleanupMaxBackoff) +func (lr *LockResolver) GetTxnStatus(ctx context.Context, txnID uint64, primary []byte) (TxnStatus, error) { + bo := retry.NewBackoffer(ctx, retry.CleanupMaxBackoff) return lr.getTxnStatus(bo, txnID, primary) } diff --git a/txnkv/store/scan.go b/txnkv/store/scan.go index 5d2b8660..9ee9e449 100644 --- a/txnkv/store/scan.go +++ b/txnkv/store/scan.go @@ -40,7 +40,7 @@ type Scanner struct { eof bool } -func newScanner(snapshot *TiKVSnapshot, startKey []byte, endKey []byte, batchSize int) (*Scanner, error) { +func newScanner(ctx context.Context, snapshot *TiKVSnapshot, startKey []byte, endKey []byte, batchSize int) (*Scanner, error) { // It must be > 1. Otherwise scanner won't skipFirst. if batchSize <= 1 { batchSize = snapshot.conf.Txn.ScanBatchSize @@ -53,7 +53,7 @@ func newScanner(snapshot *TiKVSnapshot, startKey []byte, endKey []byte, batchSiz nextStartKey: startKey, endKey: endKey, } - err := scanner.Next() + err := scanner.Next(ctx) if kv.IsErrNotFound(err) { return scanner, nil } @@ -82,8 +82,8 @@ func (s *Scanner) Value() []byte { } // Next return next element. -func (s *Scanner) Next() error { - bo := retry.NewBackoffer(context.Background(), retry.ScannerNextMaxBackoff) +func (s *Scanner) Next(ctx context.Context) error { + bo := retry.NewBackoffer(ctx, retry.ScannerNextMaxBackoff) if !s.valid { return errors.New("scanner iterator is invalid") } diff --git a/txnkv/store/snapshot.go b/txnkv/store/snapshot.go index f7e60d60..071acb09 100644 --- a/txnkv/store/snapshot.go +++ b/txnkv/store/snapshot.go @@ -55,7 +55,7 @@ func newTiKVSnapshot(store *TiKVStore, ts uint64) *TiKVSnapshot { // BatchGet gets all the keys' value from kv-server and returns a map contains key/value pairs. // The map will not contain nonexistent keys. -func (s *TiKVSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) { +func (s *TiKVSnapshot) BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error) { m := make(map[string][]byte) if len(keys) == 0 { return m, nil @@ -66,7 +66,7 @@ func (s *TiKVSnapshot) BatchGet(keys []key.Key) (map[string][]byte, error) { // We want [][]byte instead of []key.Key, use some magic to save memory. bytesKeys := *(*[][]byte)(unsafe.Pointer(&keys)) - bo := retry.NewBackoffer(context.Background(), retry.BatchGetMaxBackoff) + bo := retry.NewBackoffer(ctx, retry.BatchGetMaxBackoff) // Create a map to collect key-values from region servers. var mu sync.Mutex @@ -198,8 +198,8 @@ func (s *TiKVSnapshot) batchGetSingleRegion(bo *retry.Backoffer, batch batchKeys } // Get gets the value for key k from snapshot. -func (s *TiKVSnapshot) Get(k key.Key) ([]byte, error) { - val, err := s.get(retry.NewBackoffer(context.Background(), retry.GetMaxBackoff), k) +func (s *TiKVSnapshot) Get(ctx context.Context, k key.Key) ([]byte, error) { + val, err := s.get(retry.NewBackoffer(ctx, retry.GetMaxBackoff), k) if err != nil { return nil, err } @@ -270,13 +270,13 @@ func (s *TiKVSnapshot) get(bo *retry.Backoffer, k key.Key) ([]byte, error) { } // Iter returns a list of key-value pair after `k`. -func (s *TiKVSnapshot) Iter(k key.Key, upperBound key.Key) (kv.Iterator, error) { - scanner, err := newScanner(s, k, upperBound, s.conf.Txn.ScanBatchSize) +func (s *TiKVSnapshot) Iter(ctx context.Context, k key.Key, upperBound key.Key) (kv.Iterator, error) { + scanner, err := newScanner(ctx, s, k, upperBound, s.conf.Txn.ScanBatchSize) return scanner, err } // IterReverse creates a reversed Iterator positioned on the first entry which key is less than k. -func (s *TiKVSnapshot) IterReverse(k key.Key) (kv.Iterator, error) { +func (s *TiKVSnapshot) IterReverse(ctx context.Context, k key.Key) (kv.Iterator, error) { return nil, ErrNotImplemented } diff --git a/txnkv/store/split_region.go b/txnkv/store/split_region.go index 5c8ae8f8..1dc74a22 100644 --- a/txnkv/store/split_region.go +++ b/txnkv/store/split_region.go @@ -27,9 +27,9 @@ import ( // SplitRegion splits the region contains splitKey into 2 regions: [start, // splitKey) and [splitKey, end). -func SplitRegion(store *TiKVStore, splitKey key.Key) error { +func SplitRegion(ctx context.Context, store *TiKVStore, splitKey key.Key) error { log.Infof("start split_region at %q", splitKey) - bo := retry.NewBackoffer(context.Background(), retry.SplitRegionBackoff) + bo := retry.NewBackoffer(ctx, retry.SplitRegionBackoff) sender := rpc.NewRegionRequestSender(store.GetRegionCache(), store.GetRPCClient()) req := &rpc.Request{ Type: rpc.CmdSplitRegion, diff --git a/txnkv/store/store.go b/txnkv/store/store.go index 2f1721f1..ab9019bc 100644 --- a/txnkv/store/store.go +++ b/txnkv/store/store.go @@ -55,7 +55,7 @@ type TiKVStore struct { } // NewStore creates a TiKVStore instance. -func NewStore(pdAddrs []string, conf config.Config) (*TiKVStore, error) { +func NewStore(ctx context.Context, pdAddrs []string, conf config.Config) (*TiKVStore, error) { pdCli, err := pd.NewClient(pdAddrs, pd.SecurityOption{ CAPath: conf.RPC.Security.SSLCA, CertPath: conf.RPC.Security.SSLCert, @@ -82,7 +82,7 @@ func NewStore(pdAddrs []string, conf config.Config) (*TiKVStore, error) { return nil, err } - clusterID := pdCli.GetClusterID(context.TODO()) + clusterID := pdCli.GetClusterID(ctx) store := &TiKVStore{ conf: &conf, diff --git a/txnkv/txn.go b/txnkv/txn.go index bd304a2e..4225c714 100644 --- a/txnkv/txn.go +++ b/txnkv/txn.go @@ -57,12 +57,12 @@ func newTransaction(tikvStore *store.TiKVStore, ts uint64) *Transaction { } // Get implements transaction interface. -func (txn *Transaction) Get(k key.Key) ([]byte, error) { +func (txn *Transaction) Get(ctx context.Context, k key.Key) ([]byte, error) { metrics.TxnCmdCounter.WithLabelValues("get").Inc() start := time.Now() defer func() { metrics.TxnCmdHistogram.WithLabelValues("get").Observe(time.Since(start).Seconds()) }() - ret, err := txn.us.Get(k) + ret, err := txn.us.Get(ctx, k) if kv.IsErrNotFound(err) { return nil, err } @@ -79,14 +79,14 @@ func (txn *Transaction) Get(k key.Key) ([]byte, error) { } // BatchGet gets a batch of values from TiKV server. -func (txn *Transaction) BatchGet(keys []key.Key) (map[string][]byte, error) { +func (txn *Transaction) BatchGet(ctx context.Context, keys []key.Key) (map[string][]byte, error) { if txn.IsReadOnly() { - return txn.snapshot.BatchGet(keys) + return txn.snapshot.BatchGet(ctx, keys) } bufferValues := make([][]byte, len(keys)) shrinkKeys := make([]key.Key, 0, len(keys)) for i, key := range keys { - val, err := txn.us.GetMemBuffer().Get(key) + val, err := txn.us.GetMemBuffer().Get(ctx, key) if kv.IsErrNotFound(err) { shrinkKeys = append(shrinkKeys, key) continue @@ -98,7 +98,7 @@ func (txn *Transaction) BatchGet(keys []key.Key) (map[string][]byte, error) { bufferValues[i] = val } } - storageValues, err := txn.snapshot.BatchGet(shrinkKeys) + storageValues, err := txn.snapshot.BatchGet(ctx, shrinkKeys) if err != nil { return nil, err } @@ -125,23 +125,23 @@ func (txn *Transaction) String() string { // If such entry is not found, it returns an invalid Iterator with no error. // It yields only keys that < upperBound. If upperBound is nil, it means the upperBound is unbounded. // The Iterator must be closed after use. -func (txn *Transaction) Iter(k key.Key, upperBound key.Key) (kv.Iterator, error) { +func (txn *Transaction) Iter(ctx context.Context, k key.Key, upperBound key.Key) (kv.Iterator, error) { metrics.TxnCmdCounter.WithLabelValues("seek").Inc() start := time.Now() defer func() { metrics.TxnCmdHistogram.WithLabelValues("seek").Observe(time.Since(start).Seconds()) }() - return txn.us.Iter(k, upperBound) + return txn.us.Iter(ctx, k, upperBound) } // IterReverse creates a reversed Iterator positioned on the first entry which key is less than k. -func (txn *Transaction) IterReverse(k key.Key) (kv.Iterator, error) { +func (txn *Transaction) IterReverse(ctx context.Context, k key.Key) (kv.Iterator, error) { metrics.TxnCmdCounter.WithLabelValues("seek_reverse").Inc() start := time.Now() defer func() { metrics.TxnCmdHistogram.WithLabelValues("seek_reverse").Observe(time.Since(start).Seconds()) }() - return txn.us.IterReverse(k) + return txn.us.IterReverse(ctx, k) } // IsReadOnly returns if there are pending key-value to commit in the transaction.