mirror of https://github.com/linkerd/linkerd2.git
501 lines
16 KiB
Go
501 lines
16 KiB
Go
package public
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
pb "github.com/linkerd/linkerd2/controller/gen/public"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// stubResponseWriter is an in-memory http.ResponseWriter used by these tests.
// Bytes written to it accumulate in body, and headers set through Header()
// land in headers, so assertions can inspect both directly.
type stubResponseWriter struct {
	body    *bytes.Buffer
	headers http.Header
}

// Header returns the writer's own header map; mutations are visible via w.headers.
func (w *stubResponseWriter) Header() http.Header { return w.headers }

// Write appends p to the captured body and passes through the buffer's result.
func (w *stubResponseWriter) Write(p []byte) (int, error) {
	return w.body.Write(p)
}

// WriteHeader discards the status code; the tests only inspect headers and body.
func (w *stubResponseWriter) WriteHeader(int) {}

// Flush does nothing. Having the method is what distinguishes this stub from
// nonStreamingResponseWriter in TestNewStreamingWriter.
func (w *stubResponseWriter) Flush() {}
|
|
|
|
// nonStreamingResponseWriter is an http.ResponseWriter that deliberately has
// no Flush method. TestNewStreamingWriter uses it as the writer that must be
// rejected.
type nonStreamingResponseWriter struct {
}

// Header returns nil. Setting a header on the result would panic (assignment
// to a nil map), which keeps the stub strict about code that touches headers.
func (w *nonStreamingResponseWriter) Header() http.Header { return nil }

// Write discards p and reports every byte as written. When the error is nil,
// io.Writer requires n == len(p); the previous -1 violated that contract.
func (w *nonStreamingResponseWriter) Write(p []byte) (int, error) { return len(p), nil }

// WriteHeader ignores the status code.
func (w *nonStreamingResponseWriter) WriteHeader(int) {}
|
|
|
|
func newStubResponseWriter() *stubResponseWriter {
|
|
return &stubResponseWriter{
|
|
headers: make(http.Header),
|
|
body: bytes.NewBufferString(""),
|
|
}
|
|
}
|
|
|
|
func TestHttpRequestToProto(t *testing.T) {
|
|
someURL := "https://www.example.org/something"
|
|
someMethod := http.MethodPost
|
|
|
|
t.Run("Given a valid request, serializes its contents into protobuf object", func(t *testing.T) {
|
|
expectedProtoMessage := pb.Pod{
|
|
Name: "some-name",
|
|
PodIP: "some-name",
|
|
Owner: &pb.Pod_Deployment{Deployment: "some-name"},
|
|
Status: "some-name",
|
|
Added: false,
|
|
ControllerNamespace: "some-name",
|
|
ControlPlane: false,
|
|
}
|
|
payload, err := proto.Marshal(&expectedProtoMessage)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
req, err := http.NewRequest(someMethod, someURL, bytes.NewReader(payload))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
var actualProtoMessage pb.Pod
|
|
err = httpRequestToProto(req, &actualProtoMessage)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !proto.Equal(&actualProtoMessage, &expectedProtoMessage) {
|
|
t.Fatalf("Expected request to be [%v], but got [%v]", expectedProtoMessage, actualProtoMessage)
|
|
}
|
|
})
|
|
|
|
t.Run("Given a broken request, returns http error", func(t *testing.T) {
|
|
var actualProtoMessage pb.Pod
|
|
|
|
req, err := http.NewRequest(someMethod, someURL, strings.NewReader("not really protobuf"))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
err = httpRequestToProto(req, &actualProtoMessage)
|
|
if err == nil {
|
|
t.Fatalf("Expecting error, got nothing")
|
|
}
|
|
|
|
if httpErr, ok := err.(httpError); ok {
|
|
expectedStatusCode := http.StatusBadRequest
|
|
if httpErr.Code != expectedStatusCode || httpErr.WrappedError == nil {
|
|
t.Fatalf("Expected error status to be [%d] and contain wrapper error, got status [%d] and error [%v]", expectedStatusCode, httpErr.Code, httpErr.WrappedError)
|
|
}
|
|
} else {
|
|
t.Fatalf("Expected error to be httpError, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestWriteErrorToHttpResponse(t *testing.T) {
|
|
t.Run("Writes generic error correctly to response", func(t *testing.T) {
|
|
expectedErrorStatusCode := defaultHTTPErrorStatusCode
|
|
|
|
responseWriter := newStubResponseWriter()
|
|
genericError := errors.New("expected generic error")
|
|
|
|
writeErrorToHTTPResponse(responseWriter, genericError)
|
|
|
|
assertResponseHasProtobufContentType(t, responseWriter)
|
|
|
|
actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
|
|
if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
|
|
t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
|
|
}
|
|
|
|
payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
expectedErrorPayload := pb.ApiError{Error: genericError.Error()}
|
|
var actualErrorPayload pb.ApiError
|
|
err = proto.Unmarshal(payloadRead, &actualErrorPayload)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !proto.Equal(&actualErrorPayload, &expectedErrorPayload) {
|
|
t.Fatalf("Expecting error to be serialized as [%v], but got [%v]", expectedErrorPayload, actualErrorPayload)
|
|
}
|
|
})
|
|
|
|
t.Run("Writes http specific error correctly to response", func(t *testing.T) {
|
|
expectedErrorStatusCode := http.StatusBadGateway
|
|
responseWriter := newStubResponseWriter()
|
|
httpError := httpError{
|
|
WrappedError: errors.New("expected to be wrapped"),
|
|
Code: http.StatusBadGateway,
|
|
}
|
|
|
|
writeErrorToHTTPResponse(responseWriter, httpError)
|
|
|
|
assertResponseHasProtobufContentType(t, responseWriter)
|
|
|
|
actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
|
|
if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
|
|
t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
|
|
}
|
|
|
|
payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
expectedErrorPayload := pb.ApiError{Error: httpError.WrappedError.Error()}
|
|
var actualErrorPayload pb.ApiError
|
|
err = proto.Unmarshal(payloadRead, &actualErrorPayload)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !proto.Equal(&actualErrorPayload, &expectedErrorPayload) {
|
|
t.Fatalf("Expecting error to be serialized as [%v], but got [%v]", expectedErrorPayload, actualErrorPayload)
|
|
}
|
|
})
|
|
|
|
t.Run("Writes gRPC specific error correctly to response", func(t *testing.T) {
|
|
expectedErrorStatusCode := defaultHTTPErrorStatusCode
|
|
|
|
responseWriter := newStubResponseWriter()
|
|
expectedErrorMessage := "error message"
|
|
grpcError := status.Errorf(codes.AlreadyExists, expectedErrorMessage)
|
|
|
|
writeErrorToHTTPResponse(responseWriter, grpcError)
|
|
|
|
assertResponseHasProtobufContentType(t, responseWriter)
|
|
|
|
actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
|
|
if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
|
|
t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
|
|
}
|
|
|
|
payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
expectedErrorPayload := pb.ApiError{Error: expectedErrorMessage}
|
|
var actualErrorPayload pb.ApiError
|
|
err = proto.Unmarshal(payloadRead, &actualErrorPayload)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(actualErrorPayload, expectedErrorPayload) {
|
|
t.Fatalf("Expecting error to be serialized as [%v], but got [%v]", expectedErrorPayload, actualErrorPayload)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestWriteProtoToHttpResponse(t *testing.T) {
|
|
t.Run("Writes valid payload", func(t *testing.T) {
|
|
expectedMessage := pb.VersionInfo{
|
|
ReleaseVersion: "0.0.1",
|
|
BuildDate: "02/21/1983",
|
|
GoVersion: "10.2.45",
|
|
}
|
|
|
|
responseWriter := newStubResponseWriter()
|
|
err := writeProtoToHTTPResponse(responseWriter, &expectedMessage)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
assertResponseHasProtobufContentType(t, responseWriter)
|
|
|
|
payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
var actualMessage pb.VersionInfo
|
|
err = proto.Unmarshal(payloadRead, &actualMessage)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !proto.Equal(&actualMessage, &expectedMessage) {
|
|
t.Fatalf("Expected response body to contain message [%v], but got [%v]", expectedMessage, actualMessage)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDeserializePayloadFromReader(t *testing.T) {
|
|
t.Run("Can read message correctly based on payload size correct payload size to message", func(t *testing.T) {
|
|
expectedMessage := "this is the message"
|
|
|
|
messageWithSize := serializeAsPayload([]byte(expectedMessage))
|
|
messageWithSomeNoise := append(messageWithSize, []byte("this is noise and should not be read")...)
|
|
|
|
actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageWithSomeNoise)))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if string(actualMessage) != expectedMessage {
|
|
t.Fatalf("Expecting payload to contain message [%s], but it had [%s]", expectedMessage, actualMessage)
|
|
}
|
|
})
|
|
|
|
t.Run("Can multiple messages in the same stream", func(t *testing.T) {
|
|
expectedMessage1 := "Hit the road, Jack and don't you come back\n"
|
|
for i := 0; i < 450; i++ {
|
|
expectedMessage1 = expectedMessage1 + fmt.Sprintf("no more (%d), ", i)
|
|
}
|
|
|
|
expectedMessage2 := "back street back, alright\n"
|
|
for i := 0; i < 450; i++ {
|
|
expectedMessage2 = expectedMessage2 + fmt.Sprintf("tum (%d), ", i)
|
|
}
|
|
|
|
messageWithSize1 := serializeAsPayload([]byte(expectedMessage1))
|
|
messageWithSize2 := serializeAsPayload([]byte(expectedMessage2))
|
|
|
|
streamWithManyMessages := append(messageWithSize1, messageWithSize2...)
|
|
reader := bufio.NewReader(bytes.NewReader(streamWithManyMessages))
|
|
|
|
actualMessage1, err := deserializePayloadFromReader(reader)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
actualMessage2, err := deserializePayloadFromReader(reader)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if string(actualMessage1) != expectedMessage1 {
|
|
t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage1, actualMessage1)
|
|
}
|
|
|
|
if string(actualMessage2) != expectedMessage2 {
|
|
t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage2, actualMessage2)
|
|
}
|
|
})
|
|
|
|
t.Run("Can write and read marshalled protobuf messages", func(t *testing.T) {
|
|
expectedMessage := &pb.VersionInfo{
|
|
GoVersion: "1.9.1",
|
|
BuildDate: "2017.11.17",
|
|
ReleaseVersion: "1.2.3",
|
|
}
|
|
|
|
expectedReadArray, err := proto.Marshal(expectedMessage)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
serialized := serializeAsPayload(expectedReadArray)
|
|
|
|
reader := bufio.NewReader(bytes.NewReader(serialized))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
actualReadArray, err := deserializePayloadFromReader(reader)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(actualReadArray, expectedReadArray) {
|
|
n := len(actualReadArray)
|
|
xor := make([]byte, n)
|
|
for i := 0; i < n; i++ {
|
|
xor[i] = actualReadArray[i] ^ expectedReadArray[i]
|
|
}
|
|
t.Fatalf("Expecting read byte array to be equal to written byte array, but they were different. xor: [%v]", xor)
|
|
}
|
|
|
|
actualMessage := &pb.VersionInfo{}
|
|
err = proto.Unmarshal(actualReadArray, actualMessage)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !proto.Equal(actualMessage, expectedMessage) {
|
|
t.Fatalf("Expecting payload to contain message [%s], but it had [%s]", expectedMessage, actualMessage)
|
|
}
|
|
})
|
|
|
|
t.Run("Can read byte streams larger than Go's default buffer chunk size", func(t *testing.T) {
|
|
goDefaultChunkSize := 4000
|
|
expectedMessage := "Hit the road, Jack and don't you come back\n"
|
|
for i := 0; i < 450; i++ {
|
|
expectedMessage = expectedMessage + fmt.Sprintf("no more (%d), ", i)
|
|
}
|
|
|
|
expectedMessageAsBytes := []byte(expectedMessage)
|
|
lengthOfInputData := len(expectedMessageAsBytes)
|
|
|
|
if lengthOfInputData < goDefaultChunkSize {
|
|
t.Fatalf("Test needs data larger than [%d] bytes, currently only [%d] bytes", goDefaultChunkSize, lengthOfInputData)
|
|
}
|
|
|
|
payload := serializeAsPayload(expectedMessageAsBytes)
|
|
actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(payload)))
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if string(actualMessage) != expectedMessage {
|
|
t.Fatalf("Expecting payload to contain message:\n%s\n, but it had\n%s", expectedMessageAsBytes, actualMessage)
|
|
}
|
|
})
|
|
|
|
t.Run("Returns error when message has fewer bytes than declared message size", func(t *testing.T) {
|
|
expectedMessage := "this is the message"
|
|
|
|
messageWithSize := serializeAsPayload([]byte(expectedMessage))
|
|
messageMissingOneCharacter := messageWithSize[:len(expectedMessage)-1]
|
|
_, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageMissingOneCharacter)))
|
|
if err == nil {
|
|
t.Fatalf("Expecting error, got nothing")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestNewStreamingWriter(t *testing.T) {
|
|
t.Run("Returns a streaming writer if the ResponseWriter is compatible with streaming", func(t *testing.T) {
|
|
rawWriter := newStubResponseWriter()
|
|
flushableWriter, err := newStreamingWriter(rawWriter)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if flushableWriter != rawWriter {
|
|
t.Fatalf("Expected to return same instance of writer")
|
|
}
|
|
|
|
header := "Connection"
|
|
expectedValue := "keep-alive"
|
|
actualValue := rawWriter.Header().Get(header)
|
|
if actualValue != expectedValue {
|
|
t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue)
|
|
}
|
|
|
|
header = "Transfer-Encoding"
|
|
expectedValue = "chunked"
|
|
actualValue = rawWriter.Header().Get(header)
|
|
if actualValue != expectedValue {
|
|
t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue)
|
|
}
|
|
})
|
|
|
|
t.Run("Returns an error if writer doesnt support streaming", func(t *testing.T) {
|
|
_, err := newStreamingWriter(&nonStreamingResponseWriter{})
|
|
if err == nil {
|
|
t.Fatalf("Expecting error, got nothing")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCheckIfResponseHasError(t *testing.T) {
|
|
t.Run("returns nil if response doesn't contain linkerd-error header and is 200", func(t *testing.T) {
|
|
response := &http.Response{
|
|
Header: make(http.Header),
|
|
StatusCode: http.StatusOK,
|
|
}
|
|
err := checkIfResponseHasError(response)
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("returns error in body if response contains linkerd-error header", func(t *testing.T) {
|
|
expectedErrorMessage := "expected error message"
|
|
protoInBytes, err := proto.Marshal(&pb.ApiError{Error: expectedErrorMessage})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
message := serializeAsPayload(protoInBytes)
|
|
response := &http.Response{
|
|
Header: make(http.Header),
|
|
Body: ioutil.NopCloser(bytes.NewReader(message)),
|
|
StatusCode: http.StatusInternalServerError,
|
|
}
|
|
response.Header.Set(errorHeader, "error")
|
|
|
|
err = checkIfResponseHasError(response)
|
|
if err == nil {
|
|
t.Fatalf("Expecting error, got nothing")
|
|
}
|
|
|
|
actualErrorMessage := err.Error()
|
|
if actualErrorMessage != expectedErrorMessage {
|
|
t.Fatalf("Expected error message to be [%s], but it was [%s]", expectedErrorMessage, actualErrorMessage)
|
|
}
|
|
})
|
|
|
|
t.Run("returns error if response contains linkerd-error header but body isn't error message", func(t *testing.T) {
|
|
protoInBytes, err := proto.Marshal(&pb.VersionInfo{ReleaseVersion: "0.0.1"})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
message := serializeAsPayload(protoInBytes)
|
|
|
|
response := &http.Response{
|
|
Header: make(http.Header),
|
|
Body: ioutil.NopCloser(bytes.NewReader(message)),
|
|
StatusCode: http.StatusInternalServerError,
|
|
}
|
|
response.Header.Set(errorHeader, "error")
|
|
|
|
err = checkIfResponseHasError(response)
|
|
if err == nil {
|
|
t.Fatalf("Expecting error, got nothing")
|
|
}
|
|
})
|
|
|
|
t.Run("returns error if response is not a 200", func(t *testing.T) {
|
|
response := &http.Response{
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Status: "503 Service Unavailable",
|
|
}
|
|
|
|
err := checkIfResponseHasError(response)
|
|
if err == nil {
|
|
t.Fatalf("Expecting error, got nothing")
|
|
}
|
|
|
|
expectedErrorMessage := "Unexpected API response: 503 Service Unavailable"
|
|
actualErrorMessage := err.Error()
|
|
if actualErrorMessage != expectedErrorMessage {
|
|
t.Fatalf("Expected error message to be [%s], but it was [%s]", expectedErrorMessage, actualErrorMessage)
|
|
}
|
|
})
|
|
}
|
|
|
|
func assertResponseHasProtobufContentType(t *testing.T, responseWriter *stubResponseWriter) {
|
|
actualContentType := responseWriter.headers.Get(contentTypeHeader)
|
|
expectedContentType := protobufContentType
|
|
if actualContentType != expectedContentType {
|
|
t.Fatalf("Expected content-type to be [%s], but got [%s]", expectedContentType, actualContentType)
|
|
}
|
|
}
|