client: replace pingcap/check with testify (#166)

This commit is contained in:
disksing 2021-06-24 18:08:48 +08:00 committed by GitHub
parent 09ff177ac4
commit 901066f801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 108 additions and 123 deletions

View File

@ -36,25 +36,24 @@ import (
"context" "context"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
"testing"
"time" "time"
. "github.com/pingcap/check"
"github.com/pingcap/failpoint" "github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/tikvpb" "github.com/pingcap/kvproto/pkg/tikvpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/config" "github.com/tikv/client-go/v2/config"
"github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/tikvrpc"
) )
type testClientFailSuite struct { func TestPanicInRecvLoop(t *testing.T) {
} require.Nil(t, failpoint.Enable("tikvclient/panicInFailPendingRequests", `panic`))
require.Nil(t, failpoint.Enable("tikvclient/gotErrorInRecvLoop", `return("0")`))
func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) {
c.Assert(failpoint.Enable("tikvclient/panicInFailPendingRequests", `panic`), IsNil)
c.Assert(failpoint.Enable("tikvclient/gotErrorInRecvLoop", `return("0")`), IsNil)
server, port := startMockTikvService() server, port := startMockTikvService()
c.Assert(port > 0, IsTrue) require.True(t, port > 0)
defer server.Stop() defer server.Stop()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port) addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@ -64,24 +63,24 @@ func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) {
// Start batchRecvLoop, and it should panic in `failPendingRequests`. // Start batchRecvLoop, and it should panic in `failPendingRequests`.
_, err := rpcClient.getConnArray(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 }) _, err := rpcClient.getConnArray(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 })
c.Assert(err, IsNil, Commentf("cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again")) assert.Nil(t, err, "cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again")
req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{}) req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
_, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second/2) _, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second/2)
c.Assert(err, NotNil) assert.NotNil(t, err)
c.Assert(failpoint.Disable("tikvclient/gotErrorInRecvLoop"), IsNil) require.Nil(t, failpoint.Disable("tikvclient/gotErrorInRecvLoop"))
c.Assert(failpoint.Disable("tikvclient/panicInFailPendingRequests"), IsNil) require.Nil(t, failpoint.Disable("tikvclient/panicInFailPendingRequests"))
time.Sleep(time.Second * 2) time.Sleep(time.Second * 2)
req = tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{}) req = tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
_, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second*4) _, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second*4)
c.Assert(err, IsNil) assert.Nil(t, err)
} }
func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) { func TestRecvErrorInMultipleRecvLoops(t *testing.T) {
server, port := startMockTikvService() server, port := startMockTikvService()
c.Assert(port > 0, IsTrue) require.True(t, port > 0)
defer server.Stop() defer server.Stop()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port) addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@ -100,20 +99,20 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
for _, forwardedHost := range forwardedHosts { for _, forwardedHost := range forwardedHosts {
prewriteReq.ForwardedHost = forwardedHost prewriteReq.ForwardedHost = forwardedHost
_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second) _, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
} }
connArray, err := rpcClient.getConnArray(addr, true) connArray, err := rpcClient.getConnArray(addr, true)
c.Assert(connArray, NotNil) assert.NotNil(t, connArray)
c.Assert(err, IsNil) assert.Nil(t, err)
batchConn := connArray.batchConn batchConn := connArray.batchConn
c.Assert(batchConn, NotNil) assert.NotNil(t, batchConn)
c.Assert(len(batchConn.batchCommandsClients), Equals, 1) assert.Equal(t, len(batchConn.batchCommandsClients), 1)
batchClient := batchConn.batchCommandsClients[0] batchClient := batchConn.batchCommandsClients[0]
c.Assert(batchClient.client, NotNil) assert.NotNil(t, batchClient.client)
c.Assert(batchClient.client.forwardedHost, Equals, "") assert.Equal(t, batchClient.client.forwardedHost, "")
c.Assert(len(batchClient.forwardedClients), Equals, 3) assert.Equal(t, len(batchClient.forwardedClients), 3)
for _, forwardedHosts := range forwardedHosts[1:] { for _, forwardedHosts := range forwardedHosts[1:] {
c.Assert(batchClient.forwardedClients[forwardedHosts].forwardedHost, Equals, forwardedHosts) assert.Equal(t, batchClient.forwardedClients[forwardedHosts].forwardedHost, forwardedHosts)
} }
// Save all streams // Save all streams
@ -127,12 +126,12 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
fp := "tikvclient/gotErrorInRecvLoop" fp := "tikvclient/gotErrorInRecvLoop"
// Send a request to each stream to trigger reconnection. // Send a request to each stream to trigger reconnection.
for _, forwardedHost := range forwardedHosts { for _, forwardedHost := range forwardedHosts {
c.Assert(failpoint.Enable(fp, `1*return("0")`), IsNil) require.Nil(t, failpoint.Enable(fp, `1*return("0")`))
prewriteReq.ForwardedHost = forwardedHost prewriteReq.ForwardedHost = forwardedHost
_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second) _, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
c.Assert(failpoint.Disable(fp), IsNil) assert.Nil(t, failpoint.Disable(fp))
} }
// Wait for finishing reconnection. // Wait for finishing reconnection.
@ -150,14 +149,14 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
for _, forwardedHost := range forwardedHosts { for _, forwardedHost := range forwardedHosts {
prewriteReq.ForwardedHost = forwardedHost prewriteReq.ForwardedHost = forwardedHost
_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second) _, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
} }
// Should only reconnect once. // Should only reconnect once.
c.Assert(atomic.LoadUint64(&batchClient.epoch), Equals, epoch+1) assert.Equal(t, atomic.LoadUint64(&batchClient.epoch), epoch+1)
// All streams are refreshed. // All streams are refreshed.
c.Assert(batchClient.client.Tikv_BatchCommandsClient, Not(Equals), clientSave) assert.NotEqual(t, batchClient.client.Tikv_BatchCommandsClient, clientSave)
c.Assert(len(batchClient.forwardedClients), Equals, len(forwardedClientsSave)) assert.Equal(t, len(batchClient.forwardedClients), len(forwardedClientsSave))
for host, clientSave := range forwardedClientsSave { for host, clientSave := range forwardedClientsSave {
c.Assert(batchClient.forwardedClients[host].Tikv_BatchCommandsClient, Not(Equals), clientSave) assert.NotEqual(t, batchClient.forwardedClients[host].Tikv_BatchCommandsClient, clientSave)
} }
} }

View File

@ -40,33 +40,19 @@ import (
"testing" "testing"
"time" "time"
. "github.com/pingcap/check"
"github.com/pingcap/errors" "github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/coprocessor" "github.com/pingcap/kvproto/pkg/coprocessor"
"github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/kvproto/pkg/tikvpb" "github.com/pingcap/kvproto/pkg/tikvpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/config" "github.com/tikv/client-go/v2/config"
"github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/tikvrpc"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
func TestT(t *testing.T) { func TestConn(t *testing.T) {
CustomVerboseFlag = true
TestingT(t)
}
type testClientSuite struct {
}
type testClientSerialSuite struct {
}
var _ = SerialSuites(&testClientSuite{})
var _ = SerialSuites(&testClientFailSuite{})
var _ = SerialSuites(&testClientSerialSuite{})
func (s *testClientSerialSuite) TestConn(c *C) {
defer config.UpdateGlobal(func(conf *config.Config) { defer config.UpdateGlobal(func(conf *config.Config) {
conf.TiKVClient.MaxBatchSize = 0 conf.TiKVClient.MaxBatchSize = 0
})() })()
@ -75,39 +61,39 @@ func (s *testClientSerialSuite) TestConn(c *C) {
addr := "127.0.0.1:6379" addr := "127.0.0.1:6379"
conn1, err := client.getConnArray(addr, true) conn1, err := client.getConnArray(addr, true)
c.Assert(err, IsNil) assert.Nil(t, err)
conn2, err := client.getConnArray(addr, true) conn2, err := client.getConnArray(addr, true)
c.Assert(err, IsNil) assert.Nil(t, err)
c.Assert(conn2.Get(), Not(Equals), conn1.Get()) assert.NotEqual(t, conn2.Get(), conn1.Get())
client.Close() client.Close()
conn3, err := client.getConnArray(addr, true) conn3, err := client.getConnArray(addr, true)
c.Assert(err, NotNil) assert.NotNil(t, err)
c.Assert(conn3, IsNil) assert.Nil(t, conn3)
} }
func (s *testClientSuite) TestCancelTimeoutRetErr(c *C) { func TestCancelTimeoutRetErr(t *testing.T) {
req := new(tikvpb.BatchCommandsRequest_Request) req := new(tikvpb.BatchCommandsRequest_Request)
a := newBatchConn(1, 1, nil) a := newBatchConn(1, 1, nil)
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
cancel() cancel()
_, err := sendBatchRequest(ctx, "", "", a, req, 2*time.Second) _, err := sendBatchRequest(ctx, "", "", a, req, 2*time.Second)
c.Assert(errors.Cause(err), Equals, context.Canceled) assert.Equal(t, errors.Cause(err), context.Canceled)
_, err = sendBatchRequest(context.Background(), "", "", a, req, 0) _, err = sendBatchRequest(context.Background(), "", "", a, req, 0)
c.Assert(errors.Cause(err), Equals, context.DeadlineExceeded) assert.Equal(t, errors.Cause(err), context.DeadlineExceeded)
} }
func (s *testClientSuite) TestSendWhenReconnect(c *C) { func TestSendWhenReconnect(t *testing.T) {
server, port := startMockTikvService() server, port := startMockTikvService()
c.Assert(port > 0, IsTrue) require.True(t, port > 0)
rpcClient := NewRPCClient(config.Security{}) rpcClient := NewRPCClient(config.Security{})
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port) addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
conn, err := rpcClient.getConnArray(addr, true) conn, err := rpcClient.getConnArray(addr, true)
c.Assert(err, IsNil) assert.Nil(t, err)
// Suppose all connections are re-establishing. // Suppose all connections are re-establishing.
for _, client := range conn.batchConn.batchCommandsClients { for _, client := range conn.batchConn.batchCommandsClients {
@ -116,7 +102,7 @@ func (s *testClientSuite) TestSendWhenReconnect(c *C) {
req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{}) req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
_, err = rpcClient.SendRequest(context.Background(), addr, req, 100*time.Second) _, err = rpcClient.SendRequest(context.Background(), addr, req, 100*time.Second)
c.Assert(err.Error() == "no available connections", IsTrue) assert.True(t, err.Error() == "no available connections")
conn.Close() conn.Close()
server.Stop() server.Stop()
} }
@ -137,7 +123,7 @@ func (c *chanClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.
return nil, nil return nil, nil
} }
func (s *testClientSuite) TestCollapseResolveLock(c *C) { func TestCollapseResolveLock(t *testing.T) {
buildResolveLockReq := func(regionID uint64, startTS uint64, commitTS uint64, keys [][]byte) *tikvrpc.Request { buildResolveLockReq := func(regionID uint64, startTS uint64, commitTS uint64, keys [][]byte) *tikvrpc.Request {
region := &metapb.Region{Id: regionID} region := &metapb.Region{Id: regionID}
req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, &kvrpcpb.ResolveLockRequest{ req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, &kvrpcpb.ResolveLockRequest{
@ -170,10 +156,10 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
time.Sleep(300 * time.Millisecond) time.Sleep(300 * time.Millisecond)
wg.Done() wg.Done()
req := <-reqCh req := <-reqCh
c.Assert(*req, DeepEquals, *resolveLockReq) assert.Equal(t, *req, *resolveLockReq)
select { select {
case <-reqCh: case <-reqCh:
c.Fatal("fail to collapse ResolveLock") assert.Fail(t, "fail to collapse ResolveLock")
default: default:
} }
@ -186,7 +172,7 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
wg.Done() wg.Done()
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := <-reqCh req := <-reqCh
c.Assert(*req, DeepEquals, *resolveLockLiteReq) assert.Equal(t, *req, *resolveLockLiteReq)
} }
// Don't collapse BatchResolveLock. // Don't collapse BatchResolveLock.
@ -200,7 +186,7 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
wg.Done() wg.Done()
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := <-reqCh req := <-reqCh
c.Assert(*req, DeepEquals, *batchResolveLockReq) assert.Equal(t, *req, *batchResolveLockReq)
} }
// Mixed // Mixed
@ -215,14 +201,14 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
} }
select { select {
case <-reqCh: case <-reqCh:
c.Fatal("unexpected request") assert.Fail(t, "unexpected request")
default: default:
} }
} }
func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) { func TestForwardMetadataByUnaryCall(t *testing.T) {
server, port := startMockTikvService() server, port := startMockTikvService()
c.Assert(port > 0, IsTrue) require.True(t, port > 0)
defer server.Stop() defer server.Stop()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port) addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@ -242,7 +228,7 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
md, ok := metadata.FromIncomingContext(ctx) md, ok := metadata.FromIncomingContext(ctx)
if ok { if ok {
vals := md.Get(forwardMetadataKey) vals := md.Get(forwardMetadataKey)
c.Assert(len(vals), Equals, 0) assert.Equal(t, len(vals), 0)
} }
return nil return nil
}) })
@ -251,15 +237,15 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
prewriteReq := tikvrpc.NewRequest(tikvrpc.CmdPrewrite, &kvrpcpb.PrewriteRequest{}) prewriteReq := tikvrpc.NewRequest(tikvrpc.CmdPrewrite, &kvrpcpb.PrewriteRequest{})
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second) _, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
} }
c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(3)) assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(3))
// CopStream represents unary-stream call. // CopStream represents unary-stream call.
copStreamReq := tikvrpc.NewRequest(tikvrpc.CmdCopStream, &coprocessor.Request{}) copStreamReq := tikvrpc.NewRequest(tikvrpc.CmdCopStream, &coprocessor.Request{})
_, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second) _, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(4)) assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(4))
checkCnt = 0 checkCnt = 0
forwardedHost := "127.0.0.1:6666" forwardedHost := "127.0.0.1:6666"
@ -268,29 +254,29 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
atomic.AddUint64(&checkCnt, 1) atomic.AddUint64(&checkCnt, 1)
// gRPC may set some metadata by default, e.g. "context-type". // gRPC may set some metadata by default, e.g. "context-type".
md, ok := metadata.FromIncomingContext(ctx) md, ok := metadata.FromIncomingContext(ctx)
c.Assert(ok, IsTrue) assert.True(t, ok)
vals := md.Get(forwardMetadataKey) vals := md.Get(forwardMetadataKey)
c.Assert(vals, DeepEquals, []string{forwardedHost}) assert.Equal(t, vals, []string{forwardedHost})
return nil return nil
}) })
prewriteReq.ForwardedHost = forwardedHost prewriteReq.ForwardedHost = forwardedHost
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err = rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second) _, err = rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
} }
// checkCnt should be 3 because we don't use BatchCommands for redirection for now. // checkCnt should be 3 because we don't use BatchCommands for redirection for now.
c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(3)) assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(3))
copStreamReq.ForwardedHost = forwardedHost copStreamReq.ForwardedHost = forwardedHost
_, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second) _, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(4)) assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(4))
} }
func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) { func TestForwardMetadataByBatchCommands(t *testing.T) {
server, port := startMockTikvService() server, port := startMockTikvService()
c.Assert(port > 0, IsTrue) require.True(t, port > 0)
defer server.Stop() defer server.Stop()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port) addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@ -311,12 +297,12 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
if forwardedHost == "" { if forwardedHost == "" {
if ok { if ok {
vals := md.Get(forwardMetadataKey) vals := md.Get(forwardMetadataKey)
c.Assert(len(vals), Equals, 0) assert.Equal(t, len(vals), 0)
} }
} else { } else {
c.Assert(ok, IsTrue) assert.True(t, ok)
vals := md.Get(forwardMetadataKey) vals := md.Get(forwardMetadataKey)
c.Assert(vals, DeepEquals, []string{forwardedHost}) assert.Equal(t, vals, []string{forwardedHost})
} }
return nil return nil
@ -330,10 +316,10 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
prewriteReq.ForwardedHost = forwardedHost prewriteReq.ForwardedHost = forwardedHost
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second) _, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
} }
// checkCnt should be i because there is a stream for each forwardedHost. // checkCnt should be i because there is a stream for each forwardedHost.
c.Assert(atomic.LoadUint64(&checkCnt), Equals, 1+uint64(i)) assert.Equal(t, atomic.LoadUint64(&checkCnt), 1+uint64(i))
} }
checkCnt = 0 checkCnt = 0
@ -342,18 +328,18 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
// Check no corresponding metadata if forwardedHost is empty. // Check no corresponding metadata if forwardedHost is empty.
setCheckHandler("") setCheckHandler("")
_, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second) _, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(1)) assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(1))
copStreamReq.ForwardedHost = "127.0.0.1:6666" copStreamReq.ForwardedHost = "127.0.0.1:6666"
// Check the metadata exists. // Check the metadata exists.
setCheckHandler(copStreamReq.ForwardedHost) setCheckHandler(copStreamReq.ForwardedHost)
_, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second) _, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
c.Assert(err, IsNil) assert.Nil(t, err)
c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(2)) assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(2))
} }
func (s *testClientSuite) TestBatchCommandsBuilder(c *C) { func TestBatchCommandsBuilder(t *testing.T) {
builder := newBatchCommandsBuilder(128) builder := newBatchCommandsBuilder(128)
// Test no forwarding requests. // Test no forwarding requests.
@ -361,21 +347,21 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
req := new(tikvpb.BatchCommandsRequest_Request) req := new(tikvpb.BatchCommandsRequest_Request)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
builder.push(&batchCommandsEntry{req: req}) builder.push(&batchCommandsEntry{req: req})
c.Assert(builder.len(), Equals, i+1) assert.Equal(t, builder.len(), i+1)
} }
entryMap := make(map[uint64]*batchCommandsEntry) entryMap := make(map[uint64]*batchCommandsEntry)
batchedReq, forwardingReqs := builder.build(func(id uint64, e *batchCommandsEntry) { batchedReq, forwardingReqs := builder.build(func(id uint64, e *batchCommandsEntry) {
entryMap[id] = e entryMap[id] = e
}) })
c.Assert(len(batchedReq.GetRequests()), Equals, 10) assert.Equal(t, len(batchedReq.GetRequests()), 10)
c.Assert(len(batchedReq.GetRequestIds()), Equals, 10) assert.Equal(t, len(batchedReq.GetRequestIds()), 10)
c.Assert(len(entryMap), Equals, 10) assert.Equal(t, len(entryMap), 10)
for i, id := range batchedReq.GetRequestIds() { for i, id := range batchedReq.GetRequestIds() {
c.Assert(id, Equals, uint64(i)) assert.Equal(t, id, uint64(i))
c.Assert(entryMap[id].req, Equals, batchedReq.GetRequests()[i]) assert.Equal(t, entryMap[id].req, batchedReq.GetRequests()[i])
} }
c.Assert(len(forwardingReqs), Equals, 0) assert.Equal(t, len(forwardingReqs), 0)
c.Assert(builder.idAlloc, Equals, uint64(10)) assert.Equal(t, builder.idAlloc, uint64(10))
// Test collecting forwarding requests. // Test collecting forwarding requests.
builder.reset() builder.reset()
@ -393,19 +379,19 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) { batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) {
entryMap[id] = e entryMap[id] = e
}) })
c.Assert(len(batchedReq.GetRequests()), Equals, 1) assert.Equal(t, len(batchedReq.GetRequests()), 1)
c.Assert(len(batchedReq.GetRequestIds()), Equals, 1) assert.Equal(t, len(batchedReq.GetRequestIds()), 1)
c.Assert(len(forwardingReqs), Equals, 3) assert.Equal(t, len(forwardingReqs), 3)
for i, host := range forwardedHosts[1:] { for i, host := range forwardedHosts[1:] {
c.Assert(len(forwardingReqs[host].GetRequests()), Equals, i+2) assert.Equal(t, len(forwardingReqs[host].GetRequests()), i+2)
c.Assert(len(forwardingReqs[host].GetRequestIds()), Equals, i+2) assert.Equal(t, len(forwardingReqs[host].GetRequestIds()), i+2)
} }
c.Assert(builder.idAlloc, Equals, uint64(10+builder.len())) assert.Equal(t, builder.idAlloc, uint64(10+builder.len()))
c.Assert(len(entryMap), Equals, builder.len()) assert.Equal(t, len(entryMap), builder.len())
for host, forwardingReq := range forwardingReqs { for host, forwardingReq := range forwardingReqs {
for i, id := range forwardingReq.GetRequestIds() { for i, id := range forwardingReq.GetRequestIds() {
c.Assert(entryMap[id].req, Equals, forwardingReq.GetRequests()[i]) assert.Equal(t, entryMap[id].req, forwardingReq.GetRequests()[i])
c.Assert(entryMap[id].forwardedHost, Equals, host) assert.Equal(t, entryMap[id].forwardedHost, host)
} }
} }
@ -425,13 +411,13 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) { batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) {
entryMap[id] = e entryMap[id] = e
}) })
c.Assert(len(batchedReq.GetRequests()), Equals, 2) assert.Equal(t, len(batchedReq.GetRequests()), 2)
c.Assert(len(batchedReq.GetRequestIds()), Equals, 2) assert.Equal(t, len(batchedReq.GetRequestIds()), 2)
c.Assert(len(forwardingReqs), Equals, 0) assert.Equal(t, len(forwardingReqs), 0)
c.Assert(len(entryMap), Equals, 2) assert.Equal(t, len(entryMap), 2)
for i, id := range batchedReq.GetRequestIds() { for i, id := range batchedReq.GetRequestIds() {
c.Assert(entryMap[id].req, Equals, batchedReq.GetRequests()[i]) assert.Equal(t, entryMap[id].req, batchedReq.GetRequests()[i])
c.Assert(entryMap[id].isCanceled(), IsFalse) assert.False(t, entryMap[id].isCanceled())
} }
// Test canceling all requests // Test canceling all requests
@ -446,16 +432,16 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
builder.cancel(err) builder.cancel(err)
for _, entry := range entries { for _, entry := range entries {
_, ok := <-entry.res _, ok := <-entry.res
c.Assert(ok, IsFalse) assert.False(t, ok)
c.Assert(entry.err, Equals, err) assert.Equal(t, entry.err, err)
} }
// Test reset // Test reset
builder.reset() builder.reset()
c.Assert(builder.len(), Equals, 0) assert.Equal(t, builder.len(), 0)
c.Assert(len(builder.entries), Equals, 0) assert.Equal(t, len(builder.entries), 0)
c.Assert(len(builder.requests), Equals, 0) assert.Equal(t, len(builder.requests), 0)
c.Assert(len(builder.requestIDs), Equals, 0) assert.Equal(t, len(builder.requestIDs), 0)
c.Assert(len(builder.forwardingReqs), Equals, 0) assert.Equal(t, len(builder.forwardingReqs), 0)
c.Assert(builder.idAlloc, Not(Equals), 0) assert.NotEqual(t, builder.idAlloc, 0)
} }