diff --git a/pkg/endpoints/handlers/responsewriters/writers_test.go b/pkg/endpoints/handlers/responsewriters/writers_test.go index 874dc1980..6ec17847e 100644 --- a/pkg/endpoints/handlers/responsewriters/writers_test.go +++ b/pkg/endpoints/handlers/responsewriters/writers_test.go @@ -19,6 +19,7 @@ package responsewriters import ( "bytes" "compress/gzip" + "context" "encoding/hex" "encoding/json" "errors" @@ -32,6 +33,7 @@ import ( "os" "reflect" "strconv" + "strings" "testing" "time" @@ -371,6 +373,124 @@ func TestSerializeObject(t *testing.T) { } } +func TestDeferredResponseWriter_Write(t *testing.T) { + smallChunk := bytes.Repeat([]byte("b"), defaultGzipThresholdBytes-1) + largeChunk := bytes.Repeat([]byte("b"), defaultGzipThresholdBytes+1) + + tests := []struct { + name string + chunks [][]byte + expectGzip bool + }{ + { + name: "one small chunk write", + chunks: [][]byte{smallChunk}, + expectGzip: false, + }, + { + name: "two small chunk writes", + chunks: [][]byte{smallChunk, smallChunk}, + expectGzip: false, + }, + { + name: "one large chunk writes", + chunks: [][]byte{largeChunk}, + expectGzip: true, + }, + { + name: "two large chunk writes", + chunks: [][]byte{largeChunk, largeChunk}, + expectGzip: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockResponseWriter := httptest.NewRecorder() + + drw := &deferredResponseWriter{ + mediaType: "text/plain", + statusCode: 200, + contentEncoding: "gzip", + hw: mockResponseWriter, + ctx: context.Background(), + } + + fullPayload := []byte{} + + for _, chunk := range tt.chunks { + n, err := drw.Write(chunk) + + if err != nil { + t.Fatalf("unexpected error while writing chunk: %v", err) + } + if n != len(chunk) { + t.Errorf("write is not complete, expected: %d bytes, written: %d bytes", len(chunk), n) + } + + fullPayload = append(fullPayload, chunk...) + } + + err := drw.Close() + if err != nil { + t.Fatalf("unexpected error when closing deferredResponseWriter: %v", err) + } + + res := mockResponseWriter.Result() + + if res.StatusCode != http.StatusOK { + t.Fatalf("status code is not writtend properly, expected: 200, got: %d", res.StatusCode) + } + contentEncoding := res.Header.Get("Content-Encoding") + varyHeader := res.Header.Get("Vary") + + resBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("unexpected error occurred while reading response body: %v", err) + } + + if tt.expectGzip { + if contentEncoding != "gzip" { + t.Fatalf("content-encoding is not set properly, expected: gzip, got: %s", contentEncoding) + } + + if !strings.Contains(varyHeader, "Accept-Encoding") { + t.Errorf("vary header doesn't have Accept-Encoding") + } + + gr, err := gzip.NewReader(bytes.NewReader(resBytes)) + if err != nil { + t.Fatalf("failed to create gzip reader: %v", err) + } + + decompressed, err := io.ReadAll(gr) + if err != nil { + t.Fatalf("failed to decompress: %v", err) + } + + if !bytes.Equal(fullPayload, decompressed) { + t.Errorf("payload mismatch, expected: %s, got: %s", fullPayload, decompressed) + } + + } else { + if contentEncoding != "" { + t.Errorf("content-encoding is set unexpectedly") + } + + if strings.Contains(varyHeader, "Accept-Encoding") { + t.Errorf("accept encoding is set unexpectedly") + } + + if !bytes.Equal(fullPayload, resBytes) { + t.Errorf("payload mismatch, expected: %s, got: %s", fullPayload, resBytes) + } + + } + + }) + } +} + func randTime(t *time.Time, r *rand.Rand) { *t = time.Unix(r.Int63n(1000*365*24*60*60), r.Int63()) }