diff --git a/commands.go b/commands.go index 324afbba04..ddf9e852ff 100644 --- a/commands.go +++ b/commands.go @@ -9,7 +9,9 @@ import ( "path/filepath" "sort" "strings" + "sync" "text/tabwriter" + "time" log "github.com/Sirupsen/logrus" "github.com/codegangsta/cli" @@ -382,12 +384,6 @@ func cmdIp(c *cli.Context) { fmt.Println(ip) } -func cmdKill(c *cli.Context) { - if err := getHost(c).Driver.Kill(); err != nil { - log.Fatal(err) - } -} - func cmdLs(c *cli.Context) { quiet := c.Bool("quiet") store := NewStore(c.GlobalString("storage-path"), c.GlobalString("tls-ca-cert"), c.GlobalString("tls-ca-key")) @@ -456,12 +452,6 @@ func cmdLs(c *cli.Context) { w.Flush() } -func cmdRestart(c *cli.Context) { - if err := getHost(c).Driver.Restart(); err != nil { - log.Fatal(err) - } -} - func cmdRm(c *cli.Context) { if len(c.Args()) == 0 { cli.ShowCommandHelp(c, "rm") @@ -573,20 +563,79 @@ func cmdSsh(c *cli.Context) { } } +// machineCommand maps the command name to the corresponding machine command +// it is intended to be used by runCommand using a waitgroup and error channel +// to enable running commands across multiple machines asynchronously +func machineCommand(name string, machine *Host, wg *sync.WaitGroup, errorChan chan<- error) { + commands := map[string]interface{}{ + "start": machine.Start, + "stop": machine.Stop, + "restart": machine.Driver.Restart, + "kill": machine.Driver.Kill, + "upgrade": machine.Upgrade, + } + + log.Debugf("command=%s machine=%s", name, machine.Name) + + if err := commands[name].(func() error)(); err != nil { + errorChan <- err + } + + wg.Done() +} + +// runCommand will run the command across multiple machines +func runCommand(name string, c *cli.Context) error { + errorChan := make(chan error) + go func() { + err := <-errorChan + log.Errorf(err.Error()) + }() + + wg := &sync.WaitGroup{} + + machines, err := getHosts(c) + if err != nil { + return err + } + + for _, machine := range machines { + wg.Add(1) + go machineCommand(name, machine, wg, errorChan) + time.Sleep(1 * time.Second) + } + + wg.Wait() + + return nil +} + func cmdStart(c *cli.Context) { - if err := getHost(c).Start(); err != nil { + if err := runCommand("start", c); err != nil { log.Fatal(err) } } func cmdStop(c *cli.Context) { - if err := getHost(c).Stop(); err != nil { + if err := runCommand("stop", c); err != nil { + log.Fatal(err) + } +} + +func cmdRestart(c *cli.Context) { + if err := runCommand("restart", c); err != nil { + log.Fatal(err) + } +} + +func cmdKill(c *cli.Context) { + if err := runCommand("kill", c); err != nil { log.Fatal(err) } } func cmdUpgrade(c *cli.Context) { - if err := getHost(c).Upgrade(); err != nil { + if err := runCommand("upgrade", c); err != nil { log.Fatal(err) } } @@ -610,6 +659,31 @@ func cmdNotFound(c *cli.Context, command string) { ) } +func getHosts(c *cli.Context) ([]*Host, error) { + machines := []*Host{} + for _, n := range c.Args() { + machine, err := loadMachine(n, c) + if err != nil { + return nil, err + } + + machines = append(machines, machine) + } + + return machines, nil +} + +func loadMachine(name string, c *cli.Context) (*Host, error) { + store := NewStore(c.GlobalString("storage-path"), c.GlobalString("tls-ca-cert"), c.GlobalString("tls-ca-key")) + + machine, err := store.Load(name) + if err != nil { + return nil, err + } + + return machine, nil +} + func getHost(c *cli.Context) *Host { name := c.Args().First() store := NewStore(c.GlobalString("storage-path"), c.GlobalString("tls-ca-cert"), c.GlobalString("tls-ca-key")) diff --git a/commands_test.go b/commands_test.go index 0438847997..96d8163d9f 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1,10 +1,13 @@ package main import ( + "flag" + "fmt" "io/ioutil" "os/exec" "testing" + "github.com/codegangsta/cli" drivers "github.com/docker/machine/drivers" "github.com/docker/machine/state" ) @@ -81,6 +84,49 @@ func (d *FakeDriver) GetSSHCommand(args ...string) (*exec.Cmd, error) { return &exec.Cmd{}, nil } +func TestGetHosts(t *testing.T) { + if err := clearHosts(); err != nil { + t.Fatal(err) + } + + flags := getDefaultTestDriverFlags() + + store := NewStore(TestStoreDir, "", "") + + hostA, hostAerr := store.Create("test-a", "none", flags) + if hostAerr != nil { + t.Fatal(hostAerr) + } + + hostB, hostBerr := store.Create("test-b", "none", flags) + if hostBerr != nil { + t.Fatal(hostBerr) + } + + set := flag.NewFlagSet("start", 0) + set.Parse([]string{"test-a", "test-b"}) + + c := cli.NewContext(nil, set, nil) + globalSet := flag.NewFlagSet("-d", 0) + globalSet.String("-d", "none", "driver") + globalSet.String("storage-path", TestStoreDir, "storage path") + globalSet.String("tls-ca-cert", "", "") + globalSet.String("tls-ca-key", "", "") + + hosts, err := getHosts(c) + if err != nil { + t.Fatal(err) + } + + fmt.Println(hosts) + fmt.Println(hostA) + fmt.Println(hostB) + + if err := clearHosts(); err != nil { + t.Fatal(err) + } +} + func TestGetHostState(t *testing.T) { storePath, err := ioutil.TempDir("", ".docker") if err != nil { diff --git a/store_test.go b/store_test.go index 15edf8a290..6e5dfbcf63 100644 --- a/store_test.go +++ b/store_test.go @@ -9,6 +9,15 @@ import ( "github.com/docker/machine/utils" ) +const ( + TestStoreDir = ".store-test" +) + +func init() { + + os.Setenv("MACHINE_STORAGE_PATH", TestStoreDir) +} + type DriverOptionsMock struct { Data map[string]interface{} } @@ -26,7 +35,7 @@ func (d DriverOptionsMock) Bool(key string) bool { } func clearHosts() error { - return os.RemoveAll(utils.GetMachineDir()) + return os.RemoveAll(TestStoreDir) } func getDefaultTestDriverFlags() *DriverOptionsMock {