Merge pull request #131704 from karlkfi/karl-watch-subtests
test: Use sub-tests in watch tests Kubernetes-commit: 4e80b05087cf26188208f1c80d133566be4eae18
This commit is contained in:
commit
df39bcd7dd
File diff suppressed because it is too large
Load Diff
|
@ -18,6 +18,7 @@ package endpoints
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -268,7 +269,6 @@ func TestWatchClientClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWatchRead(t *testing.T) {
|
func TestWatchRead(t *testing.T) {
|
||||||
ctx := t.Context()
|
|
||||||
simpleStorage := &SimpleRESTStorage{}
|
simpleStorage := &SimpleRESTStorage{}
|
||||||
_ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work.
|
_ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work.
|
||||||
handler := handle(map[string]rest.Storage{"simples": simpleStorage})
|
handler := handle(map[string]rest.Storage{"simples": simpleStorage})
|
||||||
|
@ -279,7 +279,7 @@ func TestWatchRead(t *testing.T) {
|
||||||
dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/simples"
|
dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/simples"
|
||||||
dest.RawQuery = "watch=1"
|
dest.RawQuery = "watch=1"
|
||||||
|
|
||||||
connectHTTP := func(accept string) (io.ReadCloser, string) {
|
connectHTTP := func(ctx context.Context, accept string) (io.ReadCloser, string) {
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
|
request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -299,7 +299,7 @@ func TestWatchRead(t *testing.T) {
|
||||||
return response.Body, response.Header.Get("Content-Type")
|
return response.Body, response.Header.Get("Content-Type")
|
||||||
}
|
}
|
||||||
|
|
||||||
connectWebSocket := func(accept string) (io.ReadCloser, string) {
|
connectWebSocket := func(ctx context.Context, accept string) (io.ReadCloser, string) {
|
||||||
dest := *dest
|
dest := *dest
|
||||||
dest.Scheme = "ws" // Required by websocket, though the server never sees it.
|
dest.Scheme = "ws" // Required by websocket, though the server never sees it.
|
||||||
config, err := websocket.NewConfig(dest.String(), "http://localhost")
|
config, err := websocket.NewConfig(dest.String(), "http://localhost")
|
||||||
|
@ -307,14 +307,14 @@ func TestWatchRead(t *testing.T) {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
config.Header.Add("Accept", accept)
|
config.Header.Add("Accept", accept)
|
||||||
ws, err := websocket.DialConfig(config)
|
ws, err := config.DialContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
return ws, "__default__"
|
return ws, "__default__"
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []struct {
|
cases := []struct {
|
||||||
Accept string
|
Accept string
|
||||||
ExpectedContentType string
|
ExpectedContentType string
|
||||||
MediaType string
|
MediaType string
|
||||||
|
@ -351,22 +351,24 @@ func TestWatchRead(t *testing.T) {
|
||||||
protocols := []struct {
|
protocols := []struct {
|
||||||
name string
|
name string
|
||||||
selfFraming bool
|
selfFraming bool
|
||||||
fn func(string) (io.ReadCloser, string)
|
fn func(context.Context, string) (io.ReadCloser, string)
|
||||||
}{
|
}{
|
||||||
{name: "http", fn: connectHTTP},
|
{name: "http", fn: connectHTTP},
|
||||||
{name: "websocket", selfFraming: true, fn: connectWebSocket},
|
{name: "websocket", selfFraming: true, fn: connectWebSocket},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, protocol := range protocols {
|
for _, protocol := range protocols {
|
||||||
for _, test := range testCases {
|
for textIndex, test := range cases {
|
||||||
func() {
|
testName := fmt.Sprintf("%s-%d", protocol.name, textIndex)
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
ctx := t.Context()
|
||||||
info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), test.MediaType)
|
info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), test.MediaType)
|
||||||
if !ok || info.StreamSerializer == nil {
|
if !ok || info.StreamSerializer == nil {
|
||||||
t.Fatal(info)
|
t.Fatal(info)
|
||||||
}
|
}
|
||||||
streamSerializer := info.StreamSerializer
|
streamSerializer := info.StreamSerializer
|
||||||
|
|
||||||
r, contentType := protocol.fn(test.Accept)
|
r, contentType := protocol.fn(ctx, test.Accept)
|
||||||
closeBody := apitesting.Close
|
closeBody := apitesting.Close
|
||||||
defer func() {
|
defer func() {
|
||||||
closeBody(t, r)
|
closeBody(t, r)
|
||||||
|
@ -424,7 +426,7 @@ func TestWatchRead(t *testing.T) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Unexpected non-error")
|
t.Errorf("Unexpected non-error")
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -472,7 +474,7 @@ func TestWatchParamParsing(t *testing.T) {
|
||||||
rootPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples"
|
rootPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples"
|
||||||
namespacedPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/namespaces/other/simpleroots"
|
namespacedPath := "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/namespaces/other/simpleroots"
|
||||||
|
|
||||||
table := []struct {
|
cases := []struct {
|
||||||
path string
|
path string
|
||||||
rawQuery string
|
rawQuery string
|
||||||
resourceVersion string
|
resourceVersion string
|
||||||
|
@ -540,35 +542,37 @@ func TestWatchParamParsing(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, item := range table {
|
for testIndex, test := range cases {
|
||||||
ctx := t.Context()
|
testName := fmt.Sprintf("%d", testIndex)
|
||||||
simpleStorage.requestedLabelSelector = labels.Everything()
|
t.Run(testName, func(t *testing.T) {
|
||||||
simpleStorage.requestedFieldSelector = fields.Everything()
|
ctx := t.Context()
|
||||||
simpleStorage.requestedResourceVersion = "5" // Prove this is set in all cases
|
simpleStorage.requestedLabelSelector = labels.Everything()
|
||||||
simpleStorage.requestedResourceNamespace = ""
|
simpleStorage.requestedFieldSelector = fields.Everything()
|
||||||
dest.Path = item.path
|
simpleStorage.requestedResourceVersion = "5" // Prove this is set in all cases
|
||||||
dest.RawQuery = item.rawQuery
|
simpleStorage.requestedResourceNamespace = ""
|
||||||
|
dest.Path = test.path
|
||||||
|
dest.RawQuery = test.rawQuery
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
|
req, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v: unexpected error: %v", item.rawQuery, err)
|
t.Fatalf("%v: unexpected error: %v", test.rawQuery, err)
|
||||||
continue
|
}
|
||||||
}
|
defer apitesting.Close(t, resp.Body)
|
||||||
defer apitesting.Close(t, resp.Body)
|
if e, a := test.namespace, simpleStorage.requestedResourceNamespace; e != a {
|
||||||
if e, a := item.namespace, simpleStorage.requestedResourceNamespace; e != a {
|
t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
|
||||||
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a)
|
}
|
||||||
}
|
if e, a := test.resourceVersion, simpleStorage.requestedResourceVersion; e != a {
|
||||||
if e, a := item.resourceVersion, simpleStorage.requestedResourceVersion; e != a {
|
t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
|
||||||
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a)
|
}
|
||||||
}
|
if e, a := test.labelSelector, simpleStorage.requestedLabelSelector.String(); e != a {
|
||||||
if e, a := item.labelSelector, simpleStorage.requestedLabelSelector.String(); e != a {
|
t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
|
||||||
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a)
|
}
|
||||||
}
|
if e, a := test.fieldSelector, simpleStorage.requestedFieldSelector.String(); e != a {
|
||||||
if e, a := item.fieldSelector, simpleStorage.requestedFieldSelector.String(); e != a {
|
t.Errorf("%v: expected %v, got %v", test.rawQuery, e, a)
|
||||||
t.Errorf("%v: expected %v, got %v", item.rawQuery, e, a)
|
}
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -584,7 +588,7 @@ func TestWatchProtocolSelection(t *testing.T) {
|
||||||
dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples"
|
dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples"
|
||||||
dest.RawQuery = ""
|
dest.RawQuery = ""
|
||||||
|
|
||||||
table := []struct {
|
cases := []struct {
|
||||||
isWebsocket bool
|
isWebsocket bool
|
||||||
connHeader string
|
connHeader string
|
||||||
}{
|
}{
|
||||||
|
@ -594,30 +598,33 @@ func TestWatchProtocolSelection(t *testing.T) {
|
||||||
{false, "keep-alive"},
|
{false, "keep-alive"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, item := range table {
|
for _, test := range cases {
|
||||||
ctx := t.Context()
|
testName := fmt.Sprintf("websocket:%v header:%s", test.isWebsocket, test.connHeader)
|
||||||
request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
|
t.Run(testName, func(t *testing.T) {
|
||||||
if err != nil {
|
ctx := t.Context()
|
||||||
t.Errorf("unexpected error: %v", err)
|
request, err := http.NewRequestWithContext(ctx, request.MethodGet, dest.String(), nil)
|
||||||
}
|
if err != nil {
|
||||||
request.Header.Set("Connection", item.connHeader)
|
t.Errorf("unexpected error: %v", err)
|
||||||
request.Header.Set("Upgrade", "websocket")
|
}
|
||||||
|
request.Header.Set("Connection", test.connHeader)
|
||||||
|
request.Header.Set("Upgrade", "websocket")
|
||||||
|
|
||||||
response, err := client.Do(request)
|
response, err := client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The requests recognized as websocket requests based on connection
|
// The requests recognized as websocket requests based on connection
|
||||||
// and upgrade headers will not also have the necessary Sec-Websocket-*
|
// and upgrade headers will not also have the necessary Sec-Websocket-*
|
||||||
// headers so it is expected to throw a 400
|
// headers so it is expected to throw a 400
|
||||||
if item.isWebsocket && response.StatusCode != http.StatusBadRequest {
|
if test.isWebsocket && response.StatusCode != http.StatusBadRequest {
|
||||||
t.Errorf("Unexpected response %#v", response)
|
t.Errorf("Unexpected response %#v", response)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !item.isWebsocket && response.StatusCode != http.StatusOK {
|
if !test.isWebsocket && response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Unexpected response %#v", response)
|
t.Errorf("Unexpected response %#v", response)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue