linkerd2/controller/api/public/http_server_test.go

178 lines
5.0 KiB
Go

package public
import (
"context"
"errors"
"net"
"net/http"
"testing"
"github.com/golang/protobuf/proto"
destinationPb "github.com/linkerd/linkerd2-proxy-api/go/destination"
publicPb "github.com/linkerd/linkerd2/controller/gen/public"
)
// mockServer holds the canned behaviour shared by the mock servers in this
// file: the last request observed, and the response/error each call should
// produce.
type mockServer struct {
	// LastRequestReceived is overwritten with the request of every call.
	LastRequestReceived proto.Message
	// ResponseToReturn is the message handed back by unary calls.
	ResponseToReturn proto.Message
	// DestinationStreamsToReturn is shadowed by the same-named field declared
	// directly on mockGrpcServer, so selecting it through a *mockGrpcServer
	// reaches that outer field rather than this one.
	DestinationStreamsToReturn []*destinationPb.Update
	// ErrorToReturn, when non-nil, is returned by calls instead of success.
	ErrorToReturn error
}
// mockGrpcServer is the grpcServer implementation wired into the HTTP handler
// under test. It embeds mockServer for the shared request/response/error state.
type mockGrpcServer struct {
	mockServer

	// DestinationStreamsToReturn lists the updates Get sends, in order. Being
	// declared at this depth, it takes precedence over the embedded
	// mockServer.DestinationStreamsToReturn.
	DestinationStreamsToReturn []*destinationPb.Update
}
// Version records req as the last request received and returns the configured
// ResponseToReturn together with ErrorToReturn.
//
// ResponseToReturn is read with a checked type assertion. If it is nil or not a
// *publicPb.VersionInfo and no ErrorToReturn is configured, a descriptive error
// is returned instead of panicking on the unchecked assertion.
func (m *mockGrpcServer) Version(ctx context.Context, req *publicPb.Empty) (*publicPb.VersionInfo, error) {
	m.LastRequestReceived = req

	rsp, ok := m.ResponseToReturn.(*publicPb.VersionInfo)
	if !ok && m.ErrorToReturn == nil {
		return nil, errors.New("mockGrpcServer: ResponseToReturn is not a *publicPb.VersionInfo")
	}
	return rsp, m.ErrorToReturn
}
// Get records req as the last request received. If ErrorToReturn is set, it is
// returned without sending anything. Otherwise each entry of
// DestinationStreamsToReturn is sent on destinationServer in order.
//
// A failed Send aborts the stream and its error is returned; previously it was
// silently dropped, which hid transport failures from the test.
func (m *mockGrpcServer) Get(req *destinationPb.GetDestination, destinationServer destinationPb.Destination_GetServer) error {
	m.LastRequestReceived = req
	if m.ErrorToReturn != nil {
		return m.ErrorToReturn
	}

	for _, msg := range m.DestinationStreamsToReturn {
		if err := destinationServer.Send(msg); err != nil {
			return err
		}
	}
	return nil
}
// GetProfile always fails. The Public API does not serve profile lookups;
// proxies are expected to query the Destination gRPC server directly instead.
func (*mockGrpcServer) GetProfile(_ *destinationPb.GetDestination, _ destinationPb.Destination_GetProfileServer) error {
	notImplemented := errors.New("Not implemented")
	return notImplemented
}
// grpcCallTestCase describes one unary round trip through the HTTP handler.
type grpcCallTestCase struct {
	// expectedRequest is the request the mock server should record.
	expectedRequest proto.Message
	// expectedResponse is the response the mock server is primed to return.
	expectedResponse proto.Message
	// functionCall issues the RPC through the public client.
	functionCall func() (proto.Message, error)
}
// TestServer checks that the HTTP handler relays both unary and streaming
// calls to the wrapped gRPC server implementation and back to the client.
func TestServer(t *testing.T) {
	t.Run("Delegates all non-streaming RPC messages to the underlying grpc server", func(t *testing.T) {
		srv, publicClient := getServerPublicClient(t)

		req := &publicPb.Empty{}
		tc := grpcCallTestCase{
			expectedRequest:  req,
			expectedResponse: &publicPb.VersionInfo{BuildDate: "02/21/1983"},
			functionCall: func() (proto.Message, error) {
				return publicClient.Version(context.TODO(), req)
			},
		}

		assertCallWasForwarded(t, &srv.mockServer, tc.expectedRequest, tc.expectedResponse, tc.functionCall)
	})

	t.Run("Delegates all streaming Destination RPC messages to the underlying grpc server", func(t *testing.T) {
		srv, publicClient := getServerPublicClient(t)

		want := []*destinationPb.Update{
			{
				Update: &destinationPb.Update_Add{
					Add: BuildAddrSet(
						AuthorityEndpoints{
							Namespace: "emojivoto",
							ServiceID: "emoji-svc",
							Pods: []PodDetails{
								{
									Name: "emoji-6bf9f47bd5-jjcrl",
									IP:   16909060,
									Port: 8080,
								},
							},
						},
					),
				},
			},
			{
				Update: &destinationPb.Update_Add{
					Add: BuildAddrSet(
						AuthorityEndpoints{
							Namespace: "emojivoto",
							ServiceID: "voting-svc",
							Pods: []PodDetails{
								{
									Name: "voting-7bf9f47bd5-jjdrl",
									IP:   84281096,
									Port: 8080,
								},
							},
						},
					),
				},
			},
		}
		srv.DestinationStreamsToReturn = want
		srv.ErrorToReturn = nil

		stream, err := publicClient.Get(context.TODO(), &destinationPb.GetDestination{})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Every primed update must arrive, in order, unchanged.
		for i := range want {
			got, err := stream.Recv()
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			if !proto.Equal(got, want[i]) {
				t.Fatalf("Expecting destination.get event to be [%v], but was [%v]", want[i], got)
			}
		}
	})
}
// getServerPublicClient starts an HTTP server on a random localhost port and
// returns the server's mockGrpcServer and a public API client aimed at it. The
// server uses a handler that wraps a fresh mockGrpcServer.
//
// The server goroutine is never stopped explicitly; nothing in this function
// closes the listener.
func getServerPublicClient(t *testing.T) (*mockGrpcServer, Client) {
	srv := &mockGrpcServer{}

	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Could not start listener: %v", err)
	}

	h := &handler{
		grpcServer: srv,
	}
	go func() {
		// http.Serve blocks until accepting on the listener fails and always
		// returns a non-nil error. t.Fatalf (FailNow) may only be called from
		// the goroutine running the test, so report with t.Errorf, which is
		// safe to call from other goroutines.
		if err := http.Serve(listener, h); err != nil {
			t.Errorf("Could not start server: %v", err)
		}
	}()

	client, err := NewInternalClient("linkerd", listener.Addr().String())
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	return srv, client
}
// assertCallWasForwarded verifies that functionCall is relayed to the mock server.
//
// First it primes the mock with expectedResponse and no error, runs
// functionCall, and checks two things: the mock recorded expectedRequest, and
// the caller got back expectedResponse. Then it primes the mock with an error
// and checks that functionCall surfaces an error. The mock is left in that
// error state on return.
func assertCallWasForwarded(t *testing.T, server *mockServer, expectedRequest proto.Message, expectedResponse proto.Message, functionCall func() (proto.Message, error)) {
	t.Helper()

	server.ErrorToReturn = nil
	server.ResponseToReturn = expectedResponse
	actualResponse, err := functionCall()
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	actualRequest := server.LastRequestReceived
	if !proto.Equal(actualRequest, expectedRequest) {
		t.Fatalf("Expecting server to receive request [%v], but got [%v]", expectedRequest, actualRequest)
	}
	if !proto.Equal(actualResponse, expectedResponse) {
		t.Fatalf("Expecting server call to return [%v], but got [%v]", expectedResponse, actualResponse)
	}

	server.ErrorToReturn = errors.New("expected")
	if _, err = functionCall(); err == nil {
		t.Fatalf("Expecting error, got nothing")
	}
}