Merge pull request #130190 from nkeert/test-validate-deferredResponseWriter-for-multiple-writes
Add a test to validate deferredResponseWriter on multiple writes Kubernetes-commit: 7fc8a86381d874e3bb47b6d343c20f93ace42981
This commit is contained in:
commit
b6fda29776
|
@ -19,6 +19,7 @@ package responsewriters
|
|||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -32,6 +33,7 @@ import (
|
|||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -371,6 +373,124 @@ func TestSerializeObject(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDeferredResponseWriter_Write(t *testing.T) {
|
||||
smallChunk := bytes.Repeat([]byte("b"), defaultGzipThresholdBytes-1)
|
||||
largeChunk := bytes.Repeat([]byte("b"), defaultGzipThresholdBytes+1)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks [][]byte
|
||||
expectGzip bool
|
||||
}{
|
||||
{
|
||||
name: "one small chunk write",
|
||||
chunks: [][]byte{smallChunk},
|
||||
expectGzip: false,
|
||||
},
|
||||
{
|
||||
name: "two small chunk writes",
|
||||
chunks: [][]byte{smallChunk, smallChunk},
|
||||
expectGzip: false,
|
||||
},
|
||||
{
|
||||
name: "one large chunk writes",
|
||||
chunks: [][]byte{largeChunk},
|
||||
expectGzip: true,
|
||||
},
|
||||
{
|
||||
name: "two large chunk writes",
|
||||
chunks: [][]byte{largeChunk, largeChunk},
|
||||
expectGzip: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockResponseWriter := httptest.NewRecorder()
|
||||
|
||||
drw := &deferredResponseWriter{
|
||||
mediaType: "text/plain",
|
||||
statusCode: 200,
|
||||
contentEncoding: "gzip",
|
||||
hw: mockResponseWriter,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
fullPayload := []byte{}
|
||||
|
||||
for _, chunk := range tt.chunks {
|
||||
n, err := drw.Write(chunk)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error while writing chunk: %v", err)
|
||||
}
|
||||
if n != len(chunk) {
|
||||
t.Errorf("write is not complete, expected: %d bytes, written: %d bytes", len(chunk), n)
|
||||
}
|
||||
|
||||
fullPayload = append(fullPayload, chunk...)
|
||||
}
|
||||
|
||||
err := drw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error when closing deferredResponseWriter: %v", err)
|
||||
}
|
||||
|
||||
res := mockResponseWriter.Result()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status code is not writtend properly, expected: 200, got: %d", res.StatusCode)
|
||||
}
|
||||
contentEncoding := res.Header.Get("Content-Encoding")
|
||||
varyHeader := res.Header.Get("Vary")
|
||||
|
||||
resBytes, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error occurred while reading response body: %v", err)
|
||||
}
|
||||
|
||||
if tt.expectGzip {
|
||||
if contentEncoding != "gzip" {
|
||||
t.Fatalf("content-encoding is not set properly, expected: gzip, got: %s", contentEncoding)
|
||||
}
|
||||
|
||||
if !strings.Contains(varyHeader, "Accept-Encoding") {
|
||||
t.Errorf("vary header doesn't have Accept-Encoding")
|
||||
}
|
||||
|
||||
gr, err := gzip.NewReader(bytes.NewReader(resBytes))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create gzip reader: %v", err)
|
||||
}
|
||||
|
||||
decompressed, err := io.ReadAll(gr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decompress: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(fullPayload, decompressed) {
|
||||
t.Errorf("payload mismatch, expected: %s, got: %s", fullPayload, decompressed)
|
||||
}
|
||||
|
||||
} else {
|
||||
if contentEncoding != "" {
|
||||
t.Errorf("content-encoding is set unexpectedly")
|
||||
}
|
||||
|
||||
if strings.Contains(varyHeader, "Accept-Encoding") {
|
||||
t.Errorf("accept encoding is set unexpectedly")
|
||||
}
|
||||
|
||||
if !bytes.Equal(fullPayload, resBytes) {
|
||||
t.Errorf("payload mismatch, expected: %s, got: %s", fullPayload, resBytes)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func randTime(t *time.Time, r *rand.Rand) {
|
||||
*t = time.Unix(r.Int63n(1000*365*24*60*60), r.Int63())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue