diff --git a/webhook/webhook.go b/webhook/webhook.go index e29ea9611..ef7c238da 100644 --- a/webhook/webhook.go +++ b/webhook/webhook.go @@ -253,6 +253,7 @@ func (wh *Webhook) ServeHTTP(w http.ResponseWriter, r *http.Request) { default: w.WriteHeader(http.StatusOK) } + return } // Verify the content type is accurate. diff --git a/webhook/webhook_test.go b/webhook/webhook_test.go index bcdef62d9..19ce13a10 100644 --- a/webhook/webhook_test.go +++ b/webhook/webhook_test.go @@ -20,6 +20,8 @@ import ( "context" "fmt" "net" + "net/http" + "net/http/httptest" "testing" "time" @@ -96,3 +98,51 @@ func TestRegistrationStopChanFire(t *testing.T) { t.Errorf("Unexpected success to dial to port %d", opts.Port) } } + +func TestWebhookKubeletProbe(t *testing.T) { + opts := newDefaultOptions() + ctx, webhook, cancel := newNonRunningTestWebhook(t, opts) + defer cancel() + + recorder := bombRecorder{ResponseRecorder: httptest.NewRecorder()} + probeReq := httptest.NewRequest("GET", "/", nil) + probeReq.Header.Set("User-Agent", "kube-probe/1.16") + + webhook.ServeHTTP(&recorder, probeReq) + + if got, want := recorder.Code, http.StatusOK; got != want { + t.Fatalf("Probe got HTTP status %d - expected %d", got, want) + } + + if got, want := recorder.writeCount, 1; got != want { + t.Errorf("HTTP status was written %d times - expected only one write", got) + } + + // Stop the webhook - which means probes should fail + // + // The steps below aren't obvious and requires you to + // know the implementation details + cancel() + webhook.Run(ctx.Done()) + + recorder = bombRecorder{ResponseRecorder: httptest.NewRecorder()} + webhook.ServeHTTP(&recorder, probeReq) + + if got, want := recorder.Code, http.StatusInternalServerError; got != want { + t.Fatalf("Probe got HTTP status %d - expected %d", got, want) + } + + if got, want := recorder.writeCount, 1; got != want { + t.Errorf("HTTP status was written %d times - expected only one write", got) + } +} + +type bombRecorder struct { + *httptest.ResponseRecorder + writeCount int +} + +func (rw *bombRecorder) WriteHeader(code int) { + rw.writeCount += 1 + rw.ResponseRecorder.WriteHeader(code) +}