Merge pull request #131704 from karlkfi/karl-watch-subtests

test: Use sub-tests in watch tests

Kubernetes-commit: 4e80b05087cf26188208f1c80d133566be4eae18
This commit is contained in:
Kubernetes Publisher 2025-05-19 12:45:14 -07:00
commit df39bcd7dd
2 changed files with 644 additions and 604 deletions

File diff suppressed because it is too large Load Diff

View File

@ -18,6 +18,7 @@ package endpoints
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -268,7 +269,6 @@ func TestWatchClientClose(t *testing.T) {
} }
func TestWatchRead(t *testing.T) { func TestWatchRead(t *testing.T) {
ctx := t.Context()
simpleStorage := &SimpleRESTStorage{} simpleStorage := &SimpleRESTStorage{}
_ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work. _ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work.
handler := handle(map[string]rest.Storage{"simples": simpleStorage}) handler := handle(map[string]rest.Storage{"simples": simpleStorage})
@ -279,7 +279,7 @@ func TestWatchRead(t *testing.T) {
dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/simples" dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/simples"
dest.RawQuery = "watch=1" dest.RawQuery = "watch=1"
connectHTTP := func(accept string) (io.ReadCloser, string) { connectHTTP := func(ctx context.Context, accept string) (io.ReadCloser, string) {
client := http.Client{} client := http.Client{}
request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil) request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
if err != nil { if err != nil {
@ -299,7 +299,7 @@ func TestWatchRead(t *testing.T) {
return response.Body, response.Header.Get("Content-Type") return response.Body, response.Header.Get("Content-Type")
} }
connectWebSocket := func(accept string) (io.ReadCloser, string) { connectWebSocket := func(ctx context.Context, accept string) (io.ReadCloser, string) {
dest := *dest dest := *dest
dest.Scheme = "ws" // Required by websocket, though the server never sees it. dest.Scheme = "ws" // Required by websocket, though the server never sees it.
config, err := websocket.NewConfig(dest.String(), "http://localhost") config, err := websocket.NewConfig(dest.String(), "http://localhost")
@ -307,14 +307,14 @@ func TestWatchRead(t *testing.T) {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
config.Header.Add("Accept", accept) config.Header.Add("Accept", accept)
ws, err := websocket.DialConfig(config) ws, err := config.DialContext(ctx)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
return ws, "__default__" return ws, "__default__"
} }
testCases := []struct { cases := []struct {
Accept string Accept string
ExpectedContentType string ExpectedContentType string
MediaType string MediaType string
@ -351,22 +351,24 @@ func TestWatchRead(t *testing.T) {
protocols := []struct { protocols := []struct {
name string name string
selfFraming bool selfFraming bool
fn func(string) (io.ReadCloser, string) fn func(context.Context, string) (io.ReadCloser, string)
}{ }{
{name: "http", fn: connectHTTP}, {name: "http", fn: connectHTTP},
{name: "websocket", selfFraming: true, fn: connectWebSocket}, {name: "websocket", selfFraming: true, fn: connectWebSocket},
} }
for _, protocol := range protocols { for _, protocol := range protocols {
for _, test := range testCases { for textIndex, test := range cases {
func() { testName := fmt.Sprintf("%s-%d", protocol.name, textIndex)
t.Run(testName, func(t *testing.T) {
ctx := t.Context()
info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), test.MediaType) info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), test.MediaType)
if !ok || info.StreamSerializer == nil { if !ok || info.StreamSerializer == nil {
t.Fatal(info) t.Fatal(info)
} }
streamSerializer := info.StreamSerializer streamSerializer := info.StreamSerializer
r, contentType := protocol.fn(test.Accept) r, contentType := protocol.fn(ctx, test.Accept)
closeBody := apitesting.Close closeBody := apitesting.Close
defer func() { defer func() {
closeBody(t, r) closeBody(t, r)
@ -424,7 +426,7 @@ func TestWatchRead(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("Unexpected non-error") t.Errorf("Unexpected non-error")
} }
}() })
} }
} }
} }
@ -472,7 +474,7 @@ func TestWatchParamParsing(t *testing.T) {
rootPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples" rootPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples"
namespacedPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/namespaces/other/simpleroots" namespacedPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/namespaces/other/simpleroots"
table := []struct { cases := []struct {
path string path string
rawQuery string rawQuery string
resourceVersion string resourceVersion string
@ -540,35 +542,37 @@ func TestWatchParamParsing(t *testing.T) {
}, },
} }
for _, item := range table { for testIndex, test := range cases {
ctx := t.Context() testName := fmt.Sprintf("%d", testIndex)
simpleStorage.requestedLabelSelector = labels.Everything() t.Run(testName, func(t *testing.T) {
simpleStorage.requestedFieldSelector = fields.Everything() ctx := t.Context()
simpleStorage.requestedResourceVersion = "5" // Prove this is set in all cases simpleStorage.requestedLabelSelector = labels.Everything()
simpleStorage.requestedResourceNamespace = "" simpleStorage.requestedFieldSelector = fields.Everything()
dest.Path = item.path simpleStorage.requestedResourceVersion = "5" // Prove this is set in all cases
dest.RawQuery = item.rawQuery simpleStorage.requestedResourceNamespace = ""
dest.Path = test.path
dest.RawQuery = test.rawQuery
req, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil) req, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
require.NoError(t, err) require.NoError(t, err)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("%v: unexpected error: %v", item.rawQuery, err) t.Fatalf("%v: unexpected error: %v", test.rawQuery, err)
continue }
} defer apitesting.Close(t, resp.Body)
defer apitesting.Close(t, resp.Body) if e, a := test.namespace, simpleStorage.requestedResourceNamespace; e != a {
if e, a := item.namespace, simpleStorage.requestedResourceNamespace; e != a { t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a) }
} if e, a := test.resourceVersion, simpleStorage.requestedResourceVersion; e != a {
if e, a := item.resourceVersion, simpleStorage.requestedResourceVersion; e != a { t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a) }
} if e, a := test.labelSelector, simpleStorage.requestedLabelSelector.String(); e != a {
if e, a := item.labelSelector, simpleStorage.requestedLabelSelector.String(); e != a { t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a) }
} if e, a := test.fieldSelector, simpleStorage.requestedFieldSelector.String(); e != a {
if e, a := item.fieldSelector, simpleStorage.requestedFieldSelector.String(); e != a { t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a) }
} })
} }
} }
@ -584,7 +588,7 @@ func TestWatchProtocolSelection(t *testing.T) {
dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples" dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples"
dest.RawQuery = "" dest.RawQuery = ""
table := []struct { cases := []struct {
isWebsocket bool isWebsocket bool
connHeader string connHeader string
}{ }{
@ -594,30 +598,33 @@ func TestWatchProtocolSelection(t *testing.T) {
{false, "keep-alive"}, {false, "keep-alive"},
} }
for _, item := range table { for _, test := range cases {
ctx := t.Context() testName := fmt.Sprintf("websocket:%v header:%s", test.isWebsocket, test.connHeader)
request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil) t.Run(testName, func(t *testing.T) {
if err != nil { ctx := t.Context()
t.Errorf("unexpected error: %v", err) request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
} if err != nil {
request.Header.Set("Connection", item.connHeader) t.Errorf("unexpected error: %v", err)
request.Header.Set("Upgrade", "websocket") }
request.Header.Set("Connection", test.connHeader)
request.Header.Set("Upgrade", "websocket")
response, err := client.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
// The requests recognized as websocket requests based on connection // The requests recognized as websocket requests based on connection
// and upgrade headers will not also have the necessary Sec-Websocket-* // and upgrade headers will not also have the necessary Sec-Websocket-*
// headers so it is expected to throw a 400 // headers so it is expected to throw a 400
if item.isWebsocket && response.StatusCode != http.StatusBadRequest { if test.isWebsocket && response.StatusCode != http.StatusBadRequest {
t.Errorf("Unexpected response %#v", response) t.Errorf("Unexpected response %#v", response)
} }
if !item.isWebsocket && response.StatusCode != http.StatusOK { if !test.isWebsocket && response.StatusCode != http.StatusOK {
t.Errorf("Unexpected response %#v", response) t.Errorf("Unexpected response %#v", response)
} }
})
} }
} }