From c329ccff7cebfe52b968f4b855e02cbddf6df8a5 Mon Sep 17 00:00:00 2001 From: Omer Aplatony Date: Sat, 5 Oct 2024 22:07:37 +0300 Subject: [PATCH] Update waitForDelete to use PollUntilContextTimeout Signed-off-by: Omer Aplatony Kubernetes-commit: bba055067e6283f94ee05cedeb33dacafe4a1094 --- pkg/drain/drain.go | 15 +++++---------- pkg/drain/drain_test.go | 5 ++--- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/pkg/drain/drain.go b/pkg/drain/drain.go index 3633f3c4..348fd703 100644 --- a/pkg/drain/drain.go +++ b/pkg/drain/drain.go @@ -413,7 +413,10 @@ func (d *Helper) deletePods(pods []corev1.Pod, getPodFn func(namespace, name str func waitForDelete(params waitForDeleteParams) ([]corev1.Pod, error) { pods := params.pods - err := wait.PollImmediate(params.interval, params.timeout, func() (bool, error) { + if params.ctx == nil { + params.ctx = context.Background() + } + err := wait.PollUntilContextTimeout(params.ctx, params.interval, params.timeout, true, func(ctx context.Context) (done bool, err error) { pendingPods := []corev1.Pod{} for i, pod := range pods { p, err := params.getPodFn(pod.Namespace, pod.Name) @@ -440,15 +443,7 @@ func waitForDelete(params waitForDeleteParams) ([]corev1.Pod, error) { } } pods = pendingPods - if len(pendingPods) > 0 { - select { - case <-params.ctx.Done(): - return false, fmt.Errorf("global timeout reached: %v", params.globalTimeout) - default: - return false, nil - } - } - return true, nil + return len(pods) == 0, nil }) return pods, err } diff --git a/pkg/drain/drain_test.go b/pkg/drain/drain_test.go index 6ca4fe6d..6acf2489 100644 --- a/pkg/drain/drain_test.go +++ b/pkg/drain/drain_test.go @@ -36,7 +36,6 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes/fake" ktest "k8s.io/client-go/testing" ) @@ -101,7 +100,7 @@ func TestDeletePods(t *testing.T) { timeout: 3 * time.Second, expectPendingPods: true, expectError: true, - expectedError: &wait.ErrWaitTimeout, + expectedError: &context.DeadlineExceeded, getPodFn: func(namespace, name string) (*corev1.Pod, error) { oldPodMap, _ := createPods(false) if oldPod, found := oldPodMap[name]; found { @@ -117,7 +116,7 @@ func TestDeletePods(t *testing.T) { ctxTimeoutEarly: true, expectPendingPods: true, expectError: true, - expectedError: &wait.ErrWaitTimeout, + expectedError: &context.Canceled, getPodFn: func(namespace, name string) (*corev1.Pod, error) { oldPodMap, _ := createPods(false) if oldPod, found := oldPodMap[name]; found {