Add TaskContext to be passed between tasks

This commit is contained in:
Morten Torkildsen 2020-04-05 20:08:51 -07:00
parent b8c5a7c2b4
commit 260f419d8b
9 changed files with 91 additions and 68 deletions

View File

@ -198,8 +198,8 @@ func splitInfos(infos []*resource.Info) ([]*resource.Info, []*resource.Info) {
// buildTaskQueue takes the slice of infos and object identifiers, and
// builds a queue of tasks that needs to be executed.
func (a *Applier) buildTaskQueue(infos []*resource.Info, identifiers []object.ObjMetadata,
eventChannel chan event.Event) chan taskrunner.Task {
func (a *Applier) buildTaskQueue(infos []*resource.Info,
identifiers []object.ObjMetadata) chan taskrunner.Task {
tasks := []taskrunner.Task{
// This taks is responsible for applying all the resources
// in the infos slice.
@ -217,7 +217,6 @@ func (a *Applier) buildTaskQueue(infos []*resource.Info, identifiers []object.Ob
Type: event.ApplyEventCompleted,
},
},
EventChannel: eventChannel,
},
}
@ -237,7 +236,6 @@ func (a *Applier) buildTaskQueue(infos []*resource.Info, identifiers []object.Ob
EventType: pollevent.CompletedEvent,
},
},
EventChannel: eventChannel,
})
}
@ -248,7 +246,6 @@ func (a *Applier) buildTaskQueue(infos []*resource.Info, identifiers []object.Ob
&task.PruneTask{
Objects: infos,
PruneOptions: a.PruneOptions,
EventChannel: eventChannel,
},
// Once prune is completed, we send an event to notify
// the client.
@ -259,7 +256,6 @@ func (a *Applier) buildTaskQueue(infos []*resource.Info, identifiers []object.Ob
Type: event.PruneEventCompleted,
},
},
EventChannel: eventChannel,
})
}
@ -307,7 +303,7 @@ func (a *Applier) Run(ctx context.Context, options Options) <-chan event.Event {
identifiers := infosToObjMetas(infos)
// Fetch the queue (channel) of tasks that should be executed.
taskQueue := a.buildTaskQueue(infos, identifiers, eventChannel)
taskQueue := a.buildTaskQueue(infos, identifiers)
// Send event to inform the caller about the resources that
// will be applied/pruned.

View File

@ -20,11 +20,11 @@ type ApplyTask struct {
// the Run function on the ApplyOptions to update
// the cluster. It will push a TaskResult on the taskChannel
// to signal to the taskrunner that the task has completed (or failed).
func (a *ApplyTask) Start(taskChannel chan taskrunner.TaskResult) {
func (a *ApplyTask) Start(taskContext *taskrunner.TaskContext) {
go func() {
a.ApplyOptions.SetObjects(a.Objects)
err := a.ApplyOptions.Run()
taskChannel <- taskrunner.TaskResult{
taskContext.TaskChannel() <- taskrunner.TaskResult{
Err: err,
}
}()

View File

@ -5,7 +5,6 @@ package task
import (
"k8s.io/cli-runtime/pkg/resource"
"sigs.k8s.io/cli-utils/pkg/apply/event"
"sigs.k8s.io/cli-utils/pkg/apply/prune"
"sigs.k8s.io/cli-utils/pkg/apply/taskrunner"
)
@ -15,7 +14,6 @@ import (
// set of resources that have just been applied.
type PruneTask struct {
PruneOptions *prune.PruneOptions
EventChannel chan event.Event
Objects []*resource.Info
}
@ -23,10 +21,10 @@ type PruneTask struct {
// the Run function on the PruneOptions to update
// the cluster. It will push a TaskResult on the taskChannel
// to signal to the taskrunner that the task has completed (or failed).
func (p *PruneTask) Start(taskChannel chan taskrunner.TaskResult) {
func (p *PruneTask) Start(taskContext *taskrunner.TaskContext) {
go func() {
err := p.PruneOptions.Prune(p.Objects, p.EventChannel)
taskChannel <- taskrunner.TaskResult{
err := p.PruneOptions.Prune(p.Objects, taskContext.EventChannel())
taskContext.TaskChannel() <- taskrunner.TaskResult{
Err: err,
}
}()

View File

@ -12,17 +12,16 @@ import (
// that will send the provided event on the eventChannel when
// executed.
type SendEventTask struct {
EventChannel chan event.Event
Event event.Event
Event event.Event
}
// Start start a separate goroutine that will send the
// event and then push a TaskResult on the taskChannel to
// signal to the taskrunner that the task is completed.
func (s *SendEventTask) Start(taskChannel chan taskrunner.TaskResult) {
func (s *SendEventTask) Start(taskContext *taskrunner.TaskContext) {
go func() {
s.EventChannel <- s.Event
taskChannel <- taskrunner.TaskResult{}
taskContext.EventChannel() <- s.Event
taskContext.TaskChannel() <- taskrunner.TaskResult{}
}()
}

View File

@ -0,0 +1,32 @@
// Copyright 2020 The Kubernetes Authors.
// SPDX-License-Identifier: Apache-2.0
package taskrunner
import (
"sigs.k8s.io/cli-utils/pkg/apply/event"
)
// NewTaskContext returns a new TaskContext
func NewTaskContext(eventChannel chan event.Event) *TaskContext {
return &TaskContext{
taskChannel: make(chan TaskResult),
eventChannel: eventChannel,
}
}
// TaskContext defines a context that is passed between all
// the tasks that is in a taskqueue.
type TaskContext struct {
taskChannel chan TaskResult
eventChannel chan event.Event
}
func (tc *TaskContext) TaskChannel() chan TaskResult {
return tc.taskChannel
}
func (tc *TaskContext) EventChannel() chan event.Event {
return tc.eventChannel
}

View File

@ -124,13 +124,13 @@ type baseOptions struct {
func (b *baseRunner) run(ctx context.Context, taskQueue chan Task,
statusChannel <-chan pollevent.Event, eventChannel chan event.Event,
o baseOptions) error {
// taskChannel is used by tasks running in a separate goroutine
// to signal back to the main loop that the task is either finished
// or it has failed.
taskChannel := make(chan TaskResult)
// taskContext is passed into all tasks when they are started. It
// provides access to the eventChannel and the taskChannel, and
// also provides a way to pass data between tasks.
taskContext := NewTaskContext(eventChannel)
// Find and start the first task in the queue.
currentTask, done := b.nextTask(taskQueue, taskChannel)
currentTask, done := b.nextTask(taskQueue, taskContext)
if done {
return nil
}
@ -173,7 +173,7 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task,
statusEvent.Error)
// If the current task is a wait task, we just set it
// to complete so we can exit the loop as soon as possible.
completeIfWaitTask(currentTask, taskChannel)
completeIfWaitTask(currentTask, taskContext)
continue
}
@ -194,7 +194,7 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task,
// the condition has been met. If so, we complete the task.
if wt, ok := currentTask.(*WaitTask); ok {
if b.collector.conditionMet(wt.Identifiers, wt.Condition) {
completeIfWaitTask(currentTask, taskChannel)
completeIfWaitTask(currentTask, taskContext)
}
}
// A message on the taskChannel means that the current task
@ -203,7 +203,7 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task,
// else has gone wrong and we are waiting for the current task to
// finish, we exit.
// If everything is ok, we fetch and start the next task.
case msg := <-taskChannel:
case msg := <-taskContext.TaskChannel():
currentTask.ClearTimeout()
if msg.Err != nil {
return msg.Err
@ -211,7 +211,7 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task,
if abort {
return abortReason
}
currentTask, done = b.nextTask(taskQueue, taskChannel)
currentTask, done = b.nextTask(taskQueue, taskContext)
// If there are no more tasks, we are done. So just
// return.
if done {
@ -223,16 +223,16 @@ func (b *baseRunner) run(ctx context.Context, taskQueue chan Task,
case <-doneCh:
doneCh = nil // Set doneCh to nil so we don't enter a busy loop.
abort = true
completeIfWaitTask(currentTask, taskChannel)
completeIfWaitTask(currentTask, taskContext)
}
}
}
// completeIfWaitTask checks if the current task is a wait task. If so,
// we invoke the complete function to complete it.
func completeIfWaitTask(currentTask Task, taskChannel chan TaskResult) {
func completeIfWaitTask(currentTask Task, taskContext *TaskContext) {
if wt, ok := currentTask.(*WaitTask); ok {
wt.complete(taskChannel)
wt.complete(taskContext)
}
}
@ -240,7 +240,7 @@ func completeIfWaitTask(currentTask Task, taskChannel chan TaskResult) {
// starts it. If the taskQueue is empty, it the second
// return value will be true.
func (b *baseRunner) nextTask(taskQueue chan Task,
taskChannel chan TaskResult) (Task, bool) {
taskContext *TaskContext) (Task, bool) {
var tsk Task
select {
// If there is any tasks left in the queue, this
@ -259,12 +259,12 @@ func (b *baseRunner) nextTask(taskQueue chan Task,
// met. Without this check, a task might end up waiting for
// status events when the condition is in fact already met.
if b.collector.conditionMet(st.Identifiers, st.Condition) {
st.startAndComplete(taskChannel)
st.startAndComplete(taskContext)
} else {
tsk.Start(taskChannel)
tsk.Start(taskContext)
}
default:
tsk.Start(taskChannel)
tsk.Start(taskContext)
}
return tsk, false
}

View File

@ -131,9 +131,6 @@ func TestBaseRunner(t *testing.T) {
eventChannel := make(chan event.Event)
taskQueue := make(chan Task, len(tc.tasks))
for _, tsk := range tc.tasks {
if bt, ok := tsk.(*busyTask); ok {
bt.eventChannel = eventChannel
}
taskQueue <- tsk
}
@ -287,9 +284,6 @@ func TestBaseRunnerCancellation(t *testing.T) {
taskQueue := make(chan Task, len(tc.tasks))
for _, tsk := range tc.tasks {
if bt, ok := tsk.(*busyTask); ok {
bt.eventChannel = eventChannel
}
taskQueue <- tsk
}
@ -349,17 +343,16 @@ func TestBaseRunnerCancellation(t *testing.T) {
}
type busyTask struct {
eventChannel chan event.Event
resultEvent event.Event
duration time.Duration
err error
resultEvent event.Event
duration time.Duration
err error
}
func (b *busyTask) Start(taskChannel chan TaskResult) {
func (b *busyTask) Start(taskContext *TaskContext) {
go func() {
<-time.NewTimer(b.duration).C
b.eventChannel <- b.resultEvent
taskChannel <- TaskResult{
taskContext.EventChannel() <- b.resultEvent
taskContext.TaskChannel() <- TaskResult{
Err: b.err,
}
}()

View File

@ -13,7 +13,7 @@ import (
// Task is the interface that must be implemented by
// all tasks that will be executed by the taskrunner.
type Task interface {
Start(taskChannel chan TaskResult)
Start(taskContext *TaskContext)
ClearTimeout()
}
@ -64,7 +64,7 @@ type WaitTask struct {
// Start kicks off the task. For the wait task, this just means
// setting up the timeout timer.
func (w *WaitTask) Start(taskChannel chan TaskResult) {
func (w *WaitTask) Start(taskContext *TaskContext) {
timer := time.NewTimer(w.Timeout)
go func() {
//TODO(mortent): See if there is a better way to do this. This
@ -75,7 +75,7 @@ func (w *WaitTask) Start(taskChannel chan TaskResult) {
// We only send the taskResult if no one has gotten
// to the token first.
case <-w.token:
taskChannel <- TaskResult{
taskContext.TaskChannel() <- TaskResult{
Err: timeoutError{
message: fmt.Sprintf("timeout after %.0f seconds waiting for %d resources to reach condition %s",
w.Timeout.Seconds(), len(w.Identifiers), w.Condition),
@ -94,20 +94,20 @@ func (w *WaitTask) Start(taskChannel chan TaskResult) {
// met when the task should be started. In this case there is no
// need to start a timer. So it just sets the cancelFunc and then
// completes the task.
func (w *WaitTask) startAndComplete(taskChannel chan TaskResult) {
func (w *WaitTask) startAndComplete(taskContext *TaskContext) {
w.cancelFunc = func() {}
w.complete(taskChannel)
w.complete(taskContext)
}
// complete is invoked by the taskrunner when all the conditions
// for the task has been met, or something has failed so the task
// need to be stopped.
func (w *WaitTask) complete(taskChannel chan TaskResult) {
func (w *WaitTask) complete(taskContext *TaskContext) {
select {
// Only do something if we can get the token.
case <-w.token:
go func() {
taskChannel <- TaskResult{}
taskContext.TaskChannel() <- TaskResult{}
}()
default:
return

View File

@ -8,21 +8,23 @@ import (
"testing"
"time"
"sigs.k8s.io/cli-utils/pkg/apply/event"
"sigs.k8s.io/cli-utils/pkg/object"
)
func TestWaitTask_TimeoutTriggered(t *testing.T) {
task := NewWaitTask([]object.ObjMetadata{}, AllCurrent, 2*time.Second)
taskChannel := make(chan TaskResult)
defer close(taskChannel)
eventChannel := make(chan event.Event)
taskContext := NewTaskContext(eventChannel)
defer close(eventChannel)
task.Start(taskChannel)
task.Start(taskContext)
timer := time.NewTimer(3 * time.Second)
select {
case res := <-taskChannel:
case res := <-taskContext.TaskChannel():
if res.Err == nil || !IsTimeoutError(res.Err) {
t.Errorf("expected timeout error, but got %v", res.Err)
}
@ -35,15 +37,16 @@ func TestWaitTask_TimeoutTriggered(t *testing.T) {
func TestWaitTask_TimeoutCancelled(t *testing.T) {
task := NewWaitTask([]object.ObjMetadata{}, AllCurrent, 2*time.Second)
taskChannel := make(chan TaskResult)
defer close(taskChannel)
eventChannel := make(chan event.Event)
taskContext := NewTaskContext(eventChannel)
defer close(eventChannel)
task.Start(taskChannel)
task.Start(taskContext)
task.ClearTimeout()
timer := time.NewTimer(3 * time.Second)
select {
case res := <-taskChannel:
case res := <-taskContext.TaskChannel():
t.Errorf("didn't expect timeout error, but got %v", res.Err)
case <-timer.C:
return
@ -53,8 +56,10 @@ func TestWaitTask_TimeoutCancelled(t *testing.T) {
func TestWaitTask_SingleTaskResult(t *testing.T) {
task := NewWaitTask([]object.ObjMetadata{}, AllCurrent, 2*time.Second)
taskChannel := make(chan TaskResult, 10)
defer close(taskChannel)
eventChannel := make(chan event.Event)
taskContext := NewTaskContext(eventChannel)
taskContext.taskChannel = make(chan TaskResult, 10)
defer close(eventChannel)
var completeWg sync.WaitGroup
@ -62,17 +67,17 @@ func TestWaitTask_SingleTaskResult(t *testing.T) {
completeWg.Add(1)
go func() {
defer completeWg.Done()
task.complete(taskChannel)
task.complete(taskContext)
}()
}
completeWg.Wait()
<-taskChannel
<-taskContext.TaskChannel()
timer := time.NewTimer(4 * time.Second)
select {
case <-taskChannel:
case <-taskContext.TaskChannel():
t.Errorf("expected only one result on taskChannel, but got more")
case <-timer.C:
return