mirror of https://github.com/grpc/grpc-go.git
				
				
				
			
		
			
				
	
	
		
			205 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			205 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  *
 | |
|  * Copyright 2023 gRPC authors.
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  *
 | |
|  */
 | |
| 
 | |
| package experimental_test
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"io"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"google.golang.org/grpc"
 | |
| 	"google.golang.org/grpc/encoding/gzip"
 | |
| 	"google.golang.org/grpc/experimental"
 | |
| 	"google.golang.org/grpc/internal/grpctest"
 | |
| 	"google.golang.org/grpc/internal/stubserver"
 | |
| 
 | |
| 	testgrpc "google.golang.org/grpc/interop/grpc_testing"
 | |
| )
 | |
| 
 | |
| type s struct {
 | |
| 	grpctest.Tester
 | |
| }
 | |
| 
 | |
| func Test(t *testing.T) {
 | |
| 	grpctest.RunSubTests(t, s{})
 | |
| }
 | |
| 
 | |
| const defaultTestTimeout = 10 * time.Second
 | |
| 
 | |
| func (s) TestRecvBufferPoolStream(t *testing.T) {
 | |
| 	tcs := []struct {
 | |
| 		name     string
 | |
| 		callOpts []grpc.CallOption
 | |
| 	}{
 | |
| 		{
 | |
| 			name: "default",
 | |
| 		},
 | |
| 		{
 | |
| 			name: "useCompressor",
 | |
| 			callOpts: []grpc.CallOption{
 | |
| 				grpc.UseCompressor(gzip.Name),
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tcs {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			const reqCount = 10
 | |
| 
 | |
| 			ss := &stubserver.StubServer{
 | |
| 				FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
 | |
| 					for i := 0; i < reqCount; i++ {
 | |
| 						preparedMsg := &grpc.PreparedMsg{}
 | |
| 						if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
 | |
| 							Payload: &testgrpc.Payload{
 | |
| 								Body: []byte{'0' + uint8(i)},
 | |
| 							},
 | |
| 						}); err != nil {
 | |
| 							return err
 | |
| 						}
 | |
| 						stream.SendMsg(preparedMsg)
 | |
| 					}
 | |
| 					return nil
 | |
| 				},
 | |
| 			}
 | |
| 
 | |
| 			pool := &checkBufferPool{}
 | |
| 			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
 | |
| 			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
 | |
| 			if err := ss.Start(sopts, dopts...); err != nil {
 | |
| 				t.Fatalf("Error starting endpoint server: %v", err)
 | |
| 			}
 | |
| 			defer ss.Stop()
 | |
| 
 | |
| 			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 | |
| 			defer cancel()
 | |
| 
 | |
| 			stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
 | |
| 			if err != nil {
 | |
| 				t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
 | |
| 			}
 | |
| 
 | |
| 			var ngot int
 | |
| 			var buf bytes.Buffer
 | |
| 			for {
 | |
| 				reply, err := stream.Recv()
 | |
| 				if err == io.EOF {
 | |
| 					break
 | |
| 				}
 | |
| 				if err != nil {
 | |
| 					t.Fatal(err)
 | |
| 				}
 | |
| 				ngot++
 | |
| 				if buf.Len() > 0 {
 | |
| 					buf.WriteByte(',')
 | |
| 				}
 | |
| 				buf.Write(reply.GetPayload().GetBody())
 | |
| 			}
 | |
| 			if want := 10; ngot != want {
 | |
| 				t.Fatalf("Got %d replies, want %d", ngot, want)
 | |
| 			}
 | |
| 			if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
 | |
| 				t.Fatalf("Got replies %q; want %q", got, want)
 | |
| 			}
 | |
| 
 | |
| 			if len(pool.puts) != reqCount {
 | |
| 				t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (s) TestRecvBufferPoolUnary(t *testing.T) {
 | |
| 	tcs := []struct {
 | |
| 		name     string
 | |
| 		callOpts []grpc.CallOption
 | |
| 	}{
 | |
| 		{
 | |
| 			name: "default",
 | |
| 		},
 | |
| 		{
 | |
| 			name: "useCompressor",
 | |
| 			callOpts: []grpc.CallOption{
 | |
| 				grpc.UseCompressor(gzip.Name),
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range tcs {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			const largeSize = 1024
 | |
| 
 | |
| 			ss := &stubserver.StubServer{
 | |
| 				UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
 | |
| 					return &testgrpc.SimpleResponse{
 | |
| 						Payload: &testgrpc.Payload{
 | |
| 							Body: make([]byte, largeSize),
 | |
| 						},
 | |
| 					}, nil
 | |
| 				},
 | |
| 			}
 | |
| 
 | |
| 			pool := &checkBufferPool{}
 | |
| 			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
 | |
| 			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
 | |
| 			if err := ss.Start(sopts, dopts...); err != nil {
 | |
| 				t.Fatalf("Error starting endpoint server: %v", err)
 | |
| 			}
 | |
| 			defer ss.Stop()
 | |
| 
 | |
| 			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 | |
| 			defer cancel()
 | |
| 
 | |
| 			const reqCount = 10
 | |
| 			for i := 0; i < reqCount; i++ {
 | |
| 				if _, err := ss.Client.UnaryCall(
 | |
| 					ctx,
 | |
| 					&testgrpc.SimpleRequest{
 | |
| 						Payload: &testgrpc.Payload{
 | |
| 							Body: make([]byte, largeSize),
 | |
| 						},
 | |
| 					},
 | |
| 					tc.callOpts...,
 | |
| 				); err != nil {
 | |
| 					t.Fatalf("ss.Client.UnaryCall failed: %v", err)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			const bufferCount = reqCount * 2 // req + resp
 | |
| 			if len(pool.puts) != bufferCount {
 | |
| 				t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type checkBufferPool struct {
 | |
| 	puts [][]byte
 | |
| }
 | |
| 
 | |
| func (p *checkBufferPool) Get(size int) []byte {
 | |
| 	return make([]byte, size)
 | |
| }
 | |
| 
 | |
| func (p *checkBufferPool) Put(bs *[]byte) {
 | |
| 	p.puts = append(p.puts, *bs)
 | |
| }
 |