fix docker deps to standard rev

Signed-off-by: Evan Hazlett <ejhazlett@gmail.com>
This commit is contained in:
Evan Hazlett 2015-02-02 21:00:01 -05:00
parent 9f152143de
commit bcc40f8f85
No known key found for this signature in database
GPG Key ID: A519480096146526
26 changed files with 2110 additions and 617 deletions

78
Godeps/Godeps.json generated
View File

@ -1,6 +1,6 @@
{
"ImportPath": "github.com/docker/machine",
"GoVersion": "go1.3.3",
"GoVersion": "go1.4.1",
"Deps": [
{
"ImportPath": "code.google.com/p/goauth2/oauth",
@ -29,83 +29,83 @@
},
{
"ImportPath": "github.com/docker/docker/api",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/dockerversion",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/engine",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/archive",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/fileutils",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/ioutils",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/parsers",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/pools",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/promise",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/system",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/term",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/timeutils",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/units",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/pkg/version",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/utils",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/docker/vendor/src/code.google.com/p/go/src/pkg/archive/tar",
"Comment": "v1.2.0-1619-g831d09f",
"Rev": "831d09f8b535a1e2939bdacb02032d238e8dd249"
"Comment": "v1.4.1",
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
},
{
"ImportPath": "github.com/docker/libtrust",
@ -115,10 +115,6 @@
"ImportPath": "github.com/google/go-querystring/query",
"Rev": "30f7a39f4a218feb5325f3aebc60c32a572a8274"
},
{
"ImportPath": "github.com/smartystreets/go-aws-auth",
"Rev": "1f0db8c0ee6362470abe06a94e3385927ed72a4b"
},
{
"ImportPath": "github.com/mitchellh/mapstructure",
"Rev": "740c764bc6149d3f1806231418adb9f52c11bcbf"
@ -133,6 +129,10 @@
"Comment": "v1.0.0-232-g2e7ab37",
"Rev": "2e7ab378257b8723e02cbceac7410be4db286436"
},
{
"ImportPath": "github.com/smartystreets/go-aws-auth",
"Rev": "1f0db8c0ee6362470abe06a94e3385927ed72a4b"
},
{
"ImportPath": "github.com/tent/http-link-go",
"Rev": "ac974c61c2f990f4115b119354b5e0b47550e888"
@ -146,6 +146,10 @@
"ImportPath": "golang.org/x/crypto/ssh",
"Rev": "1fbbd62cfec66bd39d91e97749579579d4d3037e"
},
{
"ImportPath": "golang.org/x/net/context",
"Rev": "0fd82b9a2db248ee7d5a68b10a57ccf5d4e73d06"
},
{
"ImportPath": "google.golang.org/api/compute/v1",
"Rev": "aa91ac681e18e52b1a0dfe29b9d8354e88c0dcf5"

View File

@ -1,159 +0,0 @@
package client
import (
"bufio"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"log"
"os"
"strings"
"time"
"github.com/docker/libtrust"
)
// NewIdentityAuthTLSConfig creates a tls.Config for the client to use for
// libtrust identity authentication
func NewIdentityAuthTLSConfig(trustKey libtrust.PrivateKey, knownHostsPath, proto, addr string) (*tls.Config, error) {
tlsConfig := createTLSConfig()
// Load known hosts
knownHosts, err := libtrust.LoadKeySetFile(knownHostsPath)
if err != nil {
return nil, fmt.Errorf("Could not load trusted hosts file: %s", err)
}
// Generate CA pool from known hosts
allowedHosts, err := libtrust.FilterByHosts(knownHosts, addr, false)
if err != nil {
return nil, fmt.Errorf("Error filtering hosts: %s", err)
}
certPool, err := libtrust.GenerateCACertPool(trustKey, allowedHosts)
if err != nil {
return nil, fmt.Errorf("Could not create CA pool: %s", err)
}
tlsConfig.ServerName = "docker"
tlsConfig.RootCAs = certPool
// Generate client cert from trust key
x509Cert, err := libtrust.GenerateSelfSignedClientCert(trustKey)
if err != nil {
return nil, fmt.Errorf("Certificate generation error: %s", err)
}
tlsConfig.Certificates = []tls.Certificate{{
Certificate: [][]byte{x509Cert.Raw},
PrivateKey: trustKey.CryptoPrivateKey(),
Leaf: x509Cert,
}}
// Connect to server to see if it is a known host
tlsConfig.InsecureSkipVerify = true
testConn, err := tls.Dial(proto, addr, tlsConfig)
if err != nil {
return nil, fmt.Errorf("TLS Handshake error: %s", err)
}
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
CurrentTime: time.Now(),
DNSName: tlsConfig.ServerName,
Intermediates: x509.NewCertPool(),
}
certs := testConn.ConnectionState().PeerCertificates
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
_, err = certs[0].Verify(opts)
if err != nil {
if _, ok := err.(x509.UnknownAuthorityError); ok {
pubKey, err := libtrust.FromCryptoPublicKey(certs[0].PublicKey)
if err != nil {
return nil, fmt.Errorf("Error extracting public key from certificate: %s", err)
}
// If server is not a known host, prompt user to ask whether it should
// be trusted and add to the known hosts file
if promptUnknownKey(pubKey, addr) {
pubKey.AddExtendedField("hosts", []string{addr})
err = libtrust.AddKeySetFile(knownHostsPath, pubKey)
if err != nil {
return nil, fmt.Errorf("Error saving updated host keys file: %s", err)
}
ca, err := libtrust.GenerateCACert(trustKey, pubKey)
if err != nil {
return nil, fmt.Errorf("Error generating CA: %s", err)
}
tlsConfig.RootCAs.AddCert(ca)
} else {
return nil, fmt.Errorf("Cancelling request due to invalid certificate")
}
} else {
return nil, fmt.Errorf("TLS verification error: %s", err)
}
}
testConn.Close()
tlsConfig.InsecureSkipVerify = false
return tlsConfig, nil
}
// NewCertAuthTLSConfig creates a tls.Config for the client to use for
// certificate authentication
func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
tlsConfig := createTLSConfig()
// Verify the server against a CA certificate?
if caPath != "" {
certPool := x509.NewCertPool()
file, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("Couldn't read ca cert %s: %s", caPath, err)
}
certPool.AppendCertsFromPEM(file)
tlsConfig.RootCAs = certPool
} else {
tlsConfig.InsecureSkipVerify = true
}
// Try to load and send client certificates
if certPath != "" && keyPath != "" {
_, errCert := os.Stat(certPath)
_, errKey := os.Stat(keyPath)
if errCert == nil && errKey == nil {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("Couldn't load X509 key pair: %s. Key encrypted?", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
}
return tlsConfig, nil
}
// createTLSConfig creates the base tls.Config used by auth methods with some
// sensible defaults
func createTLSConfig() *tls.Config {
return &tls.Config{
// Avoid fallback to SSL protocols < TLS1.0
MinVersion: tls.VersionTLS10,
}
}
func promptUnknownKey(key libtrust.PublicKey, host string) bool {
fmt.Printf("The authenticity of host %q can't be established.\nRemote key ID %s\n", host, key.KeyID())
fmt.Printf("Are you sure you want to continue connecting (yes/no)? ")
reader := bufio.NewReader(os.Stdin)
line, _, err := reader.ReadLine()
if err != nil {
log.Fatalf("Error reading input: %s", err)
}
input := strings.TrimSpace(strings.ToLower(string(line)))
return input == "yes" || input == "y"
}

View File

@ -3,6 +3,7 @@ package client
import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -104,6 +105,16 @@ func (cli *DockerCli) LoadConfigFile() (err error) {
return err
}
func (cli *DockerCli) CheckTtyInput(attachStdin, ttyMode bool) error {
// In order to attach to a container tty, input stream for the client must
// be a tty itself: redirecting or piping the client standard input is
// incompatible with `docker run -t`, `docker exec -t` or `docker attach`.
if ttyMode && attachStdin && !cli.isTerminalIn {
return errors.New("cannot enable tty mode on non tty input")
}
return nil
}
func NewDockerCli(in io.ReadCloser, out, err io.Writer, key libtrust.PrivateKey, proto, addr string, tlsConfig *tls.Config) *DockerCli {
var (
inFd uintptr

View File

@ -38,6 +38,7 @@ import (
"github.com/docker/docker/pkg/term"
"github.com/docker/docker/pkg/timeutils"
"github.com/docker/docker/pkg/units"
"github.com/docker/docker/pkg/urlutil"
"github.com/docker/docker/registry"
"github.com/docker/docker/runconfig"
"github.com/docker/docker/utils"
@ -115,13 +116,13 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
} else {
context = ioutil.NopCloser(buf)
}
} else if utils.IsURL(cmd.Arg(0)) && (!utils.IsGIT(cmd.Arg(0)) || !hasGit) {
} else if urlutil.IsURL(cmd.Arg(0)) && (!urlutil.IsGitURL(cmd.Arg(0)) || !hasGit) {
isRemote = true
} else {
root := cmd.Arg(0)
if utils.IsGIT(root) {
if urlutil.IsGitURL(root) {
remoteURL := cmd.Arg(0)
if !utils.ValidGitTransport(remoteURL) {
if !urlutil.IsGitTransport(remoteURL) {
remoteURL = "https://" + remoteURL
}
@ -179,7 +180,7 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
// FIXME: ProgressReader shouldn't be this annoying to use
if context != nil {
sf := utils.NewStreamFormatter(false)
body = utils.ProgressReader(context, 0, cli.err, sf, true, "", "Sending build context to Docker daemon")
body = utils.ProgressReader(context, 0, cli.out, sf, true, "", "Sending build context to Docker daemon")
}
// Send the build context
v := &url.Values{}
@ -543,6 +544,9 @@ func (cli *DockerCli) CmdInfo(args ...string) error {
if initPath := remoteInfo.Get("InitPath"); initPath != "" {
fmt.Fprintf(cli.out, "Init Path: %s\n", initPath)
}
if root := remoteInfo.Get("DockerRootDir"); root != "" {
fmt.Fprintf(cli.out, "Docker Root Dir: %s\n", root)
}
}
if len(remoteInfo.GetList("IndexServerAddress")) != 0 {
@ -881,7 +885,7 @@ func (cli *DockerCli) CmdInspect(args ...string) error {
// Remove trailing ','
indented.Truncate(indented.Len() - 1)
}
indented.WriteByte(']')
indented.WriteString("]\n")
if tmpl == nil {
if _, err := io.Copy(cli.out, indented); err != nil {
@ -1327,7 +1331,7 @@ func (cli *DockerCli) CmdPull(args ...string) error {
}
func (cli *DockerCli) CmdImages(args ...string) error {
cmd := cli.Subcmd("images", "[NAME]", "List images")
cmd := cli.Subcmd("images", "[REPOSITORY]", "List images")
quiet := cmd.Bool([]string{"q", "-quiet"}, false, "Only show numeric IDs")
all := cmd.Bool([]string{"a", "-all"}, false, "Show all images (by default filter out the intermediate image layers)")
noTrunc := cmd.Bool([]string{"#notrunc", "-no-trunc"}, false, "Don't truncate output")
@ -1781,6 +1785,10 @@ func (cli *DockerCli) CmdEvents(args ...string) error {
cmd := cli.Subcmd("events", "", "Get real time events from the server")
since := cmd.String([]string{"#since", "-since"}, "", "Show all events created since timestamp")
until := cmd.String([]string{"-until"}, "", "Stream events until this timestamp")
flFilter := opts.NewListOpts(nil)
cmd.Var(&flFilter, []string{"f", "-filter"}, "Provide filter values (i.e. 'event=stop')")
if err := cmd.Parse(args); err != nil {
return nil
}
@ -1790,9 +1798,20 @@ func (cli *DockerCli) CmdEvents(args ...string) error {
return nil
}
var (
v = url.Values{}
loc = time.FixedZone(time.Now().Zone())
v = url.Values{}
loc = time.FixedZone(time.Now().Zone())
eventFilterArgs = filters.Args{}
)
// Consolidate all filter flags, and sanity check them early.
// They'll get process in the daemon/server.
for _, f := range flFilter.GetAll() {
var err error
eventFilterArgs, err = filters.ParseFlag(f, eventFilterArgs)
if err != nil {
return err
}
}
var setTime = func(key, value string) {
format := timeutils.RFC3339NanoFixed
if len(value) < len(format) {
@ -1810,6 +1829,13 @@ func (cli *DockerCli) CmdEvents(args ...string) error {
if *until != "" {
setTime("until", *until)
}
if len(eventFilterArgs) > 0 {
filterJson, err := filters.ToParam(eventFilterArgs)
if err != nil {
return err
}
v.Set("filters", filterJson)
}
if err := cli.stream("GET", "/events?"+v.Encode(), nil, cli.out, nil); err != nil {
return err
}
@ -1948,6 +1974,10 @@ func (cli *DockerCli) CmdAttach(args ...string) error {
tty = config.GetBool("Tty")
)
if err := cli.CheckTtyInput(!*noStdin, tty); err != nil {
return err
}
if tty && cli.isTerminalOut {
if err := cli.monitorTtySize(cmd.Arg(0), false); err != nil {
log.Debugf("Error monitoring TTY size: %s", err)
@ -2262,7 +2292,11 @@ func (cli *DockerCli) CmdRun(args ...string) error {
return nil
}
if *flDetach {
if !*flDetach {
if err := cli.CheckTtyInput(config.AttachStdin, config.Tty); err != nil {
return err
}
} else {
if fl := cmd.Lookup("attach"); fl != nil {
flAttach = fl.Value.(*opts.ListOpts)
if flAttach.Len() != 0 {
@ -2574,10 +2608,16 @@ func (cli *DockerCli) CmdExec(args ...string) error {
return nil
}
if execConfig.Detach {
if !execConfig.Detach {
if err := cli.CheckTtyInput(execConfig.AttachStdin, execConfig.Tty); err != nil {
return err
}
} else {
if _, _, err := readBody(cli.call("POST", "/exec/"+execID+"/start", execConfig, false)); err != nil {
return err
}
// For now don't print this - wait for when we support exec wait()
// fmt.Fprintf(cli.out, "%s\n", execID)
return nil
}
@ -2640,5 +2680,14 @@ func (cli *DockerCli) CmdExec(args ...string) error {
return err
}
var status int
if _, status, err = getExecExitCode(cli, execID); err != nil {
return err
}
if status != 0 {
return &utils.StatusError{StatusCode: status}
}
return nil
}

View File

@ -2,6 +2,7 @@ package client
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
@ -10,6 +11,7 @@ import (
"os"
"runtime"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/docker/docker/api"
@ -19,9 +21,99 @@ import (
"github.com/docker/docker/pkg/term"
)
type tlsClientCon struct {
*tls.Conn
rawConn net.Conn
}
func (c *tlsClientCon) CloseWrite() error {
// Go standard tls.Conn doesn't provide the CloseWrite() method so we do it
// on its underlying connection.
if cwc, ok := c.rawConn.(interface {
CloseWrite() error
}); ok {
return cwc.CloseWrite()
}
return nil
}
func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) {
return tlsDialWithDialer(new(net.Dialer), network, addr, config)
}
// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in
// order to return our custom tlsClientCon struct which holds both the tls.Conn
// object _and_ its underlying raw connection. The rationale for this is that
// we need to be able to close the write end of the connection when attaching,
// which tls.Conn does not provide.
func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := dialer.Deadline.Sub(time.Now())
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- errors.New("")
})
}
rawConn, err := dialer.Dial(network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := *config
c.ServerName = hostname
config = &c
}
conn := tls.Client(rawConn, config)
if timeout == 0 {
err = conn.Handshake()
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
}
if err != nil {
rawConn.Close()
return nil, err
}
// This is Docker difference with standard's crypto/tls package: returned a
// wrapper which holds both the TLS and raw connections.
return &tlsClientCon{conn, rawConn}, nil
}
func (cli *DockerCli) dial() (net.Conn, error) {
if cli.tlsConfig != nil && cli.proto != "unix" {
return tls.Dial(cli.proto, cli.addr, cli.tlsConfig)
// Notice this isn't Go standard's tls.Dial function
return tlsDial(cli.proto, cli.addr, cli.tlsConfig)
}
return net.Dial(cli.proto, cli.addr)
}
@ -109,12 +201,11 @@ func (cli *DockerCli) hijack(method, path string, setRawTerminal bool, in io.Rea
io.Copy(rwc, in)
log.Debugf("[hijack] End of stdin")
}
if tcpc, ok := rwc.(*net.TCPConn); ok {
if err := tcpc.CloseWrite(); err != nil {
log.Debugf("Couldn't send EOF: %s", err)
}
} else if unixc, ok := rwc.(*net.UnixConn); ok {
if err := unixc.CloseWrite(); err != nil {
if conn, ok := rwc.(interface {
CloseWrite() error
}); ok {
if err := conn.CloseWrite(); err != nil {
log.Debugf("Couldn't send EOF: %s", err)
}
}

View File

@ -234,6 +234,26 @@ func getExitCode(cli *DockerCli, containerId string) (bool, int, error) {
return state.GetBool("Running"), state.GetInt("ExitCode"), nil
}
// getExecExitCode perform an inspect on the exec command. It returns
// the running state and the exit code.
func getExecExitCode(cli *DockerCli, execId string) (bool, int, error) {
stream, _, err := cli.call("GET", "/exec/"+execId+"/json", nil, false)
if err != nil {
// If we can't connect, then the daemon probably died.
if err != ErrConnectionRefused {
return false, -1, err
}
return false, -1, nil
}
var result engine.Env
if err := result.Decode(stream); err != nil {
return false, -1, err
}
return result.GetBool("Running"), result.GetInt("ExitCode"), nil
}
func (cli *DockerCli) monitorTtySize(id string, isExec bool) error {
cli.resizeTty(id, isExec)

View File

@ -67,11 +67,6 @@ func LoadOrCreateTrustKey(trustKeyPath string) (libtrust.PrivateKey, error) {
if err := libtrust.SaveKey(trustKeyPath, trustKey); err != nil {
return nil, fmt.Errorf("Error saving key file: %s", err)
}
dir, file := path.Split(trustKeyPath)
// Save public key
if err := libtrust.SavePublicKey(path.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil {
return nil, fmt.Errorf("Error saving public key file: %s", err)
}
} else if err != nil {
return nil, fmt.Errorf("Error loading key file: %s", err)
}

View File

@ -1,177 +0,0 @@
package server
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"os"
"path"
"sync"
"github.com/docker/libtrust"
)
// ClientKeyManager manages client keys on the filesystem
type ClientKeyManager struct {
key libtrust.PrivateKey
clientFile string
clientDir string
clientLock sync.RWMutex
clients []libtrust.PublicKey
configLock sync.Mutex
configs []*tls.Config
}
// NewClientKeyManager loads a new manager from a set of key files
// and managed by the given private key.
func NewClientKeyManager(trustKey libtrust.PrivateKey, clientFile, clientDir string) (*ClientKeyManager, error) {
m := &ClientKeyManager{
key: trustKey,
clientFile: clientFile,
clientDir: clientDir,
}
if err := m.loadKeys(); err != nil {
return nil, err
}
// TODO Start watching file and directory
return m, nil
}
func (c *ClientKeyManager) loadKeys() error {
// Load authorized keys file
var clients []libtrust.PublicKey
if c.clientFile != "" {
fileClients, err := libtrust.LoadKeySetFile(c.clientFile)
if err != nil {
return fmt.Errorf("unable to load authorized keys: %s", err)
}
clients = fileClients
}
// Add clients from authorized keys directory
files, err := ioutil.ReadDir(c.clientDir)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("unable to open authorized keys directory: %s", err)
}
for _, f := range files {
if !f.IsDir() {
publicKey, err := libtrust.LoadPublicKeyFile(path.Join(c.clientDir, f.Name()))
if err != nil {
return fmt.Errorf("unable to load authorized key file: %s", err)
}
clients = append(clients, publicKey)
}
}
c.clientLock.Lock()
c.clients = clients
c.clientLock.Unlock()
return nil
}
// RegisterTLSConfig registers a tls configuration to manager
// such that any changes to the keys may be reflected in
// the tls client CA pool
func (c *ClientKeyManager) RegisterTLSConfig(tlsConfig *tls.Config) error {
c.clientLock.RLock()
certPool, err := libtrust.GenerateCACertPool(c.key, c.clients)
if err != nil {
return fmt.Errorf("CA pool generation error: %s", err)
}
c.clientLock.RUnlock()
tlsConfig.ClientCAs = certPool
c.configLock.Lock()
c.configs = append(c.configs, tlsConfig)
c.configLock.Unlock()
return nil
}
// NewIdentityAuthTLSConfig creates a tls.Config for the server to use for
// libtrust identity authentication
func NewIdentityAuthTLSConfig(trustKey libtrust.PrivateKey, clients *ClientKeyManager, addr string) (*tls.Config, error) {
tlsConfig := createTLSConfig()
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
if err := clients.RegisterTLSConfig(tlsConfig); err != nil {
return nil, err
}
// Generate cert
ips, domains, err := parseAddr(addr)
if err != nil {
return nil, err
}
// add default docker domain for docker clients to look for
domains = append(domains, "docker")
x509Cert, err := libtrust.GenerateSelfSignedServerCert(trustKey, domains, ips)
if err != nil {
return nil, fmt.Errorf("certificate generation error: %s", err)
}
tlsConfig.Certificates = []tls.Certificate{{
Certificate: [][]byte{x509Cert.Raw},
PrivateKey: trustKey.CryptoPrivateKey(),
Leaf: x509Cert,
}}
return tlsConfig, nil
}
// NewCertAuthTLSConfig creates a tls.Config for the server to use for
// certificate authentication
func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
tlsConfig := createTLSConfig()
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", certPath, keyPath, err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
// Verify client certificates against a CA?
if caPath != "" {
certPool := x509.NewCertPool()
file, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
}
certPool.AppendCertsFromPEM(file)
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = certPool
}
return tlsConfig, nil
}
func createTLSConfig() *tls.Config {
return &tls.Config{
NextProtos: []string{"http/1.1"},
// Avoid fallback on insecure SSL protocols
MinVersion: tls.VersionTLS10,
}
}
// parseAddr parses an address into an array of IPs and domains
func parseAddr(addr string) ([]net.IP, []string, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, nil, err
}
var domains []string
var ips []net.IP
ip := net.ParseIP(host)
if ip != nil {
ips = []net.IP{ip}
} else {
domains = []string{host}
}
return ips, domains, nil
}

View File

@ -3,7 +3,7 @@ package server
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"encoding/json"
"expvar"
@ -18,6 +18,9 @@ import (
"strings"
"syscall"
"crypto/tls"
"crypto/x509"
"code.google.com/p/go.net/websocket"
"github.com/docker/libcontainer/user"
"github.com/gorilla/mux"
@ -62,6 +65,18 @@ func hijackServer(w http.ResponseWriter) (io.ReadCloser, io.Writer, error) {
return conn, conn, nil
}
func closeStreams(streams ...interface{}) {
for _, stream := range streams {
if tcpc, ok := stream.(interface {
CloseWrite() error
}); ok {
tcpc.CloseWrite()
} else if closer, ok := stream.(io.Closer); ok {
closer.Close()
}
}
}
// Check to make sure request's Content-Type is application/json
func checkForJson(r *http.Request) error {
ct := r.Header.Get("Content-Type")
@ -312,6 +327,7 @@ func getEvents(eng *engine.Engine, version version.Version, w http.ResponseWrite
streamJSON(job, w, true)
job.Setenv("since", r.Form.Get("since"))
job.Setenv("until", r.Form.Get("until"))
job.Setenv("filters", r.Form.Get("filters"))
return job.Run()
}
@ -867,20 +883,7 @@ func postContainersAttach(eng *engine.Engine, version version.Version, w http.Re
if err != nil {
return err
}
defer func() {
if tcpc, ok := inStream.(*net.TCPConn); ok {
tcpc.CloseWrite()
} else {
inStream.Close()
}
}()
defer func() {
if tcpc, ok := outStream.(*net.TCPConn); ok {
tcpc.CloseWrite()
} else if closer, ok := outStream.(io.Closer); ok {
closer.Close()
}
}()
defer closeStreams(inStream, outStream)
var errStream io.Writer
@ -953,6 +956,15 @@ func getContainersByName(eng *engine.Engine, version version.Version, w http.Res
return job.Run()
}
func getExecByID(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
if vars == nil {
return fmt.Errorf("Missing parameter 'id'")
}
var job = eng.Job("execInspect", vars["id"])
streamJSON(job, w, false)
return job.Run()
}
func getImagesByName(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
if vars == nil {
return fmt.Errorf("Missing parameter")
@ -1121,21 +1133,7 @@ func postContainerExecStart(eng *engine.Engine, version version.Version, w http.
if err != nil {
return err
}
defer func() {
if tcpc, ok := inStream.(*net.TCPConn); ok {
tcpc.CloseWrite()
} else {
inStream.Close()
}
}()
defer func() {
if tcpc, ok := outStream.(*net.TCPConn); ok {
tcpc.CloseWrite()
} else if closer, ok := outStream.(io.Closer); ok {
closer.Close()
}
}()
defer closeStreams(inStream, outStream)
var errStream io.Writer
@ -1246,6 +1244,7 @@ func AttachProfiler(router *mux.Router) {
router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
router.HandleFunc("/debug/pprof/profile", pprof.Profile)
router.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
router.HandleFunc("/debug/pprof/block", pprof.Handler("block").ServeHTTP)
router.HandleFunc("/debug/pprof/heap", pprof.Handler("heap").ServeHTTP)
router.HandleFunc("/debug/pprof/goroutine", pprof.Handler("goroutine").ServeHTTP)
router.HandleFunc("/debug/pprof/threadcreate", pprof.Handler("threadcreate").ServeHTTP)
@ -1277,6 +1276,7 @@ func createRouter(eng *engine.Engine, logging, enableCors bool, dockerVersion st
"/containers/{name:.*}/top": getContainersTop,
"/containers/{name:.*}/logs": getContainersLogs,
"/containers/{name:.*}/attach/ws": wsContainersAttach,
"/exec/{id:.*}/json": getExecByID,
},
"POST": {
"/auth": postAuth,
@ -1405,6 +1405,33 @@ func lookupGidByName(nameOrGid string) (int, error) {
return -1, fmt.Errorf("Group %s not found", nameOrGid)
}
func setupTls(cert, key, ca string, l net.Listener) (net.Listener, error) {
tlsCert, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?",
cert, key, err)
}
tlsConfig := &tls.Config{
NextProtos: []string{"http/1.1"},
Certificates: []tls.Certificate{tlsCert},
// Avoid fallback on insecure SSL protocols
MinVersion: tls.VersionTLS10,
}
if ca != "" {
certPool := x509.NewCertPool()
file, err := ioutil.ReadFile(ca)
if err != nil {
return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
}
certPool.AppendCertsFromPEM(file)
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = certPool
}
return tls.NewListener(l, tlsConfig), nil
}
func newListener(proto, addr string, bufferRequests bool) (net.Listener, error) {
if bufferRequests {
return listenbuffer.NewListenBuffer(proto, addr, activationLock)
@ -1467,6 +1494,10 @@ func setupUnixHttp(addr string, job *engine.Job) (*HttpServer, error) {
}
func setupTcpHttp(addr string, job *engine.Job) (*HttpServer, error) {
if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") {
log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
}
r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
if err != nil {
return nil, err
@ -1477,37 +1508,16 @@ func setupTcpHttp(addr string, job *engine.Job) (*HttpServer, error) {
return nil, err
}
var tlsConfig *tls.Config
switch job.Getenv("Auth") {
case "identity":
trustKey, err := api.LoadOrCreateTrustKey(job.Getenv("TrustKey"))
if job.GetenvBool("Tls") || job.GetenvBool("TlsVerify") {
var tlsCa string
if job.GetenvBool("TlsVerify") {
tlsCa = job.Getenv("TlsCa")
}
l, err = setupTls(job.Getenv("TlsCert"), job.Getenv("TlsKey"), tlsCa, l)
if err != nil {
return nil, err
}
manager, err := NewClientKeyManager(trustKey, job.Getenv("TrustClients"), job.Getenv("TrustDir"))
if err != nil {
return nil, err
}
if tlsConfig, err = NewIdentityAuthTLSConfig(trustKey, manager, addr); err != nil {
return nil, fmt.Errorf("Error creating TLS config: %s", err)
}
case "cert":
if tlsConfig, err = NewCertAuthTLSConfig(job.Getenv("AuthCa"), job.Getenv("AuthCert"), job.Getenv("AuthKey")); err != nil {
return nil, fmt.Errorf("Error creating TLS config: %s", err)
}
case "none":
tlsConfig = nil
default:
return nil, fmt.Errorf("Unknown auth method: %s", job.Getenv("Auth"))
}
if tlsConfig == nil {
log.Infof("/!\\ DON'T BIND INSECURELY ON A TCP ADDRESS IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
} else {
l = tls.NewListener(l, tlsConfig)
}
return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil
}

View File

@ -11,7 +11,6 @@ import (
"time"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/timeutils"
"github.com/docker/docker/utils"
)
@ -251,11 +250,3 @@ func (eng *Engine) ParseJob(input string) (*Job, error) {
job.Env().Init(&env)
return job, nil
}
func (eng *Engine) Logf(format string, args ...interface{}) (n int, err error) {
if !eng.Logging {
return 0, nil
}
prefixedFormat := fmt.Sprintf("[%s] [%s] %s\n", time.Now().Format(timeutils.RFC3339NanoFixed), eng, strings.TrimRight(format, "\n"))
return fmt.Fprintf(eng.Stderr, prefixedFormat, args...)
}

View File

@ -99,16 +99,6 @@ func TestEngineString(t *testing.T) {
}
}
func TestEngineLogf(t *testing.T) {
eng := New()
input := "Test log line"
if n, err := eng.Logf("%s\n", input); err != nil {
t.Fatal(err)
} else if n < len(input) {
t.Fatalf("Test: Logf() should print at least as much as the input\ninput=%d\nprinted=%d", len(input), n)
}
}
func TestParseJob(t *testing.T) {
eng := New()
// Verify that the resulting job calls to the right place

View File

@ -42,6 +42,11 @@ type (
Archiver struct {
Untar func(io.Reader, string, *TarOptions) error
}
// breakoutError is used to differentiate errors related to breaking out
// When testing archive breakout in the unit tests, this error is expected
// in order for the test to pass.
breakoutError error
)
var (
@ -287,11 +292,25 @@ func createTarFile(path, extractDir string, hdr *tar.Header, reader io.Reader, L
}
case tar.TypeLink:
if err := os.Link(filepath.Join(extractDir, hdr.Linkname), path); err != nil {
targetPath := filepath.Join(extractDir, hdr.Linkname)
// check for hardlink breakout
if !strings.HasPrefix(targetPath, extractDir) {
return breakoutError(fmt.Errorf("invalid hardlink %q -> %q", targetPath, hdr.Linkname))
}
if err := os.Link(targetPath, path); err != nil {
return err
}
case tar.TypeSymlink:
// path -> hdr.Linkname = targetPath
// e.g. /extractDir/path/to/symlink -> ../2/file = /extractDir/path/2/file
targetPath := filepath.Join(filepath.Dir(path), hdr.Linkname)
// the reason we don't need to check symlinks in the path (with FollowSymlinkInScope) is because
// that symlink would first have to be created, which would be caught earlier, at this very check:
if !strings.HasPrefix(targetPath, extractDir) {
return breakoutError(fmt.Errorf("invalid symlink %q -> %q", path, hdr.Linkname))
}
if err := os.Symlink(hdr.Linkname, path); err != nil {
return err
}
@ -445,30 +464,7 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error)
return pipeReader, nil
}
// Untar reads a stream of bytes from `archive`, parses it as a tar archive,
// and unpacks it into the directory at `dest`.
// The archive may be compressed with one of the following algorithms:
// identity (uncompressed), gzip, bzip2, xz.
// FIXME: specify behavior when target path exists vs. doesn't exist.
func Untar(archive io.Reader, dest string, options *TarOptions) error {
if options == nil {
options = &TarOptions{}
}
if archive == nil {
return fmt.Errorf("Empty archive")
}
if options.Excludes == nil {
options.Excludes = []string{}
}
decompressedArchive, err := DecompressStream(archive)
if err != nil {
return err
}
defer decompressedArchive.Close()
func Unpack(decompressedArchive io.Reader, dest string, options *TarOptions) error {
tr := tar.NewReader(decompressedArchive)
trBuf := pools.BufioReader32KPool.Get(nil)
defer pools.BufioReader32KPool.Put(trBuf)
@ -488,6 +484,7 @@ loop:
}
// Normalize name, for safety and for a simple is-root check
// This keeps "../" as-is, but normalizes "/../" to "/"
hdr.Name = filepath.Clean(hdr.Name)
for _, exclude := range options.Excludes {
@ -509,6 +506,13 @@ loop:
}
path := filepath.Join(dest, hdr.Name)
rel, err := filepath.Rel(dest, path)
if err != nil {
return err
}
if strings.HasPrefix(rel, "..") {
return breakoutError(fmt.Errorf("%q is outside of %q", hdr.Name, dest))
}
// If path exits we almost always just want to remove and replace it
// The only exception is when it is a directory *and* the file from
@ -543,10 +547,33 @@ loop:
return err
}
}
return nil
}
// Untar reads a stream of bytes from `archive`, parses it as a tar archive,
// and unpacks it into the directory at `dest`.
// The archive may be compressed with one of the following algorithms:
// identity (uncompressed), gzip, bzip2, xz.
// FIXME: specify behavior when target path exists vs. doesn't exist.
func Untar(archive io.Reader, dest string, options *TarOptions) error {
if archive == nil {
return fmt.Errorf("Empty archive")
}
dest = filepath.Clean(dest)
if options == nil {
options = &TarOptions{}
}
if options.Excludes == nil {
options.Excludes = []string{}
}
decompressedArchive, err := DecompressStream(archive)
if err != nil {
return err
}
defer decompressedArchive.Close()
return Unpack(decompressedArchive, dest, options)
}
func (archiver *Archiver) TarUntar(src, dst string) error {
log.Debugf("TarUntar(%s %s)", src, dst)
archive, err := TarWithOptions(src, &TarOptions{Compression: Uncompressed})
@ -742,20 +769,33 @@ func NewTempArchive(src Archive, dir string) (*TempArchive, error) {
return nil, err
}
size := st.Size()
return &TempArchive{f, size, 0}, nil
return &TempArchive{File: f, Size: size}, nil
}
type TempArchive struct {
*os.File
Size int64 // Pre-computed from Stat().Size() as a convenience
read int64
Size int64 // Pre-computed from Stat().Size() as a convenience
read int64
closed bool
}
// Close closes the underlying file if it's still open, or does a no-op
// to allow callers to try to close the TempArchive multiple times safely.
func (archive *TempArchive) Close() error {
if archive.closed {
return nil
}
archive.closed = true
return archive.File.Close()
}
func (archive *TempArchive) Read(data []byte) (int, error) {
n, err := archive.File.Read(data)
archive.read += int64(n)
if err != nil || archive.read == archive.Size {
archive.File.Close()
archive.Close()
os.Remove(archive.File.Name())
}
return n, err

View File

@ -8,6 +8,8 @@ import (
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"syscall"
"testing"
"time"
@ -214,7 +216,12 @@ func TestTarWithOptions(t *testing.T) {
// Failing prevents the archives from being uncompressed during ADD
func TestTypeXGlobalHeaderDoesNotFail(t *testing.T) {
hdr := tar.Header{Typeflag: tar.TypeXGlobalHeader}
err := createTarFile("pax_global_header", "some_dir", &hdr, nil, true)
tmpDir, err := ioutil.TempDir("", "docker-test-archive-pax-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
err = createTarFile(filepath.Join(tmpDir, "pax_global_header"), tmpDir, &hdr, nil, true)
if err != nil {
t.Fatal(err)
}
@ -403,3 +410,216 @@ func BenchmarkTarUntarWithLinks(b *testing.B) {
os.RemoveAll(target)
}
}
func TestUntarInvalidFilenames(t *testing.T) {
for i, headers := range [][]*tar.Header{
{
{
Name: "../victim/dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{
{
// Note the leading slash
Name: "/../victim/slash-dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("untar", "docker-TestUntarInvalidFilenames", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestUntarInvalidHardlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeLink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeLink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (hardlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try reading victim/hello (hardlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try removing victim directory (hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("untar", "docker-TestUntarInvalidHardlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestUntarInvalidSymlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeSymlink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeSymlink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try removing victim directory (symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try writing to victim/newdir/newfile with a symlink in the path
{
// this header needs to be before the next one, or else there is an error
Name: "dir/loophole",
Typeflag: tar.TypeSymlink,
Linkname: "../../victim",
Mode: 0755,
},
{
Name: "dir/loophole/newdir/newfile",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("untar", "docker-TestUntarInvalidSymlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestTempArchiveCloseMultipleTimes(t *testing.T) {
reader := ioutil.NopCloser(strings.NewReader("hello"))
tempArchive, err := NewTempArchive(reader, "")
buf := make([]byte, 10)
n, err := tempArchive.Read(buf)
if n != 5 {
t.Fatalf("Expected to read 5 bytes. Read %d instead", n)
}
for i := 0; i < 3; i++ {
if err = tempArchive.Close(); err != nil {
t.Fatalf("i=%d. Unexpected error closing temp archive: %v", i, err)
}
}
}

View File

@ -15,22 +15,7 @@ import (
"github.com/docker/docker/pkg/system"
)
// ApplyLayer parses a diff in the standard layer format from `layer`, and
// applies it to the directory `dest`.
func ApplyLayer(dest string, layer ArchiveReader) error {
// We need to be able to set any perms
oldmask, err := system.Umask(0)
if err != nil {
return err
}
defer system.Umask(oldmask) // ignore err, ErrNotSupportedPlatform
layer, err = DecompressStream(layer)
if err != nil {
return err
}
func UnpackLayer(dest string, layer ArchiveReader) error {
tr := tar.NewReader(layer)
trBuf := pools.BufioReader32KPool.Get(tr)
defer pools.BufioReader32KPool.Put(trBuf)
@ -90,7 +75,15 @@ func ApplyLayer(dest string, layer ArchiveReader) error {
}
path := filepath.Join(dest, hdr.Name)
rel, err := filepath.Rel(dest, path)
if err != nil {
return err
}
if strings.HasPrefix(rel, "..") {
return breakoutError(fmt.Errorf("%q is outside of %q", hdr.Name, dest))
}
base := filepath.Base(path)
if strings.HasPrefix(base, ".wh.") {
originalBase := base[len(".wh."):]
originalPath := filepath.Join(filepath.Dir(path), originalBase)
@ -149,6 +142,24 @@ func ApplyLayer(dest string, layer ArchiveReader) error {
return err
}
}
return nil
}
// ApplyLayer parses a diff in the standard layer format from `layer`, and
// applies it to the directory `dest`.
func ApplyLayer(dest string, layer ArchiveReader) error {
dest = filepath.Clean(dest)
// We need to be able to set any perms
oldmask, err := system.Umask(0)
if err != nil {
return err
}
defer system.Umask(oldmask) // ignore err, ErrNotSupportedPlatform
layer, err = DecompressStream(layer)
if err != nil {
return err
}
return UnpackLayer(dest, layer)
}

View File

@ -0,0 +1,191 @@
package archive
import (
"testing"
"github.com/docker/docker/vendor/src/code.google.com/p/go/src/pkg/archive/tar"
)
func TestApplyLayerInvalidFilenames(t *testing.T) {
for i, headers := range [][]*tar.Header{
{
{
Name: "../victim/dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{
{
// Note the leading slash
Name: "/../victim/slash-dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidFilenames", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerInvalidHardlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeLink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeLink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (hardlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try reading victim/hello (hardlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try removing victim directory (hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidHardlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerInvalidSymlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeSymlink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeSymlink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try removing victim directory (symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidSymlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}

View File

@ -0,0 +1,166 @@
package archive
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"time"
"github.com/docker/docker/vendor/src/code.google.com/p/go/src/pkg/archive/tar"
)
var testUntarFns = map[string]func(string, io.Reader) error{
"untar": func(dest string, r io.Reader) error {
return Untar(r, dest, nil)
},
"applylayer": func(dest string, r io.Reader) error {
return ApplyLayer(dest, ArchiveReader(r))
},
}
// testBreakout is a helper function that, within the provided `tmpdir` directory,
// creates a `victim` folder with a generated `hello` file in it.
// `untar` extracts to a directory named `dest`, the tar file created from `headers`.
//
// Here are the tested scenarios:
// - removed `victim` folder (write)
// - removed files from `victim` folder (write)
// - new files in `victim` folder (write)
// - modified files in `victim` folder (write)
// - file in `dest` with same content as `victim/hello` (read)
//
// When using testBreakout make sure you cover one of the scenarios listed above.
func testBreakout(untarFn string, tmpdir string, headers []*tar.Header) error {
tmpdir, err := ioutil.TempDir("", tmpdir)
if err != nil {
return err
}
defer os.RemoveAll(tmpdir)
dest := filepath.Join(tmpdir, "dest")
if err := os.Mkdir(dest, 0755); err != nil {
return err
}
victim := filepath.Join(tmpdir, "victim")
if err := os.Mkdir(victim, 0755); err != nil {
return err
}
hello := filepath.Join(victim, "hello")
helloData, err := time.Now().MarshalText()
if err != nil {
return err
}
if err := ioutil.WriteFile(hello, helloData, 0644); err != nil {
return err
}
helloStat, err := os.Stat(hello)
if err != nil {
return err
}
reader, writer := io.Pipe()
go func() {
t := tar.NewWriter(writer)
for _, hdr := range headers {
t.WriteHeader(hdr)
}
t.Close()
}()
untar := testUntarFns[untarFn]
if untar == nil {
return fmt.Errorf("could not find untar function %q in testUntarFns", untarFn)
}
if err := untar(dest, reader); err != nil {
if _, ok := err.(breakoutError); !ok {
// If untar returns an error unrelated to an archive breakout,
// then consider this an unexpected error and abort.
return err
}
// Here, untar detected the breakout.
// Let's move on verifying that indeed there was no breakout.
fmt.Printf("breakoutError: %v\n", err)
}
// Check victim folder
f, err := os.Open(victim)
if err != nil {
// codepath taken if victim folder was removed
return fmt.Errorf("archive breakout: error reading %q: %v", victim, err)
}
defer f.Close()
// Check contents of victim folder
//
// We are only interested in getting 2 files from the victim folder, because if all is well
// we expect only one result, the `hello` file. If there is a second result, it cannot
// hold the same name `hello` and we assume that a new file got created in the victim folder.
// That is enough to detect an archive breakout.
names, err := f.Readdirnames(2)
if err != nil {
// codepath taken if victim is not a folder
return fmt.Errorf("archive breakout: error reading directory content of %q: %v", victim, err)
}
for _, name := range names {
if name != "hello" {
// codepath taken if new file was created in victim folder
return fmt.Errorf("archive breakout: new file %q", name)
}
}
// Check victim/hello
f, err = os.Open(hello)
if err != nil {
// codepath taken if read permissions were removed
return fmt.Errorf("archive breakout: could not lstat %q: %v", hello, err)
}
defer f.Close()
b, err := ioutil.ReadAll(f)
if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
if helloStat.IsDir() != fi.IsDir() ||
// TODO: cannot check for fi.ModTime() change
helloStat.Mode() != fi.Mode() ||
helloStat.Size() != fi.Size() ||
!bytes.Equal(helloData, b) {
// codepath taken if hello has been modified
return fmt.Errorf("archive breakout: file %q has been modified. Contents: expected=%q, got=%q. FileInfo: expected=%#v, got=%#v.", hello, helloData, b, helloStat, fi)
}
// Check that nothing in dest/ has the same content as victim/hello.
// Since victim/hello was generated with time.Now(), it is safe to assume
// that any file whose content matches exactly victim/hello, managed somehow
// to access victim/hello.
return filepath.Walk(dest, func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
if err != nil {
// skip directory if error
return filepath.SkipDir
}
// enter directory
return nil
}
if err != nil {
// skip file if error
return nil
}
b, err := ioutil.ReadFile(path)
if err != nil {
// Houston, we have a problem. Aborting (space)walk.
return err
}
if bytes.Equal(helloData, b) {
return fmt.Errorf("archive breakout: file %q has been accessed via %q", hello, path)
}
return nil
})
}

View File

@ -7,63 +7,59 @@ import (
)
// FIXME: Change this not to receive default value as parameter
func ParseHost(defaultHost string, defaultUnix, addr string) (string, error) {
var (
proto string
host string
port int
)
func ParseHost(defaultTCPAddr, defaultUnixAddr, addr string) (string, error) {
addr = strings.TrimSpace(addr)
switch {
case addr == "tcp://":
return "", fmt.Errorf("Invalid bind address format: %s", addr)
case strings.HasPrefix(addr, "unix://"):
proto = "unix"
addr = strings.TrimPrefix(addr, "unix://")
if addr == "" {
addr = defaultUnix
}
case strings.HasPrefix(addr, "tcp://"):
proto = "tcp"
addr = strings.TrimPrefix(addr, "tcp://")
case strings.HasPrefix(addr, "fd://"):
if addr == "" {
addr = fmt.Sprintf("unix://%s", defaultUnixAddr)
}
addrParts := strings.Split(addr, "://")
if len(addrParts) == 1 {
addrParts = []string{"tcp", addrParts[0]}
}
switch addrParts[0] {
case "tcp":
return ParseTCPAddr(addrParts[1], defaultTCPAddr)
case "unix":
return ParseUnixAddr(addrParts[1], defaultUnixAddr)
case "fd":
return addr, nil
case addr == "":
proto = "unix"
addr = defaultUnix
default:
if strings.Contains(addr, "://") {
return "", fmt.Errorf("Invalid bind address protocol: %s", addr)
}
proto = "tcp"
}
if proto != "unix" && strings.Contains(addr, ":") {
hostParts := strings.Split(addr, ":")
if len(hostParts) != 2 {
return "", fmt.Errorf("Invalid bind address format: %s", addr)
}
if hostParts[0] != "" {
host = hostParts[0]
} else {
host = defaultHost
}
if p, err := strconv.Atoi(hostParts[1]); err == nil && p != 0 {
port = p
} else {
return "", fmt.Errorf("Invalid bind address format: %s", addr)
}
} else if proto == "tcp" && !strings.Contains(addr, ":") {
return "", fmt.Errorf("Invalid bind address format: %s", addr)
} else {
host = addr
}
if proto == "unix" {
return fmt.Sprintf("%s://%s", proto, host), nil
}
func ParseUnixAddr(addr string, defaultAddr string) (string, error) {
addr = strings.TrimPrefix(addr, "unix://")
if strings.Contains(addr, "://") {
return "", fmt.Errorf("Invalid proto, expected unix: %s", addr)
}
return fmt.Sprintf("%s://%s:%d", proto, host, port), nil
if addr == "" {
addr = defaultAddr
}
return fmt.Sprintf("unix://%s", addr), nil
}
func ParseTCPAddr(addr string, defaultAddr string) (string, error) {
addr = strings.TrimPrefix(addr, "tcp://")
if strings.Contains(addr, "://") || addr == "" {
return "", fmt.Errorf("Invalid proto, expected tcp: %s", addr)
}
hostParts := strings.Split(addr, ":")
if len(hostParts) != 2 {
return "", fmt.Errorf("Invalid bind address format: %s", addr)
}
host := hostParts[0]
if host == "" {
host = defaultAddr
}
p, err := strconv.Atoi(hostParts[1])
if err != nil && p == 0 {
return "", fmt.Errorf("Invalid bind address format: %s", addr)
}
return fmt.Sprintf("tcp://%s:%d", host, p), nil
}
// Get a repos name and returns the right reposName + tag

View File

@ -0,0 +1,47 @@
// +build linux,cgo
package term
import (
"syscall"
"unsafe"
)
// #include <termios.h>
import "C"
type Termios syscall.Termios
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd uintptr) (*State, error) {
var oldState State
if err := tcget(fd, &oldState.termios); err != 0 {
return nil, err
}
newState := oldState.termios
C.cfmakeraw((*C.struct_termios)(unsafe.Pointer(&newState)))
if err := tcset(fd, &newState); err != 0 {
return nil, err
}
return &oldState, nil
}
func tcget(fd uintptr, p *Termios) syscall.Errno {
ret, err := C.tcgetattr(C.int(fd), (*C.struct_termios)(unsafe.Pointer(p)))
if ret != 0 {
return err.(syscall.Errno)
}
return 0
}
func tcset(fd uintptr, p *Termios) syscall.Errno {
ret, err := C.tcsetattr(C.int(fd), C.TCSANOW, (*C.struct_termios)(unsafe.Pointer(p)))
if ret != 0 {
return err.(syscall.Errno)
}
return 0
}

View File

@ -0,0 +1,19 @@
// +build !windows
// +build !linux !cgo
package term
import (
"syscall"
"unsafe"
)
func tcget(fd uintptr, p *Termios) syscall.Errno {
_, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(getTermios), uintptr(unsafe.Pointer(p)))
return err
}
func tcset(fd uintptr, p *Termios) syscall.Errno {
_, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, setTermios, uintptr(unsafe.Pointer(p)))
return err
}

View File

@ -47,8 +47,7 @@ func SetWinsize(fd uintptr, ws *Winsize) error {
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd uintptr) bool {
var termios Termios
_, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(getTermios), uintptr(unsafe.Pointer(&termios)))
return err == 0
return tcget(fd, &termios) == 0
}
// Restore restores the terminal connected to the given file descriptor to a
@ -57,8 +56,7 @@ func RestoreTerminal(fd uintptr, state *State) error {
if state == nil {
return ErrInvalidState
}
_, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(setTermios), uintptr(unsafe.Pointer(&state.termios)))
if err != 0 {
if err := tcset(fd, &state.termios); err != 0 {
return err
}
return nil
@ -66,7 +64,7 @@ func RestoreTerminal(fd uintptr, state *State) error {
func SaveState(fd uintptr) (*State, error) {
var oldState State
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, getTermios, uintptr(unsafe.Pointer(&oldState.termios))); err != 0 {
if err := tcget(fd, &oldState.termios); err != 0 {
return nil, err
}
@ -77,7 +75,7 @@ func DisableEcho(fd uintptr, state *State) error {
newState := state.termios
newState.Lflag &^= syscall.ECHO
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, setTermios, uintptr(unsafe.Pointer(&newState))); err != 0 {
if err := tcset(fd, &newState); err != 0 {
return err
}
handleInterrupt(fd, state)

View File

@ -1,3 +1,5 @@
// +build !cgo
package term
import (

View File

@ -31,6 +31,10 @@ type KeyValuePair struct {
Value string
}
var (
validHex = regexp.MustCompile(`^([a-f0-9]{64})$`)
)
// Request a given URL and return an io.Reader
func Download(url string) (resp *http.Response, err error) {
if resp, err = http.Get(url); err != nil {
@ -190,11 +194,9 @@ func GenerateRandomID() string {
}
func ValidateID(id string) error {
if id == "" {
return fmt.Errorf("Id can't be empty")
}
if strings.Contains(id, ":") {
return fmt.Errorf("Invalid character in id: ':'")
if ok := validHex.MatchString(id); !ok {
err := fmt.Errorf("image ID '%s' is invalid", id)
return err
}
return nil
}
@ -288,21 +290,7 @@ func NewHTTPRequestError(msg string, res *http.Response) error {
}
}
func IsURL(str string) bool {
return strings.HasPrefix(str, "http://") || strings.HasPrefix(str, "https://")
}
func IsGIT(str string) bool {
return strings.HasPrefix(str, "git://") || strings.HasPrefix(str, "github.com/") || strings.HasPrefix(str, "git@github.com:") || (strings.HasSuffix(str, ".git") && IsURL(str))
}
func ValidGitTransport(str string) bool {
return strings.HasPrefix(str, "git://") || strings.HasPrefix(str, "git@") || IsURL(str)
}
var (
localHostRx = regexp.MustCompile(`(?m)^nameserver 127[^\n]+\n*`)
)
var localHostRx = regexp.MustCompile(`(?m)^nameserver 127[^\n]+\n*`)
// RemoveLocalDns looks into the /etc/resolv.conf,
// and removes any local nameserver entries.

View File

@ -97,24 +97,3 @@ func TestReadSymlinkedDirectoryToFile(t *testing.T) {
t.Errorf("failed to remove symlink: %s", err)
}
}
func TestValidGitTransport(t *testing.T) {
for _, url := range []string{
"git://github.com/docker/docker",
"git@github.com:docker/docker.git",
"https://github.com/docker/docker.git",
"http://github.com/docker/docker.git",
} {
if ValidGitTransport(url) == false {
t.Fatalf("%q should be detected as valid Git prefix", url)
}
}
for _, url := range []string{
"github.com/docker/docker",
} {
if ValidGitTransport(url) == true {
t.Fatalf("%q should not be detected as valid Git prefix", url)
}
}
}

View File

@ -0,0 +1,432 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package context defines the Context type, which carries deadlines,
// cancelation signals, and other request-scoped values across API boundaries
// and between processes.
//
// Incoming requests to a server should create a Context, and outgoing calls to
// servers should accept a Context. The chain of function calls between must
// propagate the Context, optionally replacing it with a modified copy created
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
//
// Do not store Contexts inside a struct type; instead, pass a Context
// explicitly to each function that needs it. The Context should be the first
// parameter, typically named ctx:
//
// func DoSomething(ctx context.Context, arg Arg) error {
// // ... use ctx ...
// }
//
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
// if you are unsure about which Context to use.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The same Context may be passed to functions running in different goroutines;
// Contexts are safe for simultaneous use by multiple goroutines.
//
// See http://blog.golang.org/context for example code for a server that uses
// Contexts.
package context
import (
"errors"
"fmt"
"sync"
"time"
)
// A Context carries a deadline, a cancelation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
//
// WithCancel arranges for Done to be closed when cancel is called;
// WithDeadline arranges for Done to be closed when the deadline
// expires; WithTimeout arranges for Done to be closed when the timeout
// elapses.
//
// Done is provided for use in select statements:
//
// // DoSomething calls DoSomethingSlow and returns as soon as
// // it returns or ctx.Done is closed.
// func DoSomething(ctx context.Context) (Result, error) {
// c := make(chan Result, 1)
// go func() { c <- DoSomethingSlow(ctx) }()
// select {
// case res := <-c:
// return res, nil
// case <-ctx.Done():
// return nil, ctx.Err()
// }
// }
//
// See http://blog.golang.org/pipelines for more examples of how to use
// a Done channel for cancelation.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
//
// A key identifies a specific value in a Context. Functions that wish
// to store values in Context typically allocate a key in a global
// variable then use that key as the argument to context.WithValue and
// Context.Value. A key can be any type that supports equality;
// packages should define keys as an unexported type to avoid
// collisions.
//
// Packages that define a Context key should provide type-safe accessors
// for the values stores using that key:
//
// // Package user defines a User type that's stored in Contexts.
// package user
//
// import "golang.org/x/net/context"
//
// // User is the type of value stored in the Contexts.
// type User struct {...}
//
// // key is an unexported type for keys defined in this package.
// // This prevents collisions with keys defined in other packages.
// type key int
//
// // userKey is the key for user.User values in Contexts. It is
// // unexported; clients use user.NewContext and user.FromContext
// // instead of using this key directly.
// var userKey key = 0
//
// // NewContext returns a new Context that carries value u.
// func NewContext(ctx context.Context, u *User) context.Context {
// return context.WithValue(ctx, userKey, u)
// }
//
// // FromContext returns the User value stored in ctx, if any.
// func FromContext(ctx context.Context) (*User, bool) {
// u, ok := ctx.Value(userKey).(*User)
// return u, ok
// }
Value(key interface{}) interface{}
}
// Canceled is the error returned by Context.Err when the context is canceled.
var Canceled = errors.New("context canceled")
// DeadlineExceeded is the error returned by Context.Err when the context's
// deadline passes.
var DeadlineExceeded = errors.New("context deadline exceeded")
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case background:
return "context.Background"
case todo:
return "context.TODO"
}
return "unknown empty Context"
}
var (
background = new(emptyCtx)
todo = new(emptyCtx)
)
// Background returns a non-nil, empty Context. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function,
// initialization, and tests, and as the top-level Context for incoming
// requests.
func Background() Context {
return background
}
// TODO returns a non-nil, empty Context. Code should use context.TODO when
// it's unclear which Context to use or it's is not yet available (because the
// surrounding function has not yet been extended to accept a Context
// parameter). TODO is recognized by static analysis tools that determine
// whether Contexts are propagated correctly in a program.
func TODO() Context {
return todo
}
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc func()
// WithCancel returns a copy of parent with a new Done channel. The returned
// context's Done channel is closed when the returned cancel function is called
// or when the parent context's Done channel is closed, whichever happens first.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := newCancelCtx(parent)
propagateCancel(parent, &c)
return &c, func() { c.cancel(true, Canceled) }
}
// newCancelCtx returns an initialized cancelCtx.
func newCancelCtx(parent Context) cancelCtx {
return cancelCtx{
Context: parent,
done: make(chan struct{}),
}
}
// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) {
if parent.Done() == nil {
return // parent is never canceled
}
if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err)
} else {
if p.children == nil {
p.children = make(map[canceler]bool)
}
p.children[child] = true
}
p.mu.Unlock()
} else {
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
case <-child.Done():
}
}()
}
}
// parentCancelCtx follows a chain of parent references until it finds a
// *cancelCtx. This function understands how each of the concrete types in this
// package represents its parent.
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
for {
switch c := parent.(type) {
case *cancelCtx:
return c, true
case *timerCtx:
return &c.cancelCtx, true
case *valueCtx:
parent = c.Context
default:
return nil, false
}
}
}
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
cancel(removeFromParent bool, err error)
Done() <-chan struct{}
}
// A cancelCtx can be canceled. When canceled, it also cancels any children
// that implement canceler.
type cancelCtx struct {
Context
done chan struct{} // closed by the first cancel call.
mu sync.Mutex
children map[canceler]bool // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
}
func (c *cancelCtx) Done() <-chan struct{} {
return c.done
}
func (c *cancelCtx) Err() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.err
}
func (c *cancelCtx) String() string {
return fmt.Sprintf("%v.WithCancel", c.Context)
}
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
close(c.done)
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err)
}
c.children = nil
c.mu.Unlock()
if removeFromParent {
if p, ok := parentCancelCtx(c.Context); ok {
p.mu.Lock()
if p.children != nil {
delete(p.children, c)
}
p.mu.Unlock()
}
}
}
// WithDeadline returns a copy of the parent context with the deadline adjusted
// to be no later than d. If the parent's deadline is already earlier than d,
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
// context's Done channel is closed when the deadline expires, when the returned
// cancel function is called, or when the parent context's Done channel is
// closed, whichever happens first.
//
// Canceling this context releases resources associated with the deadline
// timer, so code should call cancel as soon as the operations running in this
// Context complete.
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) {
// The current deadline is already sooner than the new one.
return WithCancel(parent)
}
c := &timerCtx{
cancelCtx: newCancelCtx(parent),
deadline: deadline,
}
propagateCancel(parent, c)
d := deadline.Sub(time.Now())
if d <= 0 {
c.cancel(true, DeadlineExceeded) // deadline has already passed
return c, func() { c.cancel(true, Canceled) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(d, func() {
c.cancel(true, DeadlineExceeded)
})
}
return c, func() { c.cancel(true, Canceled) }
}
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
}
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
return c.deadline, true
}
func (c *timerCtx) String() string {
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now()))
}
func (c *timerCtx) cancel(removeFromParent bool, err error) {
c.cancelCtx.cancel(removeFromParent, err)
c.mu.Lock()
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
c.mu.Unlock()
}
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
//
// Canceling this context releases resources associated with the deadline
// timer, so code should call cancel as soon as the operations running in this
// Context complete:
//
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
// defer cancel() // releases resources if slowOperation completes before timeout elapses
// return slowOperation(ctx)
// }
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}
// WithValue returns a copy of parent in which the value associated with key is
// val.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
func WithValue(parent Context, key interface{}, val interface{}) Context {
return &valueCtx{parent, key, val}
}
// A valueCtx carries a key-value pair. It implements Value for that key and
// delegates all other calls to the embedded Context.
type valueCtx struct {
Context
key, val interface{}
}
func (c *valueCtx) String() string {
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val)
}
func (c *valueCtx) Value(key interface{}) interface{} {
if c.key == key {
return c.val
}
return c.Context.Value(key)
}

View File

@ -0,0 +1,553 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"fmt"
"math/rand"
"runtime"
"strings"
"sync"
"testing"
"time"
)
// otherContext is a Context that's not one of the types defined in context.go.
// This lets us test code paths that differ based on the underlying type of the
// Context.
type otherContext struct {
Context
}
func TestBackground(t *testing.T) {
c := Background()
if c == nil {
t.Fatalf("Background returned nil")
}
select {
case x := <-c.Done():
t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
default:
}
if got, want := fmt.Sprint(c), "context.Background"; got != want {
t.Errorf("Background().String() = %q want %q", got, want)
}
}
func TestTODO(t *testing.T) {
c := TODO()
if c == nil {
t.Fatalf("TODO returned nil")
}
select {
case x := <-c.Done():
t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
default:
}
if got, want := fmt.Sprint(c), "context.TODO"; got != want {
t.Errorf("TODO().String() = %q want %q", got, want)
}
}
func TestWithCancel(t *testing.T) {
c1, cancel := WithCancel(Background())
if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
t.Errorf("c1.String() = %q want %q", got, want)
}
o := otherContext{c1}
c2, _ := WithCancel(o)
contexts := []Context{c1, o, c2}
for i, c := range contexts {
if d := c.Done(); d == nil {
t.Errorf("c[%d].Done() == %v want non-nil", i, d)
}
if e := c.Err(); e != nil {
t.Errorf("c[%d].Err() == %v want nil", i, e)
}
select {
case x := <-c.Done():
t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
default:
}
}
cancel()
time.Sleep(100 * time.Millisecond) // let cancelation propagate
for i, c := range contexts {
select {
case <-c.Done():
default:
t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
}
if e := c.Err(); e != Canceled {
t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
}
}
}
func TestParentFinishesChild(t *testing.T) {
// Context tree:
// parent -> cancelChild
// parent -> valueChild -> timerChild
parent, cancel := WithCancel(Background())
cancelChild, stop := WithCancel(parent)
defer stop()
valueChild := WithValue(parent, "key", "value")
timerChild, stop := WithTimeout(valueChild, 10000*time.Hour)
defer stop()
select {
case x := <-parent.Done():
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
case x := <-cancelChild.Done():
t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x)
case x := <-timerChild.Done():
t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x)
case x := <-valueChild.Done():
t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x)
default:
}
// The parent's children should contain the two cancelable children.
pc := parent.(*cancelCtx)
cc := cancelChild.(*cancelCtx)
tc := timerChild.(*timerCtx)
pc.mu.Lock()
if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] {
t.Errorf("bad linkage: pc.children = %v, want %v and %v",
pc.children, cc, tc)
}
pc.mu.Unlock()
if p, ok := parentCancelCtx(cc.Context); !ok || p != pc {
t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc)
}
if p, ok := parentCancelCtx(tc.Context); !ok || p != pc {
t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc)
}
cancel()
pc.mu.Lock()
if len(pc.children) != 0 {
t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children)
}
pc.mu.Unlock()
// parent and children should all be finished.
check := func(ctx Context, name string) {
select {
case <-ctx.Done():
default:
t.Errorf("<-%s.Done() blocked, but shouldn't have", name)
}
if e := ctx.Err(); e != Canceled {
t.Errorf("%s.Err() == %v want %v", name, e, Canceled)
}
}
check(parent, "parent")
check(cancelChild, "cancelChild")
check(valueChild, "valueChild")
check(timerChild, "timerChild")
// WithCancel should return a canceled context on a canceled parent.
precanceledChild := WithValue(parent, "key", "value")
select {
case <-precanceledChild.Done():
default:
t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have")
}
if e := precanceledChild.Err(); e != Canceled {
t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled)
}
}
func TestChildFinishesFirst(t *testing.T) {
cancelable, stop := WithCancel(Background())
defer stop()
for _, parent := range []Context{Background(), cancelable} {
child, cancel := WithCancel(parent)
select {
case x := <-parent.Done():
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
case x := <-child.Done():
t.Errorf("<-child.Done() == %v want nothing (it should block)", x)
default:
}
cc := child.(*cancelCtx)
pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background()
if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) {
t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok)
}
if pcok {
pc.mu.Lock()
if len(pc.children) != 1 || !pc.children[cc] {
t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc)
}
pc.mu.Unlock()
}
cancel()
if pcok {
pc.mu.Lock()
if len(pc.children) != 0 {
t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children)
}
pc.mu.Unlock()
}
// child should be finished.
select {
case <-child.Done():
default:
t.Errorf("<-child.Done() blocked, but shouldn't have")
}
if e := child.Err(); e != Canceled {
t.Errorf("child.Err() == %v want %v", e, Canceled)
}
// parent should not be finished.
select {
case x := <-parent.Done():
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
default:
}
if e := parent.Err(); e != nil {
t.Errorf("parent.Err() == %v want nil", e)
}
}
}
func testDeadline(c Context, wait time.Duration, t *testing.T) {
select {
case <-time.After(wait):
t.Fatalf("context should have timed out")
case <-c.Done():
}
if e := c.Err(); e != DeadlineExceeded {
t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded)
}
}
func TestDeadline(t *testing.T) {
c, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond))
if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
t.Errorf("c.String() = %q want prefix %q", got, prefix)
}
testDeadline(c, 200*time.Millisecond, t)
c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond))
o := otherContext{c}
testDeadline(o, 200*time.Millisecond, t)
c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond))
o = otherContext{c}
c, _ = WithDeadline(o, time.Now().Add(300*time.Millisecond))
testDeadline(c, 200*time.Millisecond, t)
}
func TestTimeout(t *testing.T) {
c, _ := WithTimeout(Background(), 100*time.Millisecond)
if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
t.Errorf("c.String() = %q want prefix %q", got, prefix)
}
testDeadline(c, 200*time.Millisecond, t)
c, _ = WithTimeout(Background(), 100*time.Millisecond)
o := otherContext{c}
testDeadline(o, 200*time.Millisecond, t)
c, _ = WithTimeout(Background(), 100*time.Millisecond)
o = otherContext{c}
c, _ = WithTimeout(o, 300*time.Millisecond)
testDeadline(c, 200*time.Millisecond, t)
}
func TestCanceledTimeout(t *testing.T) {
c, _ := WithTimeout(Background(), 200*time.Millisecond)
o := otherContext{c}
c, cancel := WithTimeout(o, 400*time.Millisecond)
cancel()
time.Sleep(100 * time.Millisecond) // let cancelation propagate
select {
case <-c.Done():
default:
t.Errorf("<-c.Done() blocked, but shouldn't have")
}
if e := c.Err(); e != Canceled {
t.Errorf("c.Err() == %v want %v", e, Canceled)
}
}
type key1 int
type key2 int
var k1 = key1(1)
var k2 = key2(1) // same int as k1, different type
var k3 = key2(3) // same type as k2, different int
func TestValues(t *testing.T) {
check := func(c Context, nm, v1, v2, v3 string) {
if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
}
if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
}
if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
}
}
c0 := Background()
check(c0, "c0", "", "", "")
c1 := WithValue(Background(), k1, "c1k1")
check(c1, "c1", "c1k1", "", "")
if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want {
t.Errorf("c.String() = %q want %q", got, want)
}
c2 := WithValue(c1, k2, "c2k2")
check(c2, "c2", "c1k1", "c2k2", "")
c3 := WithValue(c2, k3, "c3k3")
check(c3, "c2", "c1k1", "c2k2", "c3k3")
c4 := WithValue(c3, k1, nil)
check(c4, "c4", "", "c2k2", "c3k3")
o0 := otherContext{Background()}
check(o0, "o0", "", "", "")
o1 := otherContext{WithValue(Background(), k1, "c1k1")}
check(o1, "o1", "c1k1", "", "")
o2 := WithValue(o1, k2, "o2k2")
check(o2, "o2", "c1k1", "o2k2", "")
o3 := otherContext{c4}
check(o3, "o3", "", "c2k2", "c3k3")
o4 := WithValue(o3, k3, nil)
check(o4, "o4", "", "c2k2", "")
}
func TestAllocs(t *testing.T) {
bg := Background()
for _, test := range []struct {
desc string
f func()
limit float64
gccgoLimit float64
}{
{
desc: "Background()",
f: func() { Background() },
limit: 0,
gccgoLimit: 0,
},
{
desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
f: func() {
c := WithValue(bg, k1, nil)
c.Value(k1)
},
limit: 3,
gccgoLimit: 3,
},
{
desc: "WithTimeout(bg, 15*time.Millisecond)",
f: func() {
c, _ := WithTimeout(bg, 15*time.Millisecond)
<-c.Done()
},
limit: 8,
gccgoLimit: 13,
},
{
desc: "WithCancel(bg)",
f: func() {
c, cancel := WithCancel(bg)
cancel()
<-c.Done()
},
limit: 5,
gccgoLimit: 8,
},
{
desc: "WithTimeout(bg, 100*time.Millisecond)",
f: func() {
c, cancel := WithTimeout(bg, 100*time.Millisecond)
cancel()
<-c.Done()
},
limit: 8,
gccgoLimit: 25,
},
} {
limit := test.limit
if runtime.Compiler == "gccgo" {
// gccgo does not yet do escape analysis.
// TOOD(iant): Remove this when gccgo does do escape analysis.
limit = test.gccgoLimit
}
if n := testing.AllocsPerRun(100, test.f); n > limit {
t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
}
}
}
func TestSimultaneousCancels(t *testing.T) {
root, cancel := WithCancel(Background())
m := map[Context]CancelFunc{root: cancel}
q := []Context{root}
// Create a tree of contexts.
for len(q) != 0 && len(m) < 100 {
parent := q[0]
q = q[1:]
for i := 0; i < 4; i++ {
ctx, cancel := WithCancel(parent)
m[ctx] = cancel
q = append(q, ctx)
}
}
// Start all the cancels in a random order.
var wg sync.WaitGroup
wg.Add(len(m))
for _, cancel := range m {
go func(cancel CancelFunc) {
cancel()
wg.Done()
}(cancel)
}
// Wait on all the contexts in a random order.
for ctx := range m {
select {
case <-ctx.Done():
case <-time.After(1 * time.Second):
buf := make([]byte, 10<<10)
n := runtime.Stack(buf, true)
t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n])
}
}
// Wait for all the cancel functions to return.
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(1 * time.Second):
buf := make([]byte, 10<<10)
n := runtime.Stack(buf, true)
t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n])
}
}
func TestInterlockedCancels(t *testing.T) {
parent, cancelParent := WithCancel(Background())
child, cancelChild := WithCancel(parent)
go func() {
parent.Done()
cancelChild()
}()
cancelParent()
select {
case <-child.Done():
case <-time.After(1 * time.Second):
buf := make([]byte, 10<<10)
n := runtime.Stack(buf, true)
t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n])
}
}
func TestLayersCancel(t *testing.T) {
testLayers(t, time.Now().UnixNano(), false)
}
func TestLayersTimeout(t *testing.T) {
testLayers(t, time.Now().UnixNano(), true)
}
func testLayers(t *testing.T, seed int64, testTimeout bool) {
rand.Seed(seed)
errorf := func(format string, a ...interface{}) {
t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...)
}
const (
timeout = 200 * time.Millisecond
minLayers = 30
)
type value int
var (
vals []*value
cancels []CancelFunc
numTimers int
ctx = Background()
)
for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
switch rand.Intn(3) {
case 0:
v := new(value)
ctx = WithValue(ctx, v, v)
vals = append(vals, v)
case 1:
var cancel CancelFunc
ctx, cancel = WithCancel(ctx)
cancels = append(cancels, cancel)
case 2:
var cancel CancelFunc
ctx, cancel = WithTimeout(ctx, timeout)
cancels = append(cancels, cancel)
numTimers++
}
}
checkValues := func(when string) {
for _, key := range vals {
if val := ctx.Value(key).(*value); key != val {
errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
}
}
}
select {
case <-ctx.Done():
errorf("ctx should not be canceled yet")
default:
}
if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
}
t.Log(ctx)
checkValues("before cancel")
if testTimeout {
select {
case <-ctx.Done():
case <-time.After(timeout + timeout/10):
errorf("ctx should have timed out")
}
checkValues("after timeout")
} else {
cancel := cancels[rand.Intn(len(cancels))]
cancel()
select {
case <-ctx.Done():
default:
errorf("ctx should be canceled")
}
checkValues("after cancel")
}
}

View File

@ -0,0 +1,26 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context_test
import (
"fmt"
"time"
"golang.org/x/net/context"
)
func ExampleWithTimeout() {
// Pass a context with a timeout to tell a blocking function that it
// should abandon its work after the timeout elapses.
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
select {
case <-time.After(200 * time.Millisecond):
fmt.Println("overslept")
case <-ctx.Done():
fmt.Println(ctx.Err()) // prints "context deadline exceeded"
}
// Output:
// context deadline exceeded
}