mirror of https://github.com/tikv/client-go.git
client: replace pingcap/check with testify (#166)
parent 09ff177ac4
commit 901066f801
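The diff below applies a largely mechanical mapping from gocheck assertions to their testify equivalents. As orientation before reading the hunks, here is a minimal sketch of that correspondence; the test name and values are hypothetical, only the mapping itself is taken from the commit:

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Mapping applied throughout the commit:
//   c.Assert(x, IsNil)         -> require.Nil(t, x) or assert.Nil(t, x)
//   c.Assert(x, NotNil)        -> assert.NotNil(t, x)
//   c.Assert(x, Equals, y)     -> assert.Equal(t, x, y)
//   c.Assert(x, DeepEquals, y) -> assert.Equal(t, x, y)
//   c.Assert(b, IsTrue)        -> assert.True(t, b)
//   c.Fatal("msg")             -> assert.Fail(t, "msg")
func TestMappingSketch(t *testing.T) {
	var err error
	require.Nil(t, err)          // aborts the test on failure
	assert.Nil(t, err)           // records a failure and continues
	assert.Equal(t, 1+1, 2)      // covers both Equals and DeepEquals
	assert.True(t, len("x") > 0) // covers IsTrue
}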
@@ -36,25 +36,24 @@ import (
 	"context"
 	"fmt"
 	"sync/atomic"
+	"testing"
 	"time"
 
-	. "github.com/pingcap/check"
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/kvrpcpb"
 	"github.com/pingcap/kvproto/pkg/tikvpb"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/tikv/client-go/v2/config"
 	"github.com/tikv/client-go/v2/tikvrpc"
 )
 
-type testClientFailSuite struct {
-}
-
-func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) {
-	c.Assert(failpoint.Enable("tikvclient/panicInFailPendingRequests", `panic`), IsNil)
-	c.Assert(failpoint.Enable("tikvclient/gotErrorInRecvLoop", `return("0")`), IsNil)
+func TestPanicInRecvLoop(t *testing.T) {
+	require.Nil(t, failpoint.Enable("tikvclient/panicInFailPendingRequests", `panic`))
+	require.Nil(t, failpoint.Enable("tikvclient/gotErrorInRecvLoop", `return("0")`))
 
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@@ -64,24 +63,24 @@ func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) {
 
 	// Start batchRecvLoop, and it should panic in `failPendingRequests`.
 	_, err := rpcClient.getConnArray(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 })
-	c.Assert(err, IsNil, Commentf("cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again"))
+	assert.Nil(t, err, "cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again")
 
 	req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
 	_, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second/2)
-	c.Assert(err, NotNil)
+	assert.NotNil(t, err)
 
-	c.Assert(failpoint.Disable("tikvclient/gotErrorInRecvLoop"), IsNil)
-	c.Assert(failpoint.Disable("tikvclient/panicInFailPendingRequests"), IsNil)
+	require.Nil(t, failpoint.Disable("tikvclient/gotErrorInRecvLoop"))
+	require.Nil(t, failpoint.Disable("tikvclient/panicInFailPendingRequests"))
 	time.Sleep(time.Second * 2)
 
 	req = tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
 	_, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second*4)
-	c.Assert(err, IsNil)
+	assert.Nil(t, err)
 }
 
-func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
+func TestRecvErrorInMultipleRecvLoops(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
 
@@ -100,20 +99,20 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
 	for _, forwardedHost := range forwardedHosts {
 		prewriteReq.ForwardedHost = forwardedHost
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
 	connArray, err := rpcClient.getConnArray(addr, true)
-	c.Assert(connArray, NotNil)
-	c.Assert(err, IsNil)
+	assert.NotNil(t, connArray)
+	assert.Nil(t, err)
 	batchConn := connArray.batchConn
-	c.Assert(batchConn, NotNil)
-	c.Assert(len(batchConn.batchCommandsClients), Equals, 1)
+	assert.NotNil(t, batchConn)
+	assert.Equal(t, len(batchConn.batchCommandsClients), 1)
 	batchClient := batchConn.batchCommandsClients[0]
-	c.Assert(batchClient.client, NotNil)
-	c.Assert(batchClient.client.forwardedHost, Equals, "")
-	c.Assert(len(batchClient.forwardedClients), Equals, 3)
+	assert.NotNil(t, batchClient.client)
+	assert.Equal(t, batchClient.client.forwardedHost, "")
+	assert.Equal(t, len(batchClient.forwardedClients), 3)
 	for _, forwardedHosts := range forwardedHosts[1:] {
-		c.Assert(batchClient.forwardedClients[forwardedHosts].forwardedHost, Equals, forwardedHosts)
+		assert.Equal(t, batchClient.forwardedClients[forwardedHosts].forwardedHost, forwardedHosts)
 	}
 
 	// Save all streams
@@ -127,12 +126,12 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
 	fp := "tikvclient/gotErrorInRecvLoop"
 	// Send a request to each stream to trigger reconnection.
 	for _, forwardedHost := range forwardedHosts {
-		c.Assert(failpoint.Enable(fp, `1*return("0")`), IsNil)
+		require.Nil(t, failpoint.Enable(fp, `1*return("0")`))
 		prewriteReq.ForwardedHost = forwardedHost
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 		time.Sleep(100 * time.Millisecond)
-		c.Assert(failpoint.Disable(fp), IsNil)
+		assert.Nil(t, failpoint.Disable(fp))
 	}
 
 	// Wait for finishing reconnection.
@@ -150,14 +149,14 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
 	for _, forwardedHost := range forwardedHosts {
 		prewriteReq.ForwardedHost = forwardedHost
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
 	// Should only reconnect once.
-	c.Assert(atomic.LoadUint64(&batchClient.epoch), Equals, epoch+1)
+	assert.Equal(t, atomic.LoadUint64(&batchClient.epoch), epoch+1)
 	// All streams are refreshed.
-	c.Assert(batchClient.client.Tikv_BatchCommandsClient, Not(Equals), clientSave)
-	c.Assert(len(batchClient.forwardedClients), Equals, len(forwardedClientsSave))
+	assert.NotEqual(t, batchClient.client.Tikv_BatchCommandsClient, clientSave)
+	assert.Equal(t, len(batchClient.forwardedClients), len(forwardedClientsSave))
 	for host, clientSave := range forwardedClientsSave {
-		c.Assert(batchClient.forwardedClients[host].Tikv_BatchCommandsClient, Not(Equals), clientSave)
+		assert.NotEqual(t, batchClient.forwardedClients[host].Tikv_BatchCommandsClient, clientSave)
 	}
 }
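A note on the failpoint calls in the hunks above: each test enables a failpoint before exercising the code path and disables it afterwards, and the migration uses require rather than assert for those steps so that broken setup aborts the test immediately. A minimal sketch of the pattern, with a hypothetical failpoint name:

package example

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

// "example/someFailpoint" is a made-up name for illustration; the real
// tests use names like "tikvclient/gotErrorInRecvLoop".
func TestFailpointSketch(t *testing.T) {
	require.Nil(t, failpoint.Enable("example/someFailpoint", `return("0")`))
	defer func() {
		require.Nil(t, failpoint.Disable("example/someFailpoint"))
	}()
	// ... exercise the code path that consults the failpoint ...
}

The remaining hunks modify the commit's second test file, whose import block follows.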
@@ -40,33 +40,19 @@ import (
 	"testing"
 	"time"
 
-	. "github.com/pingcap/check"
 	"github.com/pingcap/errors"
 	"github.com/pingcap/kvproto/pkg/coprocessor"
 	"github.com/pingcap/kvproto/pkg/kvrpcpb"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/kvproto/pkg/tikvpb"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/tikv/client-go/v2/config"
 	"github.com/tikv/client-go/v2/tikvrpc"
 	"google.golang.org/grpc/metadata"
 )
 
-func TestT(t *testing.T) {
-	CustomVerboseFlag = true
-	TestingT(t)
-}
-
-type testClientSuite struct {
-}
-
-type testClientSerialSuite struct {
-}
-
-var _ = SerialSuites(&testClientSuite{})
-var _ = SerialSuites(&testClientFailSuite{})
-var _ = SerialSuites(&testClientSerialSuite{})
-
-func (s *testClientSerialSuite) TestConn(c *C) {
+func TestConn(t *testing.T) {
 	defer config.UpdateGlobal(func(conf *config.Config) {
 		conf.TiKVClient.MaxBatchSize = 0
 	})()
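The hunk above also removes gocheck's bootstrap: TestT plus the SerialSuites registrations were the single entry point through which `go test` discovered the suites, and every test was a method on a suite struct. testify needs none of that; each top-level TestXxx function is picked up by `go test` directly. A before/after sketch with an illustrative test name:

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// Before (gocheck), roughly:
//
//	func TestT(t *testing.T)                { TestingT(t) }
//	var _ = SerialSuites(&testClientSuite{})
//	func (s *testClientSuite) TestFoo(c *C) { c.Assert(1, Equals, 1) }
//
// After (testify): a plain test function, no suite, no registration.
func TestFoo(t *testing.T) {
	assert.Equal(t, 1, 1)
}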
@@ -75,39 +61,39 @@ func (s *testClientSerialSuite) TestConn(c *C) {
 
 	addr := "127.0.0.1:6379"
 	conn1, err := client.getConnArray(addr, true)
-	c.Assert(err, IsNil)
+	assert.Nil(t, err)
 
 	conn2, err := client.getConnArray(addr, true)
-	c.Assert(err, IsNil)
-	c.Assert(conn2.Get(), Not(Equals), conn1.Get())
+	assert.Nil(t, err)
+	assert.NotEqual(t, conn2.Get(), conn1.Get())
 
 	client.Close()
 	conn3, err := client.getConnArray(addr, true)
-	c.Assert(err, NotNil)
-	c.Assert(conn3, IsNil)
+	assert.NotNil(t, err)
+	assert.Nil(t, conn3)
 }
 
-func (s *testClientSuite) TestCancelTimeoutRetErr(c *C) {
+func TestCancelTimeoutRetErr(t *testing.T) {
 	req := new(tikvpb.BatchCommandsRequest_Request)
 	a := newBatchConn(1, 1, nil)
 
 	ctx, cancel := context.WithCancel(context.TODO())
 	cancel()
 	_, err := sendBatchRequest(ctx, "", "", a, req, 2*time.Second)
-	c.Assert(errors.Cause(err), Equals, context.Canceled)
+	assert.Equal(t, errors.Cause(err), context.Canceled)
 
 	_, err = sendBatchRequest(context.Background(), "", "", a, req, 0)
-	c.Assert(errors.Cause(err), Equals, context.DeadlineExceeded)
+	assert.Equal(t, errors.Cause(err), context.DeadlineExceeded)
 }
 
-func (s *testClientSuite) TestSendWhenReconnect(c *C) {
+func TestSendWhenReconnect(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 
 	rpcClient := NewRPCClient(config.Security{})
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
 	conn, err := rpcClient.getConnArray(addr, true)
-	c.Assert(err, IsNil)
+	assert.Nil(t, err)
 
 	// Suppose all connections are re-establishing.
 	for _, client := range conn.batchConn.batchCommandsClients {
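TestCancelTimeoutRetErr above compares errors.Cause(err) with context.Canceled and context.DeadlineExceeded: errors.Cause unwraps a pingcap/errors-annotated error to its root before the comparison, so the assertion is insensitive to any message wrapping added along the way. A minimal sketch:

package example

import (
	"context"
	"testing"

	"github.com/pingcap/errors"
	"github.com/stretchr/testify/assert"
)

func TestCauseSketch(t *testing.T) {
	// Wrap a root error, then recover the root with errors.Cause.
	err := errors.Annotate(context.Canceled, "sending batch request")
	assert.Equal(t, errors.Cause(err), context.Canceled)
}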
@@ -116,7 +102,7 @@ func (s *testClientSuite) TestSendWhenReconnect(c *C) {
 
 	req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
 	_, err = rpcClient.SendRequest(context.Background(), addr, req, 100*time.Second)
-	c.Assert(err.Error() == "no available connections", IsTrue)
+	assert.True(t, err.Error() == "no available connections")
 	conn.Close()
 	server.Stop()
 }
@@ -137,7 +123,7 @@ func (c *chanClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.
 	return nil, nil
 }
 
-func (s *testClientSuite) TestCollapseResolveLock(c *C) {
+func TestCollapseResolveLock(t *testing.T) {
 	buildResolveLockReq := func(regionID uint64, startTS uint64, commitTS uint64, keys [][]byte) *tikvrpc.Request {
 		region := &metapb.Region{Id: regionID}
 		req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, &kvrpcpb.ResolveLockRequest{
@@ -170,10 +156,10 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	time.Sleep(300 * time.Millisecond)
 	wg.Done()
 	req := <-reqCh
-	c.Assert(*req, DeepEquals, *resolveLockReq)
+	assert.Equal(t, *req, *resolveLockReq)
 	select {
 	case <-reqCh:
-		c.Fatal("fail to collapse ResolveLock")
+		assert.Fail(t, "fail to collapse ResolveLock")
 	default:
 	}
 
@@ -186,7 +172,7 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	wg.Done()
 	for i := 0; i < 2; i++ {
 		req := <-reqCh
-		c.Assert(*req, DeepEquals, *resolveLockLiteReq)
+		assert.Equal(t, *req, *resolveLockLiteReq)
 	}
 
 	// Don't collapse BatchResolveLock.
@@ -200,7 +186,7 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	wg.Done()
 	for i := 0; i < 2; i++ {
 		req := <-reqCh
-		c.Assert(*req, DeepEquals, *batchResolveLockReq)
+		assert.Equal(t, *req, *batchResolveLockReq)
 	}
 
 	// Mixed
@@ -215,14 +201,14 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	}
 	select {
 	case <-reqCh:
-		c.Fatal("unexpected request")
+		assert.Fail(t, "unexpected request")
 	default:
 	}
 }
 
-func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
+func TestForwardMetadataByUnaryCall(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
 
@@ -242,7 +228,7 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
 		md, ok := metadata.FromIncomingContext(ctx)
 		if ok {
 			vals := md.Get(forwardMetadataKey)
-			c.Assert(len(vals), Equals, 0)
+			assert.Equal(t, len(vals), 0)
 		}
 		return nil
 	})
@@ -251,15 +237,15 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
 	prewriteReq := tikvrpc.NewRequest(tikvrpc.CmdPrewrite, &kvrpcpb.PrewriteRequest{})
 	for i := 0; i < 3; i++ {
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(3))
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(3))
 
 	// CopStream represents unary-stream call.
 	copStreamReq := tikvrpc.NewRequest(tikvrpc.CmdCopStream, &coprocessor.Request{})
 	_, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(4))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(4))
 
 	checkCnt = 0
 	forwardedHost := "127.0.0.1:6666"
@@ -268,29 +254,29 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
 		atomic.AddUint64(&checkCnt, 1)
 		// gRPC may set some metadata by default, e.g. "context-type".
 		md, ok := metadata.FromIncomingContext(ctx)
-		c.Assert(ok, IsTrue)
+		assert.True(t, ok)
 		vals := md.Get(forwardMetadataKey)
-		c.Assert(vals, DeepEquals, []string{forwardedHost})
+		assert.Equal(t, vals, []string{forwardedHost})
 		return nil
 	})
 
 	prewriteReq.ForwardedHost = forwardedHost
 	for i := 0; i < 3; i++ {
 		_, err = rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
 	// checkCnt should be 3 because we don't use BatchCommands for redirection for now.
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(3))
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(3))
 
 	copStreamReq.ForwardedHost = forwardedHost
 	_, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(4))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(4))
 }
 
-func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
+func TestForwardMetadataByBatchCommands(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
 
@@ -311,12 +297,12 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
 		if forwardedHost == "" {
 			if ok {
 				vals := md.Get(forwardMetadataKey)
-				c.Assert(len(vals), Equals, 0)
+				assert.Equal(t, len(vals), 0)
 			}
 		} else {
-			c.Assert(ok, IsTrue)
+			assert.True(t, ok)
 			vals := md.Get(forwardMetadataKey)
-			c.Assert(vals, DeepEquals, []string{forwardedHost})
+			assert.Equal(t, vals, []string{forwardedHost})
 
 		}
 		return nil
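The two forwarding tests install check handlers that read gRPC metadata from the incoming context and assert on the forwardMetadataKey entry. For context, a sketch of the server-side read; the key string here is hypothetical (the real key is the unexported forwardMetadataKey constant):

package example

import (
	"context"

	"google.golang.org/grpc/metadata"
)

// forwardedHostFrom returns the forwarding header of an incoming RPC, if any.
func forwardedHostFrom(ctx context.Context) (string, bool) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return "", false
	}
	vals := md.Get("tikv-forwarded-host") // hypothetical key name
	if len(vals) == 0 {
		return "", false
	}
	return vals[0], true
}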
@@ -330,10 +316,10 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
 		prewriteReq.ForwardedHost = forwardedHost
 		for i := 0; i < 3; i++ {
 			_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-			c.Assert(err, IsNil)
+			assert.Nil(t, err)
 		}
 		// checkCnt should be i because there is a stream for each forwardedHost.
-		c.Assert(atomic.LoadUint64(&checkCnt), Equals, 1+uint64(i))
+		assert.Equal(t, atomic.LoadUint64(&checkCnt), 1+uint64(i))
 	}
 
 	checkCnt = 0
@@ -342,18 +328,18 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
 	// Check no corresponding metadata if forwardedHost is empty.
 	setCheckHandler("")
 	_, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(1))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(1))
 
 	copStreamReq.ForwardedHost = "127.0.0.1:6666"
 	// Check the metadata exists.
 	setCheckHandler(copStreamReq.ForwardedHost)
 	_, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(2))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(2))
 }
 
-func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
+func TestBatchCommandsBuilder(t *testing.T) {
 	builder := newBatchCommandsBuilder(128)
 
 	// Test no forwarding requests.
@@ -361,21 +347,21 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	req := new(tikvpb.BatchCommandsRequest_Request)
 	for i := 0; i < 10; i++ {
 		builder.push(&batchCommandsEntry{req: req})
-		c.Assert(builder.len(), Equals, i+1)
+		assert.Equal(t, builder.len(), i+1)
 	}
 	entryMap := make(map[uint64]*batchCommandsEntry)
 	batchedReq, forwardingReqs := builder.build(func(id uint64, e *batchCommandsEntry) {
 		entryMap[id] = e
 	})
-	c.Assert(len(batchedReq.GetRequests()), Equals, 10)
-	c.Assert(len(batchedReq.GetRequestIds()), Equals, 10)
-	c.Assert(len(entryMap), Equals, 10)
+	assert.Equal(t, len(batchedReq.GetRequests()), 10)
+	assert.Equal(t, len(batchedReq.GetRequestIds()), 10)
+	assert.Equal(t, len(entryMap), 10)
 	for i, id := range batchedReq.GetRequestIds() {
-		c.Assert(id, Equals, uint64(i))
-		c.Assert(entryMap[id].req, Equals, batchedReq.GetRequests()[i])
+		assert.Equal(t, id, uint64(i))
+		assert.Equal(t, entryMap[id].req, batchedReq.GetRequests()[i])
 	}
-	c.Assert(len(forwardingReqs), Equals, 0)
-	c.Assert(builder.idAlloc, Equals, uint64(10))
+	assert.Equal(t, len(forwardingReqs), 0)
+	assert.Equal(t, builder.idAlloc, uint64(10))
 
 	// Test collecting forwarding requests.
 	builder.reset()
@@ -393,19 +379,19 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) {
 		entryMap[id] = e
 	})
-	c.Assert(len(batchedReq.GetRequests()), Equals, 1)
-	c.Assert(len(batchedReq.GetRequestIds()), Equals, 1)
-	c.Assert(len(forwardingReqs), Equals, 3)
+	assert.Equal(t, len(batchedReq.GetRequests()), 1)
+	assert.Equal(t, len(batchedReq.GetRequestIds()), 1)
+	assert.Equal(t, len(forwardingReqs), 3)
 	for i, host := range forwardedHosts[1:] {
-		c.Assert(len(forwardingReqs[host].GetRequests()), Equals, i+2)
-		c.Assert(len(forwardingReqs[host].GetRequestIds()), Equals, i+2)
+		assert.Equal(t, len(forwardingReqs[host].GetRequests()), i+2)
+		assert.Equal(t, len(forwardingReqs[host].GetRequestIds()), i+2)
 	}
-	c.Assert(builder.idAlloc, Equals, uint64(10+builder.len()))
-	c.Assert(len(entryMap), Equals, builder.len())
+	assert.Equal(t, builder.idAlloc, uint64(10+builder.len()))
+	assert.Equal(t, len(entryMap), builder.len())
 	for host, forwardingReq := range forwardingReqs {
 		for i, id := range forwardingReq.GetRequestIds() {
-			c.Assert(entryMap[id].req, Equals, forwardingReq.GetRequests()[i])
-			c.Assert(entryMap[id].forwardedHost, Equals, host)
+			assert.Equal(t, entryMap[id].req, forwardingReq.GetRequests()[i])
+			assert.Equal(t, entryMap[id].forwardedHost, host)
 		}
 	}
 
@@ -425,13 +411,13 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) {
 		entryMap[id] = e
 	})
-	c.Assert(len(batchedReq.GetRequests()), Equals, 2)
-	c.Assert(len(batchedReq.GetRequestIds()), Equals, 2)
-	c.Assert(len(forwardingReqs), Equals, 0)
-	c.Assert(len(entryMap), Equals, 2)
+	assert.Equal(t, len(batchedReq.GetRequests()), 2)
+	assert.Equal(t, len(batchedReq.GetRequestIds()), 2)
+	assert.Equal(t, len(forwardingReqs), 0)
+	assert.Equal(t, len(entryMap), 2)
 	for i, id := range batchedReq.GetRequestIds() {
-		c.Assert(entryMap[id].req, Equals, batchedReq.GetRequests()[i])
-		c.Assert(entryMap[id].isCanceled(), IsFalse)
+		assert.Equal(t, entryMap[id].req, batchedReq.GetRequests()[i])
+		assert.False(t, entryMap[id].isCanceled())
 	}
 
 	// Test canceling all requests
@@ -446,16 +432,16 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	builder.cancel(err)
 	for _, entry := range entries {
 		_, ok := <-entry.res
-		c.Assert(ok, IsFalse)
-		c.Assert(entry.err, Equals, err)
+		assert.False(t, ok)
+		assert.Equal(t, entry.err, err)
 	}
 
 	// Test reset
 	builder.reset()
-	c.Assert(builder.len(), Equals, 0)
-	c.Assert(len(builder.entries), Equals, 0)
-	c.Assert(len(builder.requests), Equals, 0)
-	c.Assert(len(builder.requestIDs), Equals, 0)
-	c.Assert(len(builder.forwardingReqs), Equals, 0)
-	c.Assert(builder.idAlloc, Not(Equals), 0)
+	assert.Equal(t, builder.len(), 0)
+	assert.Equal(t, len(builder.entries), 0)
+	assert.Equal(t, len(builder.requests), 0)
+	assert.Equal(t, len(builder.requestIDs), 0)
+	assert.Equal(t, len(builder.forwardingReqs), 0)
+	assert.NotEqual(t, builder.idAlloc, 0)
 }
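One split the migration makes consistently: require is used for preconditions (failpoint setup, the mock server's port) and assert for the checks under test. require stops the test at the first failure via t.FailNow, while assert records the failure via t.Fail and keeps going, so one run can report several broken assertions. A small sketch of the difference:

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestRequireVsAssert(t *testing.T) {
	v, err := 42, error(nil) // stand-ins for a setup step's results
	require.Nil(t, err)      // on failure: t.FailNow(), nothing below runs
	assert.Equal(t, v, 42)   // on failure: t.Fail(), execution continues
	assert.True(t, v > 0)    // still evaluated even if the line above failed
}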