diff --git a/commands.go b/commands.go index 324afbba04..e749923df8 100644 --- a/commands.go +++ b/commands.go @@ -197,7 +197,7 @@ var Commands = []cli.Command{ { Name: "kill", Usage: "Kill a machine", - Description: "Argument is a machine name. Will use the active machine if none is provided.", + Description: "Argument(s) are one or more machine names. Will use the active machine if none is provided.", Action: cmdKill, }, { @@ -214,7 +214,7 @@ var Commands = []cli.Command{ { Name: "restart", Usage: "Restart a machine", - Description: "Argument is a machine name. Will use the active machine if none is provided.", + Description: "Argument(s) are one or more machine names. Will use the active machine if none is provided.", Action: cmdRestart, }, { @@ -226,7 +226,7 @@ var Commands = []cli.Command{ }, Name: "rm", Usage: "Remove a machine", - Description: "Argument is a machine name. Will use the active machine if none is provided.", + Description: "Argument(s) are one or more machine names. Will use the active machine if none is provided.", Action: cmdRm, }, { @@ -254,19 +254,19 @@ var Commands = []cli.Command{ { Name: "start", Usage: "Start a machine", - Description: "Argument is a machine name. Will use the active machine if none is provided.", + Description: "Argument(s) are one or more machine names. Will use the active machine if none is provided.", Action: cmdStart, }, { Name: "stop", Usage: "Stop a machine", - Description: "Argument is a machine name. Will use the active machine if none is provided.", + Description: "Argument(s) are one or more machine names. Will use the active machine if none is provided.", Action: cmdStop, }, { Name: "upgrade", Usage: "Upgrade a machine to the latest version of Docker", - Description: "Argument is a machine name. Will use the active machine if none is provided.", + Description: "Argument(s) are one or more machine names. Will use the active machine if none is provided.", Action: cmdUpgrade, }, { @@ -382,12 +382,6 @@ func cmdIp(c *cli.Context) { fmt.Println(ip) } -func cmdKill(c *cli.Context) { - if err := getHost(c).Driver.Kill(); err != nil { - log.Fatal(err) - } -} - func cmdLs(c *cli.Context) { quiet := c.Bool("quiet") store := NewStore(c.GlobalString("storage-path"), c.GlobalString("tls-ca-cert"), c.GlobalString("tls-ca-key")) @@ -456,12 +450,6 @@ func cmdLs(c *cli.Context) { w.Flush() } -func cmdRestart(c *cli.Context) { - if err := getHost(c).Driver.Restart(); err != nil { - log.Fatal(err) - } -} - func cmdRm(c *cli.Context) { if len(c.Args()) == 0 { cli.ShowCommandHelp(c, "rm") @@ -573,20 +561,110 @@ func cmdSsh(c *cli.Context) { } } +// machineCommand maps the command name to the corresponding machine command. +// We run commands concurrently and communicate back an error if there was one. +func machineCommand(actionName string, machine *Host, errorChan chan<- error) { + commands := map[string](func() error){ + "start": machine.Driver.Start, + "stop": machine.Driver.Stop, + "restart": machine.Driver.Restart, + "kill": machine.Driver.Kill, + "upgrade": machine.Driver.Upgrade, + } + + log.Debugf("command=%s machine=%s", actionName, machine.Name) + + if err := commands[actionName](); err != nil { + errorChan <- err + return + } + + errorChan <- nil +} + +// runActionForeachMachine will run the command across multiple machines +func runActionForeachMachine(actionName string, machines []*Host) { + var ( + numConcurrentActions = 0 + serialMachines = []*Host{} + errorChan = make(chan error) + ) + + for _, machine := range machines { + // Virtualbox is temperamental about doing things concurrently, + // so we schedule the actions in a "queue" to be executed serially + // after the concurrent actions are scheduled. + switch machine.DriverName { + case "virtualbox": + machine := machine + serialMachines = append(serialMachines, machine) + default: + numConcurrentActions++ + go machineCommand(actionName, machine, errorChan) + } + } + + // While the concurrent actions are running, + // do the serial actions. As the name implies, + // these run one at a time. + for _, machine := range serialMachines { + serialChan := make(chan error) + go machineCommand(actionName, machine, serialChan) + if err := <-serialChan; err != nil { + log.Errorln(err) + } + close(serialChan) + } + + // TODO: We should probably only do 5-10 of these + // at a time, since otherwise cloud providers might + // rate limit us. + for i := 0; i < numConcurrentActions; i++ { + if err := <-errorChan; err != nil { + log.Errorln(err) + } + } + + close(errorChan) +} + +func runActionWithContext(actionName string, c *cli.Context) error { + machines, err := getHosts(c) + if err != nil { + return err + } + + runActionForeachMachine(actionName, machines) + + return nil +} + func cmdStart(c *cli.Context) { - if err := getHost(c).Start(); err != nil { + if err := runActionWithContext("start", c); err != nil { log.Fatal(err) } } func cmdStop(c *cli.Context) { - if err := getHost(c).Stop(); err != nil { + if err := runActionWithContext("stop", c); err != nil { + log.Fatal(err) + } +} + +func cmdRestart(c *cli.Context) { + if err := runActionWithContext("restart", c); err != nil { + log.Fatal(err) + } +} + +func cmdKill(c *cli.Context) { + if err := runActionWithContext("kill", c); err != nil { log.Fatal(err) } } func cmdUpgrade(c *cli.Context) { - if err := getHost(c).Upgrade(); err != nil { + if err := runActionWithContext("upgrade", c); err != nil { log.Fatal(err) } } @@ -610,6 +688,31 @@ func cmdNotFound(c *cli.Context, command string) { ) } +func getHosts(c *cli.Context) ([]*Host, error) { + machines := []*Host{} + for _, n := range c.Args() { + machine, err := loadMachine(n, c) + if err != nil { + return nil, err + } + + machines = append(machines, machine) + } + + return machines, nil +} + +func loadMachine(name string, c *cli.Context) (*Host, error) { + store := NewStore(c.GlobalString("storage-path"), c.GlobalString("tls-ca-cert"), c.GlobalString("tls-ca-key")) + + machine, err := store.Load(name) + if err != nil { + return nil, err + } + + return machine, nil +} + func getHost(c *cli.Context) *Host { name := c.Args().First() store := NewStore(c.GlobalString("storage-path"), c.GlobalString("tls-ca-cert"), c.GlobalString("tls-ca-key")) diff --git a/commands_test.go b/commands_test.go index 0438847997..cf934fa463 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1,10 +1,12 @@ package main import ( + "flag" "io/ioutil" "os/exec" "testing" + "github.com/codegangsta/cli" drivers "github.com/docker/machine/drivers" "github.com/docker/machine/state" ) @@ -46,10 +48,12 @@ func (d *FakeDriver) Remove() error { } func (d *FakeDriver) Start() error { + d.MockState = state.Running return nil } func (d *FakeDriver) Stop() error { + d.MockState = state.Stopped return nil } @@ -81,6 +85,46 @@ func (d *FakeDriver) GetSSHCommand(args ...string) (*exec.Cmd, error) { return &exec.Cmd{}, nil } +func TestGetHosts(t *testing.T) { + if err := clearHosts(); err != nil { + t.Fatal(err) + } + + flags := getDefaultTestDriverFlags() + + store := NewStore(TestStoreDir, "", "") + + _, hostAerr := store.Create("test-a", "none", flags) + if hostAerr != nil { + t.Fatal(hostAerr) + } + + _, hostBerr := store.Create("test-b", "none", flags) + if hostBerr != nil { + t.Fatal(hostBerr) + } + + set := flag.NewFlagSet("start", 0) + set.Parse([]string{"test-a", "test-b"}) + + globalSet := flag.NewFlagSet("-d", 0) + globalSet.String("-d", "none", "driver") + globalSet.String("storage-path", TestStoreDir, "storage path") + globalSet.String("tls-ca-cert", "", "") + globalSet.String("tls-ca-key", "", "") + + c := cli.NewContext(nil, set, globalSet) + + hosts, err := getHosts(c) + if err != nil { + t.Fatal(err) + } + + if len(hosts) != 2 { + t.Fatal("Expected %d hosts, got %d hosts", 2, len(hosts)) + } +} + func TestGetHostState(t *testing.T) { storePath, err := ioutil.TempDir("", ".docker") if err != nil { @@ -132,3 +176,104 @@ func TestGetHostState(t *testing.T) { } } } + +func TestRunActionForeachMachine(t *testing.T) { + storePath, err := ioutil.TempDir("", ".docker") + if err != nil { + t.Fatal("Error creating tmp dir:", err) + } + + // Assume a bunch of machines in randomly started or + // stopped states. + machines := []*Host{ + { + Name: "foo", + DriverName: "fakedriver", + Driver: &FakeDriver{ + MockState: state.Running, + }, + storePath: storePath, + }, + { + Name: "bar", + DriverName: "fakedriver", + Driver: &FakeDriver{ + MockState: state.Stopped, + }, + storePath: storePath, + }, + { + Name: "baz", + // Ssh, don't tell anyone but this + // driver only _thinks_ it's named + // virtualbox... (to test serial actions) + // It's actually FakeDriver! + DriverName: "virtualbox", + Driver: &FakeDriver{ + MockState: state.Stopped, + }, + storePath: storePath, + }, + { + Name: "spam", + DriverName: "virtualbox", + Driver: &FakeDriver{ + MockState: state.Running, + }, + storePath: storePath, + }, + { + Name: "eggs", + DriverName: "fakedriver", + Driver: &FakeDriver{ + MockState: state.Stopped, + }, + storePath: storePath, + }, + { + Name: "ham", + DriverName: "fakedriver", + Driver: &FakeDriver{ + MockState: state.Running, + }, + storePath: storePath, + }, + } + + runActionForeachMachine("start", machines) + + expected := map[string]state.State{ + "foo": state.Running, + "bar": state.Running, + "baz": state.Running, + "spam": state.Running, + "eggs": state.Running, + "ham": state.Running, + } + + for _, machine := range machines { + state, _ := machine.Driver.GetState() + if expected[machine.Name] != state { + t.Fatalf("Expected machine %s to have state %s, got state %s", machine.Name, state, expected[machine.Name]) + } + } + + // OK, now let's stop them all! + expected = map[string]state.State{ + "foo": state.Stopped, + "bar": state.Stopped, + "baz": state.Stopped, + "spam": state.Stopped, + "eggs": state.Stopped, + "ham": state.Stopped, + } + + runActionForeachMachine("stop", machines) + + for _, machine := range machines { + state, _ := machine.Driver.GetState() + if expected[machine.Name] != state { + t.Fatalf("Expected machine %s to have state %s, got state %s", machine.Name, state, expected[machine.Name]) + } + } +} diff --git a/host_test.go b/host_test.go index 427518ad4c..f1a7816646 100644 --- a/host_test.go +++ b/host_test.go @@ -28,7 +28,6 @@ func getTestStore() (*Store, error) { os.Exit(1) } - os.Setenv("MACHINE_DIR", tmpDir) return NewStore(tmpDir, hostTestCaCert, hostTestPrivateKey), nil } diff --git a/store_test.go b/store_test.go index 15edf8a290..524c0afa17 100644 --- a/store_test.go +++ b/store_test.go @@ -6,7 +6,10 @@ import ( "testing" _ "github.com/docker/machine/drivers/none" - "github.com/docker/machine/utils" +) + +const ( + TestStoreDir = ".store-test" ) type DriverOptionsMock struct { @@ -26,7 +29,7 @@ func (d DriverOptionsMock) Bool(key string) bool { } func clearHosts() error { - return os.RemoveAll(utils.GetMachineDir()) + return os.RemoveAll(TestStoreDir) } func getDefaultTestDriverFlags() *DriverOptionsMock { @@ -49,7 +52,7 @@ func TestStoreCreate(t *testing.T) { flags := getDefaultTestDriverFlags() - store := NewStore("", "", "") + store := NewStore(TestStoreDir, "", "") host, err := store.Create("test", "none", flags) if err != nil { @@ -58,7 +61,7 @@ func TestStoreCreate(t *testing.T) { if host.Name != "test" { t.Fatal("Host name is incorrect") } - path := filepath.Join(utils.GetMachineDir(), "test") + path := filepath.Join(TestStoreDir, "test") if _, err := os.Stat(path); os.IsNotExist(err) { t.Fatalf("Host path doesn't exist: %s", path) } @@ -71,12 +74,12 @@ func TestStoreRemove(t *testing.T) { flags := getDefaultTestDriverFlags() - store := NewStore("", "", "") + store := NewStore(TestStoreDir, "", "") _, err := store.Create("test", "none", flags) if err != nil { t.Fatal(err) } - path := filepath.Join(utils.GetMachineDir(), "test") + path := filepath.Join(TestStoreDir, "test") if _, err := os.Stat(path); os.IsNotExist(err) { t.Fatalf("Host path doesn't exist: %s", path) } @@ -96,7 +99,7 @@ func TestStoreList(t *testing.T) { flags := getDefaultTestDriverFlags() - store := NewStore("", "", "") + store := NewStore(TestStoreDir, "", "") _, err := store.Create("test", "none", flags) if err != nil { t.Fatal(err) @@ -117,7 +120,7 @@ func TestStoreExists(t *testing.T) { flags := getDefaultTestDriverFlags() - store := NewStore("", "", "") + store := NewStore(TestStoreDir, "", "") exists, err := store.Exists("test") if exists { t.Fatal("Exists returned true when it should have been false") @@ -144,13 +147,13 @@ func TestStoreLoad(t *testing.T) { flags := getDefaultTestDriverFlags() flags.Data["url"] = expectedURL - store := NewStore("", "", "") + store := NewStore(TestStoreDir, "", "") _, err := store.Create("test", "none", flags) if err != nil { t.Fatal(err) } - store = NewStore("", "", "") + store = NewStore(TestStoreDir, "", "") host, err := store.Load("test") if host.Name != "test" { t.Fatal("Host name is incorrect") @@ -171,7 +174,7 @@ func TestStoreGetSetActive(t *testing.T) { flags := getDefaultTestDriverFlags() - store := NewStore("", "", "") + store := NewStore(TestStoreDir, "", "") // No hosts set host, err := store.GetActive()