diff --git a/configmap/informer/informed_watcher.go b/configmap/informer/informed_watcher.go index 2231ee52f..1ae7ef8d4 100644 --- a/configmap/informer/informed_watcher.go +++ b/configmap/informer/informed_watcher.go @@ -107,7 +107,8 @@ var _ configmap.Watcher = (*InformedWatcher)(nil) // Asserts that InformedWatcher implements DefaultingWatcher. var _ configmap.DefaultingWatcher = (*InformedWatcher)(nil) -// WatchWithDefault implements DefaultingWatcher. +// WatchWithDefault implements DefaultingWatcher. Adding a default for the configMap being watched means that when +// Start is called, Start will not wait for the add event from the API server. func (i *InformedWatcher) WatchWithDefault(cm corev1.ConfigMap, o ...configmap.Observer) { i.defaults[cm.Name] = &cm @@ -126,31 +127,59 @@ func (i *InformedWatcher) WatchWithDefault(cm corev1.ConfigMap, o ...configmap.O i.Watch(cm.Name, o...) } -// Start implements Watcher. -func (i *InformedWatcher) Start(stopCh <-chan struct{}) error { - // Pretend that all the defaulted ConfigMaps were just created. This is done before we start - // the informer to ensure that if a defaulted ConfigMap does exist, then the real value is - // processed after the default one. +func (i *InformedWatcher) triggerAddEventForDefaultedConfigMaps(addConfigMapEvent func(obj interface{})) { i.ForEach(func(k string, _ []configmap.Observer) error { if def, ok := i.defaults[k]; ok { - i.addConfigMapEvent(def) + addConfigMapEvent(def) } return nil }) +} - if err := i.registerCallbackAndStartInformer(stopCh); err != nil { +func (i *InformedWatcher) getConfigMapNames() []string { + var configMaps []string + i.ForEach(func(k string, _ []configmap.Observer) error { + configMaps = append(configMaps, k) + return nil + }) + return configMaps +} + +// Start implements Watcher. Start will wait for all watched resources to exist and for the add event handler to be +// invoked at least once for each before continuing or for the stopCh to be signalled, whichever happens first. If +// the watched resource is defaulted, Start will invoke the add event handler directly and will not wait for a further +// add event from the API server. +func (i *InformedWatcher) Start(stopCh <-chan struct{}) error { + // using the synced callback wrapper around the add event handler will allow the caller + // to wait for the add event to be processed for all configmaps + s := newSyncedCallback(i.getConfigMapNames(), i.addConfigMapEvent) + addConfigMapEvent := func(obj interface{}) { + configMap := obj.(*corev1.ConfigMap) + s.Call(obj, configMap.Name) + } + // Pretend that all the defaulted ConfigMaps were just created. This is done before we start + // the informer to ensure that if a defaulted ConfigMap does exist, then the real value is + // processed after the default one. + i.triggerAddEventForDefaultedConfigMaps(addConfigMapEvent) + + if err := i.registerCallbackAndStartInformer(addConfigMapEvent, stopCh); err != nil { return err } - // Wait until it has been synced (WITHOUT holing the mutex, so callbacks happen) + // Wait until the shared informer has been synced (WITHOUT holing the mutex, so callbacks happen) if ok := cache.WaitForCacheSync(stopCh, i.informer.Informer().HasSynced); !ok { return errors.New("error waiting for ConfigMap informer to sync") } - return i.checkObservedResourcesExist() + if err := i.checkObservedResourcesExist(); err != nil { + return err + } + + // Wait until all config maps have been at least initially processed + return s.WaitForAllKeys(stopCh) } -func (i *InformedWatcher) registerCallbackAndStartInformer(stopCh <-chan struct{}) error { +func (i *InformedWatcher) registerCallbackAndStartInformer(addConfigMapEvent func(obj interface{}), stopCh <-chan struct{}) error { i.Lock() defer i.Unlock() if i.started { @@ -159,13 +188,14 @@ func (i *InformedWatcher) registerCallbackAndStartInformer(stopCh <-chan struct{ i.started = true i.informer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ - AddFunc: i.addConfigMapEvent, + AddFunc: addConfigMapEvent, UpdateFunc: i.updateConfigMapEvent, DeleteFunc: i.deleteConfigMapEvent, }) // Start the shared informer factory (non-blocking). i.sif.Start(stopCh) + return nil } diff --git a/configmap/informer/informed_watcher_test.go b/configmap/informer/informed_watcher_test.go index c25194122..5a812bb69 100644 --- a/configmap/informer/informed_watcher_test.go +++ b/configmap/informer/informed_watcher_test.go @@ -20,7 +20,6 @@ import ( "context" "sync" "testing" - "time" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" @@ -28,7 +27,6 @@ import ( "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/selection" - "k8s.io/apimachinery/pkg/util/wait" fakekubeclientset "k8s.io/client-go/kubernetes/fake" ) @@ -42,6 +40,7 @@ type counter struct { func (c *counter) callback(cm *corev1.ConfigMap) { c.mu.Lock() defer c.mu.Unlock() + c.cfg = append(c.cfg, cm) if c.wg != nil { c.wg.Done() @@ -54,26 +53,6 @@ func (c *counter) count() int { return len(c.cfg) } -func (c *counter) eventuallyEquals(t *testing.T, want int) { - got := 0 - - err := wait.Poll( - // interval - 100*time.Millisecond, - - // timeout - 5*time.Second, - func() (done bool, err error) { - got = c.count() - return got == want, nil - }, - ) - - if err != nil { - t.Errorf("%v.count = %d, want %d", c.name, got, want) - } -} - func TestInformedWatcher(t *testing.T) { fooCM := &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ @@ -105,10 +84,12 @@ func TestInformedWatcher(t *testing.T) { t.Fatal("cm.Start() =", err) } - // When Start returns the callbacks will eventually be called with the + // When Start returns the callbacks should have been called with the // version of the objects that is available. - for _, count := range []*counter{foo1, foo2, bar} { - count.eventuallyEquals(t, 1) + for _, obj := range []*counter{foo1, foo2, bar} { + if got, want := obj.count(), 1; got != want { + t.Errorf("%v.count = %d, want %d", obj.name, got, want) + } } // After a "foo" event, the "foo" watchers should have 2, diff --git a/configmap/informer/synced_callback.go b/configmap/informer/synced_callback.go new file mode 100644 index 000000000..355c512b5 --- /dev/null +++ b/configmap/informer/synced_callback.go @@ -0,0 +1,112 @@ +/* +Copyright 2021 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package informer + +import ( + "sync" + + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" +) + +// namedWaitGroup is used to increment and decrement a WaitGroup by name +type namedWaitGroup struct { + waitGroup sync.WaitGroup + keys sets.String + mu sync.Mutex +} + +// newNamedWaitGroup returns an instantiated namedWaitGroup. +func newNamedWaitGroup() *namedWaitGroup { + return &namedWaitGroup{ + keys: sets.NewString(), + } +} + +// Add will add the key to the list of keys being tracked and increment the wait group. +// If the key has already been added, the wait group will not be incremented again. +func (n *namedWaitGroup) Add(key string) { + n.mu.Lock() + defer n.mu.Unlock() + + if !n.keys.Has(key) { + n.keys.Insert(key) + n.waitGroup.Add(1) + } +} + +// Done will decrement the counter if the key is present in the tracked keys. If it is not present +// it will be ignored. +func (n *namedWaitGroup) Done(key string) { + n.mu.Lock() + defer n.mu.Unlock() + + if n.keys.Has(key) { + n.keys.Delete(key) + n.waitGroup.Done() + } +} + +// Wait will wait for the underlying waitGroup to complete. +func (n *namedWaitGroup) Wait() { + n.waitGroup.Wait() +} + +// syncedCallback can be used to wait for a callback to be called at least once for a list of keys. +type syncedCallback struct { + // namedWaitGroup will block until the callback has been called for all tracked entities + namedWaitGroup *namedWaitGroup + + // callback is the callback that is intended to be called at least once for each key + // being tracked via WaitGroup + callback func(obj interface{}) +} + +// newSyncedCallback will return a syncedCallback that will track the provided keys. +func newSyncedCallback(keys []string, callback func(obj interface{})) *syncedCallback { + s := &syncedCallback{ + callback: callback, + namedWaitGroup: newNamedWaitGroup(), + } + for _, key := range keys { + s.namedWaitGroup.Add(key) + } + return s +} + +// Event is intended to be a wrapper for the actual event handler; this wrapper will signal via +// the wait group that the event handler has been called at least once for the key. +func (s *syncedCallback) Call(obj interface{}, key string) { + s.callback(obj) + s.namedWaitGroup.Done(key) +} + +// WaitForAllKeys will block until s.Call has been called for all the keys we are tracking or the stop signal is +// received. +func (s *syncedCallback) WaitForAllKeys(stopCh <-chan struct{}) error { + c := make(chan struct{}) + go func() { + defer close(c) + s.namedWaitGroup.Wait() + }() + select { + case <-c: + return nil + case <-stopCh: + return wait.ErrWaitTimeout + } +} diff --git a/configmap/informer/synced_callback_test.go b/configmap/informer/synced_callback_test.go new file mode 100644 index 000000000..b4798fe4a --- /dev/null +++ b/configmap/informer/synced_callback_test.go @@ -0,0 +1,151 @@ +/* +Copyright 2021 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package informer + +import ( + "testing" + "time" +) + +func TestNamedWaitGroup(t *testing.T) { + nwg := newNamedWaitGroup() + + // nothing has been added so wait returns immediately + initiallyDone := make(chan struct{}) + go func() { + defer close(initiallyDone) + nwg.Wait() + }() + select { + case <-time.After(1 * time.Second): + t.Fatalf("Wait should have returned immediately but still hadn't after timeout elapsed") + case <-initiallyDone: + // the Wait returned as expected since nothing was tracked + } + + // Add some keys to track + nwg.Add("foo") + nwg.Add("bar") + // Adding keys multiple times shouldn't increment the counter again + nwg.Add("bar") + + // Now that we've added keys, when we Wait, it should block + done := make(chan struct{}) + go func() { + defer close(done) + nwg.Wait() + }() + + // Indicate that this key is done + nwg.Done("foo") + // Indicating done on a key that doesn't exist should do nothing + nwg.Done("doesnt exist") + + // Only one of the tracked keys has completed, so the channel should not yet have closed + select { + case <-done: + t.Fatalf("Wait returned before all keys were done") + default: + // as expected, the channel is still open (waiting for the final key to be done) + } + + // Indicate the final key is done + nwg.Done("bar") + + // Now that all keys are done, the Wait should return + select { + case <-time.After(1 * time.Second): + t.Fatalf("Wait should have returned immediately but still hadn't after timeout elapsed") + case <-done: + // completed successfully + } +} + +func TestSyncedCallback(t *testing.T) { + keys := []string{"foo", "bar"} + objs := []interface{}{"fooobj", "barobj"} + var seen []interface{} + callback := func(obj interface{}) { + seen = append(seen, obj) + } + sc := newSyncedCallback(keys, callback) + + // Wait for the callback to be called for all of the keys + stopCh := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + sc.WaitForAllKeys(stopCh) + }() + + // Call the callback for one of the keys + sc.Call(objs[0], "foo") + + // Only one of the tracked keys has been synced so we should still be waiting + select { + case <-done: + t.Fatalf("Wait returned before all keys were done") + default: + // as expected, the channel is still open (waiting for the final key to be done) + } + + // Call the callback for the other key + sc.Call(objs[1], "bar") + + // Now that all keys are done, the Wait should return + select { + case <-time.After(1 * time.Second): + t.Fatalf("WaitForAllKeys should have returned but still hadn't after timeout elapsed") + case <-done: + // completed successfully + } + + if len(seen) != 2 || seen[0] != objs[0] || seen[1] != objs[1] { + t.Errorf("callback wasn't called as expected, expected to see %v but saw %v", objs, seen) + } +} + +func TestSyncedCallbackStops(t *testing.T) { + sc := newSyncedCallback([]string{"somekey"}, func(obj interface{}) {}) + + // Wait for the callback to be called - which it won't be! + stopCh := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + sc.WaitForAllKeys(stopCh) + }() + + // Nothing has been synced so we should still be waiting + select { + case <-done: + t.Fatalf("Wait returned before all keys were done") + default: + // as expected, the channel is still open + } + + // signal to stop via the stop channel + close(stopCh) + + // Even though the callback wasn't called, the Wait should return b/c of the stop channel + select { + case <-time.After(1 * time.Second): + t.Fatalf("WaitForAllKeys should have returned because of the stop channel but still hadn't after timeout elapsed") + case <-done: + // stopped successfully + } +}