diff --git a/drivers/virtualbox/virtualbox.go b/drivers/virtualbox/virtualbox.go index 8e7005d72d..4e42202548 100644 --- a/drivers/virtualbox/virtualbox.go +++ b/drivers/virtualbox/virtualbox.go @@ -177,17 +177,14 @@ func (d *Driver) Create() error { if err := os.Mkdir(imgPath, 0700); err != nil { return err } - } if d.Boot2DockerURL != "" { isoURL = d.Boot2DockerURL log.Infof("Downloading %s from %s...", isoFilename, isoURL) - if err := b2dutils.DownloadISO(commonIsoPath, isoFilename, isoURL); err != nil { + if err := b2dutils.DownloadISO(imgPath, isoFilename, isoURL); err != nil { return err - } - } else { // todo: check latest release URL, download if it's new // until then always use "latest" @@ -202,11 +199,11 @@ func (d *Driver) Create() error { return err } } + } - isoDest := filepath.Join(d.storePath, isoFilename) - if err := utils.CopyFile(commonIsoPath, isoDest); err != nil { - return err - } + isoDest := filepath.Join(d.storePath, isoFilename) + if err := utils.CopyFile(commonIsoPath, isoDest); err != nil { + return err } log.Infof("Creating SSH key...") diff --git a/utils/b2d.go b/utils/b2d.go index 369159de29..f7fed62958 100644 --- a/utils/b2d.go +++ b/utils/b2d.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net" "net/http" + "net/url" "os" "path/filepath" "time" @@ -84,29 +85,46 @@ func (b *B2dUtils) GetLatestBoot2DockerReleaseURL() (string, error) { } // Download boot2docker ISO image for the given tag and save it at dest. -func (b *B2dUtils) DownloadISO(dir, file, url string) error { - client := getClient() - rsp, err := client.Get(url) - if err != nil { - return err +func (b *B2dUtils) DownloadISO(dir, file, isoUrl string) error { + u, err := url.Parse(isoUrl) + var src io.ReadCloser + if u.Scheme == "file" { + s, err := os.Open(u.Path) + if err != nil { + return err + } + src = s + } else { + client := getClient() + s, err := client.Get(isoUrl) + if err != nil { + return err + } + src = s.Body } - defer rsp.Body.Close() + + defer src.Close() // Download to a temp file first then rename it to avoid partial download. f, err := ioutil.TempFile(dir, file+".tmp") if err != nil { return err } + defer os.Remove(f.Name()) - if _, err := io.Copy(f, rsp.Body); err != nil { + + if _, err := io.Copy(f, src); err != nil { // TODO: display download progress? return err } + if err := f.Close(); err != nil { return err } + if err := os.Rename(f.Name(), filepath.Join(dir, file)); err != nil { return err } + return nil }