client: replace pingcap/check with testify (#166)

Author: disksing, 2021-06-24 18:08:48 +08:00 (committed by GitHub)
parent 09ff177ac4
commit 901066f801
2 changed files with 108 additions and 123 deletions
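
The conversion is mechanical: each suite-bound method like `func (s *testClientSuite) TestFoo(c *C)` becomes a plain `func TestFoo(t *testing.T)`, `c.Assert(...)` calls become `assert.*` (record the failure and continue) or `require.*` (abort the test) calls, and the `TestingT`/`SerialSuites` bootstrap is deleted because `go test` discovers `TestXxx` functions directly. A minimal sketch of the pattern, using a hypothetical `TestExample` and hypothetical helpers rather than a test from this diff:

```go
package client

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// setup and compute are hypothetical helpers standing in for real test code.
func setup() error { return nil }
func compute() int { return 42 }

// Before (pingcap/check), the test was a method on a registered suite:
//
//	func (s *testClientSuite) TestExample(c *C) {
//		c.Assert(setup(), IsNil)
//		c.Assert(compute(), Equals, 42)
//	}
//
// After (testify), it is a plain function discovered by `go test`.
func TestExample(t *testing.T) {
	// require.* aborts the test on failure; used for preconditions.
	require.Nil(t, setup())
	// assert.* records the failure and lets the test continue.
	assert.Equal(t, 42, compute())
}
```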

File 1 of 2 (client failpoint tests)

@@ -36,25 +36,24 @@ import (
 	"context"
 	"fmt"
 	"sync/atomic"
+	"testing"
 	"time"
 
-	. "github.com/pingcap/check"
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/kvrpcpb"
 	"github.com/pingcap/kvproto/pkg/tikvpb"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/tikv/client-go/v2/config"
 	"github.com/tikv/client-go/v2/tikvrpc"
 )
 
-type testClientFailSuite struct {
-}
-
-func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) {
-	c.Assert(failpoint.Enable("tikvclient/panicInFailPendingRequests", `panic`), IsNil)
-	c.Assert(failpoint.Enable("tikvclient/gotErrorInRecvLoop", `return("0")`), IsNil)
+func TestPanicInRecvLoop(t *testing.T) {
+	require.Nil(t, failpoint.Enable("tikvclient/panicInFailPendingRequests", `panic`))
+	require.Nil(t, failpoint.Enable("tikvclient/gotErrorInRecvLoop", `return("0")`))
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@@ -64,24 +63,24 @@ func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) {
 	// Start batchRecvLoop, and it should panic in `failPendingRequests`.
 	_, err := rpcClient.getConnArray(addr, true, func(cfg *config.TiKVClient) { cfg.GrpcConnectionCount = 1 })
-	c.Assert(err, IsNil, Commentf("cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again"))
+	assert.Nil(t, err, "cannot establish local connection due to env problems(e.g. heavy load in test machine), please retry again")
 
 	req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
 	_, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second/2)
-	c.Assert(err, NotNil)
+	assert.NotNil(t, err)
 
-	c.Assert(failpoint.Disable("tikvclient/gotErrorInRecvLoop"), IsNil)
-	c.Assert(failpoint.Disable("tikvclient/panicInFailPendingRequests"), IsNil)
+	require.Nil(t, failpoint.Disable("tikvclient/gotErrorInRecvLoop"))
+	require.Nil(t, failpoint.Disable("tikvclient/panicInFailPendingRequests"))
 	time.Sleep(time.Second * 2)
 
 	req = tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
 	_, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second*4)
-	c.Assert(err, IsNil)
+	assert.Nil(t, err)
 }
 
-func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
+func TestRecvErrorInMultipleRecvLoops(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@@ -100,20 +99,20 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
 	for _, forwardedHost := range forwardedHosts {
 		prewriteReq.ForwardedHost = forwardedHost
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
 
 	connArray, err := rpcClient.getConnArray(addr, true)
-	c.Assert(connArray, NotNil)
-	c.Assert(err, IsNil)
+	assert.NotNil(t, connArray)
+	assert.Nil(t, err)
 	batchConn := connArray.batchConn
-	c.Assert(batchConn, NotNil)
-	c.Assert(len(batchConn.batchCommandsClients), Equals, 1)
+	assert.NotNil(t, batchConn)
+	assert.Equal(t, len(batchConn.batchCommandsClients), 1)
 	batchClient := batchConn.batchCommandsClients[0]
-	c.Assert(batchClient.client, NotNil)
-	c.Assert(batchClient.client.forwardedHost, Equals, "")
-	c.Assert(len(batchClient.forwardedClients), Equals, 3)
+	assert.NotNil(t, batchClient.client)
+	assert.Equal(t, batchClient.client.forwardedHost, "")
+	assert.Equal(t, len(batchClient.forwardedClients), 3)
 	for _, forwardedHosts := range forwardedHosts[1:] {
-		c.Assert(batchClient.forwardedClients[forwardedHosts].forwardedHost, Equals, forwardedHosts)
+		assert.Equal(t, batchClient.forwardedClients[forwardedHosts].forwardedHost, forwardedHosts)
 	}
 
 	// Save all streams
@@ -127,12 +126,12 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
 	fp := "tikvclient/gotErrorInRecvLoop"
 	// Send a request to each stream to trigger reconnection.
 	for _, forwardedHost := range forwardedHosts {
-		c.Assert(failpoint.Enable(fp, `1*return("0")`), IsNil)
+		require.Nil(t, failpoint.Enable(fp, `1*return("0")`))
 		prewriteReq.ForwardedHost = forwardedHost
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 		time.Sleep(100 * time.Millisecond)
-		c.Assert(failpoint.Disable(fp), IsNil)
+		assert.Nil(t, failpoint.Disable(fp))
 	}
 
 	// Wait for finishing reconnection.
@@ -150,14 +149,14 @@ func (s *testClientFailSuite) TestRecvErrorInMultipleRecvLoops(c *C) {
 	for _, forwardedHost := range forwardedHosts {
 		prewriteReq.ForwardedHost = forwardedHost
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
 	// Should only reconnect once.
-	c.Assert(atomic.LoadUint64(&batchClient.epoch), Equals, epoch+1)
+	assert.Equal(t, atomic.LoadUint64(&batchClient.epoch), epoch+1)
 	// All streams are refreshed.
-	c.Assert(batchClient.client.Tikv_BatchCommandsClient, Not(Equals), clientSave)
-	c.Assert(len(batchClient.forwardedClients), Equals, len(forwardedClientsSave))
+	assert.NotEqual(t, batchClient.client.Tikv_BatchCommandsClient, clientSave)
+	assert.Equal(t, len(batchClient.forwardedClients), len(forwardedClientsSave))
 	for host, clientSave := range forwardedClientsSave {
-		c.Assert(batchClient.forwardedClients[host].Tikv_BatchCommandsClient, Not(Equals), clientSave)
+		assert.NotEqual(t, batchClient.forwardedClients[host].Tikv_BatchCommandsClient, clientSave)
 	}
 }

File 2 of 2 (client tests)

@@ -40,33 +40,19 @@ import (
 	"testing"
 	"time"
 
-	. "github.com/pingcap/check"
 	"github.com/pingcap/errors"
 	"github.com/pingcap/kvproto/pkg/coprocessor"
 	"github.com/pingcap/kvproto/pkg/kvrpcpb"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/kvproto/pkg/tikvpb"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/tikv/client-go/v2/config"
 	"github.com/tikv/client-go/v2/tikvrpc"
 	"google.golang.org/grpc/metadata"
 )
 
-func TestT(t *testing.T) {
-	CustomVerboseFlag = true
-	TestingT(t)
-}
-
-type testClientSuite struct {
-}
-
-type testClientSerialSuite struct {
-}
-
-var _ = SerialSuites(&testClientSuite{})
-var _ = SerialSuites(&testClientFailSuite{})
-var _ = SerialSuites(&testClientSerialSuite{})
-
-func (s *testClientSerialSuite) TestConn(c *C) {
+func TestConn(t *testing.T) {
 	defer config.UpdateGlobal(func(conf *config.Config) {
 		conf.TiKVClient.MaxBatchSize = 0
 	})()
@@ -75,39 +61,39 @@ func (s *testClientSerialSuite) TestConn(c *C) {
 	addr := "127.0.0.1:6379"
 	conn1, err := client.getConnArray(addr, true)
-	c.Assert(err, IsNil)
+	assert.Nil(t, err)
 
 	conn2, err := client.getConnArray(addr, true)
-	c.Assert(err, IsNil)
-	c.Assert(conn2.Get(), Not(Equals), conn1.Get())
+	assert.Nil(t, err)
+	assert.NotEqual(t, conn2.Get(), conn1.Get())
 
 	client.Close()
 	conn3, err := client.getConnArray(addr, true)
-	c.Assert(err, NotNil)
-	c.Assert(conn3, IsNil)
+	assert.NotNil(t, err)
+	assert.Nil(t, conn3)
 }
 
-func (s *testClientSuite) TestCancelTimeoutRetErr(c *C) {
+func TestCancelTimeoutRetErr(t *testing.T) {
 	req := new(tikvpb.BatchCommandsRequest_Request)
 	a := newBatchConn(1, 1, nil)
 
 	ctx, cancel := context.WithCancel(context.TODO())
 	cancel()
 	_, err := sendBatchRequest(ctx, "", "", a, req, 2*time.Second)
-	c.Assert(errors.Cause(err), Equals, context.Canceled)
+	assert.Equal(t, errors.Cause(err), context.Canceled)
 
 	_, err = sendBatchRequest(context.Background(), "", "", a, req, 0)
-	c.Assert(errors.Cause(err), Equals, context.DeadlineExceeded)
+	assert.Equal(t, errors.Cause(err), context.DeadlineExceeded)
 }
 
-func (s *testClientSuite) TestSendWhenReconnect(c *C) {
+func TestSendWhenReconnect(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 
 	rpcClient := NewRPCClient(config.Security{})
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
 	conn, err := rpcClient.getConnArray(addr, true)
-	c.Assert(err, IsNil)
+	assert.Nil(t, err)
 
 	// Suppose all connections are re-establishing.
 	for _, client := range conn.batchConn.batchCommandsClients {
@@ -116,7 +102,7 @@ func (s *testClientSuite) TestSendWhenReconnect(c *C) {
 	req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
 	_, err = rpcClient.SendRequest(context.Background(), addr, req, 100*time.Second)
-	c.Assert(err.Error() == "no available connections", IsTrue)
+	assert.True(t, err.Error() == "no available connections")
 	conn.Close()
 	server.Stop()
 }
@@ -137,7 +123,7 @@ func (c *chanClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.
 	return nil, nil
 }
 
-func (s *testClientSuite) TestCollapseResolveLock(c *C) {
+func TestCollapseResolveLock(t *testing.T) {
 	buildResolveLockReq := func(regionID uint64, startTS uint64, commitTS uint64, keys [][]byte) *tikvrpc.Request {
 		region := &metapb.Region{Id: regionID}
 		req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, &kvrpcpb.ResolveLockRequest{
@@ -170,10 +156,10 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	time.Sleep(300 * time.Millisecond)
 	wg.Done()
 
 	req := <-reqCh
-	c.Assert(*req, DeepEquals, *resolveLockReq)
+	assert.Equal(t, *req, *resolveLockReq)
 	select {
 	case <-reqCh:
-		c.Fatal("fail to collapse ResolveLock")
+		assert.Fail(t, "fail to collapse ResolveLock")
 	default:
 	}
@@ -186,7 +172,7 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	wg.Done()
 	for i := 0; i < 2; i++ {
 		req := <-reqCh
-		c.Assert(*req, DeepEquals, *resolveLockLiteReq)
+		assert.Equal(t, *req, *resolveLockLiteReq)
 	}
 
 	// Don't collapse BatchResolveLock.
@@ -200,7 +186,7 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	wg.Done()
 	for i := 0; i < 2; i++ {
 		req := <-reqCh
-		c.Assert(*req, DeepEquals, *batchResolveLockReq)
+		assert.Equal(t, *req, *batchResolveLockReq)
 	}
 
 	// Mixed
@@ -215,14 +201,14 @@ func (s *testClientSuite) TestCollapseResolveLock(c *C) {
 	}
 	select {
 	case <-reqCh:
-		c.Fatal("unexpected request")
+		assert.Fail(t, "unexpected request")
 	default:
 	}
 }
 
-func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
+func TestForwardMetadataByUnaryCall(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@@ -242,7 +228,7 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
 		md, ok := metadata.FromIncomingContext(ctx)
 		if ok {
 			vals := md.Get(forwardMetadataKey)
-			c.Assert(len(vals), Equals, 0)
+			assert.Equal(t, len(vals), 0)
 		}
 		return nil
 	})
@@ -251,15 +237,15 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
 	prewriteReq := tikvrpc.NewRequest(tikvrpc.CmdPrewrite, &kvrpcpb.PrewriteRequest{})
 	for i := 0; i < 3; i++ {
 		_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(3))
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(3))
 
 	// CopStream represents unary-stream call.
 	copStreamReq := tikvrpc.NewRequest(tikvrpc.CmdCopStream, &coprocessor.Request{})
 	_, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(4))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(4))
 
 	checkCnt = 0
 	forwardedHost := "127.0.0.1:6666"
@@ -268,29 +254,29 @@ func (s *testClientSerialSuite) TestForwardMetadataByUnaryCall(c *C) {
 		atomic.AddUint64(&checkCnt, 1)
 		// gRPC may set some metadata by default, e.g. "context-type".
 		md, ok := metadata.FromIncomingContext(ctx)
-		c.Assert(ok, IsTrue)
+		assert.True(t, ok)
 		vals := md.Get(forwardMetadataKey)
-		c.Assert(vals, DeepEquals, []string{forwardedHost})
+		assert.Equal(t, vals, []string{forwardedHost})
 		return nil
 	})
 
 	prewriteReq.ForwardedHost = forwardedHost
 	for i := 0; i < 3; i++ {
 		_, err = rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-		c.Assert(err, IsNil)
+		assert.Nil(t, err)
 	}
 	// checkCnt should be 3 because we don't use BatchCommands for redirection for now.
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(3))
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(3))
 
 	copStreamReq.ForwardedHost = forwardedHost
 	_, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(4))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(4))
 }
 
-func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
+func TestForwardMetadataByBatchCommands(t *testing.T) {
 	server, port := startMockTikvService()
-	c.Assert(port > 0, IsTrue)
+	require.True(t, port > 0)
 	defer server.Stop()
 
 	addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
@@ -311,12 +297,12 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
 		if forwardedHost == "" {
 			if ok {
 				vals := md.Get(forwardMetadataKey)
-				c.Assert(len(vals), Equals, 0)
+				assert.Equal(t, len(vals), 0)
 			}
 		} else {
-			c.Assert(ok, IsTrue)
+			assert.True(t, ok)
 			vals := md.Get(forwardMetadataKey)
-			c.Assert(vals, DeepEquals, []string{forwardedHost})
+			assert.Equal(t, vals, []string{forwardedHost})
 		}
 		return nil
@@ -330,10 +316,10 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
 		prewriteReq.ForwardedHost = forwardedHost
 		for i := 0; i < 3; i++ {
 			_, err := rpcClient.SendRequest(context.Background(), addr, prewriteReq, 10*time.Second)
-			c.Assert(err, IsNil)
+			assert.Nil(t, err)
 		}
 		// checkCnt should be i because there is a stream for each forwardedHost.
-		c.Assert(atomic.LoadUint64(&checkCnt), Equals, 1+uint64(i))
+		assert.Equal(t, atomic.LoadUint64(&checkCnt), 1+uint64(i))
 	}
 
 	checkCnt = 0
@@ -342,18 +328,18 @@ func (s *testClientSerialSuite) TestForwardMetadataByBatchCommands(c *C) {
 	// Check no corresponding metadata if forwardedHost is empty.
 	setCheckHandler("")
 	_, err := rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(1))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(1))
 
 	copStreamReq.ForwardedHost = "127.0.0.1:6666"
 	// Check the metadata exists.
 	setCheckHandler(copStreamReq.ForwardedHost)
 	_, err = rpcClient.SendRequest(context.Background(), addr, copStreamReq, 10*time.Second)
-	c.Assert(err, IsNil)
-	c.Assert(atomic.LoadUint64(&checkCnt), Equals, uint64(2))
+	assert.Nil(t, err)
+	assert.Equal(t, atomic.LoadUint64(&checkCnt), uint64(2))
 }
 
-func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
+func TestBatchCommandsBuilder(t *testing.T) {
 	builder := newBatchCommandsBuilder(128)
 
 	// Test no forwarding requests.
@@ -361,21 +347,21 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	req := new(tikvpb.BatchCommandsRequest_Request)
 	for i := 0; i < 10; i++ {
 		builder.push(&batchCommandsEntry{req: req})
-		c.Assert(builder.len(), Equals, i+1)
+		assert.Equal(t, builder.len(), i+1)
 	}
 	entryMap := make(map[uint64]*batchCommandsEntry)
 	batchedReq, forwardingReqs := builder.build(func(id uint64, e *batchCommandsEntry) {
 		entryMap[id] = e
 	})
-	c.Assert(len(batchedReq.GetRequests()), Equals, 10)
-	c.Assert(len(batchedReq.GetRequestIds()), Equals, 10)
-	c.Assert(len(entryMap), Equals, 10)
+	assert.Equal(t, len(batchedReq.GetRequests()), 10)
+	assert.Equal(t, len(batchedReq.GetRequestIds()), 10)
+	assert.Equal(t, len(entryMap), 10)
 	for i, id := range batchedReq.GetRequestIds() {
-		c.Assert(id, Equals, uint64(i))
-		c.Assert(entryMap[id].req, Equals, batchedReq.GetRequests()[i])
+		assert.Equal(t, id, uint64(i))
+		assert.Equal(t, entryMap[id].req, batchedReq.GetRequests()[i])
 	}
-	c.Assert(len(forwardingReqs), Equals, 0)
-	c.Assert(builder.idAlloc, Equals, uint64(10))
+	assert.Equal(t, len(forwardingReqs), 0)
+	assert.Equal(t, builder.idAlloc, uint64(10))
 
 	// Test collecting forwarding requests.
 	builder.reset()
@@ -393,19 +379,19 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) {
 		entryMap[id] = e
 	})
-	c.Assert(len(batchedReq.GetRequests()), Equals, 1)
-	c.Assert(len(batchedReq.GetRequestIds()), Equals, 1)
-	c.Assert(len(forwardingReqs), Equals, 3)
+	assert.Equal(t, len(batchedReq.GetRequests()), 1)
+	assert.Equal(t, len(batchedReq.GetRequestIds()), 1)
+	assert.Equal(t, len(forwardingReqs), 3)
 	for i, host := range forwardedHosts[1:] {
-		c.Assert(len(forwardingReqs[host].GetRequests()), Equals, i+2)
-		c.Assert(len(forwardingReqs[host].GetRequestIds()), Equals, i+2)
+		assert.Equal(t, len(forwardingReqs[host].GetRequests()), i+2)
+		assert.Equal(t, len(forwardingReqs[host].GetRequestIds()), i+2)
 	}
-	c.Assert(builder.idAlloc, Equals, uint64(10+builder.len()))
-	c.Assert(len(entryMap), Equals, builder.len())
+	assert.Equal(t, builder.idAlloc, uint64(10+builder.len()))
+	assert.Equal(t, len(entryMap), builder.len())
 	for host, forwardingReq := range forwardingReqs {
 		for i, id := range forwardingReq.GetRequestIds() {
-			c.Assert(entryMap[id].req, Equals, forwardingReq.GetRequests()[i])
-			c.Assert(entryMap[id].forwardedHost, Equals, host)
+			assert.Equal(t, entryMap[id].req, forwardingReq.GetRequests()[i])
+			assert.Equal(t, entryMap[id].forwardedHost, host)
 		}
 	}
@@ -425,13 +411,13 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	batchedReq, forwardingReqs = builder.build(func(id uint64, e *batchCommandsEntry) {
 		entryMap[id] = e
 	})
-	c.Assert(len(batchedReq.GetRequests()), Equals, 2)
-	c.Assert(len(batchedReq.GetRequestIds()), Equals, 2)
-	c.Assert(len(forwardingReqs), Equals, 0)
-	c.Assert(len(entryMap), Equals, 2)
+	assert.Equal(t, len(batchedReq.GetRequests()), 2)
+	assert.Equal(t, len(batchedReq.GetRequestIds()), 2)
+	assert.Equal(t, len(forwardingReqs), 0)
+	assert.Equal(t, len(entryMap), 2)
 	for i, id := range batchedReq.GetRequestIds() {
-		c.Assert(entryMap[id].req, Equals, batchedReq.GetRequests()[i])
-		c.Assert(entryMap[id].isCanceled(), IsFalse)
+		assert.Equal(t, entryMap[id].req, batchedReq.GetRequests()[i])
+		assert.False(t, entryMap[id].isCanceled())
 	}
 
 	// Test canceling all requests
@@ -446,16 +432,16 @@ func (s *testClientSuite) TestBatchCommandsBuilder(c *C) {
 	builder.cancel(err)
 	for _, entry := range entries {
 		_, ok := <-entry.res
-		c.Assert(ok, IsFalse)
-		c.Assert(entry.err, Equals, err)
+		assert.False(t, ok)
+		assert.Equal(t, entry.err, err)
 	}
 
 	// Test reset
 	builder.reset()
-	c.Assert(builder.len(), Equals, 0)
-	c.Assert(len(builder.entries), Equals, 0)
-	c.Assert(len(builder.requests), Equals, 0)
-	c.Assert(len(builder.requestIDs), Equals, 0)
-	c.Assert(len(builder.forwardingReqs), Equals, 0)
-	c.Assert(builder.idAlloc, Not(Equals), 0)
+	assert.Equal(t, builder.len(), 0)
+	assert.Equal(t, len(builder.entries), 0)
+	assert.Equal(t, len(builder.requests), 0)
+	assert.Equal(t, len(builder.requestIDs), 0)
+	assert.Equal(t, len(builder.forwardingReqs), 0)
+	assert.NotEqual(t, builder.idAlloc, 0)
 }
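
With the suite registration gone, each converted test is an ordinary `go test` function and can be run individually, e.g. `go test -run TestBatchCommandsBuilder ./...`.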