From 47065b904559f36e83b924007c375f2de0f31af0 Mon Sep 17 00:00:00 2001 From: Alexandr Morozov Date: Fri, 6 Jun 2014 15:28:12 +0400 Subject: [PATCH] State refactoring and add waiting functions Docker-DCO-1.1-Signed-off-by: Alexandr Morozov (github: LK4D4) --- daemon/state.go | 115 ++++++++++++++++++++++++++++++++++--------- daemon/state_test.go | 102 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+), 23 deletions(-) create mode 100644 daemon/state_test.go diff --git a/daemon/state.go b/daemon/state.go index 7ee8fc48c3..b9ef350568 100644 --- a/daemon/state.go +++ b/daemon/state.go @@ -16,6 +16,13 @@ type State struct { ExitCode int StartedAt time.Time FinishedAt time.Time + waitChan chan struct{} +} + +func NewState() *State { + return &State{ + waitChan: make(chan struct{}), + } } // String returns a human-readable description of the state @@ -35,56 +42,118 @@ func (s *State) String() string { return fmt.Sprintf("Exited (%d) %s ago", s.ExitCode, units.HumanDuration(time.Now().UTC().Sub(s.FinishedAt))) } +func wait(waitChan <-chan struct{}, timeout time.Duration) error { + if timeout < 0 { + <-waitChan + return nil + } + select { + case <-time.After(timeout): + return fmt.Errorf("Timed out: %v", timeout) + case <-waitChan: + return nil + } +} + +// WaitRunning waits until state is running. If state already running it returns +// immediatly. If you want wait forever you must supply negative timeout. +// Returns pid, that was passed to SetRunning +func (s *State) WaitRunning(timeout time.Duration) (int, error) { + s.RLock() + if s.IsRunning() { + pid := s.Pid + s.RUnlock() + return pid, nil + } + waitChan := s.waitChan + s.RUnlock() + if err := wait(waitChan, timeout); err != nil { + return -1, err + } + return s.GetPid(), nil +} + +// WaitStop waits until state is stopped. If state already stopped it returns +// immediatly. If you want wait forever you must supply negative timeout. +// Returns exit code, that was passed to SetRunning +func (s *State) WaitStop(timeout time.Duration) (int, error) { + s.RLock() + if !s.Running { + exitCode := s.ExitCode + s.RUnlock() + return exitCode, nil + } + waitChan := s.waitChan + s.RUnlock() + if err := wait(waitChan, timeout); err != nil { + return -1, err + } + return s.GetExitCode(), nil +} + func (s *State) IsRunning() bool { s.RLock() - defer s.RUnlock() + res := s.Running + s.RUnlock() + return res +} - return s.Running +func (s *State) GetPid() int { + s.RLock() + res := s.Pid + s.RUnlock() + return res } func (s *State) GetExitCode() int { s.RLock() - defer s.RUnlock() - - return s.ExitCode + res := s.ExitCode + s.RUnlock() + return res } func (s *State) SetRunning(pid int) { s.Lock() - defer s.Unlock() - - s.Running = true - s.Paused = false - s.ExitCode = 0 - s.Pid = pid - s.StartedAt = time.Now().UTC() + if !s.Running { + s.Running = true + s.Paused = false + s.ExitCode = 0 + s.Pid = pid + s.StartedAt = time.Now().UTC() + close(s.waitChan) // fire waiters for start + s.waitChan = make(chan struct{}) + } + s.Unlock() } func (s *State) SetStopped(exitCode int) { s.Lock() - defer s.Unlock() - - s.Running = false - s.Pid = 0 - s.FinishedAt = time.Now().UTC() - s.ExitCode = exitCode + if s.Running { + s.Running = false + s.Pid = 0 + s.FinishedAt = time.Now().UTC() + s.ExitCode = exitCode + close(s.waitChan) // fire waiters for stop + s.waitChan = make(chan struct{}) + } + s.Unlock() } func (s *State) SetPaused() { s.Lock() - defer s.Unlock() s.Paused = true + s.Unlock() } func (s *State) SetUnpaused() { s.Lock() - defer s.Unlock() s.Paused = false + s.Unlock() } func (s *State) IsPaused() bool { s.RLock() - defer s.RUnlock() - - return s.Paused + res := s.Paused + s.RUnlock() + return res } diff --git a/daemon/state_test.go b/daemon/state_test.go new file mode 100644 index 0000000000..7b02f3aeac --- /dev/null +++ b/daemon/state_test.go @@ -0,0 +1,102 @@ +package daemon + +import ( + "sync/atomic" + "testing" + "time" +) + +func TestStateRunStop(t *testing.T) { + s := NewState() + for i := 1; i < 3; i++ { // full lifecycle two times + started := make(chan struct{}) + var pid int64 + go func() { + runPid, _ := s.WaitRunning(-1 * time.Second) + atomic.StoreInt64(&pid, int64(runPid)) + close(started) + }() + s.SetRunning(i + 100) + if !s.IsRunning() { + t.Fatal("State not running") + } + if s.Pid != i+100 { + t.Fatalf("Pid %v, expected %v", s.Pid, i+100) + } + if s.ExitCode != 0 { + t.Fatalf("ExitCode %v, expected 0", s.ExitCode) + } + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("Start callback doesn't fire in 100 milliseconds") + case <-started: + t.Log("Start callback fired") + } + runPid := int(atomic.LoadInt64(&pid)) + if runPid != i+100 { + t.Fatalf("Pid %v, expected %v", runPid, i+100) + } + if pid, err := s.WaitRunning(-1 * time.Second); err != nil || pid != i+100 { + t.Fatal("WaitRunning returned pid: %v, err: %v, expected pid: %v, err: %v", pid, err, i+100, nil) + } + + stopped := make(chan struct{}) + var exit int64 + go func() { + exitCode, _ := s.WaitStop(-1 * time.Second) + atomic.StoreInt64(&exit, int64(exitCode)) + close(stopped) + }() + s.SetStopped(i) + if s.IsRunning() { + t.Fatal("State is running") + } + if s.ExitCode != i { + t.Fatalf("ExitCode %v, expected %v", s.ExitCode, i) + } + if s.Pid != 0 { + t.Fatalf("Pid %v, expected 0", s.Pid) + } + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("Stop callback doesn't fire in 100 milliseconds") + case <-stopped: + t.Log("Stop callback fired") + } + exitCode := int(atomic.LoadInt64(&exit)) + if exitCode != i { + t.Fatalf("ExitCode %v, expected %v", exitCode, i) + } + if exitCode, err := s.WaitStop(-1 * time.Second); err != nil || exitCode != i { + t.Fatal("WaitStop returned exitCode: %v, err: %v, expected exitCode: %v, err: %v", exitCode, err, i, nil) + } + } +} + +func TestStateTimeoutWait(t *testing.T) { + s := NewState() + started := make(chan struct{}) + go func() { + s.WaitRunning(100 * time.Millisecond) + close(started) + }() + select { + case <-time.After(200 * time.Millisecond): + t.Fatal("Start callback doesn't fire in 100 milliseconds") + case <-started: + t.Log("Start callback fired") + } + s.SetRunning(42) + stopped := make(chan struct{}) + go func() { + s.WaitRunning(100 * time.Millisecond) + close(stopped) + }() + select { + case <-time.After(200 * time.Millisecond): + t.Fatal("Start callback doesn't fire in 100 milliseconds") + case <-stopped: + t.Log("Start callback fired") + } + +}