mirror of https://github.com/docker/docs.git
Merge pull request #821 from ehazlett/godep-update
godep: fix upstream errors; update naturalsort
This commit is contained in:
commit
e4a37a8e2a
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/machine",
|
"ImportPath": "github.com/docker/machine",
|
||||||
"GoVersion": "go1.4.1",
|
"GoVersion": "go1.4.2",
|
||||||
"Deps": [
|
"Deps": [
|
||||||
{
|
{
|
||||||
"ImportPath": "code.google.com/p/goauth2/oauth",
|
"ImportPath": "code.google.com/p/goauth2/oauth",
|
||||||
|
@ -9,7 +9,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/MSOpenTech/azure-sdk-for-go",
|
"ImportPath": "github.com/MSOpenTech/azure-sdk-for-go",
|
||||||
"Comment": "v1.1-14-g814812a",
|
"Comment": "v1.1-17-g515f3ec",
|
||||||
"Rev": "515f3ec74ce6a5b31e934cefae997c97bd0a1b1e"
|
"Rev": "515f3ec74ce6a5b31e934cefae997c97bd0a1b1e"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -21,10 +21,6 @@
|
||||||
"ImportPath": "github.com/cenkalti/backoff",
|
"ImportPath": "github.com/cenkalti/backoff",
|
||||||
"Rev": "9831e1e25c874e0a0601b6dc43641071414eec7a"
|
"Rev": "9831e1e25c874e0a0601b6dc43641071414eec7a"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"ImportPath": "github.com/skarademir/naturalsort",
|
|
||||||
"Rev": "9688a08870fba63de22ec85d167598bb677820ee"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/codegangsta/cli",
|
"ImportPath": "github.com/codegangsta/cli",
|
||||||
"Comment": "1.2.0-64-ge1712f3",
|
"Comment": "1.2.0-64-ge1712f3",
|
||||||
|
@ -37,87 +33,92 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/api",
|
"ImportPath": "github.com/docker/docker/api",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/dockerversion",
|
"ImportPath": "github.com/docker/docker/dockerversion",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/engine",
|
"ImportPath": "github.com/docker/docker/engine",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/archive",
|
"ImportPath": "github.com/docker/docker/pkg/archive",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/fileutils",
|
"ImportPath": "github.com/docker/docker/pkg/fileutils",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/ioutils",
|
"ImportPath": "github.com/docker/docker/pkg/ioutils",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ImportPath": "github.com/docker/docker/pkg/mflag",
|
||||||
|
"Comment": "v1.5.0",
|
||||||
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/parsers",
|
"ImportPath": "github.com/docker/docker/pkg/parsers",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/pools",
|
"ImportPath": "github.com/docker/docker/pkg/pools",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/promise",
|
"ImportPath": "github.com/docker/docker/pkg/promise",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/system",
|
"ImportPath": "github.com/docker/docker/pkg/system",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/term",
|
"ImportPath": "github.com/docker/docker/pkg/term",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/timeutils",
|
"ImportPath": "github.com/docker/docker/pkg/timeutils",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/units",
|
"ImportPath": "github.com/docker/docker/pkg/units",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/pkg/version",
|
"ImportPath": "github.com/docker/docker/pkg/version",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/utils",
|
"ImportPath": "github.com/docker/docker/utils",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/docker/vendor/src/code.google.com/p/go/src/pkg/archive/tar",
|
"ImportPath": "github.com/docker/docker/vendor/src/code.google.com/p/go/src/pkg/archive/tar",
|
||||||
"Comment": "v1.4.1",
|
"Comment": "v1.5.0",
|
||||||
"Rev": "5bc2ff8a36e9a768e8b479de4fe3ea9c9daf4121"
|
"Rev": "a8a31eff10544860d2188dddabdee4d727545796"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/docker/libtrust",
|
"ImportPath": "github.com/docker/libtrust",
|
||||||
"Rev": "6b7834910dcbb3021adc193411d01f65595445fb"
|
"Rev": "c54fbb67c1f1e68d7d6f8d2ad7c9360404616a41"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/google/go-querystring/query",
|
"ImportPath": "github.com/google/go-querystring/query",
|
||||||
|
@ -132,6 +133,10 @@
|
||||||
"Comment": "v1.0.0-473-g7ca169d",
|
"Comment": "v1.0.0-473-g7ca169d",
|
||||||
"Rev": "7ca169d371b29e3dbab9e631c3a6151896b06330"
|
"Rev": "7ca169d371b29e3dbab9e631c3a6151896b06330"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"ImportPath": "github.com/skarademir/naturalsort",
|
||||||
|
"Rev": "983d4d86054d80f91fd04dd62ec52c1d078ce403"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/smartystreets/go-aws-auth",
|
"ImportPath": "github.com/smartystreets/go-aws-auth",
|
||||||
"Rev": "1f0db8c0ee6362470abe06a94e3385927ed72a4b"
|
"Rev": "1f0db8c0ee6362470abe06a94e3385927ed72a4b"
|
||||||
|
@ -146,8 +151,8 @@
|
||||||
"Rev": "66a23eaabc61518f91769939ff541886fe1dceef"
|
"Rev": "66a23eaabc61518f91769939ff541886fe1dceef"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "golang.org/x/crypto/ssh",
|
"ImportPath": "golang.org/x/net/context",
|
||||||
"Rev": "1fbbd62cfec66bd39d91e97749579579d4d3037e"
|
"Rev": "97d8e4e174133a4d1d2171380e510eb4dea8f5ea"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "google.golang.org/api/compute/v1",
|
"ImportPath": "google.golang.org/api/compute/v1",
|
||||||
|
|
|
@ -17,7 +17,6 @@ import (
|
||||||
flag "github.com/docker/docker/pkg/mflag"
|
flag "github.com/docker/docker/pkg/mflag"
|
||||||
"github.com/docker/docker/pkg/term"
|
"github.com/docker/docker/pkg/term"
|
||||||
"github.com/docker/docker/registry"
|
"github.com/docker/docker/registry"
|
||||||
"github.com/docker/libtrust"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type DockerCli struct {
|
type DockerCli struct {
|
||||||
|
@ -27,7 +26,7 @@ type DockerCli struct {
|
||||||
in io.ReadCloser
|
in io.ReadCloser
|
||||||
out io.Writer
|
out io.Writer
|
||||||
err io.Writer
|
err io.Writer
|
||||||
key libtrust.PrivateKey
|
keyFile string
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
scheme string
|
scheme string
|
||||||
// inFd holds file descriptor of the client's STDIN, if it's a valid file
|
// inFd holds file descriptor of the client's STDIN, if it's a valid file
|
||||||
|
@ -75,24 +74,31 @@ func (cli *DockerCli) Cmd(args ...string) error {
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
method, exists := cli.getMethod(args[0])
|
method, exists := cli.getMethod(args[0])
|
||||||
if !exists {
|
if !exists {
|
||||||
fmt.Println("Error: Command not found:", args[0])
|
fmt.Fprintf(cli.err, "docker: '%s' is not a docker command. See 'docker --help'.\n", args[0])
|
||||||
return cli.CmdHelp()
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
return method(args[1:]...)
|
return method(args[1:]...)
|
||||||
}
|
}
|
||||||
return cli.CmdHelp()
|
return cli.CmdHelp()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *DockerCli) Subcmd(name, signature, description string) *flag.FlagSet {
|
func (cli *DockerCli) Subcmd(name, signature, description string, exitOnError bool) *flag.FlagSet {
|
||||||
flags := flag.NewFlagSet(name, flag.ContinueOnError)
|
var errorHandling flag.ErrorHandling
|
||||||
|
if exitOnError {
|
||||||
|
errorHandling = flag.ExitOnError
|
||||||
|
} else {
|
||||||
|
errorHandling = flag.ContinueOnError
|
||||||
|
}
|
||||||
|
flags := flag.NewFlagSet(name, errorHandling)
|
||||||
flags.Usage = func() {
|
flags.Usage = func() {
|
||||||
options := ""
|
options := ""
|
||||||
if flags.FlagCountUndeprecated() > 0 {
|
if flags.FlagCountUndeprecated() > 0 {
|
||||||
options = "[OPTIONS] "
|
options = "[OPTIONS] "
|
||||||
}
|
}
|
||||||
fmt.Fprintf(cli.err, "\nUsage: docker %s %s%s\n\n%s\n\n", name, options, signature, description)
|
fmt.Fprintf(cli.out, "\nUsage: docker %s %s%s\n\n%s\n\n", name, options, signature, description)
|
||||||
|
flags.SetOutput(cli.out)
|
||||||
flags.PrintDefaults()
|
flags.PrintDefaults()
|
||||||
os.Exit(2)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
@ -115,7 +121,7 @@ func (cli *DockerCli) CheckTtyInput(attachStdin, ttyMode bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDockerCli(in io.ReadCloser, out, err io.Writer, key libtrust.PrivateKey, proto, addr string, tlsConfig *tls.Config) *DockerCli {
|
func NewDockerCli(in io.ReadCloser, out, err io.Writer, keyFile string, proto, addr string, tlsConfig *tls.Config) *DockerCli {
|
||||||
var (
|
var (
|
||||||
inFd uintptr
|
inFd uintptr
|
||||||
outFd uintptr
|
outFd uintptr
|
||||||
|
@ -148,6 +154,7 @@ func NewDockerCli(in io.ReadCloser, out, err io.Writer, key libtrust.PrivateKey,
|
||||||
|
|
||||||
// The transport is created here for reuse during the client session
|
// The transport is created here for reuse during the client session
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
TLSClientConfig: tlsConfig,
|
TLSClientConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,7 +176,7 @@ func NewDockerCli(in io.ReadCloser, out, err io.Writer, key libtrust.PrivateKey,
|
||||||
in: in,
|
in: in,
|
||||||
out: out,
|
out: out,
|
||||||
err: err,
|
err: err,
|
||||||
key: key,
|
keyFile: keyFile,
|
||||||
inFd: inFd,
|
inFd: inFd,
|
||||||
outFd: outFd,
|
outFd: outFd,
|
||||||
isTerminalIn: isTerminalIn,
|
isTerminalIn: isTerminalIn,
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -72,6 +72,15 @@ func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Con
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// When we set up a TCP connection for hijack, there could be long periods
|
||||||
|
// of inactivity (a long running command with no output) that in certain
|
||||||
|
// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
|
||||||
|
// state. Setting TCP KeepAlive on the socket connection will prohibit
|
||||||
|
// ECONNTIMEOUT unless the socket connection truly is broken
|
||||||
|
if tcpConn, ok := rawConn.(*net.TCPConn); ok {
|
||||||
|
tcpConn.SetKeepAlive(true)
|
||||||
|
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
colonPos := strings.LastIndex(addr, ":")
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
if colonPos == -1 {
|
if colonPos == -1 {
|
||||||
|
@ -134,10 +143,21 @@ func (cli *DockerCli) hijack(method, path string, setRawTerminal bool, in io.Rea
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
req.Header.Set("User-Agent", "Docker-Client/"+dockerversion.VERSION)
|
req.Header.Set("User-Agent", "Docker-Client/"+dockerversion.VERSION)
|
||||||
req.Header.Set("Content-Type", "plain/text")
|
req.Header.Set("Content-Type", "text/plain")
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "tcp")
|
||||||
req.Host = cli.addr
|
req.Host = cli.addr
|
||||||
|
|
||||||
dial, err := cli.dial()
|
dial, err := cli.dial()
|
||||||
|
// When we set up a TCP connection for hijack, there could be long periods
|
||||||
|
// of inactivity (a long running command with no output) that in certain
|
||||||
|
// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
|
||||||
|
// state. Setting TCP KeepAlive on the socket connection will prohibit
|
||||||
|
// ECONNTIMEOUT unless the socket connection truly is broken
|
||||||
|
if tcpConn, ok := dial.(*net.TCPConn); ok {
|
||||||
|
tcpConn.SetKeepAlive(true)
|
||||||
|
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "connection refused") {
|
if strings.Contains(err.Error(), "connection refused") {
|
||||||
return fmt.Errorf("Cannot connect to the Docker daemon. Is 'docker -d' running on this host?")
|
return fmt.Errorf("Cannot connect to the Docker daemon. Is 'docker -d' running on this host?")
|
||||||
|
|
|
@ -66,7 +66,7 @@ func (cli *DockerCli) call(method, path string, data interface{}, passAuthInfo b
|
||||||
if passAuthInfo {
|
if passAuthInfo {
|
||||||
cli.LoadConfigFile()
|
cli.LoadConfigFile()
|
||||||
// Resolve the Auth config relevant for this server
|
// Resolve the Auth config relevant for this server
|
||||||
authConfig := cli.configFile.ResolveAuthConfig(registry.IndexServerAddress())
|
authConfig := cli.configFile.Configs[registry.IndexServerAddress()]
|
||||||
getHeaders := func(authConfig registry.AuthConfig) (map[string][]string, error) {
|
getHeaders := func(authConfig registry.AuthConfig) (map[string][]string, error) {
|
||||||
buf, err := json.Marshal(authConfig)
|
buf, err := json.Marshal(authConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -89,7 +89,7 @@ func (cli *DockerCli) call(method, path string, data interface{}, passAuthInfo b
|
||||||
if data != nil {
|
if data != nil {
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
} else if method == "POST" {
|
} else if method == "POST" {
|
||||||
req.Header.Set("Content-Type", "plain/text")
|
req.Header.Set("Content-Type", "text/plain")
|
||||||
}
|
}
|
||||||
resp, err := cli.HTTPClient().Do(req)
|
resp, err := cli.HTTPClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -135,7 +135,7 @@ func (cli *DockerCli) streamHelper(method, path string, setRawTerminal bool, in
|
||||||
req.URL.Host = cli.addr
|
req.URL.Host = cli.addr
|
||||||
req.URL.Scheme = cli.scheme
|
req.URL.Scheme = cli.scheme
|
||||||
if method == "POST" {
|
if method == "POST" {
|
||||||
req.Header.Set("Content-Type", "plain/text")
|
req.Header.Set("Content-Type", "text/plain")
|
||||||
}
|
}
|
||||||
|
|
||||||
if headers != nil {
|
if headers != nil {
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"mime"
|
"mime"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
|
@ -15,9 +15,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
APIVERSION version.Version = "1.16"
|
APIVERSION version.Version = "1.17"
|
||||||
DEFAULTHTTPHOST = "127.0.0.1"
|
DEFAULTHTTPHOST = "127.0.0.1"
|
||||||
DEFAULTUNIXSOCKET = "/var/run/docker.sock"
|
DEFAULTUNIXSOCKET = "/var/run/docker.sock"
|
||||||
|
DefaultDockerfileName string = "Dockerfile"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ValidateHost(val string) (string, error) {
|
func ValidateHost(val string) (string, error) {
|
||||||
|
@ -54,7 +55,7 @@ func MatchesContentType(contentType, expectedType string) bool {
|
||||||
// LoadOrCreateTrustKey attempts to load the libtrust key at the given path,
|
// LoadOrCreateTrustKey attempts to load the libtrust key at the given path,
|
||||||
// otherwise generates a new one
|
// otherwise generates a new one
|
||||||
func LoadOrCreateTrustKey(trustKeyPath string) (libtrust.PrivateKey, error) {
|
func LoadOrCreateTrustKey(trustKeyPath string) (libtrust.PrivateKey, error) {
|
||||||
err := os.MkdirAll(path.Dir(trustKeyPath), 0700)
|
err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -68,7 +69,7 @@ func LoadOrCreateTrustKey(trustKeyPath string) (libtrust.PrivateKey, error) {
|
||||||
return nil, fmt.Errorf("Error saving key file: %s", err)
|
return nil, fmt.Errorf("Error saving key file: %s", err)
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, fmt.Errorf("Error loading key file: %s", err)
|
return nil, fmt.Errorf("Error loading key file %s: %s", trustKeyPath, err)
|
||||||
}
|
}
|
||||||
return trustKey, nil
|
return trustKey, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
Victor Vieux <vieux@docker.com> (@vieux)
|
Victor Vieux <vieux@docker.com> (@vieux)
|
||||||
Johan Euphrosine <proppy@google.com> (@proppy)
|
# Johan Euphrosine <proppy@google.com> (@proppy)
|
||||||
|
|
|
@ -27,6 +27,7 @@ import (
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
"github.com/docker/docker/api"
|
"github.com/docker/docker/api"
|
||||||
|
"github.com/docker/docker/daemon/networkdriver/portallocator"
|
||||||
"github.com/docker/docker/engine"
|
"github.com/docker/docker/engine"
|
||||||
"github.com/docker/docker/pkg/listenbuffer"
|
"github.com/docker/docker/pkg/listenbuffer"
|
||||||
"github.com/docker/docker/pkg/parsers"
|
"github.com/docker/docker/pkg/parsers"
|
||||||
|
@ -410,6 +411,19 @@ func getContainersJSON(eng *engine.Engine, version version.Version, w http.Respo
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getContainersStats(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||||
|
if err := parseForm(r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if vars == nil {
|
||||||
|
return fmt.Errorf("Missing parameter")
|
||||||
|
}
|
||||||
|
name := vars["name"]
|
||||||
|
job := eng.Job("container_stats", name)
|
||||||
|
streamJSON(job, w, true)
|
||||||
|
return job.Run()
|
||||||
|
}
|
||||||
|
|
||||||
func getContainersLogs(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
func getContainersLogs(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||||
if err := parseForm(r); err != nil {
|
if err := parseForm(r); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -738,6 +752,24 @@ func postContainersRestart(eng *engine.Engine, version version.Version, w http.R
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func postContainerRename(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||||
|
if err := parseForm(r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if vars == nil {
|
||||||
|
return fmt.Errorf("Missing parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
newName := r.URL.Query().Get("name")
|
||||||
|
job := eng.Job("container_rename", vars["name"], newName)
|
||||||
|
job.Setenv("t", r.Form.Get("t"))
|
||||||
|
if err := job.Run(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func deleteContainers(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
func deleteContainers(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||||
if err := parseForm(r); err != nil {
|
if err := parseForm(r); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -887,7 +919,11 @@ func postContainersAttach(eng *engine.Engine, version version.Version, w http.Re
|
||||||
|
|
||||||
var errStream io.Writer
|
var errStream io.Writer
|
||||||
|
|
||||||
|
if _, ok := r.Header["Upgrade"]; ok {
|
||||||
|
fmt.Fprintf(outStream, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\r\n")
|
||||||
|
} else {
|
||||||
fmt.Fprintf(outStream, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
|
fmt.Fprintf(outStream, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
if c.GetSubEnv("Config") != nil && !c.GetSubEnv("Config").GetBool("Tty") && version.GreaterThanOrEqualTo("1.6") {
|
if c.GetSubEnv("Config") != nil && !c.GetSubEnv("Config").GetBool("Tty") && version.GreaterThanOrEqualTo("1.6") {
|
||||||
errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
|
errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
|
||||||
|
@ -1030,6 +1066,7 @@ func postBuild(eng *engine.Engine, version version.Version, w http.ResponseWrite
|
||||||
}
|
}
|
||||||
job.Stdin.Add(r.Body)
|
job.Stdin.Add(r.Body)
|
||||||
job.Setenv("remote", r.FormValue("remote"))
|
job.Setenv("remote", r.FormValue("remote"))
|
||||||
|
job.Setenv("dockerfile", r.FormValue("dockerfile"))
|
||||||
job.Setenv("t", r.FormValue("t"))
|
job.Setenv("t", r.FormValue("t"))
|
||||||
job.Setenv("q", r.FormValue("q"))
|
job.Setenv("q", r.FormValue("q"))
|
||||||
job.Setenv("nocache", r.FormValue("nocache"))
|
job.Setenv("nocache", r.FormValue("nocache"))
|
||||||
|
@ -1137,7 +1174,12 @@ func postContainerExecStart(eng *engine.Engine, version version.Version, w http.
|
||||||
|
|
||||||
var errStream io.Writer
|
var errStream io.Writer
|
||||||
|
|
||||||
|
if _, ok := r.Header["Upgrade"]; ok {
|
||||||
|
fmt.Fprintf(outStream, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\r\n")
|
||||||
|
} else {
|
||||||
fmt.Fprintf(outStream, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
|
fmt.Fprintf(outStream, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
if !job.GetenvBool("Tty") && version.GreaterThanOrEqualTo("1.6") {
|
if !job.GetenvBool("Tty") && version.GreaterThanOrEqualTo("1.6") {
|
||||||
errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
|
errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
|
||||||
outStream = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
|
outStream = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
|
||||||
|
@ -1250,7 +1292,7 @@ func AttachProfiler(router *mux.Router) {
|
||||||
router.HandleFunc("/debug/pprof/threadcreate", pprof.Handler("threadcreate").ServeHTTP)
|
router.HandleFunc("/debug/pprof/threadcreate", pprof.Handler("threadcreate").ServeHTTP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRouter(eng *engine.Engine, logging, enableCors bool, dockerVersion string) (*mux.Router, error) {
|
func createRouter(eng *engine.Engine, logging, enableCors bool, dockerVersion string) *mux.Router {
|
||||||
r := mux.NewRouter()
|
r := mux.NewRouter()
|
||||||
if os.Getenv("DEBUG") != "" {
|
if os.Getenv("DEBUG") != "" {
|
||||||
AttachProfiler(r)
|
AttachProfiler(r)
|
||||||
|
@ -1275,6 +1317,7 @@ func createRouter(eng *engine.Engine, logging, enableCors bool, dockerVersion st
|
||||||
"/containers/{name:.*}/json": getContainersByName,
|
"/containers/{name:.*}/json": getContainersByName,
|
||||||
"/containers/{name:.*}/top": getContainersTop,
|
"/containers/{name:.*}/top": getContainersTop,
|
||||||
"/containers/{name:.*}/logs": getContainersLogs,
|
"/containers/{name:.*}/logs": getContainersLogs,
|
||||||
|
"/containers/{name:.*}/stats": getContainersStats,
|
||||||
"/containers/{name:.*}/attach/ws": wsContainersAttach,
|
"/containers/{name:.*}/attach/ws": wsContainersAttach,
|
||||||
"/exec/{id:.*}/json": getExecByID,
|
"/exec/{id:.*}/json": getExecByID,
|
||||||
},
|
},
|
||||||
|
@ -1300,6 +1343,7 @@ func createRouter(eng *engine.Engine, logging, enableCors bool, dockerVersion st
|
||||||
"/containers/{name:.*}/exec": postContainerExecCreate,
|
"/containers/{name:.*}/exec": postContainerExecCreate,
|
||||||
"/exec/{name:.*}/start": postContainerExecStart,
|
"/exec/{name:.*}/start": postContainerExecStart,
|
||||||
"/exec/{name:.*}/resize": postContainerExecResize,
|
"/exec/{name:.*}/resize": postContainerExecResize,
|
||||||
|
"/containers/{name:.*}/rename": postContainerRename,
|
||||||
},
|
},
|
||||||
"DELETE": {
|
"DELETE": {
|
||||||
"/containers/{name:.*}": deleteContainers,
|
"/containers/{name:.*}": deleteContainers,
|
||||||
|
@ -1331,30 +1375,23 @@ func createRouter(eng *engine.Engine, logging, enableCors bool, dockerVersion st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeRequest processes a single http request to the docker remote api.
|
// ServeRequest processes a single http request to the docker remote api.
|
||||||
// FIXME: refactor this to be part of Server and not require re-creating a new
|
// FIXME: refactor this to be part of Server and not require re-creating a new
|
||||||
// router each time. This requires first moving ListenAndServe into Server.
|
// router each time. This requires first moving ListenAndServe into Server.
|
||||||
func ServeRequest(eng *engine.Engine, apiversion version.Version, w http.ResponseWriter, req *http.Request) error {
|
func ServeRequest(eng *engine.Engine, apiversion version.Version, w http.ResponseWriter, req *http.Request) {
|
||||||
router, err := createRouter(eng, false, true, "")
|
router := createRouter(eng, false, true, "")
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Insert APIVERSION into the request as a convenience
|
// Insert APIVERSION into the request as a convenience
|
||||||
req.URL.Path = fmt.Sprintf("/v%s%s", apiversion, req.URL.Path)
|
req.URL.Path = fmt.Sprintf("/v%s%s", apiversion, req.URL.Path)
|
||||||
router.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveFd creates an http.Server and sets it up to serve given a socket activated
|
// serveFd creates an http.Server and sets it up to serve given a socket activated
|
||||||
// argument.
|
// argument.
|
||||||
func serveFd(addr string, job *engine.Job) error {
|
func serveFd(addr string, job *engine.Job) error {
|
||||||
r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
|
r := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ls, e := systemd.ListenFD(addr)
|
ls, e := systemd.ListenFD(addr)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
|
@ -1389,7 +1426,7 @@ func serveFd(addr string, job *engine.Job) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func lookupGidByName(nameOrGid string) (int, error) {
|
func lookupGidByName(nameOrGid string) (int, error) {
|
||||||
groupFile, err := user.GetGroupFile()
|
groupFile, err := user.GetGroupPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
@ -1466,10 +1503,7 @@ func setSocketGroup(addr, group string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupUnixHttp(addr string, job *engine.Job) (*HttpServer, error) {
|
func setupUnixHttp(addr string, job *engine.Job) (*HttpServer, error) {
|
||||||
r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
|
r := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) {
|
if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -1493,18 +1527,45 @@ func setupUnixHttp(addr string, job *engine.Job) (*HttpServer, error) {
|
||||||
return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil
|
return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func allocateDaemonPort(addr string) error {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
intPort, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var hostIPs []net.IP
|
||||||
|
if parsedIP := net.ParseIP(host); parsedIP != nil {
|
||||||
|
hostIPs = append(hostIPs, parsedIP)
|
||||||
|
} else if hostIPs, err = net.LookupIP(host); err != nil {
|
||||||
|
return fmt.Errorf("failed to lookup %s address in host specification", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, hostIP := range hostIPs {
|
||||||
|
if _, err := portallocator.RequestPort(hostIP, "tcp", intPort); err != nil {
|
||||||
|
return fmt.Errorf("failed to allocate daemon listening port %d (err: %v)", intPort, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func setupTcpHttp(addr string, job *engine.Job) (*HttpServer, error) {
|
func setupTcpHttp(addr string, job *engine.Job) (*HttpServer, error) {
|
||||||
if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") {
|
if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") {
|
||||||
log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
|
log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
|
r := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
|
||||||
|
|
||||||
|
l, err := newListener("tcp", addr, job.GetenvBool("BufferRequests"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
l, err := newListener("tcp", addr, job.GetenvBool("BufferRequests"))
|
if err := allocateDaemonPort(addr); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -484,9 +484,7 @@ func serveRequestUsingVersion(method, target string, version version.Version, bo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := ServeRequest(eng, version, r, req); err != nil {
|
ServeRequest(eng, version, r, req)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
// This package is used for API stability in the types and response to the
|
||||||
|
// consumers of the API stats endpoint.
|
||||||
|
package stats
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type ThrottlingData struct {
|
||||||
|
// Number of periods with throttling active
|
||||||
|
Periods uint64 `json:"periods"`
|
||||||
|
// Number of periods when the container hit its throttling limit.
|
||||||
|
ThrottledPeriods uint64 `json:"throttled_periods"`
|
||||||
|
// Aggregate time the container was throttled for in nanoseconds.
|
||||||
|
ThrottledTime uint64 `json:"throttled_time"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// All CPU stats are aggregated since container inception.
|
||||||
|
type CpuUsage struct {
|
||||||
|
// Total CPU time consumed.
|
||||||
|
// Units: nanoseconds.
|
||||||
|
TotalUsage uint64 `json:"total_usage"`
|
||||||
|
// Total CPU time consumed per core.
|
||||||
|
// Units: nanoseconds.
|
||||||
|
PercpuUsage []uint64 `json:"percpu_usage"`
|
||||||
|
// Time spent by tasks of the cgroup in kernel mode.
|
||||||
|
// Units: nanoseconds.
|
||||||
|
UsageInKernelmode uint64 `json:"usage_in_kernelmode"`
|
||||||
|
// Time spent by tasks of the cgroup in user mode.
|
||||||
|
// Units: nanoseconds.
|
||||||
|
UsageInUsermode uint64 `json:"usage_in_usermode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CpuStats struct {
|
||||||
|
CpuUsage CpuUsage `json:"cpu_usage"`
|
||||||
|
SystemUsage uint64 `json:"system_cpu_usage"`
|
||||||
|
ThrottlingData ThrottlingData `json:"throttling_data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MemoryStats struct {
|
||||||
|
// current res_counter usage for memory
|
||||||
|
Usage uint64 `json:"usage"`
|
||||||
|
// maximum usage ever recorded.
|
||||||
|
MaxUsage uint64 `json:"max_usage"`
|
||||||
|
// TODO(vishh): Export these as stronger types.
|
||||||
|
// all the stats exported via memory.stat.
|
||||||
|
Stats map[string]uint64 `json:"stats"`
|
||||||
|
// number of times memory usage hits limits.
|
||||||
|
Failcnt uint64 `json:"failcnt"`
|
||||||
|
Limit uint64 `json:"limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BlkioStatEntry struct {
|
||||||
|
Major uint64 `json:"major"`
|
||||||
|
Minor uint64 `json:"minor"`
|
||||||
|
Op string `json:"op"`
|
||||||
|
Value uint64 `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BlkioStats struct {
|
||||||
|
// number of bytes tranferred to and from the block device
|
||||||
|
IoServiceBytesRecursive []BlkioStatEntry `json:"io_service_bytes_recursive"`
|
||||||
|
IoServicedRecursive []BlkioStatEntry `json:"io_serviced_recursive"`
|
||||||
|
IoQueuedRecursive []BlkioStatEntry `json:"io_queue_recursive"`
|
||||||
|
IoServiceTimeRecursive []BlkioStatEntry `json:"io_service_time_recursive"`
|
||||||
|
IoWaitTimeRecursive []BlkioStatEntry `json:"io_wait_time_recursive"`
|
||||||
|
IoMergedRecursive []BlkioStatEntry `json:"io_merged_recursive"`
|
||||||
|
IoTimeRecursive []BlkioStatEntry `json:"io_time_recursive"`
|
||||||
|
SectorsRecursive []BlkioStatEntry `json:"sectors_recursive"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Network struct {
|
||||||
|
RxBytes uint64 `json:"rx_bytes"`
|
||||||
|
RxPackets uint64 `json:"rx_packets"`
|
||||||
|
RxErrors uint64 `json:"rx_errors"`
|
||||||
|
RxDropped uint64 `json:"rx_dropped"`
|
||||||
|
TxBytes uint64 `json:"tx_bytes"`
|
||||||
|
TxPackets uint64 `json:"tx_packets"`
|
||||||
|
TxErrors uint64 `json:"tx_errors"`
|
||||||
|
TxDropped uint64 `json:"tx_dropped"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
Read time.Time `json:"read"`
|
||||||
|
Network Network `json:"network,omitempty"`
|
||||||
|
CpuStats CpuStats `json:"cpu_stats,omitempty"`
|
||||||
|
MemoryStats MemoryStats `json:"memory_stats,omitempty"`
|
||||||
|
BlkioStats BlkioStats `json:"blkio_stats,omitempty"`
|
||||||
|
}
|
|
@ -9,7 +9,7 @@ var (
|
||||||
GITCOMMIT string
|
GITCOMMIT string
|
||||||
VERSION string
|
VERSION string
|
||||||
|
|
||||||
IAMSTATIC bool // whether or not Docker itself was compiled statically via ./hack/make.sh binary
|
IAMSTATIC string // whether or not Docker itself was compiled statically via ./hack/make.sh binary ("true" or not "true")
|
||||||
INITSHA1 string // sha1sum of separate static dockerinit, if Docker itself was compiled dynamically via ./hack/make.sh dynbinary
|
INITSHA1 string // sha1sum of separate static dockerinit, if Docker itself was compiled dynamically via ./hack/make.sh dynbinary
|
||||||
INITPATH string // custom location to search for a valid dockerinit binary (available for packagers as a last resort escape hatch)
|
INITPATH string // custom location to search for a valid dockerinit binary (available for packagers as a last resort escape hatch)
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/docker/docker/pkg/ioutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRegister(t *testing.T) {
|
func TestRegister(t *testing.T) {
|
||||||
|
@ -150,3 +152,85 @@ func TestCatchallEmptyName(t *testing.T) {
|
||||||
t.Fatalf("Engine.Job(\"\").Run() should return an error")
|
t.Fatalf("Engine.Job(\"\").Run() should return an error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure that a job within a job both using the same underlying standard
|
||||||
|
// output writer does not close the output of the outer job when the inner
|
||||||
|
// job's stdout is wrapped with a NopCloser. When not wrapped, it should
|
||||||
|
// close the outer job's output.
|
||||||
|
func TestNestedJobSharedOutput(t *testing.T) {
|
||||||
|
var (
|
||||||
|
outerHandler Handler
|
||||||
|
innerHandler Handler
|
||||||
|
wrapOutput bool
|
||||||
|
)
|
||||||
|
|
||||||
|
outerHandler = func(job *Job) Status {
|
||||||
|
job.Stdout.Write([]byte("outer1"))
|
||||||
|
|
||||||
|
innerJob := job.Eng.Job("innerJob")
|
||||||
|
|
||||||
|
if wrapOutput {
|
||||||
|
innerJob.Stdout.Add(ioutils.NopWriteCloser(job.Stdout))
|
||||||
|
} else {
|
||||||
|
innerJob.Stdout.Add(job.Stdout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := innerJob.Run(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If wrapOutput was *false* this write will do nothing.
|
||||||
|
// FIXME (jlhawn): It should cause an error to write to
|
||||||
|
// closed output.
|
||||||
|
job.Stdout.Write([]byte(" outer2"))
|
||||||
|
|
||||||
|
return StatusOK
|
||||||
|
}
|
||||||
|
|
||||||
|
innerHandler = func(job *Job) Status {
|
||||||
|
job.Stdout.Write([]byte(" inner"))
|
||||||
|
|
||||||
|
return StatusOK
|
||||||
|
}
|
||||||
|
|
||||||
|
eng := New()
|
||||||
|
eng.Register("outerJob", outerHandler)
|
||||||
|
eng.Register("innerJob", innerHandler)
|
||||||
|
|
||||||
|
// wrapOutput starts *false* so the expected
|
||||||
|
// output of running the outer job will be:
|
||||||
|
//
|
||||||
|
// "outer1 inner"
|
||||||
|
//
|
||||||
|
outBuf := new(bytes.Buffer)
|
||||||
|
outerJob := eng.Job("outerJob")
|
||||||
|
outerJob.Stdout.Add(outBuf)
|
||||||
|
|
||||||
|
if err := outerJob.Run(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedOutput := "outer1 inner"
|
||||||
|
if outBuf.String() != expectedOutput {
|
||||||
|
t.Fatalf("expected job output to be %q, got %q", expectedOutput, outBuf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set wrapOutput to true so that the expected
|
||||||
|
// output of running the outer job will be:
|
||||||
|
//
|
||||||
|
// "outer1 inner outer2"
|
||||||
|
//
|
||||||
|
wrapOutput = true
|
||||||
|
outBuf.Reset()
|
||||||
|
outerJob = eng.Job("outerJob")
|
||||||
|
outerJob.Stdout.Add(outBuf)
|
||||||
|
|
||||||
|
if err := outerJob.Run(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedOutput = "outer1 inner outer2"
|
||||||
|
if outBuf.String() != expectedOutput {
|
||||||
|
t.Fatalf("expected job output to be %q, got %q", expectedOutput, outBuf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -111,6 +111,7 @@ func (o *Output) Close() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
o.tasks.Wait()
|
o.tasks.Wait()
|
||||||
|
o.dests = nil
|
||||||
return firstErr
|
return firstErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,8 +30,8 @@ type (
|
||||||
ArchiveReader io.Reader
|
ArchiveReader io.Reader
|
||||||
Compression int
|
Compression int
|
||||||
TarOptions struct {
|
TarOptions struct {
|
||||||
Includes []string
|
IncludeFiles []string
|
||||||
Excludes []string
|
ExcludePatterns []string
|
||||||
Compression Compression
|
Compression Compression
|
||||||
NoLchown bool
|
NoLchown bool
|
||||||
Name string
|
Name string
|
||||||
|
@ -101,7 +101,6 @@ func DecompressStream(archive io.Reader) (io.ReadCloser, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Debugf("[tar autodetect] n: %v", bs)
|
|
||||||
|
|
||||||
compression := DetectCompression(bs)
|
compression := DetectCompression(bs)
|
||||||
switch compression {
|
switch compression {
|
||||||
|
@ -378,7 +377,7 @@ func escapeName(name string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TarWithOptions creates an archive from the directory at `path`, only including files whose relative
|
// TarWithOptions creates an archive from the directory at `path`, only including files whose relative
|
||||||
// paths are included in `options.Includes` (if non-nil) or not in `options.Excludes`.
|
// paths are included in `options.IncludeFiles` (if non-nil) or not in `options.ExcludePatterns`.
|
||||||
func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) {
|
func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error) {
|
||||||
pipeReader, pipeWriter := io.Pipe()
|
pipeReader, pipeWriter := io.Pipe()
|
||||||
|
|
||||||
|
@ -401,12 +400,14 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error)
|
||||||
// mutating the filesystem and we can see transient errors
|
// mutating the filesystem and we can see transient errors
|
||||||
// from this
|
// from this
|
||||||
|
|
||||||
if options.Includes == nil {
|
if options.IncludeFiles == nil {
|
||||||
options.Includes = []string{"."}
|
options.IncludeFiles = []string{"."}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
var renamedRelFilePath string // For when tar.Options.Name is set
|
var renamedRelFilePath string // For when tar.Options.Name is set
|
||||||
for _, include := range options.Includes {
|
for _, include := range options.IncludeFiles {
|
||||||
filepath.Walk(filepath.Join(srcPath, include), func(filePath string, f os.FileInfo, err error) error {
|
filepath.Walk(filepath.Join(srcPath, include), func(filePath string, f os.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Tar: Can't stat file %s to tar: %s", srcPath, err)
|
log.Debugf("Tar: Can't stat file %s to tar: %s", srcPath, err)
|
||||||
|
@ -420,11 +421,20 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
skip, err := fileutils.Matches(relFilePath, options.Excludes)
|
skip := false
|
||||||
|
|
||||||
|
// If "include" is an exact match for the current file
|
||||||
|
// then even if there's an "excludePatterns" pattern that
|
||||||
|
// matches it, don't skip it. IOW, assume an explicit 'include'
|
||||||
|
// is asking for that file no matter what - which is true
|
||||||
|
// for some files, like .dockerignore and Dockerfile (sometimes)
|
||||||
|
if include != relFilePath {
|
||||||
|
skip, err = fileutils.Matches(relFilePath, options.ExcludePatterns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Error matching %s", relFilePath, err)
|
log.Debugf("Error matching %s", relFilePath, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if skip {
|
if skip {
|
||||||
if f.IsDir() {
|
if f.IsDir() {
|
||||||
|
@ -433,6 +443,11 @@ func TarWithOptions(srcPath string, options *TarOptions) (io.ReadCloser, error)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if seen[relFilePath] {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
seen[relFilePath] = true
|
||||||
|
|
||||||
// Rename the base resource
|
// Rename the base resource
|
||||||
if options.Name != "" && filePath == srcPath+"/"+filepath.Base(relFilePath) {
|
if options.Name != "" && filePath == srcPath+"/"+filepath.Base(relFilePath) {
|
||||||
renamedRelFilePath = relFilePath
|
renamedRelFilePath = relFilePath
|
||||||
|
@ -487,7 +502,7 @@ loop:
|
||||||
// This keeps "../" as-is, but normalizes "/../" to "/"
|
// This keeps "../" as-is, but normalizes "/../" to "/"
|
||||||
hdr.Name = filepath.Clean(hdr.Name)
|
hdr.Name = filepath.Clean(hdr.Name)
|
||||||
|
|
||||||
for _, exclude := range options.Excludes {
|
for _, exclude := range options.ExcludePatterns {
|
||||||
if strings.HasPrefix(hdr.Name, exclude) {
|
if strings.HasPrefix(hdr.Name, exclude) {
|
||||||
continue loop
|
continue loop
|
||||||
}
|
}
|
||||||
|
@ -563,8 +578,8 @@ func Untar(archive io.Reader, dest string, options *TarOptions) error {
|
||||||
if options == nil {
|
if options == nil {
|
||||||
options = &TarOptions{}
|
options = &TarOptions{}
|
||||||
}
|
}
|
||||||
if options.Excludes == nil {
|
if options.ExcludePatterns == nil {
|
||||||
options.Excludes = []string{}
|
options.ExcludePatterns = []string{}
|
||||||
}
|
}
|
||||||
decompressedArchive, err := DecompressStream(archive)
|
decompressedArchive, err := DecompressStream(archive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -166,7 +166,7 @@ func TestTarUntar(t *testing.T) {
|
||||||
} {
|
} {
|
||||||
changes, err := tarUntar(t, origin, &TarOptions{
|
changes, err := tarUntar(t, origin, &TarOptions{
|
||||||
Compression: c,
|
Compression: c,
|
||||||
Excludes: []string{"3"},
|
ExcludePatterns: []string{"3"},
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -196,8 +196,8 @@ func TestTarWithOptions(t *testing.T) {
|
||||||
opts *TarOptions
|
opts *TarOptions
|
||||||
numChanges int
|
numChanges int
|
||||||
}{
|
}{
|
||||||
{&TarOptions{Includes: []string{"1"}}, 1},
|
{&TarOptions{IncludeFiles: []string{"1"}}, 1},
|
||||||
{&TarOptions{Excludes: []string{"2"}}, 1},
|
{&TarOptions{ExcludePatterns: []string{"2"}}, 1},
|
||||||
}
|
}
|
||||||
for _, testCase := range cases {
|
for _, testCase := range cases {
|
||||||
changes, err := tarUntar(t, origin, testCase.opts)
|
changes, err := tarUntar(t, origin, testCase.opts)
|
||||||
|
|
|
@ -286,7 +286,7 @@ func TestApplyLayer(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ApplyLayer(src, layerCopy); err != nil {
|
if _, err := ApplyLayer(src, layerCopy); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
"github.com/docker/docker/pkg/system"
|
"github.com/docker/docker/pkg/system"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnpackLayer(dest string, layer ArchiveReader) error {
|
func UnpackLayer(dest string, layer ArchiveReader) (size int64, err error) {
|
||||||
tr := tar.NewReader(layer)
|
tr := tar.NewReader(layer)
|
||||||
trBuf := pools.BufioReader32KPool.Get(tr)
|
trBuf := pools.BufioReader32KPool.Get(tr)
|
||||||
defer pools.BufioReader32KPool.Put(trBuf)
|
defer pools.BufioReader32KPool.Put(trBuf)
|
||||||
|
@ -33,9 +33,11 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size += hdr.Size
|
||||||
|
|
||||||
// Normalize name, for safety and for a simple is-root check
|
// Normalize name, for safety and for a simple is-root check
|
||||||
hdr.Name = filepath.Clean(hdr.Name)
|
hdr.Name = filepath.Clean(hdr.Name)
|
||||||
|
|
||||||
|
@ -48,7 +50,7 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) {
|
if _, err := os.Lstat(parentPath); err != nil && os.IsNotExist(err) {
|
||||||
err = os.MkdirAll(parentPath, 0600)
|
err = os.MkdirAll(parentPath, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,12 +65,12 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
aufsHardlinks[basename] = hdr
|
aufsHardlinks[basename] = hdr
|
||||||
if aufsTempdir == "" {
|
if aufsTempdir == "" {
|
||||||
if aufsTempdir, err = ioutil.TempDir("", "dockerplnk"); err != nil {
|
if aufsTempdir, err = ioutil.TempDir("", "dockerplnk"); err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(aufsTempdir)
|
defer os.RemoveAll(aufsTempdir)
|
||||||
}
|
}
|
||||||
if err := createTarFile(filepath.Join(aufsTempdir, basename), dest, hdr, tr, true); err != nil {
|
if err := createTarFile(filepath.Join(aufsTempdir, basename), dest, hdr, tr, true); err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
@ -77,10 +79,10 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
path := filepath.Join(dest, hdr.Name)
|
path := filepath.Join(dest, hdr.Name)
|
||||||
rel, err := filepath.Rel(dest, path)
|
rel, err := filepath.Rel(dest, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(rel, "..") {
|
if strings.HasPrefix(rel, "..") {
|
||||||
return breakoutError(fmt.Errorf("%q is outside of %q", hdr.Name, dest))
|
return 0, breakoutError(fmt.Errorf("%q is outside of %q", hdr.Name, dest))
|
||||||
}
|
}
|
||||||
base := filepath.Base(path)
|
base := filepath.Base(path)
|
||||||
|
|
||||||
|
@ -88,7 +90,7 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
originalBase := base[len(".wh."):]
|
originalBase := base[len(".wh."):]
|
||||||
originalPath := filepath.Join(filepath.Dir(path), originalBase)
|
originalPath := filepath.Join(filepath.Dir(path), originalBase)
|
||||||
if err := os.RemoveAll(originalPath); err != nil {
|
if err := os.RemoveAll(originalPath); err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If path exits we almost always just want to remove and replace it.
|
// If path exits we almost always just want to remove and replace it.
|
||||||
|
@ -98,7 +100,7 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
if fi, err := os.Lstat(path); err == nil {
|
if fi, err := os.Lstat(path); err == nil {
|
||||||
if !(fi.IsDir() && hdr.Typeflag == tar.TypeDir) {
|
if !(fi.IsDir() && hdr.Typeflag == tar.TypeDir) {
|
||||||
if err := os.RemoveAll(path); err != nil {
|
if err := os.RemoveAll(path); err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,18 +115,18 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
linkBasename := filepath.Base(hdr.Linkname)
|
linkBasename := filepath.Base(hdr.Linkname)
|
||||||
srcHdr = aufsHardlinks[linkBasename]
|
srcHdr = aufsHardlinks[linkBasename]
|
||||||
if srcHdr == nil {
|
if srcHdr == nil {
|
||||||
return fmt.Errorf("Invalid aufs hardlink")
|
return 0, fmt.Errorf("Invalid aufs hardlink")
|
||||||
}
|
}
|
||||||
tmpFile, err := os.Open(filepath.Join(aufsTempdir, linkBasename))
|
tmpFile, err := os.Open(filepath.Join(aufsTempdir, linkBasename))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer tmpFile.Close()
|
defer tmpFile.Close()
|
||||||
srcData = tmpFile
|
srcData = tmpFile
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := createTarFile(path, dest, srcHdr, srcData, true); err != nil {
|
if err := createTarFile(path, dest, srcHdr, srcData, true); err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Directory mtimes must be handled at the end to avoid further
|
// Directory mtimes must be handled at the end to avoid further
|
||||||
|
@ -139,27 +141,29 @@ func UnpackLayer(dest string, layer ArchiveReader) error {
|
||||||
path := filepath.Join(dest, hdr.Name)
|
path := filepath.Join(dest, hdr.Name)
|
||||||
ts := []syscall.Timespec{timeToTimespec(hdr.AccessTime), timeToTimespec(hdr.ModTime)}
|
ts := []syscall.Timespec{timeToTimespec(hdr.AccessTime), timeToTimespec(hdr.ModTime)}
|
||||||
if err := syscall.UtimesNano(path, ts); err != nil {
|
if err := syscall.UtimesNano(path, ts); err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyLayer parses a diff in the standard layer format from `layer`, and
|
// ApplyLayer parses a diff in the standard layer format from `layer`, and
|
||||||
// applies it to the directory `dest`.
|
// applies it to the directory `dest`. Returns the size in bytes of the
|
||||||
func ApplyLayer(dest string, layer ArchiveReader) error {
|
// contents of the layer.
|
||||||
|
func ApplyLayer(dest string, layer ArchiveReader) (int64, error) {
|
||||||
dest = filepath.Clean(dest)
|
dest = filepath.Clean(dest)
|
||||||
|
|
||||||
// We need to be able to set any perms
|
// We need to be able to set any perms
|
||||||
oldmask, err := system.Umask(0)
|
oldmask, err := system.Umask(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer system.Umask(oldmask) // ignore err, ErrNotSupportedPlatform
|
defer system.Umask(oldmask) // ignore err, ErrNotSupportedPlatform
|
||||||
|
|
||||||
layer, err = DecompressStream(layer)
|
layer, err = DecompressStream(layer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
return UnpackLayer(dest, layer)
|
return UnpackLayer(dest, layer)
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,8 @@ var testUntarFns = map[string]func(string, io.Reader) error{
|
||||||
return Untar(r, dest, nil)
|
return Untar(r, dest, nil)
|
||||||
},
|
},
|
||||||
"applylayer": func(dest string, r io.Reader) error {
|
"applylayer": func(dest string, r io.Reader) error {
|
||||||
return ApplyLayer(dest, ArchiveReader(r))
|
_, err := ApplyLayer(dest, ArchiveReader(r))
|
||||||
|
return err
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2014-2015 The Docker & Go Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
1
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/MAINTAINERS
generated
vendored
Normal file
1
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/MAINTAINERS
generated
vendored
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Victor Vieux <vieux@docker.com> (@vieux)
|
40
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/README.md
generated
vendored
Normal file
40
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/README.md
generated
vendored
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
Package mflag (aka multiple-flag) implements command-line flag parsing.
|
||||||
|
It's an **hacky** fork of the [official golang package](http://golang.org/pkg/flag/)
|
||||||
|
|
||||||
|
It adds:
|
||||||
|
|
||||||
|
* both short and long flag version
|
||||||
|
`./example -s red` `./example --string blue`
|
||||||
|
|
||||||
|
* multiple names for the same option
|
||||||
|
```
|
||||||
|
$>./example -h
|
||||||
|
Usage of example:
|
||||||
|
-s, --string="": a simple string
|
||||||
|
```
|
||||||
|
|
||||||
|
___
|
||||||
|
It is very flexible on purpose, so you can do things like:
|
||||||
|
```
|
||||||
|
$>./example -h
|
||||||
|
Usage of example:
|
||||||
|
-s, -string, --string="": a simple string
|
||||||
|
```
|
||||||
|
|
||||||
|
Or:
|
||||||
|
```
|
||||||
|
$>./example -h
|
||||||
|
Usage of example:
|
||||||
|
-oldflag, --newflag="": a simple string
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also hide some flags from the usage, so if we want only `--newflag`:
|
||||||
|
```
|
||||||
|
$>./example -h
|
||||||
|
Usage of example:
|
||||||
|
--newflag="": a simple string
|
||||||
|
$>./example -oldflag str
|
||||||
|
str
|
||||||
|
```
|
||||||
|
|
||||||
|
See [example.go](example/example.go) for more details.
|
36
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/example/example.go
generated
vendored
Normal file
36
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/example/example.go
generated
vendored
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
flag "github.com/docker/docker/pkg/mflag"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
i int
|
||||||
|
str string
|
||||||
|
b, b2, h bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.Bool([]string{"#hp", "#-halp"}, false, "display the halp")
|
||||||
|
flag.BoolVar(&b, []string{"b", "#bal", "#bol", "-bal"}, false, "a simple bool")
|
||||||
|
flag.BoolVar(&b, []string{"g", "#gil"}, false, "a simple bool")
|
||||||
|
flag.BoolVar(&b2, []string{"#-bool"}, false, "a simple bool")
|
||||||
|
flag.IntVar(&i, []string{"-integer", "-number"}, -1, "a simple integer")
|
||||||
|
flag.StringVar(&str, []string{"s", "#hidden", "-string"}, "", "a simple string") //-s -hidden and --string will work, but -hidden won't be in the usage
|
||||||
|
flag.BoolVar(&h, []string{"h", "#help", "-help"}, false, "display the help")
|
||||||
|
flag.StringVar(&str, []string{"mode"}, "mode1", "set the mode\nmode1: use the mode1\nmode2: use the mode2\nmode3: use the mode3")
|
||||||
|
flag.Parse()
|
||||||
|
}
|
||||||
|
func main() {
|
||||||
|
if h {
|
||||||
|
flag.PrintDefaults()
|
||||||
|
} else {
|
||||||
|
fmt.Printf("s/#hidden/-string: %s\n", str)
|
||||||
|
fmt.Printf("b: %t\n", b)
|
||||||
|
fmt.Printf("-bool: %t\n", b2)
|
||||||
|
fmt.Printf("s/#hidden/-string(via lookup): %s\n", flag.Lookup("s").Value.String())
|
||||||
|
fmt.Printf("ARGS: %v\n", flag.Args())
|
||||||
|
}
|
||||||
|
}
|
1084
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/flag.go
generated
vendored
Normal file
1084
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/flag.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
516
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/flag_test.go
generated
vendored
Normal file
516
Godeps/_workspace/src/github.com/docker/docker/pkg/mflag/flag_test.go
generated
vendored
Normal file
|
@ -0,0 +1,516 @@
|
||||||
|
// Copyright 2014-2015 The Docker & Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mflag
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResetForTesting clears all flag state and sets the usage function as directed.
|
||||||
|
// After calling ResetForTesting, parse errors in flag handling will not
|
||||||
|
// exit the program.
|
||||||
|
func ResetForTesting(usage func()) {
|
||||||
|
CommandLine = NewFlagSet(os.Args[0], ContinueOnError)
|
||||||
|
Usage = usage
|
||||||
|
}
|
||||||
|
func boolString(s string) string {
|
||||||
|
if s == "0" {
|
||||||
|
return "false"
|
||||||
|
}
|
||||||
|
return "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEverything(t *testing.T) {
|
||||||
|
ResetForTesting(nil)
|
||||||
|
Bool([]string{"test_bool"}, false, "bool value")
|
||||||
|
Int([]string{"test_int"}, 0, "int value")
|
||||||
|
Int64([]string{"test_int64"}, 0, "int64 value")
|
||||||
|
Uint([]string{"test_uint"}, 0, "uint value")
|
||||||
|
Uint64([]string{"test_uint64"}, 0, "uint64 value")
|
||||||
|
String([]string{"test_string"}, "0", "string value")
|
||||||
|
Float64([]string{"test_float64"}, 0, "float64 value")
|
||||||
|
Duration([]string{"test_duration"}, 0, "time.Duration value")
|
||||||
|
|
||||||
|
m := make(map[string]*Flag)
|
||||||
|
desired := "0"
|
||||||
|
visitor := func(f *Flag) {
|
||||||
|
for _, name := range f.Names {
|
||||||
|
if len(name) > 5 && name[0:5] == "test_" {
|
||||||
|
m[name] = f
|
||||||
|
ok := false
|
||||||
|
switch {
|
||||||
|
case f.Value.String() == desired:
|
||||||
|
ok = true
|
||||||
|
case name == "test_bool" && f.Value.String() == boolString(desired):
|
||||||
|
ok = true
|
||||||
|
case name == "test_duration" && f.Value.String() == desired+"s":
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Error("Visit: bad value", f.Value.String(), "for", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VisitAll(visitor)
|
||||||
|
if len(m) != 8 {
|
||||||
|
t.Error("VisitAll misses some flags")
|
||||||
|
for k, v := range m {
|
||||||
|
t.Log(k, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m = make(map[string]*Flag)
|
||||||
|
Visit(visitor)
|
||||||
|
if len(m) != 0 {
|
||||||
|
t.Errorf("Visit sees unset flags")
|
||||||
|
for k, v := range m {
|
||||||
|
t.Log(k, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Now set all flags
|
||||||
|
Set("test_bool", "true")
|
||||||
|
Set("test_int", "1")
|
||||||
|
Set("test_int64", "1")
|
||||||
|
Set("test_uint", "1")
|
||||||
|
Set("test_uint64", "1")
|
||||||
|
Set("test_string", "1")
|
||||||
|
Set("test_float64", "1")
|
||||||
|
Set("test_duration", "1s")
|
||||||
|
desired = "1"
|
||||||
|
Visit(visitor)
|
||||||
|
if len(m) != 8 {
|
||||||
|
t.Error("Visit fails after set")
|
||||||
|
for k, v := range m {
|
||||||
|
t.Log(k, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Now test they're visited in sort order.
|
||||||
|
var flagNames []string
|
||||||
|
Visit(func(f *Flag) {
|
||||||
|
for _, name := range f.Names {
|
||||||
|
flagNames = append(flagNames, name)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if !sort.StringsAreSorted(flagNames) {
|
||||||
|
t.Errorf("flag names not sorted: %v", flagNames)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGet(t *testing.T) {
|
||||||
|
ResetForTesting(nil)
|
||||||
|
Bool([]string{"test_bool"}, true, "bool value")
|
||||||
|
Int([]string{"test_int"}, 1, "int value")
|
||||||
|
Int64([]string{"test_int64"}, 2, "int64 value")
|
||||||
|
Uint([]string{"test_uint"}, 3, "uint value")
|
||||||
|
Uint64([]string{"test_uint64"}, 4, "uint64 value")
|
||||||
|
String([]string{"test_string"}, "5", "string value")
|
||||||
|
Float64([]string{"test_float64"}, 6, "float64 value")
|
||||||
|
Duration([]string{"test_duration"}, 7, "time.Duration value")
|
||||||
|
|
||||||
|
visitor := func(f *Flag) {
|
||||||
|
for _, name := range f.Names {
|
||||||
|
if len(name) > 5 && name[0:5] == "test_" {
|
||||||
|
g, ok := f.Value.(Getter)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Visit: value does not satisfy Getter: %T", f.Value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch name {
|
||||||
|
case "test_bool":
|
||||||
|
ok = g.Get() == true
|
||||||
|
case "test_int":
|
||||||
|
ok = g.Get() == int(1)
|
||||||
|
case "test_int64":
|
||||||
|
ok = g.Get() == int64(2)
|
||||||
|
case "test_uint":
|
||||||
|
ok = g.Get() == uint(3)
|
||||||
|
case "test_uint64":
|
||||||
|
ok = g.Get() == uint64(4)
|
||||||
|
case "test_string":
|
||||||
|
ok = g.Get() == "5"
|
||||||
|
case "test_float64":
|
||||||
|
ok = g.Get() == float64(6)
|
||||||
|
case "test_duration":
|
||||||
|
ok = g.Get() == time.Duration(7)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Visit: bad value %T(%v) for %s", g.Get(), g.Get(), name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VisitAll(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testParse(f *FlagSet, t *testing.T) {
|
||||||
|
if f.Parsed() {
|
||||||
|
t.Error("f.Parse() = true before Parse")
|
||||||
|
}
|
||||||
|
boolFlag := f.Bool([]string{"bool"}, false, "bool value")
|
||||||
|
bool2Flag := f.Bool([]string{"bool2"}, false, "bool2 value")
|
||||||
|
f.Bool([]string{"bool3"}, false, "bool3 value")
|
||||||
|
bool4Flag := f.Bool([]string{"bool4"}, false, "bool4 value")
|
||||||
|
intFlag := f.Int([]string{"-int"}, 0, "int value")
|
||||||
|
int64Flag := f.Int64([]string{"-int64"}, 0, "int64 value")
|
||||||
|
uintFlag := f.Uint([]string{"uint"}, 0, "uint value")
|
||||||
|
uint64Flag := f.Uint64([]string{"-uint64"}, 0, "uint64 value")
|
||||||
|
stringFlag := f.String([]string{"string"}, "0", "string value")
|
||||||
|
f.String([]string{"string2"}, "0", "string2 value")
|
||||||
|
singleQuoteFlag := f.String([]string{"squote"}, "", "single quoted value")
|
||||||
|
doubleQuoteFlag := f.String([]string{"dquote"}, "", "double quoted value")
|
||||||
|
mixedQuoteFlag := f.String([]string{"mquote"}, "", "mixed quoted value")
|
||||||
|
mixed2QuoteFlag := f.String([]string{"mquote2"}, "", "mixed2 quoted value")
|
||||||
|
nestedQuoteFlag := f.String([]string{"nquote"}, "", "nested quoted value")
|
||||||
|
nested2QuoteFlag := f.String([]string{"nquote2"}, "", "nested2 quoted value")
|
||||||
|
float64Flag := f.Float64([]string{"float64"}, 0, "float64 value")
|
||||||
|
durationFlag := f.Duration([]string{"duration"}, 5*time.Second, "time.Duration value")
|
||||||
|
extra := "one-extra-argument"
|
||||||
|
args := []string{
|
||||||
|
"-bool",
|
||||||
|
"-bool2=true",
|
||||||
|
"-bool4=false",
|
||||||
|
"--int", "22",
|
||||||
|
"--int64", "0x23",
|
||||||
|
"-uint", "24",
|
||||||
|
"--uint64", "25",
|
||||||
|
"-string", "hello",
|
||||||
|
"-squote='single'",
|
||||||
|
`-dquote="double"`,
|
||||||
|
`-mquote='mixed"`,
|
||||||
|
`-mquote2="mixed2'`,
|
||||||
|
`-nquote="'single nested'"`,
|
||||||
|
`-nquote2='"double nested"'`,
|
||||||
|
"-float64", "2718e28",
|
||||||
|
"-duration", "2m",
|
||||||
|
extra,
|
||||||
|
}
|
||||||
|
if err := f.Parse(args); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !f.Parsed() {
|
||||||
|
t.Error("f.Parse() = false after Parse")
|
||||||
|
}
|
||||||
|
if *boolFlag != true {
|
||||||
|
t.Error("bool flag should be true, is ", *boolFlag)
|
||||||
|
}
|
||||||
|
if *bool2Flag != true {
|
||||||
|
t.Error("bool2 flag should be true, is ", *bool2Flag)
|
||||||
|
}
|
||||||
|
if !f.IsSet("bool2") {
|
||||||
|
t.Error("bool2 should be marked as set")
|
||||||
|
}
|
||||||
|
if f.IsSet("bool3") {
|
||||||
|
t.Error("bool3 should not be marked as set")
|
||||||
|
}
|
||||||
|
if !f.IsSet("bool4") {
|
||||||
|
t.Error("bool4 should be marked as set")
|
||||||
|
}
|
||||||
|
if *bool4Flag != false {
|
||||||
|
t.Error("bool4 flag should be false, is ", *bool4Flag)
|
||||||
|
}
|
||||||
|
if *intFlag != 22 {
|
||||||
|
t.Error("int flag should be 22, is ", *intFlag)
|
||||||
|
}
|
||||||
|
if *int64Flag != 0x23 {
|
||||||
|
t.Error("int64 flag should be 0x23, is ", *int64Flag)
|
||||||
|
}
|
||||||
|
if *uintFlag != 24 {
|
||||||
|
t.Error("uint flag should be 24, is ", *uintFlag)
|
||||||
|
}
|
||||||
|
if *uint64Flag != 25 {
|
||||||
|
t.Error("uint64 flag should be 25, is ", *uint64Flag)
|
||||||
|
}
|
||||||
|
if *stringFlag != "hello" {
|
||||||
|
t.Error("string flag should be `hello`, is ", *stringFlag)
|
||||||
|
}
|
||||||
|
if !f.IsSet("string") {
|
||||||
|
t.Error("string flag should be marked as set")
|
||||||
|
}
|
||||||
|
if f.IsSet("string2") {
|
||||||
|
t.Error("string2 flag should not be marked as set")
|
||||||
|
}
|
||||||
|
if *singleQuoteFlag != "single" {
|
||||||
|
t.Error("single quote string flag should be `single`, is ", *singleQuoteFlag)
|
||||||
|
}
|
||||||
|
if *doubleQuoteFlag != "double" {
|
||||||
|
t.Error("double quote string flag should be `double`, is ", *doubleQuoteFlag)
|
||||||
|
}
|
||||||
|
if *mixedQuoteFlag != `'mixed"` {
|
||||||
|
t.Error("mixed quote string flag should be `'mixed\"`, is ", *mixedQuoteFlag)
|
||||||
|
}
|
||||||
|
if *mixed2QuoteFlag != `"mixed2'` {
|
||||||
|
t.Error("mixed2 quote string flag should be `\"mixed2'`, is ", *mixed2QuoteFlag)
|
||||||
|
}
|
||||||
|
if *nestedQuoteFlag != "'single nested'" {
|
||||||
|
t.Error("nested quote string flag should be `'single nested'`, is ", *nestedQuoteFlag)
|
||||||
|
}
|
||||||
|
if *nested2QuoteFlag != `"double nested"` {
|
||||||
|
t.Error("double quote string flag should be `\"double nested\"`, is ", *nested2QuoteFlag)
|
||||||
|
}
|
||||||
|
if *float64Flag != 2718e28 {
|
||||||
|
t.Error("float64 flag should be 2718e28, is ", *float64Flag)
|
||||||
|
}
|
||||||
|
if *durationFlag != 2*time.Minute {
|
||||||
|
t.Error("duration flag should be 2m, is ", *durationFlag)
|
||||||
|
}
|
||||||
|
if len(f.Args()) != 1 {
|
||||||
|
t.Error("expected one argument, got", len(f.Args()))
|
||||||
|
} else if f.Args()[0] != extra {
|
||||||
|
t.Errorf("expected argument %q got %q", extra, f.Args()[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testPanic(f *FlagSet, t *testing.T) {
|
||||||
|
f.Int([]string{"-int"}, 0, "int value")
|
||||||
|
if f.Parsed() {
|
||||||
|
t.Error("f.Parse() = true before Parse")
|
||||||
|
}
|
||||||
|
args := []string{
|
||||||
|
"-int", "21",
|
||||||
|
}
|
||||||
|
f.Parse(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePanic(t *testing.T) {
|
||||||
|
ResetForTesting(func() {})
|
||||||
|
testPanic(CommandLine, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
ResetForTesting(func() { t.Error("bad parse") })
|
||||||
|
testParse(CommandLine, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlagSetParse(t *testing.T) {
|
||||||
|
testParse(NewFlagSet("test", ContinueOnError), t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Declare a user-defined flag type.
|
||||||
|
type flagVar []string
|
||||||
|
|
||||||
|
func (f *flagVar) String() string {
|
||||||
|
return fmt.Sprint([]string(*f))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *flagVar) Set(value string) error {
|
||||||
|
*f = append(*f, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserDefined(t *testing.T) {
|
||||||
|
var flags FlagSet
|
||||||
|
flags.Init("test", ContinueOnError)
|
||||||
|
var v flagVar
|
||||||
|
flags.Var(&v, []string{"v"}, "usage")
|
||||||
|
if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if len(v) != 3 {
|
||||||
|
t.Fatal("expected 3 args; got ", len(v))
|
||||||
|
}
|
||||||
|
expect := "[1 2 3]"
|
||||||
|
if v.String() != expect {
|
||||||
|
t.Errorf("expected value %q got %q", expect, v.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Declare a user-defined boolean flag type.
|
||||||
|
type boolFlagVar struct {
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *boolFlagVar) String() string {
|
||||||
|
return fmt.Sprintf("%d", b.count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *boolFlagVar) Set(value string) error {
|
||||||
|
if value == "true" {
|
||||||
|
b.count++
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *boolFlagVar) IsBoolFlag() bool {
|
||||||
|
return b.count < 4
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserDefinedBool(t *testing.T) {
|
||||||
|
var flags FlagSet
|
||||||
|
flags.Init("test", ContinueOnError)
|
||||||
|
var b boolFlagVar
|
||||||
|
var err error
|
||||||
|
flags.Var(&b, []string{"b"}, "usage")
|
||||||
|
if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil {
|
||||||
|
if b.count < 4 {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.count != 4 {
|
||||||
|
t.Errorf("want: %d; got: %d", 4, b.count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error; got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetOutput(t *testing.T) {
|
||||||
|
var flags FlagSet
|
||||||
|
var buf bytes.Buffer
|
||||||
|
flags.SetOutput(&buf)
|
||||||
|
flags.Init("test", ContinueOnError)
|
||||||
|
flags.Parse([]string{"-unknown"})
|
||||||
|
if out := buf.String(); !strings.Contains(out, "-unknown") {
|
||||||
|
t.Logf("expected output mentioning unknown; got %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This tests that one can reset the flags. This still works but not well, and is
|
||||||
|
// superseded by FlagSet.
|
||||||
|
func TestChangingArgs(t *testing.T) {
|
||||||
|
ResetForTesting(func() { t.Fatal("bad parse") })
|
||||||
|
oldArgs := os.Args
|
||||||
|
defer func() { os.Args = oldArgs }()
|
||||||
|
os.Args = []string{"cmd", "-before", "subcmd", "-after", "args"}
|
||||||
|
before := Bool([]string{"before"}, false, "")
|
||||||
|
if err := CommandLine.Parse(os.Args[1:]); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cmd := Arg(0)
|
||||||
|
os.Args = Args()
|
||||||
|
after := Bool([]string{"after"}, false, "")
|
||||||
|
Parse()
|
||||||
|
args := Args()
|
||||||
|
|
||||||
|
if !*before || cmd != "subcmd" || !*after || len(args) != 1 || args[0] != "args" {
|
||||||
|
t.Fatalf("expected true subcmd true [args] got %v %v %v %v", *before, cmd, *after, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that -help invokes the usage message and returns ErrHelp.
|
||||||
|
func TestHelp(t *testing.T) {
|
||||||
|
var helpCalled = false
|
||||||
|
fs := NewFlagSet("help test", ContinueOnError)
|
||||||
|
fs.Usage = func() { helpCalled = true }
|
||||||
|
var flag bool
|
||||||
|
fs.BoolVar(&flag, []string{"flag"}, false, "regular flag")
|
||||||
|
// Regular flag invocation should work
|
||||||
|
err := fs.Parse([]string{"-flag=true"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("expected no error; got ", err)
|
||||||
|
}
|
||||||
|
if !flag {
|
||||||
|
t.Error("flag was not set by -flag")
|
||||||
|
}
|
||||||
|
if helpCalled {
|
||||||
|
t.Error("help called for regular flag")
|
||||||
|
helpCalled = false // reset for next test
|
||||||
|
}
|
||||||
|
// Help flag should work as expected.
|
||||||
|
err = fs.Parse([]string{"-help"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("error expected")
|
||||||
|
}
|
||||||
|
if err != ErrHelp {
|
||||||
|
t.Fatal("expected ErrHelp; got ", err)
|
||||||
|
}
|
||||||
|
if !helpCalled {
|
||||||
|
t.Fatal("help was not called")
|
||||||
|
}
|
||||||
|
// If we define a help flag, that should override.
|
||||||
|
var help bool
|
||||||
|
fs.BoolVar(&help, []string{"help"}, false, "help flag")
|
||||||
|
helpCalled = false
|
||||||
|
err = fs.Parse([]string{"-help"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("expected no error for defined -help; got ", err)
|
||||||
|
}
|
||||||
|
if helpCalled {
|
||||||
|
t.Fatal("help was called; should not have been for defined help flag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the flag count functions.
|
||||||
|
func TestFlagCounts(t *testing.T) {
|
||||||
|
fs := NewFlagSet("help test", ContinueOnError)
|
||||||
|
var flag bool
|
||||||
|
fs.BoolVar(&flag, []string{"flag1"}, false, "regular flag")
|
||||||
|
fs.BoolVar(&flag, []string{"#deprecated1"}, false, "regular flag")
|
||||||
|
fs.BoolVar(&flag, []string{"f", "flag2"}, false, "regular flag")
|
||||||
|
fs.BoolVar(&flag, []string{"#d", "#deprecated2"}, false, "regular flag")
|
||||||
|
fs.BoolVar(&flag, []string{"flag3"}, false, "regular flag")
|
||||||
|
fs.BoolVar(&flag, []string{"g", "#flag4", "-flag4"}, false, "regular flag")
|
||||||
|
|
||||||
|
if fs.FlagCount() != 6 {
|
||||||
|
t.Fatal("FlagCount wrong. ", fs.FlagCount())
|
||||||
|
}
|
||||||
|
if fs.FlagCountUndeprecated() != 4 {
|
||||||
|
t.Fatal("FlagCountUndeprecated wrong. ", fs.FlagCountUndeprecated())
|
||||||
|
}
|
||||||
|
if fs.NFlag() != 0 {
|
||||||
|
t.Fatal("NFlag wrong. ", fs.NFlag())
|
||||||
|
}
|
||||||
|
err := fs.Parse([]string{"-fd", "-g", "-flag4"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("expected no error for defined -help; got ", err)
|
||||||
|
}
|
||||||
|
if fs.NFlag() != 4 {
|
||||||
|
t.Fatal("NFlag wrong. ", fs.NFlag())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show up bug in sortFlags
|
||||||
|
func TestSortFlags(t *testing.T) {
|
||||||
|
fs := NewFlagSet("help TestSortFlags", ContinueOnError)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
var b bool
|
||||||
|
fs.BoolVar(&b, []string{"b", "-banana"}, false, "usage")
|
||||||
|
|
||||||
|
err = fs.Parse([]string{"--banana=true"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("expected no error; got ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
fs.VisitAll(func(flag *Flag) {
|
||||||
|
count++
|
||||||
|
if flag == nil {
|
||||||
|
t.Fatal("VisitAll should not return a nil flag")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
flagcount := fs.FlagCount()
|
||||||
|
if flagcount != count {
|
||||||
|
t.Fatalf("FlagCount (%d) != number (%d) of elements visited", flagcount, count)
|
||||||
|
}
|
||||||
|
// Make sure its idempotent
|
||||||
|
if flagcount != fs.FlagCount() {
|
||||||
|
t.Fatalf("FlagCount (%d) != fs.FlagCount() (%d) of elements visited", flagcount, fs.FlagCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
fs.Visit(func(flag *Flag) {
|
||||||
|
count++
|
||||||
|
if flag == nil {
|
||||||
|
t.Fatal("Visit should not return a nil flag")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
nflag := fs.NFlag()
|
||||||
|
if nflag != count {
|
||||||
|
t.Fatalf("NFlag (%d) != number (%d) of elements visited", nflag, count)
|
||||||
|
}
|
||||||
|
if nflag != fs.NFlag() {
|
||||||
|
t.Fatalf("NFlag (%d) != fs.NFlag() (%d) of elements visited", nflag, fs.NFlag())
|
||||||
|
}
|
||||||
|
}
|
|
@ -104,3 +104,28 @@ func ParseKeyValueOpt(opt string) (string, string, error) {
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]), nil
|
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ParsePortRange(ports string) (uint64, uint64, error) {
|
||||||
|
if ports == "" {
|
||||||
|
return 0, 0, fmt.Errorf("Empty string specified for ports.")
|
||||||
|
}
|
||||||
|
if !strings.Contains(ports, "-") {
|
||||||
|
start, err := strconv.ParseUint(ports, 10, 16)
|
||||||
|
end := start
|
||||||
|
return start, end, err
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(ports, "-")
|
||||||
|
start, err := strconv.ParseUint(parts[0], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
end, err := strconv.ParseUint(parts[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
if end < start {
|
||||||
|
return 0, 0, fmt.Errorf("Invalid range specified for the Port: %s", ports)
|
||||||
|
}
|
||||||
|
return start, end, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package parsers
|
package parsers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,3 +82,35 @@ func TestParsePortMapping(t *testing.T) {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParsePortRange(t *testing.T) {
|
||||||
|
if start, end, err := ParsePortRange("8000-8080"); err != nil || start != 8000 || end != 8080 {
|
||||||
|
t.Fatalf("Error: %s or Expecting {start,end} values {8000,8080} but found {%d,%d}.", err, start, end)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePortRangeIncorrectRange(t *testing.T) {
|
||||||
|
if _, _, err := ParsePortRange("9000-8080"); err == nil || !strings.Contains(err.Error(), "Invalid range specified for the Port") {
|
||||||
|
t.Fatalf("Expecting error 'Invalid range specified for the Port' but received %s.", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePortRangeIncorrectEndRange(t *testing.T) {
|
||||||
|
if _, _, err := ParsePortRange("8000-a"); err == nil || !strings.Contains(err.Error(), "invalid syntax") {
|
||||||
|
t.Fatalf("Expecting error 'Invalid range specified for the Port' but received %s.", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, _, err := ParsePortRange("8000-30a"); err == nil || !strings.Contains(err.Error(), "invalid syntax") {
|
||||||
|
t.Fatalf("Expecting error 'Invalid range specified for the Port' but received %s.", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePortRangeIncorrectStartRange(t *testing.T) {
|
||||||
|
if _, _, err := ParsePortRange("a-8000"); err == nil || !strings.Contains(err.Error(), "invalid syntax") {
|
||||||
|
t.Fatalf("Expecting error 'Invalid range specified for the Port' but received %s.", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, _, err := ParsePortRange("30a-8000"); err == nil || !strings.Contains(err.Error(), "invalid syntax") {
|
||||||
|
t.Fatalf("Expecting error 'Invalid range specified for the Port' but received %s.", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ var binaryAbbrs = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB",
|
||||||
|
|
||||||
// HumanSize returns a human-readable approximation of a size
|
// HumanSize returns a human-readable approximation of a size
|
||||||
// using SI standard (eg. "44kB", "17MB")
|
// using SI standard (eg. "44kB", "17MB")
|
||||||
func HumanSize(size int64) string {
|
func HumanSize(size float64) string {
|
||||||
return intToString(float64(size), 1000.0, decimapAbbrs)
|
return intToString(float64(size), 1000.0, decimapAbbrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,9 +23,9 @@ func TestHumanSize(t *testing.T) {
|
||||||
assertEquals(t, "1 MB", HumanSize(1000000))
|
assertEquals(t, "1 MB", HumanSize(1000000))
|
||||||
assertEquals(t, "1.049 MB", HumanSize(1048576))
|
assertEquals(t, "1.049 MB", HumanSize(1048576))
|
||||||
assertEquals(t, "2 MB", HumanSize(2*MB))
|
assertEquals(t, "2 MB", HumanSize(2*MB))
|
||||||
assertEquals(t, "3.42 GB", HumanSize(3.42*GB))
|
assertEquals(t, "3.42 GB", HumanSize(float64(3.42*GB)))
|
||||||
assertEquals(t, "5.372 TB", HumanSize(5.372*TB))
|
assertEquals(t, "5.372 TB", HumanSize(float64(5.372*TB)))
|
||||||
assertEquals(t, "2.22 PB", HumanSize(2.22*PB))
|
assertEquals(t, "2.22 PB", HumanSize(float64(2.22*PB)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFromHumanSize(t *testing.T) {
|
func TestFromHumanSize(t *testing.T) {
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
flag "github.com/docker/docker/pkg/mflag"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseFlags is a utility function that adds a help flag if withHelp is true,
|
||||||
|
// calls cmd.Parse(args) and prints a relevant error message if there are
|
||||||
|
// incorrect number of arguments. It returns error only if error handling is
|
||||||
|
// set to ContinueOnError and parsing fails. If error handling is set to
|
||||||
|
// ExitOnError, it's safe to ignore the return value.
|
||||||
|
// TODO: move this to a better package than utils
|
||||||
|
func ParseFlags(cmd *flag.FlagSet, args []string, withHelp bool) error {
|
||||||
|
var help *bool
|
||||||
|
if withHelp {
|
||||||
|
help = cmd.Bool([]string{"#help", "-help"}, false, "Print usage")
|
||||||
|
}
|
||||||
|
if err := cmd.Parse(args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if help != nil && *help {
|
||||||
|
cmd.Usage()
|
||||||
|
// just in case Usage does not exit
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
if str := cmd.CheckArgs(); str != "" {
|
||||||
|
ReportError(cmd, str, withHelp)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReportError(cmd *flag.FlagSet, str string, withHelp bool) {
|
||||||
|
if withHelp {
|
||||||
|
if os.Args[0] == cmd.Name() {
|
||||||
|
str += ". See '" + os.Args[0] + " --help'"
|
||||||
|
} else {
|
||||||
|
str += ". See '" + os.Args[0] + " " + cmd.Name() + " --help'"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintf(cmd.Out(), "docker: %s.\n", str)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
|
@ -134,6 +134,10 @@ func (self *HTTPRequestFactory) AddDecorator(d ...HTTPRequestDecorator) {
|
||||||
self.decorators = append(self.decorators, d...)
|
self.decorators = append(self.decorators, d...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (self *HTTPRequestFactory) GetDecorators() []HTTPRequestDecorator {
|
||||||
|
return self.decorators
|
||||||
|
}
|
||||||
|
|
||||||
// NewRequest() creates a new *http.Request,
|
// NewRequest() creates a new *http.Request,
|
||||||
// applies all decorators in the HTTPRequestFactory on the request,
|
// applies all decorators in the HTTPRequestFactory on the request,
|
||||||
// then applies decorators provided by d on the request.
|
// then applies decorators provided by d on the request.
|
||||||
|
|
|
@ -44,12 +44,15 @@ func (p *JSONProgress) String() string {
|
||||||
if p.Current <= 0 && p.Total <= 0 {
|
if p.Current <= 0 && p.Total <= 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
current := units.HumanSize(int64(p.Current))
|
current := units.HumanSize(float64(p.Current))
|
||||||
if p.Total <= 0 {
|
if p.Total <= 0 {
|
||||||
return fmt.Sprintf("%8v", current)
|
return fmt.Sprintf("%8v", current)
|
||||||
}
|
}
|
||||||
total := units.HumanSize(int64(p.Total))
|
total := units.HumanSize(float64(p.Total))
|
||||||
percentage := int(float64(p.Current)/float64(p.Total)*100) / 2
|
percentage := int(float64(p.Current)/float64(p.Total)*100) / 2
|
||||||
|
if percentage > 50 {
|
||||||
|
percentage = 50
|
||||||
|
}
|
||||||
if width > 110 {
|
if width > 110 {
|
||||||
// this number can't be negetive gh#7136
|
// this number can't be negetive gh#7136
|
||||||
numSpaces := 0
|
numSpaces := 0
|
||||||
|
|
|
@ -30,7 +30,7 @@ func TestProgress(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// this number can't be negetive gh#7136
|
// this number can't be negetive gh#7136
|
||||||
expected = "[==============================================================>] 50 B/40 B"
|
expected = "[==================================================>] 50 B/40 B"
|
||||||
jp4 := JSONProgress{Current: 50, Total: 40}
|
jp4 := JSONProgress{Current: 50, Total: 40}
|
||||||
if jp4.String() != expected {
|
if jp4.String() != expected {
|
||||||
t.Fatalf("Expected %q, got %q", expected, jp4.String())
|
t.Fatalf("Expected %q, got %q", expected, jp4.String())
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
|
@ -93,7 +94,7 @@ func isValidDockerInitPath(target string, selfPath string) bool { // target and
|
||||||
if target == "" {
|
if target == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if dockerversion.IAMSTATIC {
|
if dockerversion.IAMSTATIC == "true" {
|
||||||
if selfPath == "" {
|
if selfPath == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -290,14 +291,6 @@ func NewHTTPRequestError(msg string, res *http.Response) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var localHostRx = regexp.MustCompile(`(?m)^nameserver 127[^\n]+\n*`)
|
|
||||||
|
|
||||||
// RemoveLocalDns looks into the /etc/resolv.conf,
|
|
||||||
// and removes any local nameserver entries.
|
|
||||||
func RemoveLocalDns(resolvConf []byte) []byte {
|
|
||||||
return localHostRx.ReplaceAll(resolvConf, []byte{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// An StatusError reports an unsuccessful exit by a command.
|
// An StatusError reports an unsuccessful exit by a command.
|
||||||
type StatusError struct {
|
type StatusError struct {
|
||||||
Status string
|
Status string
|
||||||
|
@ -408,7 +401,17 @@ func ReplaceOrAppendEnvValues(defaults, overrides []string) []string {
|
||||||
parts := strings.SplitN(e, "=", 2)
|
parts := strings.SplitN(e, "=", 2)
|
||||||
cache[parts[0]] = i
|
cache[parts[0]] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range overrides {
|
for _, value := range overrides {
|
||||||
|
// Values w/o = means they want this env to be removed/unset.
|
||||||
|
if !strings.Contains(value, "=") {
|
||||||
|
if i, exists := cache[value]; exists {
|
||||||
|
defaults[i] = "" // Used to indicate it should be removed
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just do a normal set/update
|
||||||
parts := strings.SplitN(value, "=", 2)
|
parts := strings.SplitN(value, "=", 2)
|
||||||
if i, exists := cache[parts[0]]; exists {
|
if i, exists := cache[parts[0]]; exists {
|
||||||
defaults[i] = value
|
defaults[i] = value
|
||||||
|
@ -416,9 +419,28 @@ func ReplaceOrAppendEnvValues(defaults, overrides []string) []string {
|
||||||
defaults = append(defaults, value)
|
defaults = append(defaults, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now remove all entries that we want to "unset"
|
||||||
|
for i := 0; i < len(defaults); i++ {
|
||||||
|
if defaults[i] == "" {
|
||||||
|
defaults = append(defaults[:i], defaults[i+1:]...)
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return defaults
|
return defaults
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DoesEnvExist(name string) bool {
|
||||||
|
for _, entry := range os.Environ() {
|
||||||
|
parts := strings.SplitN(entry, "=", 2)
|
||||||
|
if parts[0] == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// ReadSymlinkedDirectory returns the target directory of a symlink.
|
// ReadSymlinkedDirectory returns the target directory of a symlink.
|
||||||
// The target of the symbolic link may not be a file.
|
// The target of the symbolic link may not be a file.
|
||||||
func ReadSymlinkedDirectory(path string) (string, error) {
|
func ReadSymlinkedDirectory(path string) (string, error) {
|
||||||
|
@ -492,3 +514,34 @@ func StringsContainsNoCase(slice []string, s string) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reads a .dockerignore file and returns the list of file patterns
|
||||||
|
// to ignore. Note this will trim whitespace from each line as well
|
||||||
|
// as use GO's "clean" func to get the shortest/cleanest path for each.
|
||||||
|
func ReadDockerIgnore(path string) ([]string, error) {
|
||||||
|
// Note that a missing .dockerignore file isn't treated as an error
|
||||||
|
reader, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("Error reading '%s': %v", path, err)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(reader)
|
||||||
|
var excludes []string
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
pattern := strings.TrimSpace(scanner.Text())
|
||||||
|
if pattern == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pattern = filepath.Clean(pattern)
|
||||||
|
excludes = append(excludes, pattern)
|
||||||
|
}
|
||||||
|
if err = scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("Error reading '%s': %v", path, err)
|
||||||
|
}
|
||||||
|
return excludes, nil
|
||||||
|
}
|
||||||
|
|
|
@ -37,3 +37,13 @@ func TreeSize(dir string) (size int64, err error) {
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsFileOwner checks whether the current user is the owner of the given file.
|
||||||
|
func IsFileOwner(f string) bool {
|
||||||
|
if fileInfo, err := os.Stat(f); err == nil && fileInfo != nil {
|
||||||
|
if stat, ok := fileInfo.Sys().(*syscall.Stat_t); ok && int(stat.Uid) == os.Getuid() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
"unicode"
|
"unicode"
|
||||||
)
|
)
|
||||||
|
@ -31,11 +32,27 @@ type jsHeader struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type jsSignature struct {
|
type jsSignature struct {
|
||||||
Header *jsHeader `json:"header"`
|
Header jsHeader `json:"header"`
|
||||||
Signature string `json:"signature"`
|
Signature string `json:"signature"`
|
||||||
Protected string `json:"protected,omitempty"`
|
Protected string `json:"protected,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type jsSignaturesSorted []jsSignature
|
||||||
|
|
||||||
|
func (jsbkid jsSignaturesSorted) Swap(i, j int) { jsbkid[i], jsbkid[j] = jsbkid[j], jsbkid[i] }
|
||||||
|
func (jsbkid jsSignaturesSorted) Len() int { return len(jsbkid) }
|
||||||
|
|
||||||
|
func (jsbkid jsSignaturesSorted) Less(i, j int) bool {
|
||||||
|
ki, kj := jsbkid[i].Header.JWK.KeyID(), jsbkid[j].Header.JWK.KeyID()
|
||||||
|
si, sj := jsbkid[i].Signature, jsbkid[j].Signature
|
||||||
|
|
||||||
|
if ki == kj {
|
||||||
|
return si < sj
|
||||||
|
}
|
||||||
|
|
||||||
|
return ki < kj
|
||||||
|
}
|
||||||
|
|
||||||
type signKey struct {
|
type signKey struct {
|
||||||
PrivateKey
|
PrivateKey
|
||||||
Chain []*x509.Certificate
|
Chain []*x509.Certificate
|
||||||
|
@ -44,7 +61,7 @@ type signKey struct {
|
||||||
// JSONSignature represents a signature of a json object.
|
// JSONSignature represents a signature of a json object.
|
||||||
type JSONSignature struct {
|
type JSONSignature struct {
|
||||||
payload string
|
payload string
|
||||||
signatures []*jsSignature
|
signatures []jsSignature
|
||||||
indent string
|
indent string
|
||||||
formatLength int
|
formatLength int
|
||||||
formatTail []byte
|
formatTail []byte
|
||||||
|
@ -52,7 +69,7 @@ type JSONSignature struct {
|
||||||
|
|
||||||
func newJSONSignature() *JSONSignature {
|
func newJSONSignature() *JSONSignature {
|
||||||
return &JSONSignature{
|
return &JSONSignature{
|
||||||
signatures: make([]*jsSignature, 0, 1),
|
signatures: make([]jsSignature, 0, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,17 +116,14 @@ func (js *JSONSignature) Sign(key PrivateKey) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
header := &jsHeader{
|
js.signatures = append(js.signatures, jsSignature{
|
||||||
|
Header: jsHeader{
|
||||||
JWK: key.PublicKey(),
|
JWK: key.PublicKey(),
|
||||||
Algorithm: algorithm,
|
Algorithm: algorithm,
|
||||||
}
|
},
|
||||||
sig := &jsSignature{
|
|
||||||
Header: header,
|
|
||||||
Signature: joseBase64UrlEncode(sigBytes),
|
Signature: joseBase64UrlEncode(sigBytes),
|
||||||
Protected: protected,
|
Protected: protected,
|
||||||
}
|
})
|
||||||
|
|
||||||
js.signatures = append(js.signatures, sig)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -136,7 +150,7 @@ func (js *JSONSignature) SignWithChain(key PrivateKey, chain []*x509.Certificate
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
header := &jsHeader{
|
header := jsHeader{
|
||||||
Chain: make([]string, len(chain)),
|
Chain: make([]string, len(chain)),
|
||||||
Algorithm: algorithm,
|
Algorithm: algorithm,
|
||||||
}
|
}
|
||||||
|
@ -145,13 +159,11 @@ func (js *JSONSignature) SignWithChain(key PrivateKey, chain []*x509.Certificate
|
||||||
header.Chain[i] = base64.StdEncoding.EncodeToString(cert.Raw)
|
header.Chain[i] = base64.StdEncoding.EncodeToString(cert.Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
sig := &jsSignature{
|
js.signatures = append(js.signatures, jsSignature{
|
||||||
Header: header,
|
Header: header,
|
||||||
Signature: joseBase64UrlEncode(sigBytes),
|
Signature: joseBase64UrlEncode(sigBytes),
|
||||||
Protected: protected,
|
Protected: protected,
|
||||||
}
|
})
|
||||||
|
|
||||||
js.signatures = append(js.signatures, sig)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -272,6 +284,9 @@ func (js *JSONSignature) JWS() ([]byte, error) {
|
||||||
if len(js.signatures) == 0 {
|
if len(js.signatures) == 0 {
|
||||||
return nil, errors.New("missing signature")
|
return nil, errors.New("missing signature")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sort.Sort(jsSignaturesSorted(js.signatures))
|
||||||
|
|
||||||
jsonMap := map[string]interface{}{
|
jsonMap := map[string]interface{}{
|
||||||
"payload": js.payload,
|
"payload": js.payload,
|
||||||
"signatures": js.signatures,
|
"signatures": js.signatures,
|
||||||
|
@ -301,7 +316,7 @@ type jsParsedHeader struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type jsParsedSignature struct {
|
type jsParsedSignature struct {
|
||||||
Header *jsParsedHeader `json:"header"`
|
Header jsParsedHeader `json:"header"`
|
||||||
Signature string `json:"signature"`
|
Signature string `json:"signature"`
|
||||||
Protected string `json:"protected"`
|
Protected string `json:"protected"`
|
||||||
}
|
}
|
||||||
|
@ -310,7 +325,7 @@ type jsParsedSignature struct {
|
||||||
func ParseJWS(content []byte) (*JSONSignature, error) {
|
func ParseJWS(content []byte) (*JSONSignature, error) {
|
||||||
type jsParsed struct {
|
type jsParsed struct {
|
||||||
Payload string `json:"payload"`
|
Payload string `json:"payload"`
|
||||||
Signatures []*jsParsedSignature `json:"signatures"`
|
Signatures []jsParsedSignature `json:"signatures"`
|
||||||
}
|
}
|
||||||
parsed := &jsParsed{}
|
parsed := &jsParsed{}
|
||||||
err := json.Unmarshal(content, parsed)
|
err := json.Unmarshal(content, parsed)
|
||||||
|
@ -329,9 +344,9 @@ func ParseJWS(content []byte) (*JSONSignature, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
js.signatures = make([]*jsSignature, len(parsed.Signatures))
|
js.signatures = make([]jsSignature, len(parsed.Signatures))
|
||||||
for i, signature := range parsed.Signatures {
|
for i, signature := range parsed.Signatures {
|
||||||
header := &jsHeader{
|
header := jsHeader{
|
||||||
Algorithm: signature.Header.Algorithm,
|
Algorithm: signature.Header.Algorithm,
|
||||||
}
|
}
|
||||||
if signature.Header.Chain != nil {
|
if signature.Header.Chain != nil {
|
||||||
|
@ -344,7 +359,7 @@ func ParseJWS(content []byte) (*JSONSignature, error) {
|
||||||
}
|
}
|
||||||
header.JWK = publicKey
|
header.JWK = publicKey
|
||||||
}
|
}
|
||||||
js.signatures[i] = &jsSignature{
|
js.signatures[i] = jsSignature{
|
||||||
Header: header,
|
Header: header,
|
||||||
Signature: signature.Signature,
|
Signature: signature.Signature,
|
||||||
Protected: signature.Protected,
|
Protected: signature.Protected,
|
||||||
|
@ -356,7 +371,11 @@ func ParseJWS(content []byte) (*JSONSignature, error) {
|
||||||
|
|
||||||
// NewJSONSignature returns a new unsigned JWS from a json byte array.
|
// NewJSONSignature returns a new unsigned JWS from a json byte array.
|
||||||
// JSONSignature will need to be signed before serializing or storing.
|
// JSONSignature will need to be signed before serializing or storing.
|
||||||
func NewJSONSignature(content []byte) (*JSONSignature, error) {
|
// Optionally, one or more signatures can be provided as byte buffers,
|
||||||
|
// containing serialized JWS signatures, to assemble a fully signed JWS
|
||||||
|
// package. It is the callers responsibility to ensure uniqueness of the
|
||||||
|
// provided signatures.
|
||||||
|
func NewJSONSignature(content []byte, signatures ...[]byte) (*JSONSignature, error) {
|
||||||
var dataMap map[string]interface{}
|
var dataMap map[string]interface{}
|
||||||
err := json.Unmarshal(content, &dataMap)
|
err := json.Unmarshal(content, &dataMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -380,6 +399,40 @@ func NewJSONSignature(content []byte) (*JSONSignature, error) {
|
||||||
js.formatLength = lastRuneIndex + 1
|
js.formatLength = lastRuneIndex + 1
|
||||||
js.formatTail = content[js.formatLength:]
|
js.formatTail = content[js.formatLength:]
|
||||||
|
|
||||||
|
if len(signatures) > 0 {
|
||||||
|
for _, signature := range signatures {
|
||||||
|
var parsedJSig jsParsedSignature
|
||||||
|
|
||||||
|
if err := json.Unmarshal(signature, &parsedJSig); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(stevvooe): A lot of the code below is repeated in
|
||||||
|
// ParseJWS. It will require more refactoring to fix that.
|
||||||
|
jsig := jsSignature{
|
||||||
|
Header: jsHeader{
|
||||||
|
Algorithm: parsedJSig.Header.Algorithm,
|
||||||
|
},
|
||||||
|
Signature: parsedJSig.Signature,
|
||||||
|
Protected: parsedJSig.Protected,
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedJSig.Header.Chain != nil {
|
||||||
|
jsig.Header.Chain = parsedJSig.Header.Chain
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedJSig.Header.JWK != nil {
|
||||||
|
publicKey, err := UnmarshalPublicKeyJWK([]byte(parsedJSig.Header.JWK))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jsig.Header.JWK = publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
js.signatures = append(js.signatures, jsig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return js, nil
|
return js, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -455,7 +508,7 @@ func ParsePrettySignature(content []byte, signatureKey string) (*JSONSignature,
|
||||||
}
|
}
|
||||||
|
|
||||||
js := newJSONSignature()
|
js := newJSONSignature()
|
||||||
js.signatures = make([]*jsSignature, len(signatureBlocks))
|
js.signatures = make([]jsSignature, len(signatureBlocks))
|
||||||
|
|
||||||
for i, signatureBlock := range signatureBlocks {
|
for i, signatureBlock := range signatureBlocks {
|
||||||
protectedBytes, err := joseBase64UrlDecode(signatureBlock.Protected)
|
protectedBytes, err := joseBase64UrlDecode(signatureBlock.Protected)
|
||||||
|
@ -491,7 +544,7 @@ func ParsePrettySignature(content []byte, signatureKey string) (*JSONSignature,
|
||||||
return nil, errors.New("conflicting format tail")
|
return nil, errors.New("conflicting format tail")
|
||||||
}
|
}
|
||||||
|
|
||||||
header := &jsHeader{
|
header := jsHeader{
|
||||||
Algorithm: signatureBlock.Header.Algorithm,
|
Algorithm: signatureBlock.Header.Algorithm,
|
||||||
Chain: signatureBlock.Header.Chain,
|
Chain: signatureBlock.Header.Chain,
|
||||||
}
|
}
|
||||||
|
@ -502,7 +555,7 @@ func ParsePrettySignature(content []byte, signatureKey string) (*JSONSignature,
|
||||||
}
|
}
|
||||||
header.JWK = publicKey
|
header.JWK = publicKey
|
||||||
}
|
}
|
||||||
js.signatures[i] = &jsSignature{
|
js.signatures[i] = jsSignature{
|
||||||
Header: header,
|
Header: header,
|
||||||
Signature: signatureBlock.Signature,
|
Signature: signatureBlock.Signature,
|
||||||
Protected: signatureBlock.Protected,
|
Protected: signatureBlock.Protected,
|
||||||
|
@ -532,6 +585,8 @@ func (js *JSONSignature) PrettySignature(signatureKey string) ([]byte, error) {
|
||||||
}
|
}
|
||||||
payload = payload[:js.formatLength]
|
payload = payload[:js.formatLength]
|
||||||
|
|
||||||
|
sort.Sort(jsSignaturesSorted(js.signatures))
|
||||||
|
|
||||||
var marshalled []byte
|
var marshalled []byte
|
||||||
var marshallErr error
|
var marshallErr error
|
||||||
if js.indent != "" {
|
if js.indent != "" {
|
||||||
|
@ -564,3 +619,39 @@ func (js *JSONSignature) PrettySignature(signatureKey string) ([]byte, error) {
|
||||||
|
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Signatures provides the signatures on this JWS as opaque blobs, sorted by
|
||||||
|
// keyID. These blobs can be stored and reassembled with payloads. Internally,
|
||||||
|
// they are simply marshaled json web signatures but implementations should
|
||||||
|
// not rely on this.
|
||||||
|
func (js *JSONSignature) Signatures() ([][]byte, error) {
|
||||||
|
sort.Sort(jsSignaturesSorted(js.signatures))
|
||||||
|
|
||||||
|
var sb [][]byte
|
||||||
|
for _, jsig := range js.signatures {
|
||||||
|
p, err := json.Marshal(jsig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sb = append(sb, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge combines the signatures from one or more other signatures into the
|
||||||
|
// method receiver. If the payloads differ for any argument, an error will be
|
||||||
|
// returned and the receiver will not be modified.
|
||||||
|
func (js *JSONSignature) Merge(others ...*JSONSignature) error {
|
||||||
|
merged := js.signatures
|
||||||
|
for _, other := range others {
|
||||||
|
if js.payload != other.payload {
|
||||||
|
return fmt.Errorf("payloads differ from merge target")
|
||||||
|
}
|
||||||
|
merged = append(merged, other.signatures...)
|
||||||
|
}
|
||||||
|
|
||||||
|
js.signatures = merged
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -2,9 +2,11 @@ package libtrust
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/docker/libtrust/testutil"
|
"github.com/docker/libtrust/testutil"
|
||||||
|
@ -295,3 +297,84 @@ func TestInvalidChain(t *testing.T) {
|
||||||
t.Fatalf("Unexpected chains returned from invalid verify")
|
t.Fatalf("Unexpected chains returned from invalid verify")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMergeSignatures(t *testing.T) {
|
||||||
|
pk1, err := GenerateECP256PrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error generating private key 1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pk2, err := GenerateECP256PrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error generating private key 2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, 1<<10)
|
||||||
|
if _, err = io.ReadFull(rand.Reader, payload); err != nil {
|
||||||
|
t.Fatalf("error generating payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, _ = json.Marshal(map[string]interface{}{"data": payload})
|
||||||
|
|
||||||
|
sig1, err := NewJSONSignature(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error creating signature 1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sig1.Sign(pk1); err != nil {
|
||||||
|
t.Fatalf("unexpected error signing with pk1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sig2, err := NewJSONSignature(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error creating signature 2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sig2.Sign(pk2); err != nil {
|
||||||
|
t.Fatalf("unexpected error signing with pk2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, we actually merge into sig1
|
||||||
|
if err := sig1.Merge(sig2); err != nil {
|
||||||
|
t.Fatalf("unexpected error merging: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the new signature package
|
||||||
|
pubkeys, err := sig1.Verify()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error during verify: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the pubkeys match the two private keys from before
|
||||||
|
privkeys := map[string]PrivateKey{
|
||||||
|
pk1.KeyID(): pk1,
|
||||||
|
pk2.KeyID(): pk2,
|
||||||
|
}
|
||||||
|
|
||||||
|
found := map[string]struct{}{}
|
||||||
|
|
||||||
|
for _, pubkey := range pubkeys {
|
||||||
|
if _, ok := privkeys[pubkey.KeyID()]; !ok {
|
||||||
|
t.Fatalf("unexpected public key found during verification: %v", pubkey)
|
||||||
|
}
|
||||||
|
|
||||||
|
found[pubkey.KeyID()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we've found all the private keys from verification
|
||||||
|
for keyid, _ := range privkeys {
|
||||||
|
if _, ok := found[keyid]; !ok {
|
||||||
|
t.Fatalf("public key %v not found during verification", keyid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create another signature, with a different payload, and ensure we get an error.
|
||||||
|
sig3, err := NewJSONSignature([]byte("{}"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error making signature for sig3: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sig1.Merge(sig3); err == nil {
|
||||||
|
t.Fatalf("error expected during invalid merge with different payload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,175 @@
|
||||||
|
package libtrust
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientKeyManager manages client keys on the filesystem
|
||||||
|
type ClientKeyManager struct {
|
||||||
|
key PrivateKey
|
||||||
|
clientFile string
|
||||||
|
clientDir string
|
||||||
|
|
||||||
|
clientLock sync.RWMutex
|
||||||
|
clients []PublicKey
|
||||||
|
|
||||||
|
configLock sync.Mutex
|
||||||
|
configs []*tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientKeyManager loads a new manager from a set of key files
|
||||||
|
// and managed by the given private key.
|
||||||
|
func NewClientKeyManager(trustKey PrivateKey, clientFile, clientDir string) (*ClientKeyManager, error) {
|
||||||
|
m := &ClientKeyManager{
|
||||||
|
key: trustKey,
|
||||||
|
clientFile: clientFile,
|
||||||
|
clientDir: clientDir,
|
||||||
|
}
|
||||||
|
if err := m.loadKeys(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// TODO Start watching file and directory
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientKeyManager) loadKeys() (err error) {
|
||||||
|
// Load authorized keys file
|
||||||
|
var clients []PublicKey
|
||||||
|
if c.clientFile != "" {
|
||||||
|
clients, err = LoadKeySetFile(c.clientFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to load authorized keys: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add clients from authorized keys directory
|
||||||
|
files, err := ioutil.ReadDir(c.clientDir)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("unable to open authorized keys directory: %s", err)
|
||||||
|
}
|
||||||
|
for _, f := range files {
|
||||||
|
if !f.IsDir() {
|
||||||
|
publicKey, err := LoadPublicKeyFile(path.Join(c.clientDir, f.Name()))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to load authorized key file: %s", err)
|
||||||
|
}
|
||||||
|
clients = append(clients, publicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.clientLock.Lock()
|
||||||
|
c.clients = clients
|
||||||
|
c.clientLock.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTLSConfig registers a tls configuration to manager
|
||||||
|
// such that any changes to the keys may be reflected in
|
||||||
|
// the tls client CA pool
|
||||||
|
func (c *ClientKeyManager) RegisterTLSConfig(tlsConfig *tls.Config) error {
|
||||||
|
c.clientLock.RLock()
|
||||||
|
certPool, err := GenerateCACertPool(c.key, c.clients)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("CA pool generation error: %s", err)
|
||||||
|
}
|
||||||
|
c.clientLock.RUnlock()
|
||||||
|
|
||||||
|
tlsConfig.ClientCAs = certPool
|
||||||
|
|
||||||
|
c.configLock.Lock()
|
||||||
|
c.configs = append(c.configs, tlsConfig)
|
||||||
|
c.configLock.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIdentityAuthTLSConfig creates a tls.Config for the server to use for
|
||||||
|
// libtrust identity authentication for the domain specified
|
||||||
|
func NewIdentityAuthTLSConfig(trustKey PrivateKey, clients *ClientKeyManager, addr string, domain string) (*tls.Config, error) {
|
||||||
|
tlsConfig := newTLSConfig()
|
||||||
|
|
||||||
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
if err := clients.RegisterTLSConfig(tlsConfig); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate cert
|
||||||
|
ips, domains, err := parseAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// add domain that it expects clients to use
|
||||||
|
domains = append(domains, domain)
|
||||||
|
x509Cert, err := GenerateSelfSignedServerCert(trustKey, domains, ips)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("certificate generation error: %s", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{{
|
||||||
|
Certificate: [][]byte{x509Cert.Raw},
|
||||||
|
PrivateKey: trustKey.CryptoPrivateKey(),
|
||||||
|
Leaf: x509Cert,
|
||||||
|
}}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCertAuthTLSConfig creates a tls.Config for the server to use for
|
||||||
|
// certificate authentication
|
||||||
|
func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
|
||||||
|
tlsConfig := newTLSConfig()
|
||||||
|
|
||||||
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", certPath, keyPath, err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
|
||||||
|
// Verify client certificates against a CA?
|
||||||
|
if caPath != "" {
|
||||||
|
certPool := x509.NewCertPool()
|
||||||
|
file, err := ioutil.ReadFile(caPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
|
||||||
|
}
|
||||||
|
certPool.AppendCertsFromPEM(file)
|
||||||
|
|
||||||
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
tlsConfig.ClientCAs = certPool
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTLSConfig() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
NextProtos: []string{"http/1.1"},
|
||||||
|
// Avoid fallback on insecure SSL protocols
|
||||||
|
MinVersion: tls.VersionTLS10,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAddr parses an address into an array of IPs and domains
|
||||||
|
func parseAddr(addr string) ([]net.IP, []string, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
var domains []string
|
||||||
|
var ips []net.IP
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip != nil {
|
||||||
|
ips = []net.IP{ip}
|
||||||
|
} else {
|
||||||
|
domains = []string{host}
|
||||||
|
}
|
||||||
|
return ips, domains, nil
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -12,9 +13,144 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LoadOrCreateTrustKey will load a PrivateKey from the specified path
|
||||||
|
func LoadOrCreateTrustKey(trustKeyPath string) (PrivateKey, error) {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
trustKey, err := LoadKeyFile(trustKeyPath)
|
||||||
|
if err == ErrKeyFileDoesNotExist {
|
||||||
|
trustKey, err = GenerateECP256PrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error generating key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SaveKey(trustKeyPath, trustKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("error saving key file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir, file := filepath.Split(trustKeyPath)
|
||||||
|
if err := SavePublicKey(filepath.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil {
|
||||||
|
return nil, fmt.Errorf("error saving public key file: %s", err)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("error loading key file: %s", err)
|
||||||
|
}
|
||||||
|
return trustKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIdentityAuthTLSClientConfig returns a tls.Config configured to use identity
|
||||||
|
// based authentication from the specified dockerUrl, the rootConfigPath and
|
||||||
|
// the server name to which it is connecting.
|
||||||
|
// If trustUnknownHosts is true it will automatically add the host to the
|
||||||
|
// known-hosts.json in rootConfigPath.
|
||||||
|
func NewIdentityAuthTLSClientConfig(dockerUrl string, trustUnknownHosts bool, rootConfigPath string, serverName string) (*tls.Config, error) {
|
||||||
|
tlsConfig := newTLSConfig()
|
||||||
|
|
||||||
|
trustKeyPath := filepath.Join(rootConfigPath, "key.json")
|
||||||
|
knownHostsPath := filepath.Join(rootConfigPath, "known-hosts.json")
|
||||||
|
|
||||||
|
u, err := url.Parse(dockerUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse machine url")
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme == "unix" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := u.Host
|
||||||
|
proto := "tcp"
|
||||||
|
|
||||||
|
trustKey, err := LoadOrCreateTrustKey(trustKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to load trust key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
knownHosts, err := LoadKeySetFile(knownHostsPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not load trusted hosts file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowedHosts, err := FilterByHosts(knownHosts, addr, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error filtering hosts: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPool, err := GenerateCACertPool(trustKey, allowedHosts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not create CA pool: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.ServerName = serverName
|
||||||
|
tlsConfig.RootCAs = certPool
|
||||||
|
|
||||||
|
x509Cert, err := GenerateSelfSignedClientCert(trustKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("certificate generation error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{{
|
||||||
|
Certificate: [][]byte{x509Cert.Raw},
|
||||||
|
PrivateKey: trustKey.CryptoPrivateKey(),
|
||||||
|
Leaf: x509Cert,
|
||||||
|
}}
|
||||||
|
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
|
||||||
|
testConn, err := tls.Dial(proto, addr, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tls Handshake error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: tlsConfig.RootCAs,
|
||||||
|
CurrentTime: time.Now(),
|
||||||
|
DNSName: tlsConfig.ServerName,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
|
||||||
|
certs := testConn.ConnectionState().PeerCertificates
|
||||||
|
for i, cert := range certs {
|
||||||
|
if i == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opts.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := certs[0].Verify(opts); err != nil {
|
||||||
|
if _, ok := err.(x509.UnknownAuthorityError); ok {
|
||||||
|
if trustUnknownHosts {
|
||||||
|
pubKey, err := FromCryptoPublicKey(certs[0].PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error extracting public key from cert: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey.AddExtendedField("hosts", []string{addr})
|
||||||
|
|
||||||
|
if err := AddKeySetFile(knownHostsPath, pubKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding machine to known hosts: %s", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unable to connect. unknown host: %s", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testConn.Close()
|
||||||
|
tlsConfig.InsecureSkipVerify = false
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
// joseBase64UrlEncode encodes the given data using the standard base64 url
|
// joseBase64UrlEncode encodes the given data using the standard base64 url
|
||||||
// encoding format but with all trailing '=' characters ommitted in accordance
|
// encoding format but with all trailing '=' characters ommitted in accordance
|
||||||
// with the jose specification.
|
// with the jose specification.
|
||||||
|
|
|
@ -15,11 +15,16 @@ func (s NaturalSort) Swap(i, j int) {
|
||||||
s[i], s[j] = s[j], s[i]
|
s[i], s[j] = s[j], s[i]
|
||||||
}
|
}
|
||||||
func (s NaturalSort) Less(i, j int) bool {
|
func (s NaturalSort) Less(i, j int) bool {
|
||||||
r1 := regexp.MustCompilePOSIX(`^([^0-9]*)+|[0-9]+`)
|
r := regexp.MustCompilePOSIX(`^([^0-9]*)+|[0-9]+`)
|
||||||
|
|
||||||
spliti := r1.FindAllString(strings.Replace(s[i], " ", "", -1), -1)
|
spliti := r.FindAllString(strings.Replace(s[i], " ", "", -1), -1)
|
||||||
splitj := r1.FindAllString(strings.Replace(s[j], " ", "", -1), -1)
|
splitj := r.FindAllString(strings.Replace(s[j], " ", "", -1), -1)
|
||||||
for index := range spliti {
|
|
||||||
|
splitshortest := len(spliti)
|
||||||
|
if len(spliti) > len(splitj) {
|
||||||
|
splitshortest = len(splitj)
|
||||||
|
}
|
||||||
|
for index := 0; index < splitshortest; index ++{
|
||||||
if spliti[index] != splitj[index] {
|
if spliti[index] != splitj[index] {
|
||||||
inti, ei := strconv.Atoi(spliti[index])
|
inti, ei := strconv.Atoi(spliti[index])
|
||||||
intj, ej := strconv.Atoi(splitj[index])
|
intj, ej := strconv.Atoi(splitj[index])
|
||||||
|
|
|
@ -26,6 +26,10 @@ func TestSortValid(t *testing.T) {
|
||||||
[]string{"0"},
|
[]string{"0"},
|
||||||
[]string{"0"},
|
[]string{"0"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
[]string{"data","data20","data3"},
|
||||||
|
[]string{"data","data3","data20"},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
[]string{"1", "2", "30", "22", "0", "00", "3"},
|
[]string{"1", "2", "30", "22", "0", "00", "3"},
|
||||||
[]string{"0", "00", "1", "2", "3", "22", "30"},
|
[]string{"0", "00", "1", "2", "3", "22", "30"},
|
||||||
|
|
|
@ -1,563 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
/*
|
|
||||||
Package agent implements a client to an ssh-agent daemon.
|
|
||||||
|
|
||||||
References:
|
|
||||||
[PROTOCOL.agent]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent
|
|
||||||
*/
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/dsa"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Agent represents the capabilities of an ssh-agent.
|
|
||||||
type Agent interface {
|
|
||||||
// List returns the identities known to the agent.
|
|
||||||
List() ([]*Key, error)
|
|
||||||
|
|
||||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
|
||||||
// in [PROTOCOL.agent] section 2.6.2.
|
|
||||||
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
|
|
||||||
|
|
||||||
// Insert adds a private key to the agent. If a certificate
|
|
||||||
// is given, that certificate is added as public key.
|
|
||||||
Add(s interface{}, cert *ssh.Certificate, comment string) error
|
|
||||||
|
|
||||||
// Remove removes all identities with the given public key.
|
|
||||||
Remove(key ssh.PublicKey) error
|
|
||||||
|
|
||||||
// RemoveAll removes all identities.
|
|
||||||
RemoveAll() error
|
|
||||||
|
|
||||||
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
|
||||||
Lock(passphrase []byte) error
|
|
||||||
|
|
||||||
// Unlock undoes the effect of Lock
|
|
||||||
Unlock(passphrase []byte) error
|
|
||||||
|
|
||||||
// Signers returns signers for all the known keys.
|
|
||||||
Signers() ([]ssh.Signer, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 3.
|
|
||||||
const (
|
|
||||||
agentRequestV1Identities = 1
|
|
||||||
|
|
||||||
// 3.2 Requests from client to agent for protocol 2 key operations
|
|
||||||
agentAddIdentity = 17
|
|
||||||
agentRemoveIdentity = 18
|
|
||||||
agentRemoveAllIdentities = 19
|
|
||||||
agentAddIdConstrained = 25
|
|
||||||
|
|
||||||
// 3.3 Key-type independent requests from client to agent
|
|
||||||
agentAddSmartcardKey = 20
|
|
||||||
agentRemoveSmartcardKey = 21
|
|
||||||
agentLock = 22
|
|
||||||
agentUnlock = 23
|
|
||||||
agentAddSmartcardKeyConstrained = 26
|
|
||||||
|
|
||||||
// 3.7 Key constraint identifiers
|
|
||||||
agentConstrainLifetime = 1
|
|
||||||
agentConstrainConfirm = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
|
|
||||||
// is a sanity check, not a limit in the spec.
|
|
||||||
const maxAgentResponseBytes = 16 << 20
|
|
||||||
|
|
||||||
// Agent messages:
|
|
||||||
// These structures mirror the wire format of the corresponding ssh agent
|
|
||||||
// messages found in [PROTOCOL.agent].
|
|
||||||
|
|
||||||
// 3.4 Generic replies from agent to client
|
|
||||||
const agentFailure = 5
|
|
||||||
|
|
||||||
type failureAgentMsg struct{}
|
|
||||||
|
|
||||||
const agentSuccess = 6
|
|
||||||
|
|
||||||
type successAgentMsg struct{}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.5.2.
|
|
||||||
const agentRequestIdentities = 11
|
|
||||||
|
|
||||||
type requestIdentitiesAgentMsg struct{}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.5.2.
|
|
||||||
const agentIdentitiesAnswer = 12
|
|
||||||
|
|
||||||
type identitiesAnswerAgentMsg struct {
|
|
||||||
NumKeys uint32 `sshtype:"12"`
|
|
||||||
Keys []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.6.2.
|
|
||||||
const agentSignRequest = 13
|
|
||||||
|
|
||||||
type signRequestAgentMsg struct {
|
|
||||||
KeyBlob []byte `sshtype:"13"`
|
|
||||||
Data []byte
|
|
||||||
Flags uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.6.2.
|
|
||||||
|
|
||||||
// 3.6 Replies from agent to client for protocol 2 key operations
|
|
||||||
const agentSignResponse = 14
|
|
||||||
|
|
||||||
type signResponseAgentMsg struct {
|
|
||||||
SigBlob []byte `sshtype:"14"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type publicKey struct {
|
|
||||||
Format string
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Key represents a protocol 2 public key as defined in
|
|
||||||
// [PROTOCOL.agent], section 2.5.2.
|
|
||||||
type Key struct {
|
|
||||||
Format string
|
|
||||||
Blob []byte
|
|
||||||
Comment string
|
|
||||||
}
|
|
||||||
|
|
||||||
func clientErr(err error) error {
|
|
||||||
return fmt.Errorf("agent: client error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the storage form of an agent key with the format, base64
|
|
||||||
// encoded serialized key, and the comment if it is not empty.
|
|
||||||
func (k *Key) String() string {
|
|
||||||
s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob)
|
|
||||||
|
|
||||||
if k.Comment != "" {
|
|
||||||
s += " " + k.Comment
|
|
||||||
}
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns the public key type.
|
|
||||||
func (k *Key) Type() string {
|
|
||||||
return k.Format
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal returns key blob to satisfy the ssh.PublicKey interface.
|
|
||||||
func (k *Key) Marshal() []byte {
|
|
||||||
return k.Blob
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify satisfies the ssh.PublicKey interface, but is not
|
|
||||||
// implemented for agent keys.
|
|
||||||
func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
|
|
||||||
return errors.New("agent: agent key does not know how to verify")
|
|
||||||
}
|
|
||||||
|
|
||||||
type wireKey struct {
|
|
||||||
Format string
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseKey(in []byte) (out *Key, rest []byte, err error) {
|
|
||||||
var record struct {
|
|
||||||
Blob []byte
|
|
||||||
Comment string
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ssh.Unmarshal(in, &record); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wk wireKey
|
|
||||||
if err := ssh.Unmarshal(record.Blob, &wk); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Key{
|
|
||||||
Format: wk.Format,
|
|
||||||
Blob: record.Blob,
|
|
||||||
Comment: record.Comment,
|
|
||||||
}, record.Rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// client is a client for an ssh-agent process.
|
|
||||||
type client struct {
|
|
||||||
// conn is typically a *net.UnixConn
|
|
||||||
conn io.ReadWriter
|
|
||||||
// mu is used to prevent concurrent access to the agent
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient returns an Agent that talks to an ssh-agent process over
|
|
||||||
// the given connection.
|
|
||||||
func NewClient(rw io.ReadWriter) Agent {
|
|
||||||
return &client{conn: rw}
|
|
||||||
}
|
|
||||||
|
|
||||||
// call sends an RPC to the agent. On success, the reply is
|
|
||||||
// unmarshaled into reply and replyType is set to the first byte of
|
|
||||||
// the reply, which contains the type of the message.
|
|
||||||
func (c *client) call(req []byte) (reply interface{}, err error) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
msg := make([]byte, 4+len(req))
|
|
||||||
binary.BigEndian.PutUint32(msg, uint32(len(req)))
|
|
||||||
copy(msg[4:], req)
|
|
||||||
if _, err = c.conn.Write(msg); err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var respSizeBuf [4]byte
|
|
||||||
if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
respSize := binary.BigEndian.Uint32(respSizeBuf[:])
|
|
||||||
if respSize > maxAgentResponseBytes {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, respSize)
|
|
||||||
if _, err = io.ReadFull(c.conn, buf); err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
reply, err = unmarshal(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
return reply, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) simpleCall(req []byte) error {
|
|
||||||
resp, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, ok := resp.(*successAgentMsg); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("agent: failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) RemoveAll() error {
|
|
||||||
return c.simpleCall([]byte{agentRemoveAllIdentities})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Remove(key ssh.PublicKey) error {
|
|
||||||
req := ssh.Marshal(&agentRemoveIdentityMsg{
|
|
||||||
KeyBlob: key.Marshal(),
|
|
||||||
})
|
|
||||||
return c.simpleCall(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Lock(passphrase []byte) error {
|
|
||||||
req := ssh.Marshal(&agentLockMsg{
|
|
||||||
Passphrase: passphrase,
|
|
||||||
})
|
|
||||||
return c.simpleCall(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Unlock(passphrase []byte) error {
|
|
||||||
req := ssh.Marshal(&agentUnlockMsg{
|
|
||||||
Passphrase: passphrase,
|
|
||||||
})
|
|
||||||
return c.simpleCall(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns the identities known to the agent.
|
|
||||||
func (c *client) List() ([]*Key, error) {
|
|
||||||
// see [PROTOCOL.agent] section 2.5.2.
|
|
||||||
req := []byte{agentRequestIdentities}
|
|
||||||
|
|
||||||
msg, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *identitiesAnswerAgentMsg:
|
|
||||||
if msg.NumKeys > maxAgentResponseBytes/8 {
|
|
||||||
return nil, errors.New("agent: too many keys in agent reply")
|
|
||||||
}
|
|
||||||
keys := make([]*Key, msg.NumKeys)
|
|
||||||
data := msg.Keys
|
|
||||||
for i := uint32(0); i < msg.NumKeys; i++ {
|
|
||||||
var key *Key
|
|
||||||
var err error
|
|
||||||
if key, data, err = parseKey(data); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
keys[i] = key
|
|
||||||
}
|
|
||||||
return keys, nil
|
|
||||||
case *failureAgentMsg:
|
|
||||||
return nil, errors.New("agent: failed to list keys")
|
|
||||||
}
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
|
||||||
// in [PROTOCOL.agent] section 2.6.2.
|
|
||||||
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
|
|
||||||
req := ssh.Marshal(signRequestAgentMsg{
|
|
||||||
KeyBlob: key.Marshal(),
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
|
|
||||||
msg, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *signResponseAgentMsg:
|
|
||||||
var sig ssh.Signature
|
|
||||||
if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &sig, nil
|
|
||||||
case *failureAgentMsg:
|
|
||||||
return nil, errors.New("agent: failed to sign challenge")
|
|
||||||
}
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmarshal parses an agent message in packet, returning the parsed
|
|
||||||
// form and the message type of packet.
|
|
||||||
func unmarshal(packet []byte) (interface{}, error) {
|
|
||||||
if len(packet) < 1 {
|
|
||||||
return nil, errors.New("agent: empty packet")
|
|
||||||
}
|
|
||||||
var msg interface{}
|
|
||||||
switch packet[0] {
|
|
||||||
case agentFailure:
|
|
||||||
return new(failureAgentMsg), nil
|
|
||||||
case agentSuccess:
|
|
||||||
return new(successAgentMsg), nil
|
|
||||||
case agentIdentitiesAnswer:
|
|
||||||
msg = new(identitiesAnswerAgentMsg)
|
|
||||||
case agentSignResponse:
|
|
||||||
msg = new(signResponseAgentMsg)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("agent: unknown type tag %d", packet[0])
|
|
||||||
}
|
|
||||||
if err := ssh.Unmarshal(packet, msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type rsaKeyMsg struct {
|
|
||||||
Type string `sshtype:"17"`
|
|
||||||
N *big.Int
|
|
||||||
E *big.Int
|
|
||||||
D *big.Int
|
|
||||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
|
||||||
P *big.Int
|
|
||||||
Q *big.Int
|
|
||||||
Comments string
|
|
||||||
}
|
|
||||||
|
|
||||||
type dsaKeyMsg struct {
|
|
||||||
Type string `sshtype:"17"`
|
|
||||||
P *big.Int
|
|
||||||
Q *big.Int
|
|
||||||
G *big.Int
|
|
||||||
Y *big.Int
|
|
||||||
X *big.Int
|
|
||||||
Comments string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ecdsaKeyMsg struct {
|
|
||||||
Type string `sshtype:"17"`
|
|
||||||
Curve string
|
|
||||||
KeyBytes []byte
|
|
||||||
D *big.Int
|
|
||||||
Comments string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert adds a private key to the agent.
|
|
||||||
func (c *client) insertKey(s interface{}, comment string) error {
|
|
||||||
var req []byte
|
|
||||||
switch k := s.(type) {
|
|
||||||
case *rsa.PrivateKey:
|
|
||||||
if len(k.Primes) != 2 {
|
|
||||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
|
|
||||||
}
|
|
||||||
k.Precompute()
|
|
||||||
req = ssh.Marshal(rsaKeyMsg{
|
|
||||||
Type: ssh.KeyAlgoRSA,
|
|
||||||
N: k.N,
|
|
||||||
E: big.NewInt(int64(k.E)),
|
|
||||||
D: k.D,
|
|
||||||
Iqmp: k.Precomputed.Qinv,
|
|
||||||
P: k.Primes[0],
|
|
||||||
Q: k.Primes[1],
|
|
||||||
Comments: comment,
|
|
||||||
})
|
|
||||||
case *dsa.PrivateKey:
|
|
||||||
req = ssh.Marshal(dsaKeyMsg{
|
|
||||||
Type: ssh.KeyAlgoDSA,
|
|
||||||
P: k.P,
|
|
||||||
Q: k.Q,
|
|
||||||
G: k.G,
|
|
||||||
Y: k.Y,
|
|
||||||
X: k.X,
|
|
||||||
Comments: comment,
|
|
||||||
})
|
|
||||||
case *ecdsa.PrivateKey:
|
|
||||||
nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
|
|
||||||
req = ssh.Marshal(ecdsaKeyMsg{
|
|
||||||
Type: "ecdsa-sha2-" + nistID,
|
|
||||||
Curve: nistID,
|
|
||||||
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
|
|
||||||
D: k.D,
|
|
||||||
Comments: comment,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("agent: unsupported key type %T", s)
|
|
||||||
}
|
|
||||||
resp, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, ok := resp.(*successAgentMsg); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("agent: failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
type rsaCertMsg struct {
|
|
||||||
Type string `sshtype:"17"`
|
|
||||||
CertBytes []byte
|
|
||||||
D *big.Int
|
|
||||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
|
||||||
P *big.Int
|
|
||||||
Q *big.Int
|
|
||||||
Comments string
|
|
||||||
}
|
|
||||||
|
|
||||||
type dsaCertMsg struct {
|
|
||||||
Type string `sshtype:"17"`
|
|
||||||
CertBytes []byte
|
|
||||||
X *big.Int
|
|
||||||
Comments string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ecdsaCertMsg struct {
|
|
||||||
Type string `sshtype:"17"`
|
|
||||||
CertBytes []byte
|
|
||||||
D *big.Int
|
|
||||||
Comments string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert adds a private key to the agent. If a certificate is given,
|
|
||||||
// that certificate is added instead as public key.
|
|
||||||
func (c *client) Add(s interface{}, cert *ssh.Certificate, comment string) error {
|
|
||||||
if cert == nil {
|
|
||||||
return c.insertKey(s, comment)
|
|
||||||
} else {
|
|
||||||
return c.insertCert(s, cert, comment)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string) error {
|
|
||||||
var req []byte
|
|
||||||
switch k := s.(type) {
|
|
||||||
case *rsa.PrivateKey:
|
|
||||||
if len(k.Primes) != 2 {
|
|
||||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
|
|
||||||
}
|
|
||||||
k.Precompute()
|
|
||||||
req = ssh.Marshal(rsaCertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
D: k.D,
|
|
||||||
Iqmp: k.Precomputed.Qinv,
|
|
||||||
P: k.Primes[0],
|
|
||||||
Q: k.Primes[1],
|
|
||||||
Comments: comment,
|
|
||||||
})
|
|
||||||
case *dsa.PrivateKey:
|
|
||||||
req = ssh.Marshal(dsaCertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
X: k.X,
|
|
||||||
Comments: comment,
|
|
||||||
})
|
|
||||||
case *ecdsa.PrivateKey:
|
|
||||||
req = ssh.Marshal(ecdsaCertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
D: k.D,
|
|
||||||
Comments: comment,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("agent: unsupported key type %T", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := ssh.NewSignerFromKey(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
|
|
||||||
return errors.New("agent: signer and cert have different public key")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, ok := resp.(*successAgentMsg); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("agent: failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signers provides a callback for client authentication.
|
|
||||||
func (c *client) Signers() ([]ssh.Signer, error) {
|
|
||||||
keys, err := c.List()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var result []ssh.Signer
|
|
||||||
for _, k := range keys {
|
|
||||||
result = append(result, &agentKeyringSigner{c, k})
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentKeyringSigner struct {
|
|
||||||
agent *client
|
|
||||||
pub ssh.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *agentKeyringSigner) PublicKey() ssh.PublicKey {
|
|
||||||
return s.pub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
|
||||||
// The agent has its own entropy source, so the rand argument is ignored.
|
|
||||||
return s.agent.Sign(s.pub, data)
|
|
||||||
}
|
|
|
@ -1,278 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// startAgent executes ssh-agent, and returns a Agent interface to it.
|
|
||||||
func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) {
|
|
||||||
if testing.Short() {
|
|
||||||
// ssh-agent is not always available, and the key
|
|
||||||
// types supported vary by platform.
|
|
||||||
t.Skip("skipping test due to -short")
|
|
||||||
}
|
|
||||||
|
|
||||||
bin, err := exec.LookPath("ssh-agent")
|
|
||||||
if err != nil {
|
|
||||||
t.Skip("could not find ssh-agent")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command(bin, "-s")
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("cmd.Output: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Output looks like:
|
|
||||||
|
|
||||||
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
|
|
||||||
SSH_AGENT_PID=15542; export SSH_AGENT_PID;
|
|
||||||
echo Agent pid 15542;
|
|
||||||
*/
|
|
||||||
fields := bytes.Split(out, []byte(";"))
|
|
||||||
line := bytes.SplitN(fields[0], []byte("="), 2)
|
|
||||||
line[0] = bytes.TrimLeft(line[0], "\n")
|
|
||||||
if string(line[0]) != "SSH_AUTH_SOCK" {
|
|
||||||
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
|
|
||||||
}
|
|
||||||
socket = string(line[1])
|
|
||||||
|
|
||||||
line = bytes.SplitN(fields[2], []byte("="), 2)
|
|
||||||
line[0] = bytes.TrimLeft(line[0], "\n")
|
|
||||||
if string(line[0]) != "SSH_AGENT_PID" {
|
|
||||||
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
|
|
||||||
}
|
|
||||||
pidStr := line[1]
|
|
||||||
pid, err := strconv.Atoi(string(pidStr))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Atoi(%q): %v", pidStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.Dial("unix", string(socket))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("net.Dial: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ac := NewClient(conn)
|
|
||||||
return ac, socket, func() {
|
|
||||||
proc, _ := os.FindProcess(pid)
|
|
||||||
if proc != nil {
|
|
||||||
proc.Kill()
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
os.RemoveAll(filepath.Dir(socket))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate) {
|
|
||||||
agent, _, cleanup := startAgent(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
testAgentInterface(t, agent, key, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate) {
|
|
||||||
signer, err := ssh.NewSignerFromKey(key)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewSignerFromKey(%T): %v", key, err)
|
|
||||||
}
|
|
||||||
// The agent should start up empty.
|
|
||||||
if keys, err := agent.List(); err != nil {
|
|
||||||
t.Fatalf("RequestIdentities: %v", err)
|
|
||||||
} else if len(keys) > 0 {
|
|
||||||
t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to insert the key, with certificate if specified.
|
|
||||||
var pubKey ssh.PublicKey
|
|
||||||
if cert != nil {
|
|
||||||
err = agent.Add(key, cert, "comment")
|
|
||||||
pubKey = cert
|
|
||||||
} else {
|
|
||||||
err = agent.Add(key, nil, "comment")
|
|
||||||
pubKey = signer.PublicKey()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("insert(%T): %v", key, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Did the key get inserted successfully?
|
|
||||||
if keys, err := agent.List(); err != nil {
|
|
||||||
t.Fatalf("List: %v", err)
|
|
||||||
} else if len(keys) != 1 {
|
|
||||||
t.Fatalf("got %v, want 1 key", keys)
|
|
||||||
} else if keys[0].Comment != "comment" {
|
|
||||||
t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
|
|
||||||
} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) {
|
|
||||||
t.Fatalf("key mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Can the agent make a valid signature?
|
|
||||||
data := []byte("hello")
|
|
||||||
sig, err := agent.Sign(pubKey, data)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := pubKey.Verify(data, sig); err != nil {
|
|
||||||
t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent(t *testing.T) {
|
|
||||||
for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
|
|
||||||
testAgent(t, testPrivateKeys[keyType], nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCert(t *testing.T) {
|
|
||||||
cert := &ssh.Certificate{
|
|
||||||
Key: testPublicKeys["rsa"],
|
|
||||||
ValidBefore: ssh.CertTimeInfinity,
|
|
||||||
CertType: ssh.UserCert,
|
|
||||||
}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
|
|
||||||
testAgent(t, testPrivateKeys["rsa"], cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
|
||||||
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
|
||||||
// a write.)
|
|
||||||
func netPipe() (net.Conn, net.Conn, error) {
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
c1, err := net.Dial("tcp", listener.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c2, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
c1.Close()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c1, c2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth(t *testing.T) {
|
|
||||||
a, b, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer a.Close()
|
|
||||||
defer b.Close()
|
|
||||||
|
|
||||||
agent, _, cleanup := startAgent(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil {
|
|
||||||
t.Errorf("Add: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
serverConf := ssh.ServerConfig{}
|
|
||||||
serverConf.AddHostKey(testSigners["rsa"])
|
|
||||||
serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
||||||
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("pubkey rejected")
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
conn, _, _, err := ssh.NewServerConn(a, &serverConf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Server: %v", err)
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
conf := ssh.ClientConfig{}
|
|
||||||
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
|
|
||||||
conn, _, _, err := ssh.NewClientConn(b, "", &conf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewClientConn: %v", err)
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockClient(t *testing.T) {
|
|
||||||
agent, _, cleanup := startAgent(t)
|
|
||||||
defer cleanup()
|
|
||||||
testLockAgent(agent, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testLockAgent(agent Agent, t *testing.T) {
|
|
||||||
if err := agent.Add(testPrivateKeys["rsa"], nil, "comment 1"); err != nil {
|
|
||||||
t.Errorf("Add: %v", err)
|
|
||||||
}
|
|
||||||
if err := agent.Add(testPrivateKeys["dsa"], nil, "comment dsa"); err != nil {
|
|
||||||
t.Errorf("Add: %v", err)
|
|
||||||
}
|
|
||||||
if keys, err := agent.List(); err != nil {
|
|
||||||
t.Errorf("List: %v", err)
|
|
||||||
} else if len(keys) != 2 {
|
|
||||||
t.Errorf("Want 2 keys, got %v", keys)
|
|
||||||
}
|
|
||||||
|
|
||||||
passphrase := []byte("secret")
|
|
||||||
if err := agent.Lock(passphrase); err != nil {
|
|
||||||
t.Errorf("Lock: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if keys, err := agent.List(); err != nil {
|
|
||||||
t.Errorf("List: %v", err)
|
|
||||||
} else if len(keys) != 0 {
|
|
||||||
t.Errorf("Want 0 keys, got %v", keys)
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
|
|
||||||
if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
|
|
||||||
t.Fatalf("Sign did not fail")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := agent.Remove(signer.PublicKey()); err == nil {
|
|
||||||
t.Fatalf("Remove did not fail")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := agent.RemoveAll(); err == nil {
|
|
||||||
t.Fatalf("RemoveAll did not fail")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := agent.Unlock(nil); err == nil {
|
|
||||||
t.Errorf("Unlock with wrong passphrase succeeded")
|
|
||||||
}
|
|
||||||
if err := agent.Unlock(passphrase); err != nil {
|
|
||||||
t.Errorf("Unlock: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := agent.Remove(signer.PublicKey()); err != nil {
|
|
||||||
t.Fatalf("Remove: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if keys, err := agent.List(); err != nil {
|
|
||||||
t.Errorf("List: %v", err)
|
|
||||||
} else if len(keys) != 1 {
|
|
||||||
t.Errorf("Want 1 keys, got %v", keys)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,103 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RequestAgentForwarding sets up agent forwarding for the session.
|
|
||||||
// ForwardToAgent or ForwardToRemote should be called to route
|
|
||||||
// the authentication requests.
|
|
||||||
func RequestAgentForwarding(session *ssh.Session) error {
|
|
||||||
ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return errors.New("forwarding request denied")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardToAgent routes authentication requests to the given keyring.
|
|
||||||
func ForwardToAgent(client *ssh.Client, keyring Agent) error {
|
|
||||||
channels := client.HandleChannelOpen(channelType)
|
|
||||||
if channels == nil {
|
|
||||||
return errors.New("agent: already have handler for " + channelType)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for ch := range channels {
|
|
||||||
channel, reqs, err := ch.Accept()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
go func() {
|
|
||||||
ServeAgent(keyring, channel)
|
|
||||||
channel.Close()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const channelType = "auth-agent@openssh.com"
|
|
||||||
|
|
||||||
// ForwardToRemote routes authentication requests to the ssh-agent
|
|
||||||
// process serving on the given unix socket.
|
|
||||||
func ForwardToRemote(client *ssh.Client, addr string) error {
|
|
||||||
channels := client.HandleChannelOpen(channelType)
|
|
||||||
if channels == nil {
|
|
||||||
return errors.New("agent: already have handler for " + channelType)
|
|
||||||
}
|
|
||||||
conn, err := net.Dial("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for ch := range channels {
|
|
||||||
channel, reqs, err := ch.Accept()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
go forwardUnixSocket(channel, addr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func forwardUnixSocket(channel ssh.Channel, addr string) {
|
|
||||||
conn, err := net.Dial("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
go func() {
|
|
||||||
io.Copy(conn, channel)
|
|
||||||
conn.(*net.UnixConn).CloseWrite()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
io.Copy(channel, conn)
|
|
||||||
channel.CloseWrite()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
conn.Close()
|
|
||||||
channel.Close()
|
|
||||||
}
|
|
|
@ -1,183 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/subtle"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type privKey struct {
|
|
||||||
signer ssh.Signer
|
|
||||||
comment string
|
|
||||||
}
|
|
||||||
|
|
||||||
type keyring struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
keys []privKey
|
|
||||||
|
|
||||||
locked bool
|
|
||||||
passphrase []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var errLocked = errors.New("agent: locked")
|
|
||||||
|
|
||||||
// NewKeyring returns an Agent that holds keys in memory. It is safe
|
|
||||||
// for concurrent use by multiple goroutines.
|
|
||||||
func NewKeyring() Agent {
|
|
||||||
return &keyring{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAll removes all identities.
|
|
||||||
func (r *keyring) RemoveAll() error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
r.keys = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove removes all identities with the given public key.
|
|
||||||
func (r *keyring) Remove(key ssh.PublicKey) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
want := key.Marshal()
|
|
||||||
found := false
|
|
||||||
for i := 0; i < len(r.keys); {
|
|
||||||
if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
|
|
||||||
found = true
|
|
||||||
r.keys[i] = r.keys[len(r.keys)-1]
|
|
||||||
r.keys = r.keys[len(r.keys)-1:]
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return errors.New("agent: key not found")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
|
||||||
func (r *keyring) Lock(passphrase []byte) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
r.locked = true
|
|
||||||
r.passphrase = passphrase
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unlock undoes the effect of Lock
|
|
||||||
func (r *keyring) Unlock(passphrase []byte) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if !r.locked {
|
|
||||||
return errors.New("agent: not locked")
|
|
||||||
}
|
|
||||||
if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
|
|
||||||
return fmt.Errorf("agent: incorrect passphrase")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.locked = false
|
|
||||||
r.passphrase = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns the identities known to the agent.
|
|
||||||
func (r *keyring) List() ([]*Key, error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
// section 2.7: locked agents return empty.
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var ids []*Key
|
|
||||||
for _, k := range r.keys {
|
|
||||||
pub := k.signer.PublicKey()
|
|
||||||
ids = append(ids, &Key{
|
|
||||||
Format: pub.Type(),
|
|
||||||
Blob: pub.Marshal(),
|
|
||||||
Comment: k.comment})
|
|
||||||
}
|
|
||||||
return ids, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert adds a private key to the keyring. If a certificate
|
|
||||||
// is given, that certificate is added as public key.
|
|
||||||
func (r *keyring) Add(priv interface{}, cert *ssh.Certificate, comment string) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
signer, err := ssh.NewSignerFromKey(priv)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert != nil {
|
|
||||||
signer, err = ssh.NewCertSigner(cert, signer)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.keys = append(r.keys, privKey{signer, comment})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign returns a signature for the data.
|
|
||||||
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return nil, errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
wanted := key.Marshal()
|
|
||||||
for _, k := range r.keys {
|
|
||||||
if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
|
|
||||||
return k.signer.Sign(rand.Reader, data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, errors.New("not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signers returns signers for all the known keys.
|
|
||||||
func (r *keyring) Signers() ([]ssh.Signer, error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return nil, errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
s := make([]ssh.Signer, len(r.keys))
|
|
||||||
for _, k := range r.keys {
|
|
||||||
s = append(s, k.signer)
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
|
|
@ -1,209 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"math/big"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server wraps an Agent and uses it to implement the agent side of
|
|
||||||
// the SSH-agent, wire protocol.
|
|
||||||
type server struct {
|
|
||||||
agent Agent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) processRequestBytes(reqData []byte) []byte {
|
|
||||||
rep, err := s.processRequest(reqData)
|
|
||||||
if err != nil {
|
|
||||||
if err != errLocked {
|
|
||||||
// TODO(hanwen): provide better logging interface?
|
|
||||||
log.Printf("agent %d: %v", reqData[0], err)
|
|
||||||
}
|
|
||||||
return []byte{agentFailure}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil && rep == nil {
|
|
||||||
return []byte{agentSuccess}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.Marshal(rep)
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalKey(k *Key) []byte {
|
|
||||||
var record struct {
|
|
||||||
Blob []byte
|
|
||||||
Comment string
|
|
||||||
}
|
|
||||||
record.Blob = k.Marshal()
|
|
||||||
record.Comment = k.Comment
|
|
||||||
|
|
||||||
return ssh.Marshal(&record)
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentV1IdentityMsg struct {
|
|
||||||
Numkeys uint32 `sshtype:"2"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentRemoveIdentityMsg struct {
|
|
||||||
KeyBlob []byte `sshtype:"18"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentLockMsg struct {
|
|
||||||
Passphrase []byte `sshtype:"22"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentUnlockMsg struct {
|
|
||||||
Passphrase []byte `sshtype:"23"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) processRequest(data []byte) (interface{}, error) {
|
|
||||||
switch data[0] {
|
|
||||||
case agentRequestV1Identities:
|
|
||||||
return &agentV1IdentityMsg{0}, nil
|
|
||||||
case agentRemoveIdentity:
|
|
||||||
var req agentRemoveIdentityMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wk wireKey
|
|
||||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
|
|
||||||
|
|
||||||
case agentRemoveAllIdentities:
|
|
||||||
return nil, s.agent.RemoveAll()
|
|
||||||
|
|
||||||
case agentLock:
|
|
||||||
var req agentLockMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, s.agent.Lock(req.Passphrase)
|
|
||||||
|
|
||||||
case agentUnlock:
|
|
||||||
var req agentLockMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return nil, s.agent.Unlock(req.Passphrase)
|
|
||||||
|
|
||||||
case agentSignRequest:
|
|
||||||
var req signRequestAgentMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wk wireKey
|
|
||||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
k := &Key{
|
|
||||||
Format: wk.Format,
|
|
||||||
Blob: req.KeyBlob,
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
|
|
||||||
case agentRequestIdentities:
|
|
||||||
keys, err := s.agent.List()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rep := identitiesAnswerAgentMsg{
|
|
||||||
NumKeys: uint32(len(keys)),
|
|
||||||
}
|
|
||||||
for _, k := range keys {
|
|
||||||
rep.Keys = append(rep.Keys, marshalKey(k)...)
|
|
||||||
}
|
|
||||||
return rep, nil
|
|
||||||
case agentAddIdentity:
|
|
||||||
return nil, s.insertIdentity(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unknown opcode %d", data[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) insertIdentity(req []byte) error {
|
|
||||||
var record struct {
|
|
||||||
Type string `sshtype:"17"`
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
if err := ssh.Unmarshal(req, &record); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch record.Type {
|
|
||||||
case ssh.KeyAlgoRSA:
|
|
||||||
var k rsaKeyMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
priv := rsa.PrivateKey{
|
|
||||||
PublicKey: rsa.PublicKey{
|
|
||||||
E: int(k.E.Int64()),
|
|
||||||
N: k.N,
|
|
||||||
},
|
|
||||||
D: k.D,
|
|
||||||
Primes: []*big.Int{k.P, k.Q},
|
|
||||||
}
|
|
||||||
priv.Precompute()
|
|
||||||
|
|
||||||
return s.agent.Add(&priv, nil, k.Comments)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("not implemented: %s", record.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeAgent serves the agent protocol on the given connection. It
|
|
||||||
// returns when an I/O error occurs.
|
|
||||||
func ServeAgent(agent Agent, c io.ReadWriter) error {
|
|
||||||
s := &server{agent}
|
|
||||||
|
|
||||||
var length [4]byte
|
|
||||||
for {
|
|
||||||
if _, err := io.ReadFull(c, length[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
l := binary.BigEndian.Uint32(length[:])
|
|
||||||
if l > maxAgentResponseBytes {
|
|
||||||
// We also cap requests.
|
|
||||||
return fmt.Errorf("agent: request too large: %d", l)
|
|
||||||
}
|
|
||||||
|
|
||||||
req := make([]byte, l)
|
|
||||||
if _, err := io.ReadFull(c, req); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
repData := s.processRequestBytes(req)
|
|
||||||
if len(repData) > maxAgentResponseBytes {
|
|
||||||
return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
|
|
||||||
}
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
|
|
||||||
if _, err := c.Write(length[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := c.Write(repData); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,77 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestServer(t *testing.T) {
|
|
||||||
c1, c2, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
defer c1.Close()
|
|
||||||
defer c2.Close()
|
|
||||||
client := NewClient(c1)
|
|
||||||
|
|
||||||
go ServeAgent(NewKeyring(), c2)
|
|
||||||
|
|
||||||
testAgentInterface(t, client, testPrivateKeys["rsa"], nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockServer(t *testing.T) {
|
|
||||||
testLockAgent(NewKeyring(), t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetupForwardAgent(t *testing.T) {
|
|
||||||
a, b, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer a.Close()
|
|
||||||
defer b.Close()
|
|
||||||
|
|
||||||
_, socket, cleanup := startAgent(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
serverConf := ssh.ServerConfig{
|
|
||||||
NoClientAuth: true,
|
|
||||||
}
|
|
||||||
serverConf.AddHostKey(testSigners["rsa"])
|
|
||||||
incoming := make(chan *ssh.ServerConn, 1)
|
|
||||||
go func() {
|
|
||||||
conn, _, _, err := ssh.NewServerConn(a, &serverConf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Server: %v", err)
|
|
||||||
}
|
|
||||||
incoming <- conn
|
|
||||||
}()
|
|
||||||
|
|
||||||
conf := ssh.ClientConfig{}
|
|
||||||
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewClientConn: %v", err)
|
|
||||||
}
|
|
||||||
client := ssh.NewClient(conn, chans, reqs)
|
|
||||||
|
|
||||||
if err := ForwardToRemote(client, socket); err != nil {
|
|
||||||
t.Fatalf("SetupForwardAgent: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
server := <-incoming
|
|
||||||
ch, reqs, err := server.OpenChannel(channelType, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("OpenChannel(%q): %v", channelType, err)
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
|
|
||||||
agentClient := NewClient(ch)
|
|
||||||
testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil)
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
|
@ -1,64 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
|
||||||
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
|
||||||
// instances.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"golang.org/x/crypto/ssh/testdata"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
testPrivateKeys map[string]interface{}
|
|
||||||
testSigners map[string]ssh.Signer
|
|
||||||
testPublicKeys map[string]ssh.PublicKey
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
n := len(testdata.PEMBytes)
|
|
||||||
testPrivateKeys = make(map[string]interface{}, n)
|
|
||||||
testSigners = make(map[string]ssh.Signer, n)
|
|
||||||
testPublicKeys = make(map[string]ssh.PublicKey, n)
|
|
||||||
for t, k := range testdata.PEMBytes {
|
|
||||||
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
|
|
||||||
}
|
|
||||||
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
|
|
||||||
}
|
|
||||||
testPublicKeys[t] = testSigners[t].PublicKey()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a cert and sign it for use in tests.
|
|
||||||
testCert := &ssh.Certificate{
|
|
||||||
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
|
||||||
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
|
||||||
ValidAfter: 0, // unix epoch
|
|
||||||
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
|
|
||||||
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
|
||||||
Key: testPublicKeys["ecdsa"],
|
|
||||||
SignatureKey: testPublicKeys["rsa"],
|
|
||||||
Permissions: ssh.Permissions{
|
|
||||||
CriticalOptions: map[string]string{},
|
|
||||||
Extensions: map[string]string{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
testCert.SignCert(rand.Reader, testSigners["rsa"])
|
|
||||||
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
|
|
||||||
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,122 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type server struct {
|
|
||||||
*ServerConn
|
|
||||||
chans <-chan NewChannel
|
|
||||||
}
|
|
||||||
|
|
||||||
func newServer(c net.Conn, conf *ServerConfig) (*server, error) {
|
|
||||||
sconn, chans, reqs, err := NewServerConn(c, conf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go DiscardRequests(reqs)
|
|
||||||
return &server{sconn, chans}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) Accept() (NewChannel, error) {
|
|
||||||
n, ok := <-s.chans
|
|
||||||
if !ok {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func sshPipe() (Conn, *server, error) {
|
|
||||||
c1, c2, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientConf := ClientConfig{
|
|
||||||
User: "user",
|
|
||||||
}
|
|
||||||
serverConf := ServerConfig{
|
|
||||||
NoClientAuth: true,
|
|
||||||
}
|
|
||||||
serverConf.AddHostKey(testSigners["ecdsa"])
|
|
||||||
done := make(chan *server, 1)
|
|
||||||
go func() {
|
|
||||||
server, err := newServer(c2, &serverConf)
|
|
||||||
if err != nil {
|
|
||||||
done <- nil
|
|
||||||
}
|
|
||||||
done <- server
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, _, reqs, err := NewClientConn(c1, "", &clientConf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
server := <-done
|
|
||||||
if server == nil {
|
|
||||||
return nil, nil, errors.New("server handshake failed.")
|
|
||||||
}
|
|
||||||
go DiscardRequests(reqs)
|
|
||||||
|
|
||||||
return client, server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEndToEnd(b *testing.B) {
|
|
||||||
b.StopTimer()
|
|
||||||
|
|
||||||
client, server, err := sshPipe()
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("sshPipe: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer client.Close()
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
size := (1 << 20)
|
|
||||||
input := make([]byte, size)
|
|
||||||
output := make([]byte, size)
|
|
||||||
b.SetBytes(int64(size))
|
|
||||||
done := make(chan int, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
newCh, err := server.Accept()
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Client: %v", err)
|
|
||||||
}
|
|
||||||
ch, incoming, err := newCh.Accept()
|
|
||||||
go DiscardRequests(incoming)
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
if _, err := io.ReadFull(ch, output); err != nil {
|
|
||||||
b.Fatalf("ReadFull: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ch.Close()
|
|
||||||
done <- 1
|
|
||||||
}()
|
|
||||||
|
|
||||||
ch, in, err := client.OpenChannel("speed", nil)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("OpenChannel: %v", err)
|
|
||||||
}
|
|
||||||
go DiscardRequests(in)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.StartTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
if _, err := ch.Write(input); err != nil {
|
|
||||||
b.Fatalf("WriteFull: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ch.Close()
|
|
||||||
b.StopTimer()
|
|
||||||
|
|
||||||
<-done
|
|
||||||
}
|
|
|
@ -1,98 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// buffer provides a linked list buffer for data exchange
|
|
||||||
// between producer and consumer. Theoretically the buffer is
|
|
||||||
// of unlimited capacity as it does no allocation of its own.
|
|
||||||
type buffer struct {
|
|
||||||
// protects concurrent access to head, tail and closed
|
|
||||||
*sync.Cond
|
|
||||||
|
|
||||||
head *element // the buffer that will be read first
|
|
||||||
tail *element // the buffer that will be read last
|
|
||||||
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// An element represents a single link in a linked list.
|
|
||||||
type element struct {
|
|
||||||
buf []byte
|
|
||||||
next *element
|
|
||||||
}
|
|
||||||
|
|
||||||
// newBuffer returns an empty buffer that is not closed.
|
|
||||||
func newBuffer() *buffer {
|
|
||||||
e := new(element)
|
|
||||||
b := &buffer{
|
|
||||||
Cond: newCond(),
|
|
||||||
head: e,
|
|
||||||
tail: e,
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// write makes buf available for Read to receive.
|
|
||||||
// buf must not be modified after the call to write.
|
|
||||||
func (b *buffer) write(buf []byte) {
|
|
||||||
b.Cond.L.Lock()
|
|
||||||
e := &element{buf: buf}
|
|
||||||
b.tail.next = e
|
|
||||||
b.tail = e
|
|
||||||
b.Cond.Signal()
|
|
||||||
b.Cond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// eof closes the buffer. Reads from the buffer once all
|
|
||||||
// the data has been consumed will receive os.EOF.
|
|
||||||
func (b *buffer) eof() error {
|
|
||||||
b.Cond.L.Lock()
|
|
||||||
b.closed = true
|
|
||||||
b.Cond.Signal()
|
|
||||||
b.Cond.L.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads data from the internal buffer in buf. Reads will block
|
|
||||||
// if no data is available, or until the buffer is closed.
|
|
||||||
func (b *buffer) Read(buf []byte) (n int, err error) {
|
|
||||||
b.Cond.L.Lock()
|
|
||||||
defer b.Cond.L.Unlock()
|
|
||||||
|
|
||||||
for len(buf) > 0 {
|
|
||||||
// if there is data in b.head, copy it
|
|
||||||
if len(b.head.buf) > 0 {
|
|
||||||
r := copy(buf, b.head.buf)
|
|
||||||
buf, b.head.buf = buf[r:], b.head.buf[r:]
|
|
||||||
n += r
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// if there is a next buffer, make it the head
|
|
||||||
if len(b.head.buf) == 0 && b.head != b.tail {
|
|
||||||
b.head = b.head.next
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// if at least one byte has been copied, return
|
|
||||||
if n > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// if nothing was read, and there is nothing outstanding
|
|
||||||
// check to see if the buffer is closed.
|
|
||||||
if b.closed {
|
|
||||||
err = io.EOF
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// out of buffers, wait for producer
|
|
||||||
b.Cond.Wait()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,87 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
var alphabet = []byte("abcdefghijklmnopqrstuvwxyz")
|
|
||||||
|
|
||||||
func TestBufferReadwrite(t *testing.T) {
|
|
||||||
b := newBuffer()
|
|
||||||
b.write(alphabet[:10])
|
|
||||||
r, _ := b.Read(make([]byte, 10))
|
|
||||||
if r != 10 {
|
|
||||||
t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
|
|
||||||
}
|
|
||||||
|
|
||||||
b = newBuffer()
|
|
||||||
b.write(alphabet[:5])
|
|
||||||
r, _ = b.Read(make([]byte, 10))
|
|
||||||
if r != 5 {
|
|
||||||
t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
|
|
||||||
}
|
|
||||||
|
|
||||||
b = newBuffer()
|
|
||||||
b.write(alphabet[:10])
|
|
||||||
r, _ = b.Read(make([]byte, 5))
|
|
||||||
if r != 5 {
|
|
||||||
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
|
|
||||||
}
|
|
||||||
|
|
||||||
b = newBuffer()
|
|
||||||
b.write(alphabet[:5])
|
|
||||||
b.write(alphabet[5:15])
|
|
||||||
r, _ = b.Read(make([]byte, 10))
|
|
||||||
r2, _ := b.Read(make([]byte, 10))
|
|
||||||
if r != 10 || r2 != 5 || 15 != r+r2 {
|
|
||||||
t.Fatal("Expected written == read == 15")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBufferClose(t *testing.T) {
|
|
||||||
b := newBuffer()
|
|
||||||
b.write(alphabet[:10])
|
|
||||||
b.eof()
|
|
||||||
_, err := b.Read(make([]byte, 5))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("expected read of 5 to not return EOF")
|
|
||||||
}
|
|
||||||
b = newBuffer()
|
|
||||||
b.write(alphabet[:10])
|
|
||||||
b.eof()
|
|
||||||
r, err := b.Read(make([]byte, 5))
|
|
||||||
r2, err2 := b.Read(make([]byte, 10))
|
|
||||||
if r != 5 || r2 != 5 || err != nil || err2 != nil {
|
|
||||||
t.Fatal("expected reads of 5 and 5")
|
|
||||||
}
|
|
||||||
|
|
||||||
b = newBuffer()
|
|
||||||
b.write(alphabet[:10])
|
|
||||||
b.eof()
|
|
||||||
r, err = b.Read(make([]byte, 5))
|
|
||||||
r2, err2 = b.Read(make([]byte, 10))
|
|
||||||
r3, err3 := b.Read(make([]byte, 10))
|
|
||||||
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF {
|
|
||||||
t.Fatal("expected reads of 5 and 5 and 0, with EOF")
|
|
||||||
}
|
|
||||||
|
|
||||||
b = newBuffer()
|
|
||||||
b.write(make([]byte, 5))
|
|
||||||
b.write(make([]byte, 10))
|
|
||||||
b.eof()
|
|
||||||
r, err = b.Read(make([]byte, 9))
|
|
||||||
r2, err2 = b.Read(make([]byte, 3))
|
|
||||||
r3, err3 = b.Read(make([]byte, 3))
|
|
||||||
r4, err4 := b.Read(make([]byte, 10))
|
|
||||||
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF {
|
|
||||||
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4)
|
|
||||||
}
|
|
||||||
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 {
|
|
||||||
t.Fatal("Expected written == read == 15", r, r2, r3, r4)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,474 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sort"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These constants from [PROTOCOL.certkeys] represent the algorithm names
|
|
||||||
// for certificate types supported by this package.
|
|
||||||
const (
|
|
||||||
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
|
|
||||||
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
|
|
||||||
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
|
|
||||||
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
|
|
||||||
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Certificate types distinguish between host and user
|
|
||||||
// certificates. The values can be set in the CertType field of
|
|
||||||
// Certificate.
|
|
||||||
const (
|
|
||||||
UserCert = 1
|
|
||||||
HostCert = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// Signature represents a cryptographic signature.
|
|
||||||
type Signature struct {
|
|
||||||
Format string
|
|
||||||
Blob []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
|
|
||||||
// a certificate does not expire.
|
|
||||||
const CertTimeInfinity = 1<<64 - 1
|
|
||||||
|
|
||||||
// An Certificate represents an OpenSSH certificate as defined in
|
|
||||||
// [PROTOCOL.certkeys]?rev=1.8.
|
|
||||||
type Certificate struct {
|
|
||||||
Nonce []byte
|
|
||||||
Key PublicKey
|
|
||||||
Serial uint64
|
|
||||||
CertType uint32
|
|
||||||
KeyId string
|
|
||||||
ValidPrincipals []string
|
|
||||||
ValidAfter uint64
|
|
||||||
ValidBefore uint64
|
|
||||||
Permissions
|
|
||||||
Reserved []byte
|
|
||||||
SignatureKey PublicKey
|
|
||||||
Signature *Signature
|
|
||||||
}
|
|
||||||
|
|
||||||
// genericCertData holds the key-independent part of the certificate data.
|
|
||||||
// Overall, certificates contain an nonce, public key fields and
|
|
||||||
// key-independent fields.
|
|
||||||
type genericCertData struct {
|
|
||||||
Serial uint64
|
|
||||||
CertType uint32
|
|
||||||
KeyId string
|
|
||||||
ValidPrincipals []byte
|
|
||||||
ValidAfter uint64
|
|
||||||
ValidBefore uint64
|
|
||||||
CriticalOptions []byte
|
|
||||||
Extensions []byte
|
|
||||||
Reserved []byte
|
|
||||||
SignatureKey []byte
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalStringList(namelist []string) []byte {
|
|
||||||
var to []byte
|
|
||||||
for _, name := range namelist {
|
|
||||||
s := struct{ N string }{name}
|
|
||||||
to = append(to, Marshal(&s)...)
|
|
||||||
}
|
|
||||||
return to
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalTuples(tups map[string]string) []byte {
|
|
||||||
keys := make([]string, 0, len(tups))
|
|
||||||
for k := range tups {
|
|
||||||
keys = append(keys, k)
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
var r []byte
|
|
||||||
for _, k := range keys {
|
|
||||||
s := struct{ K, V string }{k, tups[k]}
|
|
||||||
r = append(r, Marshal(&s)...)
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTuples(in []byte) (map[string]string, error) {
|
|
||||||
tups := map[string]string{}
|
|
||||||
var lastKey string
|
|
||||||
var haveLastKey bool
|
|
||||||
|
|
||||||
for len(in) > 0 {
|
|
||||||
nameBytes, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
data, rest, ok := parseString(rest)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
name := string(nameBytes)
|
|
||||||
|
|
||||||
// according to [PROTOCOL.certkeys], the names must be in
|
|
||||||
// lexical order.
|
|
||||||
if haveLastKey && name <= lastKey {
|
|
||||||
return nil, fmt.Errorf("ssh: certificate options are not in lexical order")
|
|
||||||
}
|
|
||||||
lastKey, haveLastKey = name, true
|
|
||||||
|
|
||||||
tups[name] = string(data)
|
|
||||||
in = rest
|
|
||||||
}
|
|
||||||
return tups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCert(in []byte, privAlgo string) (*Certificate, error) {
|
|
||||||
nonce, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
|
|
||||||
key, rest, err := parsePubKey(rest, privAlgo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var g genericCertData
|
|
||||||
if err := Unmarshal(rest, &g); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Certificate{
|
|
||||||
Nonce: nonce,
|
|
||||||
Key: key,
|
|
||||||
Serial: g.Serial,
|
|
||||||
CertType: g.CertType,
|
|
||||||
KeyId: g.KeyId,
|
|
||||||
ValidAfter: g.ValidAfter,
|
|
||||||
ValidBefore: g.ValidBefore,
|
|
||||||
}
|
|
||||||
|
|
||||||
for principals := g.ValidPrincipals; len(principals) > 0; {
|
|
||||||
principal, rest, ok := parseString(principals)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
c.ValidPrincipals = append(c.ValidPrincipals, string(principal))
|
|
||||||
principals = rest
|
|
||||||
}
|
|
||||||
|
|
||||||
c.CriticalOptions, err = parseTuples(g.CriticalOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.Extensions, err = parseTuples(g.Extensions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.Reserved = g.Reserved
|
|
||||||
k, err := ParsePublicKey(g.SignatureKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SignatureKey = k
|
|
||||||
c.Signature, rest, ok = parseSignatureBody(g.Signature)
|
|
||||||
if !ok || len(rest) > 0 {
|
|
||||||
return nil, errors.New("ssh: signature parse error")
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type openSSHCertSigner struct {
|
|
||||||
pub *Certificate
|
|
||||||
signer Signer
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCertSigner returns a Signer that signs with the given Certificate, whose
|
|
||||||
// private key is held by signer. It returns an error if the public key in cert
|
|
||||||
// doesn't match the key used by signer.
|
|
||||||
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
|
|
||||||
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
|
|
||||||
return nil, errors.New("ssh: signer and cert have different public key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &openSSHCertSigner{cert, signer}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
|
|
||||||
return s.signer.Sign(rand, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *openSSHCertSigner) PublicKey() PublicKey {
|
|
||||||
return s.pub
|
|
||||||
}
|
|
||||||
|
|
||||||
const sourceAddressCriticalOption = "source-address"
|
|
||||||
|
|
||||||
// CertChecker does the work of verifying a certificate. Its methods
|
|
||||||
// can be plugged into ClientConfig.HostKeyCallback and
|
|
||||||
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
|
|
||||||
// minimally, the IsAuthority callback should be set.
|
|
||||||
type CertChecker struct {
|
|
||||||
// SupportedCriticalOptions lists the CriticalOptions that the
|
|
||||||
// server application layer understands. These are only used
|
|
||||||
// for user certificates.
|
|
||||||
SupportedCriticalOptions []string
|
|
||||||
|
|
||||||
// IsAuthority should return true if the key is recognized as
|
|
||||||
// an authority. This allows for certificates to be signed by other
|
|
||||||
// certificates.
|
|
||||||
IsAuthority func(auth PublicKey) bool
|
|
||||||
|
|
||||||
// Clock is used for verifying time stamps. If nil, time.Now
|
|
||||||
// is used.
|
|
||||||
Clock func() time.Time
|
|
||||||
|
|
||||||
// UserKeyFallback is called when CertChecker.Authenticate encounters a
|
|
||||||
// public key that is not a certificate. It must implement validation
|
|
||||||
// of user keys or else, if nil, all such keys are rejected.
|
|
||||||
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
|
|
||||||
|
|
||||||
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
|
|
||||||
// public key that is not a certificate. It must implement host key
|
|
||||||
// validation or else, if nil, all such keys are rejected.
|
|
||||||
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error
|
|
||||||
|
|
||||||
// IsRevoked is called for each certificate so that revocation checking
|
|
||||||
// can be implemented. It should return true if the given certificate
|
|
||||||
// is revoked and false otherwise. If nil, no certificates are
|
|
||||||
// considered to have been revoked.
|
|
||||||
IsRevoked func(cert *Certificate) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckHostKey checks a host key certificate. This method can be
|
|
||||||
// plugged into ClientConfig.HostKeyCallback.
|
|
||||||
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error {
|
|
||||||
cert, ok := key.(*Certificate)
|
|
||||||
if !ok {
|
|
||||||
if c.HostKeyFallback != nil {
|
|
||||||
return c.HostKeyFallback(addr, remote, key)
|
|
||||||
}
|
|
||||||
return errors.New("ssh: non-certificate host key")
|
|
||||||
}
|
|
||||||
if cert.CertType != HostCert {
|
|
||||||
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.CheckCert(addr, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate checks a user certificate. Authenticate can be used as
|
|
||||||
// a value for ServerConfig.PublicKeyCallback.
|
|
||||||
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) {
|
|
||||||
cert, ok := pubKey.(*Certificate)
|
|
||||||
if !ok {
|
|
||||||
if c.UserKeyFallback != nil {
|
|
||||||
return c.UserKeyFallback(conn, pubKey)
|
|
||||||
}
|
|
||||||
return nil, errors.New("ssh: normal key pairs not accepted")
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert.CertType != UserCert {
|
|
||||||
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.CheckCert(conn.User(), cert); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &cert.Permissions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
|
|
||||||
// the signature of the certificate.
|
|
||||||
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
|
|
||||||
if c.IsRevoked != nil && c.IsRevoked(cert) {
|
|
||||||
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
|
|
||||||
}
|
|
||||||
|
|
||||||
for opt, _ := range cert.CriticalOptions {
|
|
||||||
// sourceAddressCriticalOption will be enforced by
|
|
||||||
// serverAuthenticate
|
|
||||||
if opt == sourceAddressCriticalOption {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for _, supp := range c.SupportedCriticalOptions {
|
|
||||||
if supp == opt {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cert.ValidPrincipals) > 0 {
|
|
||||||
// By default, certs are valid for all users/hosts.
|
|
||||||
found := false
|
|
||||||
for _, p := range cert.ValidPrincipals {
|
|
||||||
if p == principal {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !c.IsAuthority(cert.SignatureKey) {
|
|
||||||
return fmt.Errorf("ssh: certificate signed by unrecognized authority")
|
|
||||||
}
|
|
||||||
|
|
||||||
clock := c.Clock
|
|
||||||
if clock == nil {
|
|
||||||
clock = time.Now
|
|
||||||
}
|
|
||||||
|
|
||||||
unixNow := clock().Unix()
|
|
||||||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
|
|
||||||
return fmt.Errorf("ssh: cert is not yet valid")
|
|
||||||
}
|
|
||||||
if before := int64(cert.ValidBefore); cert.ValidBefore != CertTimeInfinity && (unixNow >= before || before < 0) {
|
|
||||||
return fmt.Errorf("ssh: cert has expired")
|
|
||||||
}
|
|
||||||
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
|
|
||||||
return fmt.Errorf("ssh: certificate signature does not verify")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignCert sets c.SignatureKey to the authority's public key and stores a
|
|
||||||
// Signature, by authority, in the certificate.
|
|
||||||
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
|
|
||||||
c.Nonce = make([]byte, 32)
|
|
||||||
if _, err := io.ReadFull(rand, c.Nonce); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.SignatureKey = authority.PublicKey()
|
|
||||||
|
|
||||||
sig, err := authority.Sign(rand, c.bytesForSigning())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.Signature = sig
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var certAlgoNames = map[string]string{
|
|
||||||
KeyAlgoRSA: CertAlgoRSAv01,
|
|
||||||
KeyAlgoDSA: CertAlgoDSAv01,
|
|
||||||
KeyAlgoECDSA256: CertAlgoECDSA256v01,
|
|
||||||
KeyAlgoECDSA384: CertAlgoECDSA384v01,
|
|
||||||
KeyAlgoECDSA521: CertAlgoECDSA521v01,
|
|
||||||
}
|
|
||||||
|
|
||||||
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
|
|
||||||
// Panics if a non-certificate algorithm is passed.
|
|
||||||
func certToPrivAlgo(algo string) string {
|
|
||||||
for privAlgo, pubAlgo := range certAlgoNames {
|
|
||||||
if pubAlgo == algo {
|
|
||||||
return privAlgo
|
|
||||||
}
|
|
||||||
}
|
|
||||||
panic("unknown cert algorithm")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cert *Certificate) bytesForSigning() []byte {
|
|
||||||
c2 := *cert
|
|
||||||
c2.Signature = nil
|
|
||||||
out := c2.Marshal()
|
|
||||||
// Drop trailing signature length.
|
|
||||||
return out[:len(out)-4]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal serializes c into OpenSSH's wire format. It is part of the
|
|
||||||
// PublicKey interface.
|
|
||||||
func (c *Certificate) Marshal() []byte {
|
|
||||||
generic := genericCertData{
|
|
||||||
Serial: c.Serial,
|
|
||||||
CertType: c.CertType,
|
|
||||||
KeyId: c.KeyId,
|
|
||||||
ValidPrincipals: marshalStringList(c.ValidPrincipals),
|
|
||||||
ValidAfter: uint64(c.ValidAfter),
|
|
||||||
ValidBefore: uint64(c.ValidBefore),
|
|
||||||
CriticalOptions: marshalTuples(c.CriticalOptions),
|
|
||||||
Extensions: marshalTuples(c.Extensions),
|
|
||||||
Reserved: c.Reserved,
|
|
||||||
SignatureKey: c.SignatureKey.Marshal(),
|
|
||||||
}
|
|
||||||
if c.Signature != nil {
|
|
||||||
generic.Signature = Marshal(c.Signature)
|
|
||||||
}
|
|
||||||
genericBytes := Marshal(&generic)
|
|
||||||
keyBytes := c.Key.Marshal()
|
|
||||||
_, keyBytes, _ = parseString(keyBytes)
|
|
||||||
prefix := Marshal(&struct {
|
|
||||||
Name string
|
|
||||||
Nonce []byte
|
|
||||||
Key []byte `ssh:"rest"`
|
|
||||||
}{c.Type(), c.Nonce, keyBytes})
|
|
||||||
|
|
||||||
result := make([]byte, 0, len(prefix)+len(genericBytes))
|
|
||||||
result = append(result, prefix...)
|
|
||||||
result = append(result, genericBytes...)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns the key name. It is part of the PublicKey interface.
|
|
||||||
func (c *Certificate) Type() string {
|
|
||||||
algo, ok := certAlgoNames[c.Key.Type()]
|
|
||||||
if !ok {
|
|
||||||
panic("unknown cert key type")
|
|
||||||
}
|
|
||||||
return algo
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify verifies a signature against the certificate's public
|
|
||||||
// key. It is part of the PublicKey interface.
|
|
||||||
func (c *Certificate) Verify(data []byte, sig *Signature) error {
|
|
||||||
return c.Key.Verify(data, sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) {
|
|
||||||
format, in, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
out = &Signature{
|
|
||||||
Format: string(format),
|
|
||||||
}
|
|
||||||
|
|
||||||
if out.Blob, in, ok = parseString(in); !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, in, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) {
|
|
||||||
sigBytes, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
out, trailing, ok := parseSignatureBody(sigBytes)
|
|
||||||
if !ok || len(trailing) > 0 {
|
|
||||||
return nil, nil, false
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,156 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cert generated by ssh-keygen 6.0p1 Debian-4.
|
|
||||||
// % ssh-keygen -s ca-key -I test user-key
|
|
||||||
var exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
|
|
||||||
|
|
||||||
func TestParseCert(t *testing.T) {
|
|
||||||
authKeyBytes := []byte(exampleSSHCert)
|
|
||||||
|
|
||||||
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ParseAuthorizedKey: %v", err)
|
|
||||||
}
|
|
||||||
if len(rest) > 0 {
|
|
||||||
t.Errorf("rest: got %q, want empty", rest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := key.(*Certificate); !ok {
|
|
||||||
t.Fatalf("got %#v, want *Certificate", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
marshaled := MarshalAuthorizedKey(key)
|
|
||||||
// Before comparison, remove the trailing newline that
|
|
||||||
// MarshalAuthorizedKey adds.
|
|
||||||
marshaled = marshaled[:len(marshaled)-1]
|
|
||||||
if !bytes.Equal(authKeyBytes, marshaled) {
|
|
||||||
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateCert(t *testing.T) {
|
|
||||||
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ParseAuthorizedKey: %v", err)
|
|
||||||
}
|
|
||||||
validCert, ok := key.(*Certificate)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("got %v (%T), want *Certificate", key, key)
|
|
||||||
}
|
|
||||||
checker := CertChecker{}
|
|
||||||
checker.IsAuthority = func(k PublicKey) bool {
|
|
||||||
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checker.CheckCert("user", validCert); err != nil {
|
|
||||||
t.Errorf("Unable to validate certificate: %v", err)
|
|
||||||
}
|
|
||||||
invalidCert := &Certificate{
|
|
||||||
Key: testPublicKeys["rsa"],
|
|
||||||
SignatureKey: testPublicKeys["ecdsa"],
|
|
||||||
ValidBefore: CertTimeInfinity,
|
|
||||||
Signature: &Signature{},
|
|
||||||
}
|
|
||||||
if err := checker.CheckCert("user", invalidCert); err == nil {
|
|
||||||
t.Error("Invalid cert signature passed validation")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateCertTime(t *testing.T) {
|
|
||||||
cert := Certificate{
|
|
||||||
ValidPrincipals: []string{"user"},
|
|
||||||
Key: testPublicKeys["rsa"],
|
|
||||||
ValidAfter: 50,
|
|
||||||
ValidBefore: 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
|
|
||||||
for ts, ok := range map[int64]bool{
|
|
||||||
25: false,
|
|
||||||
50: true,
|
|
||||||
99: true,
|
|
||||||
100: false,
|
|
||||||
125: false,
|
|
||||||
} {
|
|
||||||
checker := CertChecker{
|
|
||||||
Clock: func() time.Time { return time.Unix(ts, 0) },
|
|
||||||
}
|
|
||||||
checker.IsAuthority = func(k PublicKey) bool {
|
|
||||||
return bytes.Equal(k.Marshal(),
|
|
||||||
testPublicKeys["ecdsa"].Marshal())
|
|
||||||
}
|
|
||||||
|
|
||||||
if v := checker.CheckCert("user", &cert); (v == nil) != ok {
|
|
||||||
t.Errorf("Authenticate(%d): %v", ts, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(hanwen): tests for
|
|
||||||
//
|
|
||||||
// host keys:
|
|
||||||
// * fallbacks
|
|
||||||
|
|
||||||
func TestHostKeyCert(t *testing.T) {
|
|
||||||
cert := &Certificate{
|
|
||||||
ValidPrincipals: []string{"hostname", "hostname.domain"},
|
|
||||||
Key: testPublicKeys["rsa"],
|
|
||||||
ValidBefore: CertTimeInfinity,
|
|
||||||
CertType: HostCert,
|
|
||||||
}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
|
|
||||||
checker := &CertChecker{
|
|
||||||
IsAuthority: func(p PublicKey) bool {
|
|
||||||
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewCertSigner: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, name := range []string{"hostname", "otherhost"} {
|
|
||||||
c1, c2, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
defer c1.Close()
|
|
||||||
defer c2.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
conf := ServerConfig{
|
|
||||||
NoClientAuth: true,
|
|
||||||
}
|
|
||||||
conf.AddHostKey(certSigner)
|
|
||||||
_, _, _, err := NewServerConn(c1, &conf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewServerConn: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "user",
|
|
||||||
HostKeyCallback: checker.CheckHostKey,
|
|
||||||
}
|
|
||||||
_, _, _, err = NewClientConn(c2, name, config)
|
|
||||||
|
|
||||||
succeed := name == "hostname"
|
|
||||||
if (err == nil) != succeed {
|
|
||||||
t.Fatalf("NewClientConn(%q): %v", name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,631 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
minPacketLength = 9
|
|
||||||
// channelMaxPacket contains the maximum number of bytes that will be
|
|
||||||
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
|
||||||
// the minimum.
|
|
||||||
channelMaxPacket = 1 << 15
|
|
||||||
// We follow OpenSSH here.
|
|
||||||
channelWindowSize = 64 * channelMaxPacket
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewChannel represents an incoming request to a channel. It must either be
|
|
||||||
// accepted for use by calling Accept, or rejected by calling Reject.
|
|
||||||
type NewChannel interface {
|
|
||||||
// Accept accepts the channel creation request. It returns the Channel
|
|
||||||
// and a Go channel containing SSH requests. The Go channel must be
|
|
||||||
// serviced otherwise the Channel will hang.
|
|
||||||
Accept() (Channel, <-chan *Request, error)
|
|
||||||
|
|
||||||
// Reject rejects the channel creation request. After calling
|
|
||||||
// this, no other methods on the Channel may be called.
|
|
||||||
Reject(reason RejectionReason, message string) error
|
|
||||||
|
|
||||||
// ChannelType returns the type of the channel, as supplied by the
|
|
||||||
// client.
|
|
||||||
ChannelType() string
|
|
||||||
|
|
||||||
// ExtraData returns the arbitrary payload for this channel, as supplied
|
|
||||||
// by the client. This data is specific to the channel type.
|
|
||||||
ExtraData() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Channel is an ordered, reliable, flow-controlled, duplex stream
|
|
||||||
// that is multiplexed over an SSH connection.
|
|
||||||
type Channel interface {
|
|
||||||
// Read reads up to len(data) bytes from the channel.
|
|
||||||
Read(data []byte) (int, error)
|
|
||||||
|
|
||||||
// Write writes len(data) bytes to the channel.
|
|
||||||
Write(data []byte) (int, error)
|
|
||||||
|
|
||||||
// Close signals end of channel use. No data may be sent after this
|
|
||||||
// call.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// CloseWrite signals the end of sending in-band
|
|
||||||
// data. Requests may still be sent, and the other side may
|
|
||||||
// still send data
|
|
||||||
CloseWrite() error
|
|
||||||
|
|
||||||
// SendRequest sends a channel request. If wantReply is true,
|
|
||||||
// it will wait for a reply and return the result as a
|
|
||||||
// boolean, otherwise the return value will be false. Channel
|
|
||||||
// requests are out-of-band messages so they may be sent even
|
|
||||||
// if the data stream is closed or blocked by flow control.
|
|
||||||
SendRequest(name string, wantReply bool, payload []byte) (bool, error)
|
|
||||||
|
|
||||||
// Stderr returns an io.ReadWriter that writes to this channel
|
|
||||||
// with the extended data type set to stderr. Stderr may
|
|
||||||
// safely be read and written from a different goroutine than
|
|
||||||
// Read and Write respectively.
|
|
||||||
Stderr() io.ReadWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request is a request sent outside of the normal stream of
|
|
||||||
// data. Requests can either be specific to an SSH channel, or they
|
|
||||||
// can be global.
|
|
||||||
type Request struct {
|
|
||||||
Type string
|
|
||||||
WantReply bool
|
|
||||||
Payload []byte
|
|
||||||
|
|
||||||
ch *channel
|
|
||||||
mux *mux
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reply sends a response to a request. It must be called for all requests
|
|
||||||
// where WantReply is true and is a no-op otherwise. The payload argument is
|
|
||||||
// ignored for replies to channel-specific requests.
|
|
||||||
func (r *Request) Reply(ok bool, payload []byte) error {
|
|
||||||
if !r.WantReply {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.ch == nil {
|
|
||||||
return r.mux.ackRequest(ok, payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.ch.ackRequest(ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RejectionReason is an enumeration used when rejecting channel creation
|
|
||||||
// requests. See RFC 4254, section 5.1.
|
|
||||||
type RejectionReason uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
Prohibited RejectionReason = iota + 1
|
|
||||||
ConnectionFailed
|
|
||||||
UnknownChannelType
|
|
||||||
ResourceShortage
|
|
||||||
)
|
|
||||||
|
|
||||||
// String converts the rejection reason to human readable form.
|
|
||||||
func (r RejectionReason) String() string {
|
|
||||||
switch r {
|
|
||||||
case Prohibited:
|
|
||||||
return "administratively prohibited"
|
|
||||||
case ConnectionFailed:
|
|
||||||
return "connect failed"
|
|
||||||
case UnknownChannelType:
|
|
||||||
return "unknown channel type"
|
|
||||||
case ResourceShortage:
|
|
||||||
return "resource shortage"
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("unknown reason %d", int(r))
|
|
||||||
}
|
|
||||||
|
|
||||||
func min(a uint32, b int) uint32 {
|
|
||||||
if a < uint32(b) {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return uint32(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
type channelDirection uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
channelInbound channelDirection = iota
|
|
||||||
channelOutbound
|
|
||||||
)
|
|
||||||
|
|
||||||
// channel is an implementation of the Channel interface that works
|
|
||||||
// with the mux class.
|
|
||||||
type channel struct {
|
|
||||||
// R/O after creation
|
|
||||||
chanType string
|
|
||||||
extraData []byte
|
|
||||||
localId, remoteId uint32
|
|
||||||
|
|
||||||
// maxIncomingPayload and maxRemotePayload are the maximum
|
|
||||||
// payload sizes of normal and extended data packets for
|
|
||||||
// receiving and sending, respectively. The wire packet will
|
|
||||||
// be 9 or 13 bytes larger (excluding encryption overhead).
|
|
||||||
maxIncomingPayload uint32
|
|
||||||
maxRemotePayload uint32
|
|
||||||
|
|
||||||
mux *mux
|
|
||||||
|
|
||||||
// decided is set to true if an accept or reject message has been sent
|
|
||||||
// (for outbound channels) or received (for inbound channels).
|
|
||||||
decided bool
|
|
||||||
|
|
||||||
// direction contains either channelOutbound, for channels created
|
|
||||||
// locally, or channelInbound, for channels created by the peer.
|
|
||||||
direction channelDirection
|
|
||||||
|
|
||||||
// Pending internal channel messages.
|
|
||||||
msg chan interface{}
|
|
||||||
|
|
||||||
// Since requests have no ID, there can be only one request
|
|
||||||
// with WantReply=true outstanding. This lock is held by a
|
|
||||||
// goroutine that has such an outgoing request pending.
|
|
||||||
sentRequestMu sync.Mutex
|
|
||||||
|
|
||||||
incomingRequests chan *Request
|
|
||||||
|
|
||||||
sentEOF bool
|
|
||||||
|
|
||||||
// thread-safe data
|
|
||||||
remoteWin window
|
|
||||||
pending *buffer
|
|
||||||
extPending *buffer
|
|
||||||
|
|
||||||
// windowMu protects myWindow, the flow-control window.
|
|
||||||
windowMu sync.Mutex
|
|
||||||
myWindow uint32
|
|
||||||
|
|
||||||
// writeMu serializes calls to mux.conn.writePacket() and
|
|
||||||
// protects sentClose and packetPool. This mutex must be
|
|
||||||
// different from windowMu, as writePacket can block if there
|
|
||||||
// is a key exchange pending.
|
|
||||||
writeMu sync.Mutex
|
|
||||||
sentClose bool
|
|
||||||
|
|
||||||
// packetPool has a buffer for each extended channel ID to
|
|
||||||
// save allocations during writes.
|
|
||||||
packetPool map[uint32][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// writePacket sends a packet. If the packet is a channel close, it updates
|
|
||||||
// sentClose. This method takes the lock c.writeMu.
|
|
||||||
func (c *channel) writePacket(packet []byte) error {
|
|
||||||
c.writeMu.Lock()
|
|
||||||
if c.sentClose {
|
|
||||||
c.writeMu.Unlock()
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
c.sentClose = (packet[0] == msgChannelClose)
|
|
||||||
err := c.mux.conn.writePacket(packet)
|
|
||||||
c.writeMu.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) sendMessage(msg interface{}) error {
|
|
||||||
if debugMux {
|
|
||||||
log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
p := Marshal(msg)
|
|
||||||
binary.BigEndian.PutUint32(p[1:], c.remoteId)
|
|
||||||
return c.writePacket(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteExtended writes data to a specific extended stream. These streams are
|
|
||||||
// used, for example, for stderr.
|
|
||||||
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
|
|
||||||
if c.sentEOF {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
|
|
||||||
opCode := byte(msgChannelData)
|
|
||||||
headerLength := uint32(9)
|
|
||||||
if extendedCode > 0 {
|
|
||||||
headerLength += 4
|
|
||||||
opCode = msgChannelExtendedData
|
|
||||||
}
|
|
||||||
|
|
||||||
c.writeMu.Lock()
|
|
||||||
packet := c.packetPool[extendedCode]
|
|
||||||
// We don't remove the buffer from packetPool, so
|
|
||||||
// WriteExtended calls from different goroutines will be
|
|
||||||
// flagged as errors by the race detector.
|
|
||||||
c.writeMu.Unlock()
|
|
||||||
|
|
||||||
for len(data) > 0 {
|
|
||||||
space := min(c.maxRemotePayload, len(data))
|
|
||||||
if space, err = c.remoteWin.reserve(space); err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
if want := headerLength + space; uint32(cap(packet)) < want {
|
|
||||||
packet = make([]byte, want)
|
|
||||||
} else {
|
|
||||||
packet = packet[:want]
|
|
||||||
}
|
|
||||||
|
|
||||||
todo := data[:space]
|
|
||||||
|
|
||||||
packet[0] = opCode
|
|
||||||
binary.BigEndian.PutUint32(packet[1:], c.remoteId)
|
|
||||||
if extendedCode > 0 {
|
|
||||||
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
|
|
||||||
copy(packet[headerLength:], todo)
|
|
||||||
if err = c.writePacket(packet); err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n += len(todo)
|
|
||||||
data = data[len(todo):]
|
|
||||||
}
|
|
||||||
|
|
||||||
c.writeMu.Lock()
|
|
||||||
c.packetPool[extendedCode] = packet
|
|
||||||
c.writeMu.Unlock()
|
|
||||||
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) handleData(packet []byte) error {
|
|
||||||
headerLen := 9
|
|
||||||
isExtendedData := packet[0] == msgChannelExtendedData
|
|
||||||
if isExtendedData {
|
|
||||||
headerLen = 13
|
|
||||||
}
|
|
||||||
if len(packet) < headerLen {
|
|
||||||
// malformed data packet
|
|
||||||
return parseError(packet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var extended uint32
|
|
||||||
if isExtendedData {
|
|
||||||
extended = binary.BigEndian.Uint32(packet[5:])
|
|
||||||
}
|
|
||||||
|
|
||||||
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
|
|
||||||
if length == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if length > c.maxIncomingPayload {
|
|
||||||
// TODO(hanwen): should send Disconnect?
|
|
||||||
return errors.New("ssh: incoming packet exceeds maximum payload size")
|
|
||||||
}
|
|
||||||
|
|
||||||
data := packet[headerLen:]
|
|
||||||
if length != uint32(len(data)) {
|
|
||||||
return errors.New("ssh: wrong packet length")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.windowMu.Lock()
|
|
||||||
if c.myWindow < length {
|
|
||||||
c.windowMu.Unlock()
|
|
||||||
// TODO(hanwen): should send Disconnect with reason?
|
|
||||||
return errors.New("ssh: remote side wrote too much")
|
|
||||||
}
|
|
||||||
c.myWindow -= length
|
|
||||||
c.windowMu.Unlock()
|
|
||||||
|
|
||||||
if extended == 1 {
|
|
||||||
c.extPending.write(data)
|
|
||||||
} else if extended > 0 {
|
|
||||||
// discard other extended data.
|
|
||||||
} else {
|
|
||||||
c.pending.write(data)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) adjustWindow(n uint32) error {
|
|
||||||
c.windowMu.Lock()
|
|
||||||
// Since myWindow is managed on our side, and can never exceed
|
|
||||||
// the initial window setting, we don't worry about overflow.
|
|
||||||
c.myWindow += uint32(n)
|
|
||||||
c.windowMu.Unlock()
|
|
||||||
return c.sendMessage(windowAdjustMsg{
|
|
||||||
AdditionalBytes: uint32(n),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
|
|
||||||
switch extended {
|
|
||||||
case 1:
|
|
||||||
n, err = c.extPending.Read(data)
|
|
||||||
case 0:
|
|
||||||
n, err = c.pending.Read(data)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
|
|
||||||
}
|
|
||||||
|
|
||||||
if n > 0 {
|
|
||||||
err = c.adjustWindow(uint32(n))
|
|
||||||
// sendWindowAdjust can return io.EOF if the remote
|
|
||||||
// peer has closed the connection, however we want to
|
|
||||||
// defer forwarding io.EOF to the caller of Read until
|
|
||||||
// the buffer has been drained.
|
|
||||||
if n > 0 && err == io.EOF {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) close() {
|
|
||||||
c.pending.eof()
|
|
||||||
c.extPending.eof()
|
|
||||||
close(c.msg)
|
|
||||||
close(c.incomingRequests)
|
|
||||||
c.writeMu.Lock()
|
|
||||||
// This is not necesary for a normal channel teardown, but if
|
|
||||||
// there was another error, it is.
|
|
||||||
c.sentClose = true
|
|
||||||
c.writeMu.Unlock()
|
|
||||||
// Unblock writers.
|
|
||||||
c.remoteWin.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// responseMessageReceived is called when a success or failure message is
|
|
||||||
// received on a channel to check that such a message is reasonable for the
|
|
||||||
// given channel.
|
|
||||||
func (c *channel) responseMessageReceived() error {
|
|
||||||
if c.direction == channelInbound {
|
|
||||||
return errors.New("ssh: channel response message received on inbound channel")
|
|
||||||
}
|
|
||||||
if c.decided {
|
|
||||||
return errors.New("ssh: duplicate response received for channel")
|
|
||||||
}
|
|
||||||
c.decided = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) handlePacket(packet []byte) error {
|
|
||||||
switch packet[0] {
|
|
||||||
case msgChannelData, msgChannelExtendedData:
|
|
||||||
return c.handleData(packet)
|
|
||||||
case msgChannelClose:
|
|
||||||
c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
|
|
||||||
c.mux.chanList.remove(c.localId)
|
|
||||||
c.close()
|
|
||||||
return nil
|
|
||||||
case msgChannelEOF:
|
|
||||||
// RFC 4254 is mute on how EOF affects dataExt messages but
|
|
||||||
// it is logical to signal EOF at the same time.
|
|
||||||
c.extPending.eof()
|
|
||||||
c.pending.eof()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
decoded, err := decode(packet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := decoded.(type) {
|
|
||||||
case *channelOpenFailureMsg:
|
|
||||||
if err := c.responseMessageReceived(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.mux.chanList.remove(msg.PeersId)
|
|
||||||
c.msg <- msg
|
|
||||||
case *channelOpenConfirmMsg:
|
|
||||||
if err := c.responseMessageReceived(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
|
|
||||||
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
|
|
||||||
}
|
|
||||||
c.remoteId = msg.MyId
|
|
||||||
c.maxRemotePayload = msg.MaxPacketSize
|
|
||||||
c.remoteWin.add(msg.MyWindow)
|
|
||||||
c.msg <- msg
|
|
||||||
case *windowAdjustMsg:
|
|
||||||
if !c.remoteWin.add(msg.AdditionalBytes) {
|
|
||||||
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
|
|
||||||
}
|
|
||||||
case *channelRequestMsg:
|
|
||||||
req := Request{
|
|
||||||
Type: msg.Request,
|
|
||||||
WantReply: msg.WantReply,
|
|
||||||
Payload: msg.RequestSpecificData,
|
|
||||||
ch: c,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.incomingRequests <- &req
|
|
||||||
default:
|
|
||||||
c.msg <- msg
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
|
|
||||||
ch := &channel{
|
|
||||||
remoteWin: window{Cond: newCond()},
|
|
||||||
myWindow: channelWindowSize,
|
|
||||||
pending: newBuffer(),
|
|
||||||
extPending: newBuffer(),
|
|
||||||
direction: direction,
|
|
||||||
incomingRequests: make(chan *Request, 16),
|
|
||||||
msg: make(chan interface{}, 16),
|
|
||||||
chanType: chanType,
|
|
||||||
extraData: extraData,
|
|
||||||
mux: m,
|
|
||||||
packetPool: make(map[uint32][]byte),
|
|
||||||
}
|
|
||||||
ch.localId = m.chanList.add(ch)
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
var errUndecided = errors.New("ssh: must Accept or Reject channel")
|
|
||||||
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
|
|
||||||
|
|
||||||
type extChannel struct {
|
|
||||||
code uint32
|
|
||||||
ch *channel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *extChannel) Write(data []byte) (n int, err error) {
|
|
||||||
return e.ch.WriteExtended(data, e.code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *extChannel) Read(data []byte) (n int, err error) {
|
|
||||||
return e.ch.ReadExtended(data, e.code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) Accept() (Channel, <-chan *Request, error) {
|
|
||||||
if c.decided {
|
|
||||||
return nil, nil, errDecidedAlready
|
|
||||||
}
|
|
||||||
c.maxIncomingPayload = channelMaxPacket
|
|
||||||
confirm := channelOpenConfirmMsg{
|
|
||||||
PeersId: c.remoteId,
|
|
||||||
MyId: c.localId,
|
|
||||||
MyWindow: c.myWindow,
|
|
||||||
MaxPacketSize: c.maxIncomingPayload,
|
|
||||||
}
|
|
||||||
c.decided = true
|
|
||||||
if err := c.sendMessage(confirm); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, c.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Reject(reason RejectionReason, message string) error {
|
|
||||||
if ch.decided {
|
|
||||||
return errDecidedAlready
|
|
||||||
}
|
|
||||||
reject := channelOpenFailureMsg{
|
|
||||||
PeersId: ch.remoteId,
|
|
||||||
Reason: reason,
|
|
||||||
Message: message,
|
|
||||||
Language: "en",
|
|
||||||
}
|
|
||||||
ch.decided = true
|
|
||||||
return ch.sendMessage(reject)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Read(data []byte) (int, error) {
|
|
||||||
if !ch.decided {
|
|
||||||
return 0, errUndecided
|
|
||||||
}
|
|
||||||
return ch.ReadExtended(data, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Write(data []byte) (int, error) {
|
|
||||||
if !ch.decided {
|
|
||||||
return 0, errUndecided
|
|
||||||
}
|
|
||||||
return ch.WriteExtended(data, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) CloseWrite() error {
|
|
||||||
if !ch.decided {
|
|
||||||
return errUndecided
|
|
||||||
}
|
|
||||||
ch.sentEOF = true
|
|
||||||
return ch.sendMessage(channelEOFMsg{
|
|
||||||
PeersId: ch.remoteId})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Close() error {
|
|
||||||
if !ch.decided {
|
|
||||||
return errUndecided
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch.sendMessage(channelCloseMsg{
|
|
||||||
PeersId: ch.remoteId})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extended returns an io.ReadWriter that sends and receives data on the given,
|
|
||||||
// SSH extended stream. Such streams are used, for example, for stderr.
|
|
||||||
func (ch *channel) Extended(code uint32) io.ReadWriter {
|
|
||||||
if !ch.decided {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &extChannel{code, ch}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Stderr() io.ReadWriter {
|
|
||||||
return ch.Extended(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
|
||||||
if !ch.decided {
|
|
||||||
return false, errUndecided
|
|
||||||
}
|
|
||||||
|
|
||||||
if wantReply {
|
|
||||||
ch.sentRequestMu.Lock()
|
|
||||||
defer ch.sentRequestMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := channelRequestMsg{
|
|
||||||
PeersId: ch.remoteId,
|
|
||||||
Request: name,
|
|
||||||
WantReply: wantReply,
|
|
||||||
RequestSpecificData: payload,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ch.sendMessage(msg); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if wantReply {
|
|
||||||
m, ok := (<-ch.msg)
|
|
||||||
if !ok {
|
|
||||||
return false, io.EOF
|
|
||||||
}
|
|
||||||
switch m.(type) {
|
|
||||||
case *channelRequestFailureMsg:
|
|
||||||
return false, nil
|
|
||||||
case *channelRequestSuccessMsg:
|
|
||||||
return true, nil
|
|
||||||
default:
|
|
||||||
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ackRequest either sends an ack or nack to the channel request.
|
|
||||||
func (ch *channel) ackRequest(ok bool) error {
|
|
||||||
if !ch.decided {
|
|
||||||
return errUndecided
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg interface{}
|
|
||||||
if !ok {
|
|
||||||
msg = channelRequestFailureMsg{
|
|
||||||
PeersId: ch.remoteId,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg = channelRequestSuccessMsg{
|
|
||||||
PeersId: ch.remoteId,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ch.sendMessage(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) ChannelType() string {
|
|
||||||
return ch.chanType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) ExtraData() []byte {
|
|
||||||
return ch.extraData
|
|
||||||
}
|
|
|
@ -1,344 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/rc4"
|
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
|
||||||
|
|
||||||
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
|
|
||||||
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
|
|
||||||
// indicates implementations SHOULD be able to handle larger packet sizes, but then
|
|
||||||
// waffles on about reasonable limits.
|
|
||||||
//
|
|
||||||
// OpenSSH caps their maxPacket at 256kB so we choose to do
|
|
||||||
// the same. maxPacket is also used to ensure that uint32
|
|
||||||
// length fields do not overflow, so it should remain well
|
|
||||||
// below 4G.
|
|
||||||
maxPacket = 256 * 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
|
||||||
// by the transport before the first key-exchange.
|
|
||||||
type noneCipher struct{}
|
|
||||||
|
|
||||||
func (c noneCipher) XORKeyStream(dst, src []byte) {
|
|
||||||
copy(dst, src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return cipher.NewCTR(c, iv), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRC4(key, iv []byte) (cipher.Stream, error) {
|
|
||||||
return rc4.NewCipher(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
type streamCipherMode struct {
|
|
||||||
keySize int
|
|
||||||
ivSize int
|
|
||||||
skip int
|
|
||||||
createFunc func(key, iv []byte) (cipher.Stream, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
|
|
||||||
if len(key) < c.keySize {
|
|
||||||
panic("ssh: key length too small for cipher")
|
|
||||||
}
|
|
||||||
if len(iv) < c.ivSize {
|
|
||||||
panic("ssh: iv too small for cipher")
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var streamDump []byte
|
|
||||||
if c.skip > 0 {
|
|
||||||
streamDump = make([]byte, 512)
|
|
||||||
}
|
|
||||||
|
|
||||||
for remainingToDump := c.skip; remainingToDump > 0; {
|
|
||||||
dumpThisTime := remainingToDump
|
|
||||||
if dumpThisTime > len(streamDump) {
|
|
||||||
dumpThisTime = len(streamDump)
|
|
||||||
}
|
|
||||||
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
|
|
||||||
remainingToDump -= dumpThisTime
|
|
||||||
}
|
|
||||||
|
|
||||||
return stream, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cipherModes documents properties of supported ciphers. Ciphers not included
|
|
||||||
// are not supported and will not be negotiated, even if explicitly requested in
|
|
||||||
// ClientConfig.Crypto.Ciphers.
|
|
||||||
var cipherModes = map[string]*streamCipherMode{
|
|
||||||
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
|
|
||||||
// are defined in the order specified in the RFC.
|
|
||||||
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
|
|
||||||
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR},
|
|
||||||
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR},
|
|
||||||
|
|
||||||
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
|
|
||||||
// They are defined in the order specified in the RFC.
|
|
||||||
"arcfour128": {16, 0, 1536, newRC4},
|
|
||||||
"arcfour256": {32, 0, 1536, newRC4},
|
|
||||||
|
|
||||||
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
|
|
||||||
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
|
|
||||||
// RC4) has problems with weak keys, and should be used with caution."
|
|
||||||
// RFC4345 introduces improved versions of Arcfour.
|
|
||||||
"arcfour": {16, 0, 0, newRC4},
|
|
||||||
|
|
||||||
// AES-GCM is not a stream cipher, so it is constructed with a
|
|
||||||
// special case. If we add any more non-stream ciphers, we
|
|
||||||
// should invest a cleaner way to do this.
|
|
||||||
gcmCipherID: {16, 12, 0, nil},
|
|
||||||
}
|
|
||||||
|
|
||||||
// prefixLen is the length of the packet prefix that contains the packet length
|
|
||||||
// and number of padding bytes.
|
|
||||||
const prefixLen = 5
|
|
||||||
|
|
||||||
// streamPacketCipher is a packetCipher using a stream cipher.
|
|
||||||
type streamPacketCipher struct {
|
|
||||||
mac hash.Hash
|
|
||||||
cipher cipher.Stream
|
|
||||||
|
|
||||||
// The following members are to avoid per-packet allocations.
|
|
||||||
prefix [prefixLen]byte
|
|
||||||
seqNumBytes [4]byte
|
|
||||||
padding [2 * packetSizeMultiple]byte
|
|
||||||
packetData []byte
|
|
||||||
macResult []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// readPacket reads and decrypt a single packet from the reader argument.
|
|
||||||
func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
|
||||||
if _, err := io.ReadFull(r, s.prefix[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
|
|
||||||
length := binary.BigEndian.Uint32(s.prefix[0:4])
|
|
||||||
paddingLength := uint32(s.prefix[4])
|
|
||||||
|
|
||||||
var macSize uint32
|
|
||||||
if s.mac != nil {
|
|
||||||
s.mac.Reset()
|
|
||||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
|
|
||||||
s.mac.Write(s.seqNumBytes[:])
|
|
||||||
s.mac.Write(s.prefix[:])
|
|
||||||
macSize = uint32(s.mac.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
if length <= paddingLength+1 {
|
|
||||||
return nil, errors.New("ssh: invalid packet length, packet too small")
|
|
||||||
}
|
|
||||||
|
|
||||||
if length > maxPacket {
|
|
||||||
return nil, errors.New("ssh: invalid packet length, packet too large")
|
|
||||||
}
|
|
||||||
|
|
||||||
// the maxPacket check above ensures that length-1+macSize
|
|
||||||
// does not overflow.
|
|
||||||
if uint32(cap(s.packetData)) < length-1+macSize {
|
|
||||||
s.packetData = make([]byte, length-1+macSize)
|
|
||||||
} else {
|
|
||||||
s.packetData = s.packetData[:length-1+macSize]
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, s.packetData); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mac := s.packetData[length-1:]
|
|
||||||
data := s.packetData[:length-1]
|
|
||||||
s.cipher.XORKeyStream(data, data)
|
|
||||||
|
|
||||||
if s.mac != nil {
|
|
||||||
s.mac.Write(data)
|
|
||||||
s.macResult = s.mac.Sum(s.macResult[:0])
|
|
||||||
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
|
|
||||||
return nil, errors.New("ssh: MAC failure")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.packetData[:length-paddingLength-1], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writePacket encrypts and sends a packet of data to the writer argument
|
|
||||||
func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
|
||||||
if len(packet) > maxPacket {
|
|
||||||
return errors.New("ssh: packet too large")
|
|
||||||
}
|
|
||||||
|
|
||||||
paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple
|
|
||||||
if paddingLength < 4 {
|
|
||||||
paddingLength += packetSizeMultiple
|
|
||||||
}
|
|
||||||
|
|
||||||
length := len(packet) + 1 + paddingLength
|
|
||||||
binary.BigEndian.PutUint32(s.prefix[:], uint32(length))
|
|
||||||
s.prefix[4] = byte(paddingLength)
|
|
||||||
padding := s.padding[:paddingLength]
|
|
||||||
if _, err := io.ReadFull(rand, padding); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.mac != nil {
|
|
||||||
s.mac.Reset()
|
|
||||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
|
|
||||||
s.mac.Write(s.seqNumBytes[:])
|
|
||||||
s.mac.Write(s.prefix[:])
|
|
||||||
s.mac.Write(packet)
|
|
||||||
s.mac.Write(padding)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
|
|
||||||
s.cipher.XORKeyStream(packet, packet)
|
|
||||||
s.cipher.XORKeyStream(padding, padding)
|
|
||||||
|
|
||||||
if _, err := w.Write(s.prefix[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := w.Write(packet); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := w.Write(padding); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.mac != nil {
|
|
||||||
s.macResult = s.mac.Sum(s.macResult[:0])
|
|
||||||
if _, err := w.Write(s.macResult); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type gcmCipher struct {
|
|
||||||
aead cipher.AEAD
|
|
||||||
prefix [4]byte
|
|
||||||
iv []byte
|
|
||||||
buf []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) {
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
aead, err := cipher.NewGCM(c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &gcmCipher{
|
|
||||||
aead: aead,
|
|
||||||
iv: iv,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const gcmTagSize = 16
|
|
||||||
|
|
||||||
func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
|
||||||
// Pad out to multiple of 16 bytes. This is different from the
|
|
||||||
// stream cipher because that encrypts the length too.
|
|
||||||
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple)
|
|
||||||
if padding < 4 {
|
|
||||||
padding += packetSizeMultiple
|
|
||||||
}
|
|
||||||
|
|
||||||
length := uint32(len(packet) + int(padding) + 1)
|
|
||||||
binary.BigEndian.PutUint32(c.prefix[:], length)
|
|
||||||
if _, err := w.Write(c.prefix[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(c.buf) < int(length) {
|
|
||||||
c.buf = make([]byte, length)
|
|
||||||
} else {
|
|
||||||
c.buf = c.buf[:length]
|
|
||||||
}
|
|
||||||
|
|
||||||
c.buf[0] = padding
|
|
||||||
copy(c.buf[1:], packet)
|
|
||||||
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:])
|
|
||||||
if _, err := w.Write(c.buf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.incIV()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *gcmCipher) incIV() {
|
|
||||||
for i := 4 + 7; i >= 4; i-- {
|
|
||||||
c.iv[i]++
|
|
||||||
if c.iv[i] != 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
|
||||||
if _, err := io.ReadFull(r, c.prefix[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
length := binary.BigEndian.Uint32(c.prefix[:])
|
|
||||||
if length > maxPacket {
|
|
||||||
return nil, errors.New("ssh: max packet length exceeded.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(c.buf) < int(length+gcmTagSize) {
|
|
||||||
c.buf = make([]byte, length+gcmTagSize)
|
|
||||||
} else {
|
|
||||||
c.buf = c.buf[:length+gcmTagSize]
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, c.buf); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.incIV()
|
|
||||||
|
|
||||||
padding := plain[0]
|
|
||||||
if padding < 4 || padding >= 20 {
|
|
||||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(padding+1) >= len(plain) {
|
|
||||||
return nil, fmt.Errorf("ssh: padding %d too large", padding)
|
|
||||||
}
|
|
||||||
plain = plain[1 : length-uint32(padding)]
|
|
||||||
return plain, nil
|
|
||||||
}
|
|
|
@ -1,59 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto"
|
|
||||||
"crypto/rand"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDefaultCiphersExist(t *testing.T) {
|
|
||||||
for _, cipherAlgo := range supportedCiphers {
|
|
||||||
if _, ok := cipherModes[cipherAlgo]; !ok {
|
|
||||||
t.Errorf("default cipher %q is unknown", cipherAlgo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPacketCiphers(t *testing.T) {
|
|
||||||
for cipher := range cipherModes {
|
|
||||||
kr := &kexResult{Hash: crypto.SHA1}
|
|
||||||
algs := directionAlgorithms{
|
|
||||||
Cipher: cipher,
|
|
||||||
MAC: "hmac-sha1",
|
|
||||||
Compression: "none",
|
|
||||||
}
|
|
||||||
client, err := newPacketCipher(clientKeys, algs, kr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
server, err := newPacketCipher(clientKeys, algs, kr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
want := "bla bla"
|
|
||||||
input := []byte(want)
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
|
|
||||||
t.Errorf("writePacket(%q): %v", cipher, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := server.readPacket(0, buf)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("readPacket(%q): %v", cipher, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(packet) != want {
|
|
||||||
t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,202 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client implements a traditional SSH client that supports shells,
|
|
||||||
// subprocesses, port forwarding and tunneled dialing.
|
|
||||||
type Client struct {
|
|
||||||
Conn
|
|
||||||
|
|
||||||
forwards forwardList // forwarded tcpip connections from the remote side
|
|
||||||
mu sync.Mutex
|
|
||||||
channelHandlers map[string]chan NewChannel
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleChannelOpen returns a channel on which NewChannel requests
|
|
||||||
// for the given type are sent. If the type already is being handled,
|
|
||||||
// nil is returned. The channel is closed when the connection is closed.
|
|
||||||
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
if c.channelHandlers == nil {
|
|
||||||
// The SSH channel has been closed.
|
|
||||||
c := make(chan NewChannel)
|
|
||||||
close(c)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := c.channelHandlers[channelType]
|
|
||||||
if ch != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ch = make(chan NewChannel, 16)
|
|
||||||
c.channelHandlers[channelType] = ch
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a Client on top of the given connection.
|
|
||||||
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
|
|
||||||
conn := &Client{
|
|
||||||
Conn: c,
|
|
||||||
channelHandlers: make(map[string]chan NewChannel, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
go conn.handleGlobalRequests(reqs)
|
|
||||||
go conn.handleChannelOpens(chans)
|
|
||||||
go func() {
|
|
||||||
conn.Wait()
|
|
||||||
conn.forwards.closeAll()
|
|
||||||
}()
|
|
||||||
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClientConn establishes an authenticated SSH connection using c
|
|
||||||
// as the underlying transport. The Request and NewChannel channels
|
|
||||||
// must be serviced or the connection will hang.
|
|
||||||
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
|
|
||||||
fullConf := *config
|
|
||||||
fullConf.SetDefaults()
|
|
||||||
conn := &connection{
|
|
||||||
sshConn: sshConn{conn: c},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := conn.clientHandshake(addr, &fullConf); err != nil {
|
|
||||||
c.Close()
|
|
||||||
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
|
|
||||||
}
|
|
||||||
conn.mux = newMux(conn.transport)
|
|
||||||
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientHandshake performs the client side key exchange. See RFC 4253 Section
|
|
||||||
// 7.
|
|
||||||
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
|
|
||||||
c.clientVersion = []byte(packageVersion)
|
|
||||||
if config.ClientVersion != "" {
|
|
||||||
c.clientVersion = []byte(config.ClientVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.transport = newClientTransport(
|
|
||||||
newTransport(c.sshConn.conn, config.Rand, true /* is client */),
|
|
||||||
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
|
|
||||||
if err := c.transport.requestKeyChange(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if packet, err := c.transport.readPacket(); err != nil {
|
|
||||||
return err
|
|
||||||
} else if packet[0] != msgNewKeys {
|
|
||||||
return unexpectedMessageError(msgNewKeys, packet[0])
|
|
||||||
}
|
|
||||||
return c.clientAuthenticate(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// verifyHostKeySignature verifies the host key obtained in the key
|
|
||||||
// exchange.
|
|
||||||
func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error {
|
|
||||||
sig, rest, ok := parseSignatureBody(result.Signature)
|
|
||||||
if len(rest) > 0 || !ok {
|
|
||||||
return errors.New("ssh: signature parse error")
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostKey.Verify(result.H, sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSession opens a new Session for this client. (A session is a remote
|
|
||||||
// execution of a program.)
|
|
||||||
func (c *Client) NewSession() (*Session, error) {
|
|
||||||
ch, in, err := c.OpenChannel("session", nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return newSession(ch, in)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
|
|
||||||
for r := range incoming {
|
|
||||||
// This handles keepalive messages and matches
|
|
||||||
// the behaviour of OpenSSH.
|
|
||||||
r.Reply(false, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleChannelOpens channel open messages from the remote side.
|
|
||||||
func (c *Client) handleChannelOpens(in <-chan NewChannel) {
|
|
||||||
for ch := range in {
|
|
||||||
c.mu.Lock()
|
|
||||||
handler := c.channelHandlers[ch.ChannelType()]
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
if handler != nil {
|
|
||||||
handler <- ch
|
|
||||||
} else {
|
|
||||||
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
for _, ch := range c.channelHandlers {
|
|
||||||
close(ch)
|
|
||||||
}
|
|
||||||
c.channelHandlers = nil
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial starts a client connection to the given SSH server. It is a
|
|
||||||
// convenience function that connects to the given network address,
|
|
||||||
// initiates the SSH handshake, and then sets up a Client. For access
|
|
||||||
// to incoming channels and requests, use net.Dial with NewClientConn
|
|
||||||
// instead.
|
|
||||||
func Dial(network, addr string, config *ClientConfig) (*Client, error) {
|
|
||||||
conn, err := net.Dial(network, addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c, chans, reqs, err := NewClientConn(conn, addr, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewClient(c, chans, reqs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// A ClientConfig structure is used to configure a Client. It must not be
|
|
||||||
// modified after having been passed to an SSH function.
|
|
||||||
type ClientConfig struct {
|
|
||||||
// Config contains configuration that is shared between clients and
|
|
||||||
// servers.
|
|
||||||
Config
|
|
||||||
|
|
||||||
// User contains the username to authenticate as.
|
|
||||||
User string
|
|
||||||
|
|
||||||
// Auth contains possible authentication methods to use with the
|
|
||||||
// server. Only the first instance of a particular RFC 4252 method will
|
|
||||||
// be used during authentication.
|
|
||||||
Auth []AuthMethod
|
|
||||||
|
|
||||||
// HostKeyCallback, if not nil, is called during the cryptographic
|
|
||||||
// handshake to validate the server's host key. A nil HostKeyCallback
|
|
||||||
// implies that all host keys are accepted.
|
|
||||||
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
|
|
||||||
|
|
||||||
// ClientVersion contains the version identification string that will
|
|
||||||
// be used for the connection. If empty, a reasonable default is used.
|
|
||||||
ClientVersion string
|
|
||||||
}
|
|
|
@ -1,441 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// clientAuthenticate authenticates with the remote server. See RFC 4252.
|
|
||||||
func (c *connection) clientAuthenticate(config *ClientConfig) error {
|
|
||||||
// initiate user auth session
|
|
||||||
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
packet, err := c.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var serviceAccept serviceAcceptMsg
|
|
||||||
if err := Unmarshal(packet, &serviceAccept); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// during the authentication phase the client first attempts the "none" method
|
|
||||||
// then any untried methods suggested by the server.
|
|
||||||
tried := make(map[string]bool)
|
|
||||||
var lastMethods []string
|
|
||||||
for auth := AuthMethod(new(noneAuth)); auth != nil; {
|
|
||||||
ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
// success
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
tried[auth.method()] = true
|
|
||||||
if methods == nil {
|
|
||||||
methods = lastMethods
|
|
||||||
}
|
|
||||||
lastMethods = methods
|
|
||||||
|
|
||||||
auth = nil
|
|
||||||
|
|
||||||
findNext:
|
|
||||||
for _, a := range config.Auth {
|
|
||||||
candidateMethod := a.method()
|
|
||||||
if tried[candidateMethod] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, meth := range methods {
|
|
||||||
if meth == candidateMethod {
|
|
||||||
auth = a
|
|
||||||
break findNext
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried))
|
|
||||||
}
|
|
||||||
|
|
||||||
func keys(m map[string]bool) []string {
|
|
||||||
s := make([]string, 0, len(m))
|
|
||||||
|
|
||||||
for key := range m {
|
|
||||||
s = append(s, key)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// An AuthMethod represents an instance of an RFC 4252 authentication method.
|
|
||||||
type AuthMethod interface {
|
|
||||||
// auth authenticates user over transport t.
|
|
||||||
// Returns true if authentication is successful.
|
|
||||||
// If authentication is not successful, a []string of alternative
|
|
||||||
// method names is returned. If the slice is nil, it will be ignored
|
|
||||||
// and the previous set of possible methods will be reused.
|
|
||||||
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
|
|
||||||
|
|
||||||
// method returns the RFC 4252 method name.
|
|
||||||
method() string
|
|
||||||
}
|
|
||||||
|
|
||||||
// "none" authentication, RFC 4252 section 5.2.
|
|
||||||
type noneAuth int
|
|
||||||
|
|
||||||
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
|
|
||||||
if err := c.writePacket(Marshal(&userAuthRequestMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: "none",
|
|
||||||
})); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return handleAuthResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *noneAuth) method() string {
|
|
||||||
return "none"
|
|
||||||
}
|
|
||||||
|
|
||||||
// passwordCallback is an AuthMethod that fetches the password through
|
|
||||||
// a function call, e.g. by prompting the user.
|
|
||||||
type passwordCallback func() (password string, err error)
|
|
||||||
|
|
||||||
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
|
|
||||||
type passwordAuthMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Reply bool
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
pw, err := cb()
|
|
||||||
// REVIEW NOTE: is there a need to support skipping a password attempt?
|
|
||||||
// The program may only find out that the user doesn't have a password
|
|
||||||
// when prompting.
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.writePacket(Marshal(&passwordAuthMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: cb.method(),
|
|
||||||
Reply: false,
|
|
||||||
Password: pw,
|
|
||||||
})); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return handleAuthResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb passwordCallback) method() string {
|
|
||||||
return "password"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Password returns an AuthMethod using the given password.
|
|
||||||
func Password(secret string) AuthMethod {
|
|
||||||
return passwordCallback(func() (string, error) { return secret, nil })
|
|
||||||
}
|
|
||||||
|
|
||||||
// PasswordCallback returns an AuthMethod that uses a callback for
|
|
||||||
// fetching a password.
|
|
||||||
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
|
|
||||||
return passwordCallback(prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
type publickeyAuthMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
// HasSig indicates to the receiver packet that the auth request is signed and
|
|
||||||
// should be used for authentication of the request.
|
|
||||||
HasSig bool
|
|
||||||
Algoname string
|
|
||||||
PubKey []byte
|
|
||||||
// Sig is tagged with "rest" so Marshal will exclude it during
|
|
||||||
// validateKey
|
|
||||||
Sig []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// publicKeyCallback is an AuthMethod that uses a set of key
|
|
||||||
// pairs for authentication.
|
|
||||||
type publicKeyCallback func() ([]Signer, error)
|
|
||||||
|
|
||||||
func (cb publicKeyCallback) method() string {
|
|
||||||
return "publickey"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
|
|
||||||
// Authentication is performed in two stages. The first stage sends an
|
|
||||||
// enquiry to test if each key is acceptable to the remote. The second
|
|
||||||
// stage attempts to authenticate with the valid keys obtained in the
|
|
||||||
// first stage.
|
|
||||||
|
|
||||||
signers, err := cb()
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
var validKeys []Signer
|
|
||||||
for _, signer := range signers {
|
|
||||||
if ok, err := validateKey(signer.PublicKey(), user, c); ok {
|
|
||||||
validKeys = append(validKeys, signer)
|
|
||||||
} else {
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// methods that may continue if this auth is not successful.
|
|
||||||
var methods []string
|
|
||||||
for _, signer := range validKeys {
|
|
||||||
pub := signer.PublicKey()
|
|
||||||
|
|
||||||
pubKey := pub.Marshal()
|
|
||||||
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: cb.method(),
|
|
||||||
}, []byte(pub.Type()), pubKey))
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// manually wrap the serialized signature in a string
|
|
||||||
s := Marshal(sign)
|
|
||||||
sig := make([]byte, stringLength(len(s)))
|
|
||||||
marshalString(sig, s)
|
|
||||||
msg := publickeyAuthMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: cb.method(),
|
|
||||||
HasSig: true,
|
|
||||||
Algoname: pub.Type(),
|
|
||||||
PubKey: pubKey,
|
|
||||||
Sig: sig,
|
|
||||||
}
|
|
||||||
p := Marshal(&msg)
|
|
||||||
if err := c.writePacket(p); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
var success bool
|
|
||||||
success, methods, err = handleAuthResponse(c)
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
if success {
|
|
||||||
return success, methods, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, methods, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateKey validates the key provided is acceptable to the server.
|
|
||||||
func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
|
|
||||||
pubKey := key.Marshal()
|
|
||||||
msg := publickeyAuthMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: "publickey",
|
|
||||||
HasSig: false,
|
|
||||||
Algoname: key.Type(),
|
|
||||||
PubKey: pubKey,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&msg)); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return confirmKeyAck(key, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
|
|
||||||
pubKey := key.Marshal()
|
|
||||||
algoname := key.Type()
|
|
||||||
|
|
||||||
for {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
switch packet[0] {
|
|
||||||
case msgUserAuthBanner:
|
|
||||||
// TODO(gpaul): add callback to present the banner to the user
|
|
||||||
case msgUserAuthPubKeyOk:
|
|
||||||
var msg userAuthPubKeyOkMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
return false, nil
|
|
||||||
default:
|
|
||||||
return false, unexpectedMessageError(msgUserAuthSuccess, packet[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublicKeys returns an AuthMethod that uses the given key
|
|
||||||
// pairs.
|
|
||||||
func PublicKeys(signers ...Signer) AuthMethod {
|
|
||||||
return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublicKeysCallback returns an AuthMethod that runs the given
|
|
||||||
// function to obtain a list of key pairs.
|
|
||||||
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
|
|
||||||
return publicKeyCallback(getSigners)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleAuthResponse returns whether the preceding authentication request succeeded
|
|
||||||
// along with a list of remaining authentication methods to try next and
|
|
||||||
// an error if an unexpected response was received.
|
|
||||||
func handleAuthResponse(c packetConn) (bool, []string, error) {
|
|
||||||
for {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch packet[0] {
|
|
||||||
case msgUserAuthBanner:
|
|
||||||
// TODO: add callback to present the banner to the user
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
var msg userAuthFailureMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
return false, msg.Methods, nil
|
|
||||||
case msgUserAuthSuccess:
|
|
||||||
return true, nil, nil
|
|
||||||
case msgDisconnect:
|
|
||||||
return false, nil, io.EOF
|
|
||||||
default:
|
|
||||||
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyboardInteractiveChallenge should print questions, optionally
|
|
||||||
// disabling echoing (e.g. for passwords), and return all the answers.
|
|
||||||
// Challenge may be called multiple times in a single session. After
|
|
||||||
// successful authentication, the server may send a challenge with no
|
|
||||||
// questions, for which the user and instruction messages should be
|
|
||||||
// printed. RFC 4256 section 3.3 details how the UI should behave for
|
|
||||||
// both CLI and GUI environments.
|
|
||||||
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
|
|
||||||
|
|
||||||
// KeyboardInteractive returns a AuthMethod using a prompt/response
|
|
||||||
// sequence controlled by the server.
|
|
||||||
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
|
|
||||||
return challenge
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb KeyboardInteractiveChallenge) method() string {
|
|
||||||
return "keyboard-interactive"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
|
|
||||||
type initiateMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Language string
|
|
||||||
Submethods string
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.writePacket(Marshal(&initiateMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: "keyboard-interactive",
|
|
||||||
})); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// like handleAuthResponse, but with less options.
|
|
||||||
switch packet[0] {
|
|
||||||
case msgUserAuthBanner:
|
|
||||||
// TODO: Print banners during userauth.
|
|
||||||
continue
|
|
||||||
case msgUserAuthInfoRequest:
|
|
||||||
// OK
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
var msg userAuthFailureMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
return false, msg.Methods, nil
|
|
||||||
case msgUserAuthSuccess:
|
|
||||||
return true, nil, nil
|
|
||||||
default:
|
|
||||||
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg userAuthInfoRequestMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manually unpack the prompt/echo pairs.
|
|
||||||
rest := msg.Prompts
|
|
||||||
var prompts []string
|
|
||||||
var echos []bool
|
|
||||||
for i := 0; i < int(msg.NumPrompts); i++ {
|
|
||||||
prompt, r, ok := parseString(rest)
|
|
||||||
if !ok || len(r) == 0 {
|
|
||||||
return false, nil, errors.New("ssh: prompt format error")
|
|
||||||
}
|
|
||||||
prompts = append(prompts, string(prompt))
|
|
||||||
echos = append(echos, r[0] != 0)
|
|
||||||
rest = r[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rest) != 0 {
|
|
||||||
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
|
|
||||||
}
|
|
||||||
|
|
||||||
answers, err := cb(msg.User, msg.Instruction, prompts, echos)
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(answers) != len(prompts) {
|
|
||||||
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
|
|
||||||
}
|
|
||||||
responseLength := 1 + 4
|
|
||||||
for _, a := range answers {
|
|
||||||
responseLength += stringLength(len(a))
|
|
||||||
}
|
|
||||||
serialized := make([]byte, responseLength)
|
|
||||||
p := serialized
|
|
||||||
p[0] = msgUserAuthInfoResponse
|
|
||||||
p = p[1:]
|
|
||||||
p = marshalUint32(p, uint32(len(answers)))
|
|
||||||
for _, a := range answers {
|
|
||||||
p = marshalString(p, []byte(a))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.writePacket(serialized); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,393 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type keyboardInteractive map[string]string
|
|
||||||
|
|
||||||
func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
|
|
||||||
var answers []string
|
|
||||||
for _, q := range questions {
|
|
||||||
answers = append(answers, cr[q])
|
|
||||||
}
|
|
||||||
return answers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// reused internally by tests
|
|
||||||
var clientPassword = "tiger"
|
|
||||||
|
|
||||||
// tryAuth runs a handshake with a given config against an SSH server
|
|
||||||
// with config serverConfig
|
|
||||||
func tryAuth(t *testing.T, config *ClientConfig) error {
|
|
||||||
c1, c2, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
defer c1.Close()
|
|
||||||
defer c2.Close()
|
|
||||||
|
|
||||||
certChecker := CertChecker{
|
|
||||||
IsAuthority: func(k PublicKey) bool {
|
|
||||||
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
|
|
||||||
},
|
|
||||||
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
|
|
||||||
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
|
|
||||||
},
|
|
||||||
IsRevoked: func(c *Certificate) bool {
|
|
||||||
return c.Serial == 666
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
serverConfig := &ServerConfig{
|
|
||||||
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
|
|
||||||
if conn.User() == "testuser" && string(pass) == clientPassword {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, errors.New("password auth failed")
|
|
||||||
},
|
|
||||||
PublicKeyCallback: certChecker.Authenticate,
|
|
||||||
KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) {
|
|
||||||
ans, err := challenge("user",
|
|
||||||
"instruction",
|
|
||||||
[]string{"question1", "question2"},
|
|
||||||
[]bool{true, true})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
|
|
||||||
if ok {
|
|
||||||
challenge("user", "motd", nil, nil)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, errors.New("keyboard-interactive failed")
|
|
||||||
},
|
|
||||||
AuthLogCallback: func(conn ConnMetadata, method string, err error) {
|
|
||||||
t.Logf("user %q, method %q: %v", conn.User(), method, err)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
serverConfig.AddHostKey(testSigners["rsa"])
|
|
||||||
|
|
||||||
go newServer(c1, serverConfig)
|
|
||||||
_, _, _, err = NewClientConn(c2, "", config)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientAuthPublicKey(t *testing.T) {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(testSigners["rsa"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := tryAuth(t, config); err != nil {
|
|
||||||
t.Fatalf("unable to dial remote side: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthMethodPassword(t *testing.T) {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
Password(clientPassword),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tryAuth(t, config); err != nil {
|
|
||||||
t.Fatalf("unable to dial remote side: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthMethodFallback(t *testing.T) {
|
|
||||||
var passwordCalled bool
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(testSigners["rsa"]),
|
|
||||||
PasswordCallback(
|
|
||||||
func() (string, error) {
|
|
||||||
passwordCalled = true
|
|
||||||
return "WRONG", nil
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tryAuth(t, config); err != nil {
|
|
||||||
t.Fatalf("unable to dial remote side: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if passwordCalled {
|
|
||||||
t.Errorf("password auth tried before public-key auth.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthMethodWrongPassword(t *testing.T) {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
Password("wrong"),
|
|
||||||
PublicKeys(testSigners["rsa"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tryAuth(t, config); err != nil {
|
|
||||||
t.Fatalf("unable to dial remote side: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthMethodKeyboardInteractive(t *testing.T) {
|
|
||||||
answers := keyboardInteractive(map[string]string{
|
|
||||||
"question1": "answer1",
|
|
||||||
"question2": "answer2",
|
|
||||||
})
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
KeyboardInteractive(answers.Challenge),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tryAuth(t, config); err != nil {
|
|
||||||
t.Fatalf("unable to dial remote side: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthMethodWrongKeyboardInteractive(t *testing.T) {
|
|
||||||
answers := keyboardInteractive(map[string]string{
|
|
||||||
"question1": "answer1",
|
|
||||||
"question2": "WRONG",
|
|
||||||
})
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
KeyboardInteractive(answers.Challenge),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tryAuth(t, config); err == nil {
|
|
||||||
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// the mock server will only authenticate ssh-rsa keys
|
|
||||||
func TestAuthMethodInvalidPublicKey(t *testing.T) {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(testSigners["dsa"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tryAuth(t, config); err == nil {
|
|
||||||
t.Fatalf("dsa private key should not have authenticated with rsa public key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// the client should authenticate with the second key
|
|
||||||
func TestAuthMethodRSAandDSA(t *testing.T) {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(testSigners["dsa"], testSigners["rsa"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := tryAuth(t, config); err != nil {
|
|
||||||
t.Fatalf("client could not authenticate with rsa key: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientHMAC(t *testing.T) {
|
|
||||||
for _, mac := range supportedMACs {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(testSigners["rsa"]),
|
|
||||||
},
|
|
||||||
Config: Config{
|
|
||||||
MACs: []string{mac},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := tryAuth(t, config); err != nil {
|
|
||||||
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// issue 4285.
|
|
||||||
func TestClientUnsupportedCipher(t *testing.T) {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(),
|
|
||||||
},
|
|
||||||
Config: Config{
|
|
||||||
Ciphers: []string{"aes128-cbc"}, // not currently supported
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := tryAuth(t, config); err == nil {
|
|
||||||
t.Errorf("expected no ciphers in common")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientUnsupportedKex(t *testing.T) {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(),
|
|
||||||
},
|
|
||||||
Config: Config{
|
|
||||||
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "no common algorithms") {
|
|
||||||
t.Errorf("got %v, expected 'no common algorithms'", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientLoginCert(t *testing.T) {
|
|
||||||
cert := &Certificate{
|
|
||||||
Key: testPublicKeys["rsa"],
|
|
||||||
ValidBefore: CertTimeInfinity,
|
|
||||||
CertType: UserCert,
|
|
||||||
}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
certSigner, err := NewCertSigner(cert, testSigners["rsa"])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewCertSigner: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
clientConfig := &ClientConfig{
|
|
||||||
User: "user",
|
|
||||||
}
|
|
||||||
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
|
|
||||||
|
|
||||||
t.Log("should succeed")
|
|
||||||
if err := tryAuth(t, clientConfig); err != nil {
|
|
||||||
t.Errorf("cert login failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("corrupted signature")
|
|
||||||
cert.Signature.Blob[0]++
|
|
||||||
if err := tryAuth(t, clientConfig); err == nil {
|
|
||||||
t.Errorf("cert login passed with corrupted sig")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("revoked")
|
|
||||||
cert.Serial = 666
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err == nil {
|
|
||||||
t.Errorf("revoked cert login succeeded")
|
|
||||||
}
|
|
||||||
cert.Serial = 1
|
|
||||||
|
|
||||||
t.Log("sign with wrong key")
|
|
||||||
cert.SignCert(rand.Reader, testSigners["dsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err == nil {
|
|
||||||
t.Errorf("cert login passed with non-authoritive key")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("host cert")
|
|
||||||
cert.CertType = HostCert
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err == nil {
|
|
||||||
t.Errorf("cert login passed with wrong type")
|
|
||||||
}
|
|
||||||
cert.CertType = UserCert
|
|
||||||
|
|
||||||
t.Log("principal specified")
|
|
||||||
cert.ValidPrincipals = []string{"user"}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err != nil {
|
|
||||||
t.Errorf("cert login failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("wrong principal specified")
|
|
||||||
cert.ValidPrincipals = []string{"fred"}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err == nil {
|
|
||||||
t.Errorf("cert login passed with wrong principal")
|
|
||||||
}
|
|
||||||
cert.ValidPrincipals = nil
|
|
||||||
|
|
||||||
t.Log("added critical option")
|
|
||||||
cert.CriticalOptions = map[string]string{"root-access": "yes"}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err == nil {
|
|
||||||
t.Errorf("cert login passed with unrecognized critical option")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("allowed source address")
|
|
||||||
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err != nil {
|
|
||||||
t.Errorf("cert login with source-address failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("disallowed source address")
|
|
||||||
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"}
|
|
||||||
cert.SignCert(rand.Reader, testSigners["ecdsa"])
|
|
||||||
if err := tryAuth(t, clientConfig); err == nil {
|
|
||||||
t.Errorf("cert login with source-address succeeded")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testPermissionsPassing(withPermissions bool, t *testing.T) {
|
|
||||||
serverConfig := &ServerConfig{
|
|
||||||
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
|
|
||||||
if conn.User() == "nopermissions" {
|
|
||||||
return nil, nil
|
|
||||||
} else {
|
|
||||||
return &Permissions{}, nil
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
serverConfig.AddHostKey(testSigners["rsa"])
|
|
||||||
|
|
||||||
clientConfig := &ClientConfig{
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
PublicKeys(testSigners["rsa"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if withPermissions {
|
|
||||||
clientConfig.User = "permissions"
|
|
||||||
} else {
|
|
||||||
clientConfig.User = "nopermissions"
|
|
||||||
}
|
|
||||||
|
|
||||||
c1, c2, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
defer c1.Close()
|
|
||||||
defer c2.Close()
|
|
||||||
|
|
||||||
go NewClientConn(c2, "", clientConfig)
|
|
||||||
serverConn, err := newServer(c1, serverConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if p := serverConn.Permissions; (p != nil) != withPermissions {
|
|
||||||
t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPermissionsPassing(t *testing.T) {
|
|
||||||
testPermissionsPassing(true, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoPermissionsPassing(t *testing.T) {
|
|
||||||
testPermissionsPassing(false, t)
|
|
||||||
}
|
|
|
@ -1,39 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
|
|
||||||
clientConn, serverConn := net.Pipe()
|
|
||||||
defer clientConn.Close()
|
|
||||||
receivedVersion := make(chan string, 1)
|
|
||||||
go func() {
|
|
||||||
version, err := readVersion(serverConn)
|
|
||||||
if err != nil {
|
|
||||||
receivedVersion <- ""
|
|
||||||
} else {
|
|
||||||
receivedVersion <- string(version)
|
|
||||||
}
|
|
||||||
serverConn.Close()
|
|
||||||
}()
|
|
||||||
NewClientConn(clientConn, "", config)
|
|
||||||
actual := <-receivedVersion
|
|
||||||
if actual != expected {
|
|
||||||
t.Fatalf("got %s; want %s", actual, expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCustomClientVersion(t *testing.T) {
|
|
||||||
version := "Test-Client-Version-0.0"
|
|
||||||
testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultClientVersion(t *testing.T) {
|
|
||||||
testClientVersion(t, &ClientConfig{}, packageVersion)
|
|
||||||
}
|
|
|
@ -1,357 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
_ "crypto/sha1"
|
|
||||||
_ "crypto/sha256"
|
|
||||||
_ "crypto/sha512"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These are string constants in the SSH protocol.
|
|
||||||
const (
|
|
||||||
compressionNone = "none"
|
|
||||||
serviceUserAuth = "ssh-userauth"
|
|
||||||
serviceSSH = "ssh-connection"
|
|
||||||
)
|
|
||||||
|
|
||||||
// supportedCiphers specifies the supported ciphers in preference order.
|
|
||||||
var supportedCiphers = []string{
|
|
||||||
"aes128-ctr", "aes192-ctr", "aes256-ctr",
|
|
||||||
"aes128-gcm@openssh.com",
|
|
||||||
"arcfour256", "arcfour128",
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportedKexAlgos specifies the supported key-exchange algorithms in
|
|
||||||
// preference order.
|
|
||||||
var supportedKexAlgos = []string{
|
|
||||||
// P384 and P521 are not constant-time yet, but since we don't
|
|
||||||
// reuse ephemeral keys, using them for ECDH should be OK.
|
|
||||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
|
|
||||||
kexAlgoDH14SHA1, kexAlgoDH1SHA1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
|
|
||||||
// of authenticating servers) in preference order.
|
|
||||||
var supportedHostKeyAlgos = []string{
|
|
||||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
|
|
||||||
CertAlgoECDSA384v01, CertAlgoECDSA521v01,
|
|
||||||
|
|
||||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
|
|
||||||
KeyAlgoRSA, KeyAlgoDSA,
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportedMACs specifies a default set of MAC algorithms in preference order.
|
|
||||||
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
|
|
||||||
// because they have reached the end of their useful life.
|
|
||||||
var supportedMACs = []string{
|
|
||||||
"hmac-sha1", "hmac-sha1-96",
|
|
||||||
}
|
|
||||||
|
|
||||||
var supportedCompressions = []string{compressionNone}
|
|
||||||
|
|
||||||
// hashFuncs keeps the mapping of supported algorithms to their respective
|
|
||||||
// hashes needed for signature verification.
|
|
||||||
var hashFuncs = map[string]crypto.Hash{
|
|
||||||
KeyAlgoRSA: crypto.SHA1,
|
|
||||||
KeyAlgoDSA: crypto.SHA1,
|
|
||||||
KeyAlgoECDSA256: crypto.SHA256,
|
|
||||||
KeyAlgoECDSA384: crypto.SHA384,
|
|
||||||
KeyAlgoECDSA521: crypto.SHA512,
|
|
||||||
CertAlgoRSAv01: crypto.SHA1,
|
|
||||||
CertAlgoDSAv01: crypto.SHA1,
|
|
||||||
CertAlgoECDSA256v01: crypto.SHA256,
|
|
||||||
CertAlgoECDSA384v01: crypto.SHA384,
|
|
||||||
CertAlgoECDSA521v01: crypto.SHA512,
|
|
||||||
}
|
|
||||||
|
|
||||||
// unexpectedMessageError results when the SSH message that we received didn't
|
|
||||||
// match what we wanted.
|
|
||||||
func unexpectedMessageError(expected, got uint8) error {
|
|
||||||
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseError results from a malformed SSH message.
|
|
||||||
func parseError(tag uint8) error {
|
|
||||||
return fmt.Errorf("ssh: parse error in message type %d", tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
|
|
||||||
for _, clientAlgo := range clientAlgos {
|
|
||||||
for _, serverAlgo := range serverAlgos {
|
|
||||||
if clientAlgo == serverAlgo {
|
|
||||||
return clientAlgo, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCipher string, ok bool) {
|
|
||||||
for _, clientCipher := range clientCiphers {
|
|
||||||
for _, serverCipher := range serverCiphers {
|
|
||||||
// reject the cipher if we have no cipherModes definition
|
|
||||||
if clientCipher == serverCipher && cipherModes[clientCipher] != nil {
|
|
||||||
return clientCipher, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type directionAlgorithms struct {
|
|
||||||
Cipher string
|
|
||||||
MAC string
|
|
||||||
Compression string
|
|
||||||
}
|
|
||||||
|
|
||||||
type algorithms struct {
|
|
||||||
kex string
|
|
||||||
hostKey string
|
|
||||||
w directionAlgorithms
|
|
||||||
r directionAlgorithms
|
|
||||||
}
|
|
||||||
|
|
||||||
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
|
|
||||||
var ok bool
|
|
||||||
result := &algorithms{}
|
|
||||||
result.kex, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.hostKey, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.w.Cipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.r.Cipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.w.MAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.r.MAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.w.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.r.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// If rekeythreshold is too small, we can't make any progress sending
|
|
||||||
// stuff.
|
|
||||||
const minRekeyThreshold uint64 = 256
|
|
||||||
|
|
||||||
// Config contains configuration data common to both ServerConfig and
|
|
||||||
// ClientConfig.
|
|
||||||
type Config struct {
|
|
||||||
// Rand provides the source of entropy for cryptographic
|
|
||||||
// primitives. If Rand is nil, the cryptographic random reader
|
|
||||||
// in package crypto/rand will be used.
|
|
||||||
Rand io.Reader
|
|
||||||
|
|
||||||
// The maximum number of bytes sent or received after which a
|
|
||||||
// new key is negotiated. It must be at least 256. If
|
|
||||||
// unspecified, 1 gigabyte is used.
|
|
||||||
RekeyThreshold uint64
|
|
||||||
|
|
||||||
// The allowed key exchanges algorithms. If unspecified then a
|
|
||||||
// default set of algorithms is used.
|
|
||||||
KeyExchanges []string
|
|
||||||
|
|
||||||
// The allowed cipher algorithms. If unspecified then a sensible
|
|
||||||
// default is used.
|
|
||||||
Ciphers []string
|
|
||||||
|
|
||||||
// The allowed MAC algorithms. If unspecified then a sensible default
|
|
||||||
// is used.
|
|
||||||
MACs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDefaults sets sensible values for unset fields in config. This is
|
|
||||||
// exported for testing: Configs passed to SSH functions are copied and have
|
|
||||||
// default values set automatically.
|
|
||||||
func (c *Config) SetDefaults() {
|
|
||||||
if c.Rand == nil {
|
|
||||||
c.Rand = rand.Reader
|
|
||||||
}
|
|
||||||
if c.Ciphers == nil {
|
|
||||||
c.Ciphers = supportedCiphers
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.KeyExchanges == nil {
|
|
||||||
c.KeyExchanges = supportedKexAlgos
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.MACs == nil {
|
|
||||||
c.MACs = supportedMACs
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.RekeyThreshold == 0 {
|
|
||||||
// RFC 4253, section 9 suggests rekeying after 1G.
|
|
||||||
c.RekeyThreshold = 1 << 30
|
|
||||||
}
|
|
||||||
if c.RekeyThreshold < minRekeyThreshold {
|
|
||||||
c.RekeyThreshold = minRekeyThreshold
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
|
||||||
// possession of a private key. See RFC 4252, section 7.
|
|
||||||
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
|
|
||||||
data := struct {
|
|
||||||
Session []byte
|
|
||||||
Type byte
|
|
||||||
User string
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Sign bool
|
|
||||||
Algo []byte
|
|
||||||
PubKey []byte
|
|
||||||
}{
|
|
||||||
sessionId,
|
|
||||||
msgUserAuthRequest,
|
|
||||||
req.User,
|
|
||||||
req.Service,
|
|
||||||
req.Method,
|
|
||||||
true,
|
|
||||||
algo,
|
|
||||||
pubKey,
|
|
||||||
}
|
|
||||||
return Marshal(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendU16(buf []byte, n uint16) []byte {
|
|
||||||
return append(buf, byte(n>>8), byte(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendU32(buf []byte, n uint32) []byte {
|
|
||||||
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendU64(buf []byte, n uint64) []byte {
|
|
||||||
return append(buf,
|
|
||||||
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32),
|
|
||||||
byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendInt(buf []byte, n int) []byte {
|
|
||||||
return appendU32(buf, uint32(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendString(buf []byte, s string) []byte {
|
|
||||||
buf = appendU32(buf, uint32(len(s)))
|
|
||||||
buf = append(buf, s...)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendBool(buf []byte, b bool) []byte {
|
|
||||||
if b {
|
|
||||||
return append(buf, 1)
|
|
||||||
}
|
|
||||||
return append(buf, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newCond is a helper to hide the fact that there is no usable zero
|
|
||||||
// value for sync.Cond.
|
|
||||||
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) }
|
|
||||||
|
|
||||||
// window represents the buffer available to clients
|
|
||||||
// wishing to write to a channel.
|
|
||||||
type window struct {
|
|
||||||
*sync.Cond
|
|
||||||
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
|
||||||
writeWaiters int
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// add adds win to the amount of window available
|
|
||||||
// for consumers.
|
|
||||||
func (w *window) add(win uint32) bool {
|
|
||||||
// a zero sized window adjust is a noop.
|
|
||||||
if win == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
w.L.Lock()
|
|
||||||
if w.win+win < win {
|
|
||||||
w.L.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
w.win += win
|
|
||||||
// It is unusual that multiple goroutines would be attempting to reserve
|
|
||||||
// window space, but not guaranteed. Use broadcast to notify all waiters
|
|
||||||
// that additional window is available.
|
|
||||||
w.Broadcast()
|
|
||||||
w.L.Unlock()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// close sets the window to closed, so all reservations fail
|
|
||||||
// immediately.
|
|
||||||
func (w *window) close() {
|
|
||||||
w.L.Lock()
|
|
||||||
w.closed = true
|
|
||||||
w.Broadcast()
|
|
||||||
w.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// reserve reserves win from the available window capacity.
|
|
||||||
// If no capacity remains, reserve will block. reserve may
|
|
||||||
// return less than requested.
|
|
||||||
func (w *window) reserve(win uint32) (uint32, error) {
|
|
||||||
var err error
|
|
||||||
w.L.Lock()
|
|
||||||
w.writeWaiters++
|
|
||||||
w.Broadcast()
|
|
||||||
for w.win == 0 && !w.closed {
|
|
||||||
w.Wait()
|
|
||||||
}
|
|
||||||
w.writeWaiters--
|
|
||||||
if w.win < win {
|
|
||||||
win = w.win
|
|
||||||
}
|
|
||||||
w.win -= win
|
|
||||||
if w.closed {
|
|
||||||
err = io.EOF
|
|
||||||
}
|
|
||||||
w.L.Unlock()
|
|
||||||
return win, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitWriterBlocked waits until some goroutine is blocked for further
|
|
||||||
// writes. It is used in tests only.
|
|
||||||
func (w *window) waitWriterBlocked() {
|
|
||||||
w.Cond.L.Lock()
|
|
||||||
for w.writeWaiters == 0 {
|
|
||||||
w.Cond.Wait()
|
|
||||||
}
|
|
||||||
w.Cond.L.Unlock()
|
|
||||||
}
|
|
|
@ -1,144 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OpenChannelError is returned if the other side rejects an
|
|
||||||
// OpenChannel request.
|
|
||||||
type OpenChannelError struct {
|
|
||||||
Reason RejectionReason
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *OpenChannelError) Error() string {
|
|
||||||
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnMetadata holds metadata for the connection.
|
|
||||||
type ConnMetadata interface {
|
|
||||||
// User returns the user ID for this connection.
|
|
||||||
// It is empty if no authentication is used.
|
|
||||||
User() string
|
|
||||||
|
|
||||||
// SessionID returns the sesson hash, also denoted by H.
|
|
||||||
SessionID() []byte
|
|
||||||
|
|
||||||
// ClientVersion returns the client's version string as hashed
|
|
||||||
// into the session ID.
|
|
||||||
ClientVersion() []byte
|
|
||||||
|
|
||||||
// ServerVersion returns the client's version string as hashed
|
|
||||||
// into the session ID.
|
|
||||||
ServerVersion() []byte
|
|
||||||
|
|
||||||
// RemoteAddr returns the remote address for this connection.
|
|
||||||
RemoteAddr() net.Addr
|
|
||||||
|
|
||||||
// LocalAddr returns the local address for this connection.
|
|
||||||
LocalAddr() net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conn represents an SSH connection for both server and client roles.
|
|
||||||
// Conn is the basis for implementing an application layer, such
|
|
||||||
// as ClientConn, which implements the traditional shell access for
|
|
||||||
// clients.
|
|
||||||
type Conn interface {
|
|
||||||
ConnMetadata
|
|
||||||
|
|
||||||
// SendRequest sends a global request, and returns the
|
|
||||||
// reply. If wantReply is true, it returns the response status
|
|
||||||
// and payload. See also RFC4254, section 4.
|
|
||||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
|
|
||||||
|
|
||||||
// OpenChannel tries to open an channel. If the request is
|
|
||||||
// rejected, it returns *OpenChannelError. On success it returns
|
|
||||||
// the SSH Channel and a Go channel for incoming, out-of-band
|
|
||||||
// requests. The Go channel must be serviced, or the
|
|
||||||
// connection will hang.
|
|
||||||
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error)
|
|
||||||
|
|
||||||
// Close closes the underlying network connection
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// Wait blocks until the connection has shut down, and returns the
|
|
||||||
// error causing the shutdown.
|
|
||||||
Wait() error
|
|
||||||
|
|
||||||
// TODO(hanwen): consider exposing:
|
|
||||||
// RequestKeyChange
|
|
||||||
// Disconnect
|
|
||||||
}
|
|
||||||
|
|
||||||
// DiscardRequests consumes and rejects all requests from the
|
|
||||||
// passed-in channel.
|
|
||||||
func DiscardRequests(in <-chan *Request) {
|
|
||||||
for req := range in {
|
|
||||||
if req.WantReply {
|
|
||||||
req.Reply(false, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A connection represents an incoming connection.
|
|
||||||
type connection struct {
|
|
||||||
transport *handshakeTransport
|
|
||||||
sshConn
|
|
||||||
|
|
||||||
// The connection protocol.
|
|
||||||
*mux
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) Close() error {
|
|
||||||
return c.sshConn.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshconn provides net.Conn metadata, but disallows direct reads and
|
|
||||||
// writes.
|
|
||||||
type sshConn struct {
|
|
||||||
conn net.Conn
|
|
||||||
|
|
||||||
user string
|
|
||||||
sessionID []byte
|
|
||||||
clientVersion []byte
|
|
||||||
serverVersion []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func dup(src []byte) []byte {
|
|
||||||
dst := make([]byte, len(src))
|
|
||||||
copy(dst, src)
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) User() string {
|
|
||||||
return c.user
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) RemoteAddr() net.Addr {
|
|
||||||
return c.conn.RemoteAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) Close() error {
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) LocalAddr() net.Addr {
|
|
||||||
return c.conn.LocalAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) SessionID() []byte {
|
|
||||||
return dup(c.sessionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) ClientVersion() []byte {
|
|
||||||
return dup(c.clientVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) ServerVersion() []byte {
|
|
||||||
return dup(c.serverVersion)
|
|
||||||
}
|
|
|
@ -1,18 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
/*
|
|
||||||
Package ssh implements an SSH client and server.
|
|
||||||
|
|
||||||
SSH is a transport security protocol, an authentication protocol and a
|
|
||||||
family of application protocols. The most typical application level
|
|
||||||
protocol is a remote shell and this is specifically implemented. However,
|
|
||||||
the multiplexed nature of SSH is exposed to users that wish to support
|
|
||||||
others.
|
|
||||||
|
|
||||||
References:
|
|
||||||
[PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys
|
|
||||||
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
|
||||||
*/
|
|
||||||
package ssh
|
|
|
@ -1,210 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ExampleNewServerConn() {
|
|
||||||
// An SSH server is represented by a ServerConfig, which holds
|
|
||||||
// certificate details and handles authentication of ServerConns.
|
|
||||||
config := &ServerConfig{
|
|
||||||
PasswordCallback: func(c ConnMetadata, pass []byte) (*Permissions, error) {
|
|
||||||
// Should use constant-time compare (or better, salt+hash) in
|
|
||||||
// a production setting.
|
|
||||||
if c.User() == "testuser" && string(pass) == "tiger" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("password rejected for %q", c.User())
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
privateBytes, err := ioutil.ReadFile("id_rsa")
|
|
||||||
if err != nil {
|
|
||||||
panic("Failed to load private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
private, err := ParsePrivateKey(privateBytes)
|
|
||||||
if err != nil {
|
|
||||||
panic("Failed to parse private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
config.AddHostKey(private)
|
|
||||||
|
|
||||||
// Once a ServerConfig has been configured, connections can be
|
|
||||||
// accepted.
|
|
||||||
listener, err := net.Listen("tcp", "0.0.0.0:2022")
|
|
||||||
if err != nil {
|
|
||||||
panic("failed to listen for connection")
|
|
||||||
}
|
|
||||||
nConn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
panic("failed to accept incoming connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Before use, a handshake must be performed on the incoming
|
|
||||||
// net.Conn.
|
|
||||||
_, chans, reqs, err := NewServerConn(nConn, config)
|
|
||||||
if err != nil {
|
|
||||||
panic("failed to handshake")
|
|
||||||
}
|
|
||||||
// The incoming Request channel must be serviced.
|
|
||||||
go DiscardRequests(reqs)
|
|
||||||
|
|
||||||
// Service the incoming Channel channel.
|
|
||||||
for newChannel := range chans {
|
|
||||||
// Channels have a type, depending on the application level
|
|
||||||
// protocol intended. In the case of a shell, the type is
|
|
||||||
// "session" and ServerShell may be used to present a simple
|
|
||||||
// terminal interface.
|
|
||||||
if newChannel.ChannelType() != "session" {
|
|
||||||
newChannel.Reject(UnknownChannelType, "unknown channel type")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
channel, requests, err := newChannel.Accept()
|
|
||||||
if err != nil {
|
|
||||||
panic("could not accept channel.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sessions have out-of-band requests such as "shell",
|
|
||||||
// "pty-req" and "env". Here we handle only the
|
|
||||||
// "shell" request.
|
|
||||||
go func(in <-chan *Request) {
|
|
||||||
for req := range in {
|
|
||||||
ok := false
|
|
||||||
switch req.Type {
|
|
||||||
case "shell":
|
|
||||||
ok = true
|
|
||||||
if len(req.Payload) > 0 {
|
|
||||||
// We don't accept any
|
|
||||||
// commands, only the
|
|
||||||
// default shell.
|
|
||||||
ok = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
req.Reply(ok, nil)
|
|
||||||
}
|
|
||||||
}(requests)
|
|
||||||
|
|
||||||
term := terminal.NewTerminal(channel, "> ")
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer channel.Close()
|
|
||||||
for {
|
|
||||||
line, err := term.ReadLine()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
fmt.Println(line)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExampleDial() {
|
|
||||||
// An SSH client is represented with a ClientConn. Currently only
|
|
||||||
// the "password" authentication method is supported.
|
|
||||||
//
|
|
||||||
// To authenticate with the remote server you must pass at least one
|
|
||||||
// implementation of AuthMethod via the Auth field in ClientConfig.
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "username",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
Password("yourpassword"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
client, err := Dial("tcp", "yourserver.com:22", config)
|
|
||||||
if err != nil {
|
|
||||||
panic("Failed to dial: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each ClientConn can support multiple interactive sessions,
|
|
||||||
// represented by a Session.
|
|
||||||
session, err := client.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
panic("Failed to create session: " + err.Error())
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
// Once a Session is created, you can execute a single command on
|
|
||||||
// the remote side using the Run method.
|
|
||||||
var b bytes.Buffer
|
|
||||||
session.Stdout = &b
|
|
||||||
if err := session.Run("/usr/bin/whoami"); err != nil {
|
|
||||||
panic("Failed to run: " + err.Error())
|
|
||||||
}
|
|
||||||
fmt.Println(b.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExampleClient_Listen() {
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "username",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
Password("password"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// Dial your ssh server.
|
|
||||||
conn, err := Dial("tcp", "localhost:22", config)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("unable to connect: %s", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Request the remote side to open port 8080 on all interfaces.
|
|
||||||
l, err := conn.Listen("tcp", "0.0.0.0:8080")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("unable to register tcp forward: %v", err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
// Serve HTTP with your SSH server acting as a reverse proxy.
|
|
||||||
http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
|
||||||
fmt.Fprintf(resp, "Hello world!\n")
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExampleSession_RequestPty() {
|
|
||||||
// Create client config
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "username",
|
|
||||||
Auth: []AuthMethod{
|
|
||||||
Password("password"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// Connect to ssh server
|
|
||||||
conn, err := Dial("tcp", "localhost:22", config)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("unable to connect: %s", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
// Create a session
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("unable to create session: %s", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
// Set up terminal modes
|
|
||||||
modes := TerminalModes{
|
|
||||||
ECHO: 0, // disable echoing
|
|
||||||
TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
|
|
||||||
TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
|
|
||||||
}
|
|
||||||
// Request pseudo terminal
|
|
||||||
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
|
|
||||||
log.Fatalf("request for pseudo terminal failed: %s", err)
|
|
||||||
}
|
|
||||||
// Start remote shell
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
log.Fatalf("failed to start shell: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,393 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// debugHandshake, if set, prints messages sent and received. Key
|
|
||||||
// exchange messages are printed as if DH were used, so the debug
|
|
||||||
// messages are wrong when using ECDH.
|
|
||||||
const debugHandshake = false
|
|
||||||
|
|
||||||
// keyingTransport is a packet based transport that supports key
|
|
||||||
// changes. It need not be thread-safe. It should pass through
|
|
||||||
// msgNewKeys in both directions.
|
|
||||||
type keyingTransport interface {
|
|
||||||
packetConn
|
|
||||||
|
|
||||||
// prepareKeyChange sets up a key change. The key change for a
|
|
||||||
// direction will be effected if a msgNewKeys message is sent
|
|
||||||
// or received.
|
|
||||||
prepareKeyChange(*algorithms, *kexResult) error
|
|
||||||
|
|
||||||
// getSessionID returns the session ID. prepareKeyChange must
|
|
||||||
// have been called once.
|
|
||||||
getSessionID() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// rekeyingTransport is the interface of handshakeTransport that we
|
|
||||||
// (internally) expose to ClientConn and ServerConn.
|
|
||||||
type rekeyingTransport interface {
|
|
||||||
packetConn
|
|
||||||
|
|
||||||
// requestKeyChange asks the remote side to change keys. All
|
|
||||||
// writes are blocked until the key change succeeds, which is
|
|
||||||
// signaled by reading a msgNewKeys.
|
|
||||||
requestKeyChange() error
|
|
||||||
|
|
||||||
// getSessionID returns the session ID. This is only valid
|
|
||||||
// after the first key change has completed.
|
|
||||||
getSessionID() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// handshakeTransport implements rekeying on top of a keyingTransport
|
|
||||||
// and offers a thread-safe writePacket() interface.
|
|
||||||
type handshakeTransport struct {
|
|
||||||
conn keyingTransport
|
|
||||||
config *Config
|
|
||||||
|
|
||||||
serverVersion []byte
|
|
||||||
clientVersion []byte
|
|
||||||
|
|
||||||
hostKeys []Signer // If hostKeys are given, we are the server.
|
|
||||||
|
|
||||||
// On read error, incoming is closed, and readError is set.
|
|
||||||
incoming chan []byte
|
|
||||||
readError error
|
|
||||||
|
|
||||||
// data for host key checking
|
|
||||||
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
|
|
||||||
dialAddress string
|
|
||||||
remoteAddr net.Addr
|
|
||||||
|
|
||||||
readSinceKex uint64
|
|
||||||
|
|
||||||
// Protects the writing side of the connection
|
|
||||||
mu sync.Mutex
|
|
||||||
cond *sync.Cond
|
|
||||||
sentInitPacket []byte
|
|
||||||
sentInitMsg *kexInitMsg
|
|
||||||
writtenSinceKex uint64
|
|
||||||
writeError error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
|
|
||||||
t := &handshakeTransport{
|
|
||||||
conn: conn,
|
|
||||||
serverVersion: serverVersion,
|
|
||||||
clientVersion: clientVersion,
|
|
||||||
incoming: make(chan []byte, 16),
|
|
||||||
config: config,
|
|
||||||
}
|
|
||||||
t.cond = sync.NewCond(&t.mu)
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
|
|
||||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
|
|
||||||
t.dialAddress = dialAddr
|
|
||||||
t.remoteAddr = addr
|
|
||||||
t.hostKeyCallback = config.HostKeyCallback
|
|
||||||
go t.readLoop()
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
|
|
||||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
|
|
||||||
t.hostKeys = config.hostKeys
|
|
||||||
go t.readLoop()
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) getSessionID() []byte {
|
|
||||||
return t.conn.getSessionID()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) id() string {
|
|
||||||
if len(t.hostKeys) > 0 {
|
|
||||||
return "server"
|
|
||||||
}
|
|
||||||
return "client"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) readPacket() ([]byte, error) {
|
|
||||||
p, ok := <-t.incoming
|
|
||||||
if !ok {
|
|
||||||
return nil, t.readError
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) readLoop() {
|
|
||||||
for {
|
|
||||||
p, err := t.readOnePacket()
|
|
||||||
if err != nil {
|
|
||||||
t.readError = err
|
|
||||||
close(t.incoming)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if p[0] == msgIgnore || p[0] == msgDebug {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
t.incoming <- p
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) readOnePacket() ([]byte, error) {
|
|
||||||
if t.readSinceKex > t.config.RekeyThreshold {
|
|
||||||
if err := t.requestKeyChange(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := t.conn.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.readSinceKex += uint64(len(p))
|
|
||||||
if debugHandshake {
|
|
||||||
msg, err := decode(p)
|
|
||||||
log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
|
|
||||||
}
|
|
||||||
if p[0] != msgKexInit {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
err = t.enterKeyExchange(p)
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
if err != nil {
|
|
||||||
// drop connection
|
|
||||||
t.conn.Close()
|
|
||||||
t.writeError = err
|
|
||||||
}
|
|
||||||
|
|
||||||
if debugHandshake {
|
|
||||||
log.Printf("%s exited key exchange, err %v", t.id(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unblock writers.
|
|
||||||
t.sentInitMsg = nil
|
|
||||||
t.sentInitPacket = nil
|
|
||||||
t.cond.Broadcast()
|
|
||||||
t.writtenSinceKex = 0
|
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.readSinceKex = 0
|
|
||||||
return []byte{msgNewKeys}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendKexInit sends a key change message, and returns the message
|
|
||||||
// that was sent. After initiating the key change, all writes will be
|
|
||||||
// blocked until the change is done, and a failed key change will
|
|
||||||
// close the underlying transport. This function is safe for
|
|
||||||
// concurrent use by multiple goroutines.
|
|
||||||
func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
return t.sendKexInitLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) requestKeyChange() error {
|
|
||||||
_, _, err := t.sendKexInit()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendKexInitLocked sends a key change message. t.mu must be locked
|
|
||||||
// while this happens.
|
|
||||||
func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
|
|
||||||
// kexInits may be sent either in response to the other side,
|
|
||||||
// or because our side wants to initiate a key change, so we
|
|
||||||
// may have already sent a kexInit. In that case, don't send a
|
|
||||||
// second kexInit.
|
|
||||||
if t.sentInitMsg != nil {
|
|
||||||
return t.sentInitMsg, t.sentInitPacket, nil
|
|
||||||
}
|
|
||||||
msg := &kexInitMsg{
|
|
||||||
KexAlgos: t.config.KeyExchanges,
|
|
||||||
CiphersClientServer: t.config.Ciphers,
|
|
||||||
CiphersServerClient: t.config.Ciphers,
|
|
||||||
MACsClientServer: t.config.MACs,
|
|
||||||
MACsServerClient: t.config.MACs,
|
|
||||||
CompressionClientServer: supportedCompressions,
|
|
||||||
CompressionServerClient: supportedCompressions,
|
|
||||||
}
|
|
||||||
io.ReadFull(rand.Reader, msg.Cookie[:])
|
|
||||||
|
|
||||||
if len(t.hostKeys) > 0 {
|
|
||||||
for _, k := range t.hostKeys {
|
|
||||||
msg.ServerHostKeyAlgos = append(
|
|
||||||
msg.ServerHostKeyAlgos, k.PublicKey().Type())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg.ServerHostKeyAlgos = supportedHostKeyAlgos
|
|
||||||
}
|
|
||||||
packet := Marshal(msg)
|
|
||||||
|
|
||||||
// writePacket destroys the contents, so save a copy.
|
|
||||||
packetCopy := make([]byte, len(packet))
|
|
||||||
copy(packetCopy, packet)
|
|
||||||
|
|
||||||
if err := t.conn.writePacket(packetCopy); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.sentInitMsg = msg
|
|
||||||
t.sentInitPacket = packet
|
|
||||||
return msg, packet, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) writePacket(p []byte) error {
|
|
||||||
t.mu.Lock()
|
|
||||||
if t.writtenSinceKex > t.config.RekeyThreshold {
|
|
||||||
t.sendKexInitLocked()
|
|
||||||
}
|
|
||||||
for t.sentInitMsg != nil {
|
|
||||||
t.cond.Wait()
|
|
||||||
}
|
|
||||||
if t.writeError != nil {
|
|
||||||
return t.writeError
|
|
||||||
}
|
|
||||||
t.writtenSinceKex += uint64(len(p))
|
|
||||||
|
|
||||||
var err error
|
|
||||||
switch p[0] {
|
|
||||||
case msgKexInit:
|
|
||||||
err = errors.New("ssh: only handshakeTransport can send kexInit")
|
|
||||||
case msgNewKeys:
|
|
||||||
err = errors.New("ssh: only handshakeTransport can send newKeys")
|
|
||||||
default:
|
|
||||||
err = t.conn.writePacket(p)
|
|
||||||
}
|
|
||||||
t.mu.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) Close() error {
|
|
||||||
return t.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// enterKeyExchange runs the key exchange.
|
|
||||||
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
|
|
||||||
if debugHandshake {
|
|
||||||
log.Printf("%s entered key exchange", t.id())
|
|
||||||
}
|
|
||||||
myInit, myInitPacket, err := t.sendKexInit()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
otherInit := &kexInitMsg{}
|
|
||||||
if err := Unmarshal(otherInitPacket, otherInit); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
magics := handshakeMagics{
|
|
||||||
clientVersion: t.clientVersion,
|
|
||||||
serverVersion: t.serverVersion,
|
|
||||||
clientKexInit: otherInitPacket,
|
|
||||||
serverKexInit: myInitPacket,
|
|
||||||
}
|
|
||||||
|
|
||||||
clientInit := otherInit
|
|
||||||
serverInit := myInit
|
|
||||||
if len(t.hostKeys) == 0 {
|
|
||||||
clientInit = myInit
|
|
||||||
serverInit = otherInit
|
|
||||||
|
|
||||||
magics.clientKexInit = myInitPacket
|
|
||||||
magics.serverKexInit = otherInitPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
algs := findAgreedAlgorithms(clientInit, serverInit)
|
|
||||||
if algs == nil {
|
|
||||||
return errors.New("ssh: no common algorithms")
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't send FirstKexFollows, but we handle receiving it.
|
|
||||||
if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] {
|
|
||||||
// other side sent a kex message for the wrong algorithm,
|
|
||||||
// which we have to ignore.
|
|
||||||
if _, err := t.conn.readPacket(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
kex, ok := kexAlgoMap[algs.kex]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *kexResult
|
|
||||||
if len(t.hostKeys) > 0 {
|
|
||||||
result, err = t.server(kex, algs, &magics)
|
|
||||||
} else {
|
|
||||||
result, err = t.client(kex, algs, &magics)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.conn.prepareKeyChange(algs, result)
|
|
||||||
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if packet, err := t.conn.readPacket(); err != nil {
|
|
||||||
return err
|
|
||||||
} else if packet[0] != msgNewKeys {
|
|
||||||
return unexpectedMessageError(msgNewKeys, packet[0])
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
var hostKey Signer
|
|
||||||
for _, k := range t.hostKeys {
|
|
||||||
if algs.hostKey == k.PublicKey().Type() {
|
|
||||||
hostKey = k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
|
|
||||||
return r, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
result, err := kex.Client(t.conn, t.config.Rand, magics)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKey, err := ParsePublicKey(result.HostKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := verifyHostKeySignature(hostKey, result); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.hostKeyCallback != nil {
|
|
||||||
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
|
@ -1,311 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type testChecker struct {
|
|
||||||
calls []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
|
|
||||||
if dialAddr == "bad" {
|
|
||||||
return fmt.Errorf("dialAddr is bad")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
|
|
||||||
return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
|
||||||
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
|
||||||
// a write.)
|
|
||||||
func netPipe() (net.Conn, net.Conn, error) {
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
c1, err := net.Dial("tcp", listener.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c2, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
c1.Close()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c1, c2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
|
|
||||||
a, b, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
trC := newTransport(a, rand.Reader, true)
|
|
||||||
trS := newTransport(b, rand.Reader, false)
|
|
||||||
clientConf.SetDefaults()
|
|
||||||
|
|
||||||
v := []byte("version")
|
|
||||||
client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
|
|
||||||
|
|
||||||
serverConf := &ServerConfig{}
|
|
||||||
serverConf.AddHostKey(testSigners["ecdsa"])
|
|
||||||
serverConf.SetDefaults()
|
|
||||||
server = newServerTransport(trS, v, v, serverConf)
|
|
||||||
|
|
||||||
return client, server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandshakeBasic(t *testing.T) {
|
|
||||||
checker := &testChecker{}
|
|
||||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("handshakePair: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer trC.Close()
|
|
||||||
defer trS.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// Client writes a bunch of stuff, and does a key
|
|
||||||
// change in the middle. This should not confuse the
|
|
||||||
// handshake in progress
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
p := []byte{msgRequestSuccess, byte(i)}
|
|
||||||
if err := trC.writePacket(p); err != nil {
|
|
||||||
t.Fatalf("sendPacket: %v", err)
|
|
||||||
}
|
|
||||||
if i == 5 {
|
|
||||||
// halfway through, we request a key change.
|
|
||||||
_, _, err := trC.sendKexInit()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("sendKexInit: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
trC.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Server checks that client messages come in cleanly
|
|
||||||
i := 0
|
|
||||||
for {
|
|
||||||
p, err := trS.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if p[0] == msgNewKeys {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
want := []byte{msgRequestSuccess, byte(i)}
|
|
||||||
if bytes.Compare(p, want) != 0 {
|
|
||||||
t.Errorf("message %d: got %q, want %q", i, p, want)
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i != 10 {
|
|
||||||
t.Errorf("received %d messages, want 10.", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If all went well, we registered exactly 1 key change.
|
|
||||||
if len(checker.calls) != 1 {
|
|
||||||
t.Fatalf("got %d host key checks, want 1", len(checker.calls))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub := testSigners["ecdsa"].PublicKey()
|
|
||||||
want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
|
|
||||||
if want != checker.calls[0] {
|
|
||||||
t.Errorf("got %q want %q for host key check", checker.calls[0], want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandshakeError(t *testing.T) {
|
|
||||||
checker := &testChecker{}
|
|
||||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("handshakePair: %v", err)
|
|
||||||
}
|
|
||||||
defer trC.Close()
|
|
||||||
defer trS.Close()
|
|
||||||
|
|
||||||
// send a packet
|
|
||||||
packet := []byte{msgRequestSuccess, 42}
|
|
||||||
if err := trC.writePacket(packet); err != nil {
|
|
||||||
t.Errorf("writePacket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now request a key change.
|
|
||||||
_, _, err = trC.sendKexInit()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("sendKexInit: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the key change will fail, and afterwards we can't write.
|
|
||||||
if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
|
|
||||||
t.Errorf("writePacket after botched rekey succeeded.")
|
|
||||||
}
|
|
||||||
|
|
||||||
readback, err := trS.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("server closed too soon: %v", err)
|
|
||||||
}
|
|
||||||
if bytes.Compare(readback, packet) != 0 {
|
|
||||||
t.Errorf("got %q want %q", readback, packet)
|
|
||||||
}
|
|
||||||
readback, err = trS.readPacket()
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("got a message %q after failed key change", readback)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandshakeTwice(t *testing.T) {
|
|
||||||
checker := &testChecker{}
|
|
||||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("handshakePair: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer trC.Close()
|
|
||||||
defer trS.Close()
|
|
||||||
|
|
||||||
// send a packet
|
|
||||||
packet := make([]byte, 5)
|
|
||||||
packet[0] = msgRequestSuccess
|
|
||||||
if err := trC.writePacket(packet); err != nil {
|
|
||||||
t.Errorf("writePacket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now request a key change.
|
|
||||||
_, _, err = trC.sendKexInit()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("sendKexInit: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send another packet. Use a fresh one, since writePacket destroys.
|
|
||||||
packet = make([]byte, 5)
|
|
||||||
packet[0] = msgRequestSuccess
|
|
||||||
if err := trC.writePacket(packet); err != nil {
|
|
||||||
t.Errorf("writePacket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2nd key change.
|
|
||||||
_, _, err = trC.sendKexInit()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("sendKexInit: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
packet = make([]byte, 5)
|
|
||||||
packet[0] = msgRequestSuccess
|
|
||||||
if err := trC.writePacket(packet); err != nil {
|
|
||||||
t.Errorf("writePacket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
packet = make([]byte, 5)
|
|
||||||
packet[0] = msgRequestSuccess
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
msg, err := trS.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("server closed too soon: %v", err)
|
|
||||||
}
|
|
||||||
if msg[0] == msgNewKeys {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Compare(msg, packet) != 0 {
|
|
||||||
t.Errorf("packet %d: got %q want %q", i, msg, packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(checker.calls) != 2 {
|
|
||||||
t.Errorf("got %d key changes, want 2", len(checker.calls))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandshakeAutoRekeyWrite(t *testing.T) {
|
|
||||||
checker := &testChecker{}
|
|
||||||
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
|
|
||||||
clientConf.RekeyThreshold = 500
|
|
||||||
trC, trS, err := handshakePair(clientConf, "addr")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("handshakePair: %v", err)
|
|
||||||
}
|
|
||||||
defer trC.Close()
|
|
||||||
defer trS.Close()
|
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
packet := make([]byte, 251)
|
|
||||||
packet[0] = msgRequestSuccess
|
|
||||||
if err := trC.writePacket(packet); err != nil {
|
|
||||||
t.Errorf("writePacket: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
j := 0
|
|
||||||
for ; j < 5; j++ {
|
|
||||||
_, err := trS.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if j != 5 {
|
|
||||||
t.Errorf("got %d, want 5 messages", j)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(checker.calls) != 2 {
|
|
||||||
t.Errorf("got %d key changes, wanted 2", len(checker.calls))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type syncChecker struct {
|
|
||||||
called chan int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
|
|
||||||
t.called <- 1
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandshakeAutoRekeyRead(t *testing.T) {
|
|
||||||
sync := &syncChecker{make(chan int, 2)}
|
|
||||||
clientConf := &ClientConfig{
|
|
||||||
HostKeyCallback: sync.Check,
|
|
||||||
}
|
|
||||||
clientConf.RekeyThreshold = 500
|
|
||||||
|
|
||||||
trC, trS, err := handshakePair(clientConf, "addr")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("handshakePair: %v", err)
|
|
||||||
}
|
|
||||||
defer trC.Close()
|
|
||||||
defer trS.Close()
|
|
||||||
|
|
||||||
packet := make([]byte, 501)
|
|
||||||
packet[0] = msgRequestSuccess
|
|
||||||
if err := trS.writePacket(packet); err != nil {
|
|
||||||
t.Fatalf("writePacket: %v", err)
|
|
||||||
}
|
|
||||||
// While we read out the packet, a key change will be
|
|
||||||
// initiated.
|
|
||||||
if _, err := trC.readPacket(); err != nil {
|
|
||||||
t.Fatalf("readPacket(client): %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
<-sync.called
|
|
||||||
}
|
|
|
@ -1,386 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
|
|
||||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
|
|
||||||
kexAlgoECDH256 = "ecdh-sha2-nistp256"
|
|
||||||
kexAlgoECDH384 = "ecdh-sha2-nistp384"
|
|
||||||
kexAlgoECDH521 = "ecdh-sha2-nistp521"
|
|
||||||
)
|
|
||||||
|
|
||||||
// kexResult captures the outcome of a key exchange.
|
|
||||||
type kexResult struct {
|
|
||||||
// Session hash. See also RFC 4253, section 8.
|
|
||||||
H []byte
|
|
||||||
|
|
||||||
// Shared secret. See also RFC 4253, section 8.
|
|
||||||
K []byte
|
|
||||||
|
|
||||||
// Host key as hashed into H.
|
|
||||||
HostKey []byte
|
|
||||||
|
|
||||||
// Signature of H.
|
|
||||||
Signature []byte
|
|
||||||
|
|
||||||
// A cryptographic hash function that matches the security
|
|
||||||
// level of the key exchange algorithm. It is used for
|
|
||||||
// calculating H, and for deriving keys from H and K.
|
|
||||||
Hash crypto.Hash
|
|
||||||
|
|
||||||
// The session ID, which is the first H computed. This is used
|
|
||||||
// to signal data inside transport.
|
|
||||||
SessionID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// handshakeMagics contains data that is always included in the
|
|
||||||
// session hash.
|
|
||||||
type handshakeMagics struct {
|
|
||||||
clientVersion, serverVersion []byte
|
|
||||||
clientKexInit, serverKexInit []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *handshakeMagics) write(w io.Writer) {
|
|
||||||
writeString(w, m.clientVersion)
|
|
||||||
writeString(w, m.serverVersion)
|
|
||||||
writeString(w, m.clientKexInit)
|
|
||||||
writeString(w, m.serverKexInit)
|
|
||||||
}
|
|
||||||
|
|
||||||
// kexAlgorithm abstracts different key exchange algorithms.
|
|
||||||
type kexAlgorithm interface {
|
|
||||||
// Server runs server-side key agreement, signing the result
|
|
||||||
// with a hostkey.
|
|
||||||
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error)
|
|
||||||
|
|
||||||
// Client runs the client-side key agreement. Caller is
|
|
||||||
// responsible for verifying the host key signature.
|
|
||||||
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
|
||||||
type dhGroup struct {
|
|
||||||
g, p *big.Int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
|
|
||||||
if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
|
|
||||||
return nil, errors.New("ssh: DH parameter out of bounds")
|
|
||||||
}
|
|
||||||
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
hashFunc := crypto.SHA1
|
|
||||||
|
|
||||||
x, err := rand.Int(randSource, group.p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
X := new(big.Int).Exp(group.g, x, group.p)
|
|
||||||
kexDHInit := kexDHInitMsg{
|
|
||||||
X: X,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&kexDHInit)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var kexDHReply kexDHReplyMsg
|
|
||||||
if err = Unmarshal(packet, &kexDHReply); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
kInt, err := group.diffieHellman(kexDHReply.Y, x)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
h := hashFunc.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, kexDHReply.HostKey)
|
|
||||||
writeInt(h, X)
|
|
||||||
writeInt(h, kexDHReply.Y)
|
|
||||||
K := make([]byte, intLength(kInt))
|
|
||||||
marshalInt(K, kInt)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: h.Sum(nil),
|
|
||||||
K: K,
|
|
||||||
HostKey: kexDHReply.HostKey,
|
|
||||||
Signature: kexDHReply.Signature,
|
|
||||||
Hash: crypto.SHA1,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
|
|
||||||
hashFunc := crypto.SHA1
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var kexDHInit kexDHInitMsg
|
|
||||||
if err = Unmarshal(packet, &kexDHInit); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
y, err := rand.Int(randSource, group.p)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Y := new(big.Int).Exp(group.g, y, group.p)
|
|
||||||
kInt, err := group.diffieHellman(kexDHInit.X, y)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKeyBytes := priv.PublicKey().Marshal()
|
|
||||||
|
|
||||||
h := hashFunc.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, hostKeyBytes)
|
|
||||||
writeInt(h, kexDHInit.X)
|
|
||||||
writeInt(h, Y)
|
|
||||||
|
|
||||||
K := make([]byte, intLength(kInt))
|
|
||||||
marshalInt(K, kInt)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
H := h.Sum(nil)
|
|
||||||
|
|
||||||
// H is already a hash, but the hostkey signing will apply its
|
|
||||||
// own key-specific hash algorithm.
|
|
||||||
sig, err := signAndMarshal(priv, randSource, H)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
kexDHReply := kexDHReplyMsg{
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Y: Y,
|
|
||||||
Signature: sig,
|
|
||||||
}
|
|
||||||
packet = Marshal(&kexDHReply)
|
|
||||||
|
|
||||||
err = c.writePacket(packet)
|
|
||||||
return &kexResult{
|
|
||||||
H: H,
|
|
||||||
K: K,
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Signature: sig,
|
|
||||||
Hash: crypto.SHA1,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
|
|
||||||
// described in RFC 5656, section 4.
|
|
||||||
type ecdh struct {
|
|
||||||
curve elliptic.Curve
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
kexInit := kexECDHInitMsg{
|
|
||||||
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
|
|
||||||
}
|
|
||||||
|
|
||||||
serialized := Marshal(&kexInit)
|
|
||||||
if err := c.writePacket(serialized); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var reply kexECDHReplyMsg
|
|
||||||
if err = Unmarshal(packet, &reply); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate shared secret
|
|
||||||
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes())
|
|
||||||
|
|
||||||
h := ecHash(kex.curve).New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, reply.HostKey)
|
|
||||||
writeString(h, kexInit.ClientPubKey)
|
|
||||||
writeString(h, reply.EphemeralPubKey)
|
|
||||||
K := make([]byte, intLength(secret))
|
|
||||||
marshalInt(K, secret)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: h.Sum(nil),
|
|
||||||
K: K,
|
|
||||||
HostKey: reply.HostKey,
|
|
||||||
Signature: reply.Signature,
|
|
||||||
Hash: ecHash(kex.curve),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmarshalECKey parses and checks an EC key.
|
|
||||||
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) {
|
|
||||||
x, y = elliptic.Unmarshal(curve, pubkey)
|
|
||||||
if x == nil {
|
|
||||||
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure")
|
|
||||||
}
|
|
||||||
if !validateECPublicKey(curve, x, y) {
|
|
||||||
return nil, nil, errors.New("ssh: public key not on curve")
|
|
||||||
}
|
|
||||||
return x, y, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateECPublicKey checks that the point is a valid public key for
|
|
||||||
// the given curve. See [SEC1], 3.2.2
|
|
||||||
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
|
|
||||||
if x.Sign() == 0 && y.Sign() == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if x.Cmp(curve.Params().P) >= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if y.Cmp(curve.Params().P) >= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !curve.IsOnCurve(x, y) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't check if N * PubKey == 0, since
|
|
||||||
//
|
|
||||||
// - the NIST curves have cofactor = 1, so this is implicit.
|
|
||||||
// (We don't foresee an implementation that supports non NIST
|
|
||||||
// curves)
|
|
||||||
//
|
|
||||||
// - for ephemeral keys, we don't need to worry about small
|
|
||||||
// subgroup attacks.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var kexECDHInit kexECDHInitMsg
|
|
||||||
if err = Unmarshal(packet, &kexECDHInit); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We could cache this key across multiple users/multiple
|
|
||||||
// connection attempts, but the benefit is small. OpenSSH
|
|
||||||
// generates a new key for each incoming connection.
|
|
||||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKeyBytes := priv.PublicKey().Marshal()
|
|
||||||
|
|
||||||
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
|
|
||||||
|
|
||||||
// generate shared secret
|
|
||||||
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
|
|
||||||
|
|
||||||
h := ecHash(kex.curve).New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, hostKeyBytes)
|
|
||||||
writeString(h, kexECDHInit.ClientPubKey)
|
|
||||||
writeString(h, serializedEphKey)
|
|
||||||
|
|
||||||
K := make([]byte, intLength(secret))
|
|
||||||
marshalInt(K, secret)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
H := h.Sum(nil)
|
|
||||||
|
|
||||||
// H is already a hash, but the hostkey signing will apply its
|
|
||||||
// own key-specific hash algorithm.
|
|
||||||
sig, err := signAndMarshal(priv, rand, H)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
reply := kexECDHReplyMsg{
|
|
||||||
EphemeralPubKey: serializedEphKey,
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Signature: sig,
|
|
||||||
}
|
|
||||||
|
|
||||||
serialized := Marshal(&reply)
|
|
||||||
if err := c.writePacket(serialized); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: H,
|
|
||||||
K: K,
|
|
||||||
HostKey: reply.HostKey,
|
|
||||||
Signature: sig,
|
|
||||||
Hash: ecHash(kex.curve),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var kexAlgoMap = map[string]kexAlgorithm{}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// This is the group called diffie-hellman-group1-sha1 in RFC
|
|
||||||
// 4253 and Oakley Group 2 in RFC 2409.
|
|
||||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
|
|
||||||
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
|
|
||||||
g: new(big.Int).SetInt64(2),
|
|
||||||
p: p,
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is the group called diffie-hellman-group14-sha1 in RFC
|
|
||||||
// 4253 and Oakley Group 14 in RFC 3526.
|
|
||||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
|
|
||||||
|
|
||||||
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
|
|
||||||
g: new(big.Int).SetInt64(2),
|
|
||||||
p: p,
|
|
||||||
}
|
|
||||||
|
|
||||||
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
|
|
||||||
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
|
|
||||||
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
|
|
||||||
}
|
|
|
@ -1,48 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
// Key exchange tests.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestKexes(t *testing.T) {
|
|
||||||
type kexResultErr struct {
|
|
||||||
result *kexResult
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, kex := range kexAlgoMap {
|
|
||||||
a, b := memPipe()
|
|
||||||
|
|
||||||
s := make(chan kexResultErr, 1)
|
|
||||||
c := make(chan kexResultErr, 1)
|
|
||||||
var magics handshakeMagics
|
|
||||||
go func() {
|
|
||||||
r, e := kex.Client(a, rand.Reader, &magics)
|
|
||||||
c <- kexResultErr{r, e}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
|
|
||||||
s <- kexResultErr{r, e}
|
|
||||||
}()
|
|
||||||
|
|
||||||
clientRes := <-c
|
|
||||||
serverRes := <-s
|
|
||||||
if clientRes.err != nil {
|
|
||||||
t.Errorf("client: %v", clientRes.err)
|
|
||||||
}
|
|
||||||
if serverRes.err != nil {
|
|
||||||
t.Errorf("server: %v", serverRes.err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(clientRes.result, serverRes.result) {
|
|
||||||
t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,628 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto"
|
|
||||||
"crypto/dsa"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/asn1"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These constants represent the algorithm names for key types supported by this
|
|
||||||
// package.
|
|
||||||
const (
|
|
||||||
KeyAlgoRSA = "ssh-rsa"
|
|
||||||
KeyAlgoDSA = "ssh-dss"
|
|
||||||
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256"
|
|
||||||
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
|
|
||||||
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
|
|
||||||
)
|
|
||||||
|
|
||||||
// parsePubKey parses a public key of the given algorithm.
|
|
||||||
// Use ParsePublicKey for keys with prepended algorithm.
|
|
||||||
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) {
|
|
||||||
switch algo {
|
|
||||||
case KeyAlgoRSA:
|
|
||||||
return parseRSA(in)
|
|
||||||
case KeyAlgoDSA:
|
|
||||||
return parseDSA(in)
|
|
||||||
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
|
|
||||||
return parseECDSA(in)
|
|
||||||
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
|
|
||||||
cert, err := parseCert(in, certToPrivAlgo(algo))
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return cert, nil, nil
|
|
||||||
}
|
|
||||||
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
|
|
||||||
// (see sshd(8) manual page) once the options and key type fields have been
|
|
||||||
// removed.
|
|
||||||
func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) {
|
|
||||||
in = bytes.TrimSpace(in)
|
|
||||||
|
|
||||||
i := bytes.IndexAny(in, " \t")
|
|
||||||
if i == -1 {
|
|
||||||
i = len(in)
|
|
||||||
}
|
|
||||||
base64Key := in[:i]
|
|
||||||
|
|
||||||
key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key)))
|
|
||||||
n, err := base64.StdEncoding.Decode(key, base64Key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
key = key[:n]
|
|
||||||
out, err = ParsePublicKey(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
comment = string(bytes.TrimSpace(in[i:]))
|
|
||||||
return out, comment, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseAuthorizedKeys parses a public key from an authorized_keys
|
|
||||||
// file used in OpenSSH according to the sshd(8) manual page.
|
|
||||||
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
|
|
||||||
for len(in) > 0 {
|
|
||||||
end := bytes.IndexByte(in, '\n')
|
|
||||||
if end != -1 {
|
|
||||||
rest = in[end+1:]
|
|
||||||
in = in[:end]
|
|
||||||
} else {
|
|
||||||
rest = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
end = bytes.IndexByte(in, '\r')
|
|
||||||
if end != -1 {
|
|
||||||
in = in[:end]
|
|
||||||
}
|
|
||||||
|
|
||||||
in = bytes.TrimSpace(in)
|
|
||||||
if len(in) == 0 || in[0] == '#' {
|
|
||||||
in = rest
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
i := bytes.IndexAny(in, " \t")
|
|
||||||
if i == -1 {
|
|
||||||
in = rest
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
|
|
||||||
return out, comment, options, rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// No key type recognised. Maybe there's an options field at
|
|
||||||
// the beginning.
|
|
||||||
var b byte
|
|
||||||
inQuote := false
|
|
||||||
var candidateOptions []string
|
|
||||||
optionStart := 0
|
|
||||||
for i, b = range in {
|
|
||||||
isEnd := !inQuote && (b == ' ' || b == '\t')
|
|
||||||
if (b == ',' && !inQuote) || isEnd {
|
|
||||||
if i-optionStart > 0 {
|
|
||||||
candidateOptions = append(candidateOptions, string(in[optionStart:i]))
|
|
||||||
}
|
|
||||||
optionStart = i + 1
|
|
||||||
}
|
|
||||||
if isEnd {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) {
|
|
||||||
inQuote = !inQuote
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i < len(in) && (in[i] == ' ' || in[i] == '\t') {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i == len(in) {
|
|
||||||
// Invalid line: unmatched quote
|
|
||||||
in = rest
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
in = in[i:]
|
|
||||||
i = bytes.IndexAny(in, " \t")
|
|
||||||
if i == -1 {
|
|
||||||
in = rest
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
|
|
||||||
options = candidateOptions
|
|
||||||
return out, comment, options, rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
in = rest
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, "", nil, nil, errors.New("ssh: no key found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParsePublicKey parses an SSH public key formatted for use in
|
|
||||||
// the SSH wire protocol according to RFC 4253, section 6.6.
|
|
||||||
func ParsePublicKey(in []byte) (out PublicKey, err error) {
|
|
||||||
algo, in, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
var rest []byte
|
|
||||||
out, rest, err = parsePubKey(in, string(algo))
|
|
||||||
if len(rest) > 0 {
|
|
||||||
return nil, errors.New("ssh: trailing junk in public key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
|
|
||||||
// authorized_keys file. The return value ends with newline.
|
|
||||||
func MarshalAuthorizedKey(key PublicKey) []byte {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
b.WriteString(key.Type())
|
|
||||||
b.WriteByte(' ')
|
|
||||||
e := base64.NewEncoder(base64.StdEncoding, b)
|
|
||||||
e.Write(key.Marshal())
|
|
||||||
e.Close()
|
|
||||||
b.WriteByte('\n')
|
|
||||||
return b.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublicKey is an abstraction of different types of public keys.
|
|
||||||
type PublicKey interface {
|
|
||||||
// Type returns the key's type, e.g. "ssh-rsa".
|
|
||||||
Type() string
|
|
||||||
|
|
||||||
// Marshal returns the serialized key data in SSH wire format,
|
|
||||||
// with the name prefix.
|
|
||||||
Marshal() []byte
|
|
||||||
|
|
||||||
// Verify that sig is a signature on the given data using this
|
|
||||||
// key. This function will hash the data appropriately first.
|
|
||||||
Verify(data []byte, sig *Signature) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Signer can create signatures that verify against a public key.
|
|
||||||
type Signer interface {
|
|
||||||
// PublicKey returns an associated PublicKey instance.
|
|
||||||
PublicKey() PublicKey
|
|
||||||
|
|
||||||
// Sign returns raw signature for the given data. This method
|
|
||||||
// will apply the hash specified for the keytype to the data.
|
|
||||||
Sign(rand io.Reader, data []byte) (*Signature, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type rsaPublicKey rsa.PublicKey
|
|
||||||
|
|
||||||
func (r *rsaPublicKey) Type() string {
|
|
||||||
return "ssh-rsa"
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
|
|
||||||
func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
|
|
||||||
var w struct {
|
|
||||||
E *big.Int
|
|
||||||
N *big.Int
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
if err := Unmarshal(in, &w); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if w.E.BitLen() > 24 {
|
|
||||||
return nil, nil, errors.New("ssh: exponent too large")
|
|
||||||
}
|
|
||||||
e := w.E.Int64()
|
|
||||||
if e < 3 || e&1 == 0 {
|
|
||||||
return nil, nil, errors.New("ssh: incorrect exponent")
|
|
||||||
}
|
|
||||||
|
|
||||||
var key rsa.PublicKey
|
|
||||||
key.E = int(e)
|
|
||||||
key.N = w.N
|
|
||||||
return (*rsaPublicKey)(&key), w.Rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rsaPublicKey) Marshal() []byte {
|
|
||||||
e := new(big.Int).SetInt64(int64(r.E))
|
|
||||||
wirekey := struct {
|
|
||||||
Name string
|
|
||||||
E *big.Int
|
|
||||||
N *big.Int
|
|
||||||
}{
|
|
||||||
KeyAlgoRSA,
|
|
||||||
e,
|
|
||||||
r.N,
|
|
||||||
}
|
|
||||||
return Marshal(&wirekey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
|
|
||||||
if sig.Format != r.Type() {
|
|
||||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
|
|
||||||
}
|
|
||||||
h := crypto.SHA1.New()
|
|
||||||
h.Write(data)
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob)
|
|
||||||
}
|
|
||||||
|
|
||||||
type rsaPrivateKey struct {
|
|
||||||
*rsa.PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rsaPrivateKey) PublicKey() PublicKey {
|
|
||||||
return (*rsaPublicKey)(&r.PrivateKey.PublicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
|
|
||||||
h := crypto.SHA1.New()
|
|
||||||
h.Write(data)
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Signature{
|
|
||||||
Format: r.PublicKey().Type(),
|
|
||||||
Blob: blob,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type dsaPublicKey dsa.PublicKey
|
|
||||||
|
|
||||||
func (r *dsaPublicKey) Type() string {
|
|
||||||
return "ssh-dss"
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
|
|
||||||
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
|
|
||||||
var w struct {
|
|
||||||
P, Q, G, Y *big.Int
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
if err := Unmarshal(in, &w); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
key := &dsaPublicKey{
|
|
||||||
Parameters: dsa.Parameters{
|
|
||||||
P: w.P,
|
|
||||||
Q: w.Q,
|
|
||||||
G: w.G,
|
|
||||||
},
|
|
||||||
Y: w.Y,
|
|
||||||
}
|
|
||||||
return key, w.Rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *dsaPublicKey) Marshal() []byte {
|
|
||||||
w := struct {
|
|
||||||
Name string
|
|
||||||
P, Q, G, Y *big.Int
|
|
||||||
}{
|
|
||||||
k.Type(),
|
|
||||||
k.P,
|
|
||||||
k.Q,
|
|
||||||
k.G,
|
|
||||||
k.Y,
|
|
||||||
}
|
|
||||||
|
|
||||||
return Marshal(&w)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
|
|
||||||
if sig.Format != k.Type() {
|
|
||||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
|
|
||||||
}
|
|
||||||
h := crypto.SHA1.New()
|
|
||||||
h.Write(data)
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
|
|
||||||
// Per RFC 4253, section 6.6,
|
|
||||||
// The value for 'dss_signature_blob' is encoded as a string containing
|
|
||||||
// r, followed by s (which are 160-bit integers, without lengths or
|
|
||||||
// padding, unsigned, and in network byte order).
|
|
||||||
// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
|
|
||||||
if len(sig.Blob) != 40 {
|
|
||||||
return errors.New("ssh: DSA signature parse error")
|
|
||||||
}
|
|
||||||
r := new(big.Int).SetBytes(sig.Blob[:20])
|
|
||||||
s := new(big.Int).SetBytes(sig.Blob[20:])
|
|
||||||
if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("ssh: signature did not verify")
|
|
||||||
}
|
|
||||||
|
|
||||||
type dsaPrivateKey struct {
|
|
||||||
*dsa.PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *dsaPrivateKey) PublicKey() PublicKey {
|
|
||||||
return (*dsaPublicKey)(&k.PrivateKey.PublicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
|
|
||||||
h := crypto.SHA1.New()
|
|
||||||
h.Write(data)
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
r, s, err := dsa.Sign(rand, k.PrivateKey, digest)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := make([]byte, 40)
|
|
||||||
rb := r.Bytes()
|
|
||||||
sb := s.Bytes()
|
|
||||||
|
|
||||||
copy(sig[20-len(rb):20], rb)
|
|
||||||
copy(sig[40-len(sb):], sb)
|
|
||||||
|
|
||||||
return &Signature{
|
|
||||||
Format: k.PublicKey().Type(),
|
|
||||||
Blob: sig,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ecdsaPublicKey ecdsa.PublicKey
|
|
||||||
|
|
||||||
func (key *ecdsaPublicKey) Type() string {
|
|
||||||
return "ecdsa-sha2-" + key.nistID()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (key *ecdsaPublicKey) nistID() string {
|
|
||||||
switch key.Params().BitSize {
|
|
||||||
case 256:
|
|
||||||
return "nistp256"
|
|
||||||
case 384:
|
|
||||||
return "nistp384"
|
|
||||||
case 521:
|
|
||||||
return "nistp521"
|
|
||||||
}
|
|
||||||
panic("ssh: unsupported ecdsa key size")
|
|
||||||
}
|
|
||||||
|
|
||||||
func supportedEllipticCurve(curve elliptic.Curve) bool {
|
|
||||||
return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ecHash returns the hash to match the given elliptic curve, see RFC
|
|
||||||
// 5656, section 6.2.1
|
|
||||||
func ecHash(curve elliptic.Curve) crypto.Hash {
|
|
||||||
bitSize := curve.Params().BitSize
|
|
||||||
switch {
|
|
||||||
case bitSize <= 256:
|
|
||||||
return crypto.SHA256
|
|
||||||
case bitSize <= 384:
|
|
||||||
return crypto.SHA384
|
|
||||||
}
|
|
||||||
return crypto.SHA512
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
|
|
||||||
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
|
|
||||||
identifier, in, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, errShortRead
|
|
||||||
}
|
|
||||||
|
|
||||||
key := new(ecdsa.PublicKey)
|
|
||||||
|
|
||||||
switch string(identifier) {
|
|
||||||
case "nistp256":
|
|
||||||
key.Curve = elliptic.P256()
|
|
||||||
case "nistp384":
|
|
||||||
key.Curve = elliptic.P384()
|
|
||||||
case "nistp521":
|
|
||||||
key.Curve = elliptic.P521()
|
|
||||||
default:
|
|
||||||
return nil, nil, errors.New("ssh: unsupported curve")
|
|
||||||
}
|
|
||||||
|
|
||||||
var keyBytes []byte
|
|
||||||
if keyBytes, in, ok = parseString(in); !ok {
|
|
||||||
return nil, nil, errShortRead
|
|
||||||
}
|
|
||||||
|
|
||||||
key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes)
|
|
||||||
if key.X == nil || key.Y == nil {
|
|
||||||
return nil, nil, errors.New("ssh: invalid curve point")
|
|
||||||
}
|
|
||||||
return (*ecdsaPublicKey)(key), in, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (key *ecdsaPublicKey) Marshal() []byte {
|
|
||||||
// See RFC 5656, section 3.1.
|
|
||||||
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
|
|
||||||
w := struct {
|
|
||||||
Name string
|
|
||||||
ID string
|
|
||||||
Key []byte
|
|
||||||
}{
|
|
||||||
key.Type(),
|
|
||||||
key.nistID(),
|
|
||||||
keyBytes,
|
|
||||||
}
|
|
||||||
|
|
||||||
return Marshal(&w)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
|
|
||||||
if sig.Format != key.Type() {
|
|
||||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
|
|
||||||
}
|
|
||||||
|
|
||||||
h := ecHash(key.Curve).New()
|
|
||||||
h.Write(data)
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
|
|
||||||
// Per RFC 5656, section 3.1.2,
|
|
||||||
// The ecdsa_signature_blob value has the following specific encoding:
|
|
||||||
// mpint r
|
|
||||||
// mpint s
|
|
||||||
var ecSig struct {
|
|
||||||
R *big.Int
|
|
||||||
S *big.Int
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := Unmarshal(sig.Blob, &ecSig); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("ssh: signature did not verify")
|
|
||||||
}
|
|
||||||
|
|
||||||
type ecdsaPrivateKey struct {
|
|
||||||
*ecdsa.PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *ecdsaPrivateKey) PublicKey() PublicKey {
|
|
||||||
return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
|
|
||||||
h := ecHash(k.PrivateKey.PublicKey.Curve).New()
|
|
||||||
h.Write(data)
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := make([]byte, intLength(r)+intLength(s))
|
|
||||||
rest := marshalInt(sig, r)
|
|
||||||
marshalInt(rest, s)
|
|
||||||
return &Signature{
|
|
||||||
Format: k.PublicKey().Type(),
|
|
||||||
Blob: sig,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
|
|
||||||
// returns a corresponding Signer instance. EC keys should use P256,
|
|
||||||
// P384 or P521.
|
|
||||||
func NewSignerFromKey(k interface{}) (Signer, error) {
|
|
||||||
var sshKey Signer
|
|
||||||
switch t := k.(type) {
|
|
||||||
case *rsa.PrivateKey:
|
|
||||||
sshKey = &rsaPrivateKey{t}
|
|
||||||
case *dsa.PrivateKey:
|
|
||||||
sshKey = &dsaPrivateKey{t}
|
|
||||||
case *ecdsa.PrivateKey:
|
|
||||||
if !supportedEllipticCurve(t.Curve) {
|
|
||||||
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
|
|
||||||
}
|
|
||||||
|
|
||||||
sshKey = &ecdsaPrivateKey{t}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
|
|
||||||
}
|
|
||||||
return sshKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey
|
|
||||||
// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521.
|
|
||||||
func NewPublicKey(k interface{}) (PublicKey, error) {
|
|
||||||
var sshKey PublicKey
|
|
||||||
switch t := k.(type) {
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
sshKey = (*rsaPublicKey)(t)
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
if !supportedEllipticCurve(t.Curve) {
|
|
||||||
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
|
|
||||||
}
|
|
||||||
sshKey = (*ecdsaPublicKey)(t)
|
|
||||||
case *dsa.PublicKey:
|
|
||||||
sshKey = (*dsaPublicKey)(t)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
|
|
||||||
}
|
|
||||||
return sshKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
|
|
||||||
// the same keys as ParseRawPrivateKey.
|
|
||||||
func ParsePrivateKey(pemBytes []byte) (Signer, error) {
|
|
||||||
key, err := ParseRawPrivateKey(pemBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewSignerFromKey(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
|
|
||||||
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
|
|
||||||
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
|
|
||||||
block, _ := pem.Decode(pemBytes)
|
|
||||||
if block == nil {
|
|
||||||
return nil, errors.New("ssh: no key found")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch block.Type {
|
|
||||||
case "RSA PRIVATE KEY":
|
|
||||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
||||||
case "EC PRIVATE KEY":
|
|
||||||
return x509.ParseECPrivateKey(block.Bytes)
|
|
||||||
case "DSA PRIVATE KEY":
|
|
||||||
return ParseDSAPrivateKey(block.Bytes)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
|
|
||||||
// specified by the OpenSSL DSA man page.
|
|
||||||
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
|
|
||||||
var k struct {
|
|
||||||
Version int
|
|
||||||
P *big.Int
|
|
||||||
Q *big.Int
|
|
||||||
G *big.Int
|
|
||||||
Priv *big.Int
|
|
||||||
Pub *big.Int
|
|
||||||
}
|
|
||||||
rest, err := asn1.Unmarshal(der, &k)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("ssh: failed to parse DSA key: " + err.Error())
|
|
||||||
}
|
|
||||||
if len(rest) > 0 {
|
|
||||||
return nil, errors.New("ssh: garbage after DSA key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &dsa.PrivateKey{
|
|
||||||
PublicKey: dsa.PublicKey{
|
|
||||||
Parameters: dsa.Parameters{
|
|
||||||
P: k.P,
|
|
||||||
Q: k.Q,
|
|
||||||
G: k.G,
|
|
||||||
},
|
|
||||||
Y: k.Priv,
|
|
||||||
},
|
|
||||||
X: k.Pub,
|
|
||||||
}, nil
|
|
||||||
}
|
|
|
@ -1,306 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/dsa"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh/testdata"
|
|
||||||
)
|
|
||||||
|
|
||||||
func rawKey(pub PublicKey) interface{} {
|
|
||||||
switch k := pub.(type) {
|
|
||||||
case *rsaPublicKey:
|
|
||||||
return (*rsa.PublicKey)(k)
|
|
||||||
case *dsaPublicKey:
|
|
||||||
return (*dsa.PublicKey)(k)
|
|
||||||
case *ecdsaPublicKey:
|
|
||||||
return (*ecdsa.PublicKey)(k)
|
|
||||||
case *Certificate:
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
panic("unknown key type")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyMarshalParse(t *testing.T) {
|
|
||||||
for _, priv := range testSigners {
|
|
||||||
pub := priv.PublicKey()
|
|
||||||
roundtrip, err := ParsePublicKey(pub.Marshal())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("ParsePublicKey(%T): %v", pub, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
k1 := rawKey(pub)
|
|
||||||
k2 := rawKey(roundtrip)
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(k1, k2) {
|
|
||||||
t.Errorf("got %#v in roundtrip, want %#v", k2, k1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnsupportedCurves(t *testing.T) {
|
|
||||||
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GenerateKey: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") {
|
|
||||||
t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") {
|
|
||||||
t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewPublicKey(t *testing.T) {
|
|
||||||
for _, k := range testSigners {
|
|
||||||
raw := rawKey(k.PublicKey())
|
|
||||||
// Skip certificates, as NewPublicKey does not support them.
|
|
||||||
if _, ok := raw.(*Certificate); ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pub, err := NewPublicKey(raw)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("NewPublicKey(%#v): %v", raw, err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(k.PublicKey(), pub) {
|
|
||||||
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeySignVerify(t *testing.T) {
|
|
||||||
for _, priv := range testSigners {
|
|
||||||
pub := priv.PublicKey()
|
|
||||||
|
|
||||||
data := []byte("sign me")
|
|
||||||
sig, err := priv.Sign(rand.Reader, data)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Sign(%T): %v", priv, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := pub.Verify(data, sig); err != nil {
|
|
||||||
t.Errorf("publicKey.Verify(%T): %v", priv, err)
|
|
||||||
}
|
|
||||||
sig.Blob[5]++
|
|
||||||
if err := pub.Verify(data, sig); err == nil {
|
|
||||||
t.Errorf("publicKey.Verify on broken sig did not fail")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRSAPrivateKey(t *testing.T) {
|
|
||||||
key := testPrivateKeys["rsa"]
|
|
||||||
|
|
||||||
rsa, ok := key.(*rsa.PrivateKey)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rsa.Validate(); err != nil {
|
|
||||||
t.Errorf("Validate: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseECPrivateKey(t *testing.T) {
|
|
||||||
key := testPrivateKeys["ecdsa"]
|
|
||||||
|
|
||||||
ecKey, ok := key.(*ecdsa.PrivateKey)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
|
|
||||||
t.Fatalf("public key does not validate.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseDSA(t *testing.T) {
|
|
||||||
// We actually exercise the ParsePrivateKey codepath here, as opposed to
|
|
||||||
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
|
|
||||||
// uses.
|
|
||||||
s, err := ParsePrivateKey(testdata.PEMBytes["dsa"])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ParsePrivateKey returned error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data := []byte("sign me")
|
|
||||||
sig, err := s.Sign(rand.Reader, data)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("dsa.Sign: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.PublicKey().Verify(data, sig); err != nil {
|
|
||||||
t.Errorf("Verify failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tests for authorized_keys parsing.
|
|
||||||
|
|
||||||
// getTestKey returns a public key, and its base64 encoding.
|
|
||||||
func getTestKey() (PublicKey, string) {
|
|
||||||
k := testPublicKeys["rsa"]
|
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
e := base64.NewEncoder(base64.StdEncoding, b)
|
|
||||||
e.Write(k.Marshal())
|
|
||||||
e.Close()
|
|
||||||
|
|
||||||
return k, b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarshalParsePublicKey(t *testing.T) {
|
|
||||||
pub, pubSerialized := getTestKey()
|
|
||||||
line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized)
|
|
||||||
|
|
||||||
authKeys := MarshalAuthorizedKey(pub)
|
|
||||||
actualFields := strings.Fields(string(authKeys))
|
|
||||||
if len(actualFields) == 0 {
|
|
||||||
t.Fatalf("failed authKeys: %v", authKeys)
|
|
||||||
}
|
|
||||||
|
|
||||||
// drop the comment
|
|
||||||
expectedFields := strings.Fields(line)[0:2]
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(actualFields, expectedFields) {
|
|
||||||
t.Errorf("got %v, expected %v", actualFields, expectedFields)
|
|
||||||
}
|
|
||||||
|
|
||||||
actPub, _, _, _, err := ParseAuthorizedKey([]byte(line))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("cannot parse %v: %v", line, err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(actPub, pub) {
|
|
||||||
t.Errorf("got %v, expected %v", actPub, pub)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type authResult struct {
|
|
||||||
pubKey PublicKey
|
|
||||||
options []string
|
|
||||||
comments string
|
|
||||||
rest string
|
|
||||||
ok bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
|
|
||||||
rest := authKeys
|
|
||||||
var values []authResult
|
|
||||||
for len(rest) > 0 {
|
|
||||||
var r authResult
|
|
||||||
var err error
|
|
||||||
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
|
|
||||||
r.ok = (err == nil)
|
|
||||||
t.Log(err)
|
|
||||||
r.rest = string(rest)
|
|
||||||
values = append(values, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(values, expected) {
|
|
||||||
t.Errorf("got %#v, expected %#v", values, expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthorizedKeyBasic(t *testing.T) {
|
|
||||||
pub, pubSerialized := getTestKey()
|
|
||||||
line := "ssh-rsa " + pubSerialized + " user@host"
|
|
||||||
testAuthorizedKeys(t, []byte(line),
|
|
||||||
[]authResult{
|
|
||||||
{pub, nil, "user@host", "", true},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth(t *testing.T) {
|
|
||||||
pub, pubSerialized := getTestKey()
|
|
||||||
authWithOptions := []string{
|
|
||||||
`# comments to ignore before any keys...`,
|
|
||||||
``,
|
|
||||||
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`,
|
|
||||||
`# comments to ignore, along with a blank line`,
|
|
||||||
``,
|
|
||||||
`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`,
|
|
||||||
``,
|
|
||||||
`# more comments, plus a invalid entry`,
|
|
||||||
`ssh-rsa data-that-will-not-parse user@host3`,
|
|
||||||
}
|
|
||||||
for _, eol := range []string{"\n", "\r\n"} {
|
|
||||||
authOptions := strings.Join(authWithOptions, eol)
|
|
||||||
rest2 := strings.Join(authWithOptions[3:], eol)
|
|
||||||
rest3 := strings.Join(authWithOptions[6:], eol)
|
|
||||||
testAuthorizedKeys(t, []byte(authOptions), []authResult{
|
|
||||||
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
|
|
||||||
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
|
|
||||||
{nil, nil, "", "", false},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
|
|
||||||
pub, pubSerialized := getTestKey()
|
|
||||||
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
|
|
||||||
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
|
|
||||||
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthWithQuotedCommaInEnv(t *testing.T) {
|
|
||||||
pub, pubSerialized := getTestKey()
|
|
||||||
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
|
|
||||||
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
|
|
||||||
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
|
|
||||||
pub, pubSerialized := getTestKey()
|
|
||||||
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`)
|
|
||||||
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
|
|
||||||
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
|
|
||||||
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
|
|
||||||
})
|
|
||||||
|
|
||||||
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
|
|
||||||
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthWithInvalidSpace(t *testing.T) {
|
|
||||||
_, pubSerialized := getTestKey()
|
|
||||||
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
|
|
||||||
#more to follow but still no valid keys`)
|
|
||||||
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
|
|
||||||
{nil, nil, "", "", false},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthWithMissingQuote(t *testing.T) {
|
|
||||||
pub, pubSerialized := getTestKey()
|
|
||||||
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
|
|
||||||
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
|
|
||||||
|
|
||||||
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
|
|
||||||
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvalidEntry(t *testing.T) {
|
|
||||||
authInvalid := []byte(`ssh-rsa`)
|
|
||||||
_, _, _, _, err := ParseAuthorizedKey(authInvalid)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("got valid entry for %q", authInvalid)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,53 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
// Message authentication support
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha1"
|
|
||||||
"hash"
|
|
||||||
)
|
|
||||||
|
|
||||||
type macMode struct {
|
|
||||||
keySize int
|
|
||||||
new func(key []byte) hash.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
|
||||||
// a given size.
|
|
||||||
type truncatingMAC struct {
|
|
||||||
length int
|
|
||||||
hmac hash.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Write(data []byte) (int, error) {
|
|
||||||
return t.hmac.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Sum(in []byte) []byte {
|
|
||||||
out := t.hmac.Sum(in)
|
|
||||||
return out[:len(in)+t.length]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Reset() {
|
|
||||||
t.hmac.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Size() int {
|
|
||||||
return t.length
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
|
|
||||||
|
|
||||||
var macModes = map[string]*macMode{
|
|
||||||
"hmac-sha1": {20, func(key []byte) hash.Hash {
|
|
||||||
return hmac.New(sha1.New, key)
|
|
||||||
}},
|
|
||||||
"hmac-sha1-96": {20, func(key []byte) hash.Hash {
|
|
||||||
return truncatingMAC{12, hmac.New(sha1.New, key)}
|
|
||||||
}},
|
|
||||||
}
|
|
|
@ -1,110 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// An in-memory packetConn. It is safe to call Close and writePacket
|
|
||||||
// from different goroutines.
|
|
||||||
type memTransport struct {
|
|
||||||
eof bool
|
|
||||||
pending [][]byte
|
|
||||||
write *memTransport
|
|
||||||
sync.Mutex
|
|
||||||
*sync.Cond
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *memTransport) readPacket() ([]byte, error) {
|
|
||||||
t.Lock()
|
|
||||||
defer t.Unlock()
|
|
||||||
for {
|
|
||||||
if len(t.pending) > 0 {
|
|
||||||
r := t.pending[0]
|
|
||||||
t.pending = t.pending[1:]
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
if t.eof {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
t.Cond.Wait()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *memTransport) closeSelf() error {
|
|
||||||
t.Lock()
|
|
||||||
defer t.Unlock()
|
|
||||||
if t.eof {
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
t.eof = true
|
|
||||||
t.Cond.Broadcast()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *memTransport) Close() error {
|
|
||||||
err := t.write.closeSelf()
|
|
||||||
t.closeSelf()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *memTransport) writePacket(p []byte) error {
|
|
||||||
t.write.Lock()
|
|
||||||
defer t.write.Unlock()
|
|
||||||
if t.write.eof {
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
c := make([]byte, len(p))
|
|
||||||
copy(c, p)
|
|
||||||
t.write.pending = append(t.write.pending, c)
|
|
||||||
t.write.Cond.Signal()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func memPipe() (a, b packetConn) {
|
|
||||||
t1 := memTransport{}
|
|
||||||
t2 := memTransport{}
|
|
||||||
t1.write = &t2
|
|
||||||
t2.write = &t1
|
|
||||||
t1.Cond = sync.NewCond(&t1.Mutex)
|
|
||||||
t2.Cond = sync.NewCond(&t2.Mutex)
|
|
||||||
return &t1, &t2
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestmemPipe(t *testing.T) {
|
|
||||||
a, b := memPipe()
|
|
||||||
if err := a.writePacket([]byte{42}); err != nil {
|
|
||||||
t.Fatalf("writePacket: %v", err)
|
|
||||||
}
|
|
||||||
if err := a.Close(); err != nil {
|
|
||||||
t.Fatal("Close: ", err)
|
|
||||||
}
|
|
||||||
p, err := b.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("readPacket: ", err)
|
|
||||||
}
|
|
||||||
if len(p) != 1 || p[0] != 42 {
|
|
||||||
t.Fatalf("got %v, want {42}", p)
|
|
||||||
}
|
|
||||||
p, err = b.readPacket()
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Fatalf("got %v, %v, want EOF", p, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoubleClose(t *testing.T) {
|
|
||||||
a, _ := memPipe()
|
|
||||||
err := a.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Close: %v", err)
|
|
||||||
}
|
|
||||||
err = a.Close()
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Errorf("expect EOF on double close.")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,724 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These are SSH message type numbers. They are scattered around several
|
|
||||||
// documents but many were taken from [SSH-PARAMETERS].
|
|
||||||
const (
|
|
||||||
msgIgnore = 2
|
|
||||||
msgUnimplemented = 3
|
|
||||||
msgDebug = 4
|
|
||||||
msgNewKeys = 21
|
|
||||||
|
|
||||||
// Standard authentication messages
|
|
||||||
msgUserAuthSuccess = 52
|
|
||||||
msgUserAuthBanner = 53
|
|
||||||
)
|
|
||||||
|
|
||||||
// SSH messages:
|
|
||||||
//
|
|
||||||
// These structures mirror the wire format of the corresponding SSH messages.
|
|
||||||
// They are marshaled using reflection with the marshal and unmarshal functions
|
|
||||||
// in this file. The only wrinkle is that a final member of type []byte with a
|
|
||||||
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
|
|
||||||
|
|
||||||
// See RFC 4253, section 11.1.
|
|
||||||
const msgDisconnect = 1
|
|
||||||
|
|
||||||
// disconnectMsg is the message that signals a disconnect. It is also
|
|
||||||
// the error type returned from mux.Wait()
|
|
||||||
type disconnectMsg struct {
|
|
||||||
Reason uint32 `sshtype:"1"`
|
|
||||||
Message string
|
|
||||||
Language string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *disconnectMsg) Error() string {
|
|
||||||
return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 7.1.
|
|
||||||
const msgKexInit = 20
|
|
||||||
|
|
||||||
type kexInitMsg struct {
|
|
||||||
Cookie [16]byte `sshtype:"20"`
|
|
||||||
KexAlgos []string
|
|
||||||
ServerHostKeyAlgos []string
|
|
||||||
CiphersClientServer []string
|
|
||||||
CiphersServerClient []string
|
|
||||||
MACsClientServer []string
|
|
||||||
MACsServerClient []string
|
|
||||||
CompressionClientServer []string
|
|
||||||
CompressionServerClient []string
|
|
||||||
LanguagesClientServer []string
|
|
||||||
LanguagesServerClient []string
|
|
||||||
FirstKexFollows bool
|
|
||||||
Reserved uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 8.
|
|
||||||
|
|
||||||
// Diffie-Helman
|
|
||||||
const msgKexDHInit = 30
|
|
||||||
|
|
||||||
type kexDHInitMsg struct {
|
|
||||||
X *big.Int `sshtype:"30"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexECDHInit = 30
|
|
||||||
|
|
||||||
type kexECDHInitMsg struct {
|
|
||||||
ClientPubKey []byte `sshtype:"30"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexECDHReply = 31
|
|
||||||
|
|
||||||
type kexECDHReplyMsg struct {
|
|
||||||
HostKey []byte `sshtype:"31"`
|
|
||||||
EphemeralPubKey []byte
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexDHReply = 31
|
|
||||||
|
|
||||||
type kexDHReplyMsg struct {
|
|
||||||
HostKey []byte `sshtype:"31"`
|
|
||||||
Y *big.Int
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 10.
|
|
||||||
const msgServiceRequest = 5
|
|
||||||
|
|
||||||
type serviceRequestMsg struct {
|
|
||||||
Service string `sshtype:"5"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 10.
|
|
||||||
const msgServiceAccept = 6
|
|
||||||
|
|
||||||
type serviceAcceptMsg struct {
|
|
||||||
Service string `sshtype:"6"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4252, section 5.
|
|
||||||
const msgUserAuthRequest = 50
|
|
||||||
|
|
||||||
type userAuthRequestMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Payload []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4252, section 5.1
|
|
||||||
const msgUserAuthFailure = 51
|
|
||||||
|
|
||||||
type userAuthFailureMsg struct {
|
|
||||||
Methods []string `sshtype:"51"`
|
|
||||||
PartialSuccess bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4256, section 3.2
|
|
||||||
const msgUserAuthInfoRequest = 60
|
|
||||||
const msgUserAuthInfoResponse = 61
|
|
||||||
|
|
||||||
type userAuthInfoRequestMsg struct {
|
|
||||||
User string `sshtype:"60"`
|
|
||||||
Instruction string
|
|
||||||
DeprecatedLanguage string
|
|
||||||
NumPrompts uint32
|
|
||||||
Prompts []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.1.
|
|
||||||
const msgChannelOpen = 90
|
|
||||||
|
|
||||||
type channelOpenMsg struct {
|
|
||||||
ChanType string `sshtype:"90"`
|
|
||||||
PeersId uint32
|
|
||||||
PeersWindow uint32
|
|
||||||
MaxPacketSize uint32
|
|
||||||
TypeSpecificData []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgChannelExtendedData = 95
|
|
||||||
const msgChannelData = 94
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.1.
|
|
||||||
const msgChannelOpenConfirm = 91
|
|
||||||
|
|
||||||
type channelOpenConfirmMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"91"`
|
|
||||||
MyId uint32
|
|
||||||
MyWindow uint32
|
|
||||||
MaxPacketSize uint32
|
|
||||||
TypeSpecificData []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.1.
|
|
||||||
const msgChannelOpenFailure = 92
|
|
||||||
|
|
||||||
type channelOpenFailureMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"92"`
|
|
||||||
Reason RejectionReason
|
|
||||||
Message string
|
|
||||||
Language string
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgChannelRequest = 98
|
|
||||||
|
|
||||||
type channelRequestMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"98"`
|
|
||||||
Request string
|
|
||||||
WantReply bool
|
|
||||||
RequestSpecificData []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.4.
|
|
||||||
const msgChannelSuccess = 99
|
|
||||||
|
|
||||||
type channelRequestSuccessMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"99"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.4.
|
|
||||||
const msgChannelFailure = 100
|
|
||||||
|
|
||||||
type channelRequestFailureMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"100"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.3
|
|
||||||
const msgChannelClose = 97
|
|
||||||
|
|
||||||
type channelCloseMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"97"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.3
|
|
||||||
const msgChannelEOF = 96
|
|
||||||
|
|
||||||
type channelEOFMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"96"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 4
|
|
||||||
const msgGlobalRequest = 80
|
|
||||||
|
|
||||||
type globalRequestMsg struct {
|
|
||||||
Type string `sshtype:"80"`
|
|
||||||
WantReply bool
|
|
||||||
Data []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 4
|
|
||||||
const msgRequestSuccess = 81
|
|
||||||
|
|
||||||
type globalRequestSuccessMsg struct {
|
|
||||||
Data []byte `ssh:"rest" sshtype:"81"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 4
|
|
||||||
const msgRequestFailure = 82
|
|
||||||
|
|
||||||
type globalRequestFailureMsg struct {
|
|
||||||
Data []byte `ssh:"rest" sshtype:"82"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.2
|
|
||||||
const msgChannelWindowAdjust = 93
|
|
||||||
|
|
||||||
type windowAdjustMsg struct {
|
|
||||||
PeersId uint32 `sshtype:"93"`
|
|
||||||
AdditionalBytes uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4252, section 7
|
|
||||||
const msgUserAuthPubKeyOk = 60
|
|
||||||
|
|
||||||
type userAuthPubKeyOkMsg struct {
|
|
||||||
Algo string `sshtype:"60"`
|
|
||||||
PubKey []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// typeTag returns the type byte for the given type. The type should
|
|
||||||
// be struct.
|
|
||||||
func typeTag(structType reflect.Type) byte {
|
|
||||||
var tag byte
|
|
||||||
var tagStr string
|
|
||||||
tagStr = structType.Field(0).Tag.Get("sshtype")
|
|
||||||
i, err := strconv.Atoi(tagStr)
|
|
||||||
if err == nil {
|
|
||||||
tag = byte(i)
|
|
||||||
}
|
|
||||||
return tag
|
|
||||||
}
|
|
||||||
|
|
||||||
func fieldError(t reflect.Type, field int, problem string) error {
|
|
||||||
if problem != "" {
|
|
||||||
problem = ": " + problem
|
|
||||||
}
|
|
||||||
return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errShortRead = errors.New("ssh: short read")
|
|
||||||
|
|
||||||
// Unmarshal parses data in SSH wire format into a structure. The out
|
|
||||||
// argument should be a pointer to struct. If the first member of the
|
|
||||||
// struct has the "sshtype" tag set to a number in decimal, the packet
|
|
||||||
// must start that number. In case of error, Unmarshal returns a
|
|
||||||
// ParseError or UnexpectedMessageError.
|
|
||||||
func Unmarshal(data []byte, out interface{}) error {
|
|
||||||
v := reflect.ValueOf(out).Elem()
|
|
||||||
structType := v.Type()
|
|
||||||
expectedType := typeTag(structType)
|
|
||||||
if len(data) == 0 {
|
|
||||||
return parseError(expectedType)
|
|
||||||
}
|
|
||||||
if expectedType > 0 {
|
|
||||||
if data[0] != expectedType {
|
|
||||||
return unexpectedMessageError(expectedType, data[0])
|
|
||||||
}
|
|
||||||
data = data[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
var ok bool
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
|
||||||
field := v.Field(i)
|
|
||||||
t := field.Type()
|
|
||||||
switch t.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
if len(data) < 1 {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetBool(data[0] != 0)
|
|
||||||
data = data[1:]
|
|
||||||
case reflect.Array:
|
|
||||||
if t.Elem().Kind() != reflect.Uint8 {
|
|
||||||
return fieldError(structType, i, "array of unsupported type")
|
|
||||||
}
|
|
||||||
if len(data) < t.Len() {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
for j, n := 0, t.Len(); j < n; j++ {
|
|
||||||
field.Index(j).Set(reflect.ValueOf(data[j]))
|
|
||||||
}
|
|
||||||
data = data[t.Len():]
|
|
||||||
case reflect.Uint64:
|
|
||||||
var u64 uint64
|
|
||||||
if u64, data, ok = parseUint64(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetUint(u64)
|
|
||||||
case reflect.Uint32:
|
|
||||||
var u32 uint32
|
|
||||||
if u32, data, ok = parseUint32(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetUint(uint64(u32))
|
|
||||||
case reflect.Uint8:
|
|
||||||
if len(data) < 1 {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetUint(uint64(data[0]))
|
|
||||||
data = data[1:]
|
|
||||||
case reflect.String:
|
|
||||||
var s []byte
|
|
||||||
if s, data, ok = parseString(data); !ok {
|
|
||||||
return fieldError(structType, i, "")
|
|
||||||
}
|
|
||||||
field.SetString(string(s))
|
|
||||||
case reflect.Slice:
|
|
||||||
switch t.Elem().Kind() {
|
|
||||||
case reflect.Uint8:
|
|
||||||
if structType.Field(i).Tag.Get("ssh") == "rest" {
|
|
||||||
field.Set(reflect.ValueOf(data))
|
|
||||||
data = nil
|
|
||||||
} else {
|
|
||||||
var s []byte
|
|
||||||
if s, data, ok = parseString(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.Set(reflect.ValueOf(s))
|
|
||||||
}
|
|
||||||
case reflect.String:
|
|
||||||
var nl []string
|
|
||||||
if nl, data, ok = parseNameList(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.Set(reflect.ValueOf(nl))
|
|
||||||
default:
|
|
||||||
return fieldError(structType, i, "slice of unsupported type")
|
|
||||||
}
|
|
||||||
case reflect.Ptr:
|
|
||||||
if t == bigIntType {
|
|
||||||
var n *big.Int
|
|
||||||
if n, data, ok = parseInt(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.Set(reflect.ValueOf(n))
|
|
||||||
} else {
|
|
||||||
return fieldError(structType, i, "pointer to unsupported type")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fieldError(structType, i, "unsupported type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(data) != 0 {
|
|
||||||
return parseError(expectedType)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal serializes the message in msg to SSH wire format. The msg
|
|
||||||
// argument should be a struct or pointer to struct. If the first
|
|
||||||
// member has the "sshtype" tag set to a number in decimal, that
|
|
||||||
// number is prepended to the result. If the last of member has the
|
|
||||||
// "ssh" tag set to "rest", its contents are appended to the output.
|
|
||||||
func Marshal(msg interface{}) []byte {
|
|
||||||
out := make([]byte, 0, 64)
|
|
||||||
return marshalStruct(out, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalStruct(out []byte, msg interface{}) []byte {
|
|
||||||
v := reflect.Indirect(reflect.ValueOf(msg))
|
|
||||||
msgType := typeTag(v.Type())
|
|
||||||
if msgType > 0 {
|
|
||||||
out = append(out, msgType)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, n := 0, v.NumField(); i < n; i++ {
|
|
||||||
field := v.Field(i)
|
|
||||||
switch t := field.Type(); t.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
var v uint8
|
|
||||||
if field.Bool() {
|
|
||||||
v = 1
|
|
||||||
}
|
|
||||||
out = append(out, v)
|
|
||||||
case reflect.Array:
|
|
||||||
if t.Elem().Kind() != reflect.Uint8 {
|
|
||||||
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
|
|
||||||
}
|
|
||||||
for j, l := 0, t.Len(); j < l; j++ {
|
|
||||||
out = append(out, uint8(field.Index(j).Uint()))
|
|
||||||
}
|
|
||||||
case reflect.Uint32:
|
|
||||||
out = appendU32(out, uint32(field.Uint()))
|
|
||||||
case reflect.Uint64:
|
|
||||||
out = appendU64(out, uint64(field.Uint()))
|
|
||||||
case reflect.Uint8:
|
|
||||||
out = append(out, uint8(field.Uint()))
|
|
||||||
case reflect.String:
|
|
||||||
s := field.String()
|
|
||||||
out = appendInt(out, len(s))
|
|
||||||
out = append(out, s...)
|
|
||||||
case reflect.Slice:
|
|
||||||
switch t.Elem().Kind() {
|
|
||||||
case reflect.Uint8:
|
|
||||||
if v.Type().Field(i).Tag.Get("ssh") != "rest" {
|
|
||||||
out = appendInt(out, field.Len())
|
|
||||||
}
|
|
||||||
out = append(out, field.Bytes()...)
|
|
||||||
case reflect.String:
|
|
||||||
offset := len(out)
|
|
||||||
out = appendU32(out, 0)
|
|
||||||
if n := field.Len(); n > 0 {
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
f := field.Index(j)
|
|
||||||
if j != 0 {
|
|
||||||
out = append(out, ',')
|
|
||||||
}
|
|
||||||
out = append(out, f.String()...)
|
|
||||||
}
|
|
||||||
// overwrite length value
|
|
||||||
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
|
|
||||||
}
|
|
||||||
case reflect.Ptr:
|
|
||||||
if t == bigIntType {
|
|
||||||
var n *big.Int
|
|
||||||
nValue := reflect.ValueOf(&n)
|
|
||||||
nValue.Elem().Set(field)
|
|
||||||
needed := intLength(n)
|
|
||||||
oldLength := len(out)
|
|
||||||
|
|
||||||
if cap(out)-len(out) < needed {
|
|
||||||
newOut := make([]byte, len(out), 2*(len(out)+needed))
|
|
||||||
copy(newOut, out)
|
|
||||||
out = newOut
|
|
||||||
}
|
|
||||||
out = out[:oldLength+needed]
|
|
||||||
marshalInt(out[oldLength:], n)
|
|
||||||
} else {
|
|
||||||
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
var bigOne = big.NewInt(1)
|
|
||||||
|
|
||||||
func parseString(in []byte) (out, rest []byte, ok bool) {
|
|
||||||
if len(in) < 4 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
length := binary.BigEndian.Uint32(in)
|
|
||||||
if uint32(len(in)) < 4+length {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out = in[4 : 4+length]
|
|
||||||
rest = in[4+length:]
|
|
||||||
ok = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
comma = []byte{','}
|
|
||||||
emptyNameList = []string{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
|
|
||||||
contents, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(contents) == 0 {
|
|
||||||
out = emptyNameList
|
|
||||||
return
|
|
||||||
}
|
|
||||||
parts := bytes.Split(contents, comma)
|
|
||||||
out = make([]string, len(parts))
|
|
||||||
for i, part := range parts {
|
|
||||||
out[i] = string(part)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
|
|
||||||
contents, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out = new(big.Int)
|
|
||||||
|
|
||||||
if len(contents) > 0 && contents[0]&0x80 == 0x80 {
|
|
||||||
// This is a negative number
|
|
||||||
notBytes := make([]byte, len(contents))
|
|
||||||
for i := range notBytes {
|
|
||||||
notBytes[i] = ^contents[i]
|
|
||||||
}
|
|
||||||
out.SetBytes(notBytes)
|
|
||||||
out.Add(out, bigOne)
|
|
||||||
out.Neg(out)
|
|
||||||
} else {
|
|
||||||
// Positive number
|
|
||||||
out.SetBytes(contents)
|
|
||||||
}
|
|
||||||
ok = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseUint32(in []byte) (uint32, []byte, bool) {
|
|
||||||
if len(in) < 4 {
|
|
||||||
return 0, nil, false
|
|
||||||
}
|
|
||||||
return binary.BigEndian.Uint32(in), in[4:], true
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseUint64(in []byte) (uint64, []byte, bool) {
|
|
||||||
if len(in) < 8 {
|
|
||||||
return 0, nil, false
|
|
||||||
}
|
|
||||||
return binary.BigEndian.Uint64(in), in[8:], true
|
|
||||||
}
|
|
||||||
|
|
||||||
func intLength(n *big.Int) int {
|
|
||||||
length := 4 /* length bytes */
|
|
||||||
if n.Sign() < 0 {
|
|
||||||
nMinus1 := new(big.Int).Neg(n)
|
|
||||||
nMinus1.Sub(nMinus1, bigOne)
|
|
||||||
bitLen := nMinus1.BitLen()
|
|
||||||
if bitLen%8 == 0 {
|
|
||||||
// The number will need 0xff padding
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
length += (bitLen + 7) / 8
|
|
||||||
} else if n.Sign() == 0 {
|
|
||||||
// A zero is the zero length string
|
|
||||||
} else {
|
|
||||||
bitLen := n.BitLen()
|
|
||||||
if bitLen%8 == 0 {
|
|
||||||
// The number will need 0x00 padding
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
length += (bitLen + 7) / 8
|
|
||||||
}
|
|
||||||
|
|
||||||
return length
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalUint32(to []byte, n uint32) []byte {
|
|
||||||
binary.BigEndian.PutUint32(to, n)
|
|
||||||
return to[4:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalUint64(to []byte, n uint64) []byte {
|
|
||||||
binary.BigEndian.PutUint64(to, n)
|
|
||||||
return to[8:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalInt(to []byte, n *big.Int) []byte {
|
|
||||||
lengthBytes := to
|
|
||||||
to = to[4:]
|
|
||||||
length := 0
|
|
||||||
|
|
||||||
if n.Sign() < 0 {
|
|
||||||
// A negative number has to be converted to two's-complement
|
|
||||||
// form. So we'll subtract 1 and invert. If the
|
|
||||||
// most-significant-bit isn't set then we'll need to pad the
|
|
||||||
// beginning with 0xff in order to keep the number negative.
|
|
||||||
nMinus1 := new(big.Int).Neg(n)
|
|
||||||
nMinus1.Sub(nMinus1, bigOne)
|
|
||||||
bytes := nMinus1.Bytes()
|
|
||||||
for i := range bytes {
|
|
||||||
bytes[i] ^= 0xff
|
|
||||||
}
|
|
||||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
|
|
||||||
to[0] = 0xff
|
|
||||||
to = to[1:]
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
nBytes := copy(to, bytes)
|
|
||||||
to = to[nBytes:]
|
|
||||||
length += nBytes
|
|
||||||
} else if n.Sign() == 0 {
|
|
||||||
// A zero is the zero length string
|
|
||||||
} else {
|
|
||||||
bytes := n.Bytes()
|
|
||||||
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
|
|
||||||
// We'll have to pad this with a 0x00 in order to
|
|
||||||
// stop it looking like a negative number.
|
|
||||||
to[0] = 0
|
|
||||||
to = to[1:]
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
nBytes := copy(to, bytes)
|
|
||||||
to = to[nBytes:]
|
|
||||||
length += nBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
lengthBytes[0] = byte(length >> 24)
|
|
||||||
lengthBytes[1] = byte(length >> 16)
|
|
||||||
lengthBytes[2] = byte(length >> 8)
|
|
||||||
lengthBytes[3] = byte(length)
|
|
||||||
return to
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeInt(w io.Writer, n *big.Int) {
|
|
||||||
length := intLength(n)
|
|
||||||
buf := make([]byte, length)
|
|
||||||
marshalInt(buf, n)
|
|
||||||
w.Write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeString(w io.Writer, s []byte) {
|
|
||||||
var lengthBytes [4]byte
|
|
||||||
lengthBytes[0] = byte(len(s) >> 24)
|
|
||||||
lengthBytes[1] = byte(len(s) >> 16)
|
|
||||||
lengthBytes[2] = byte(len(s) >> 8)
|
|
||||||
lengthBytes[3] = byte(len(s))
|
|
||||||
w.Write(lengthBytes[:])
|
|
||||||
w.Write(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func stringLength(n int) int {
|
|
||||||
return 4 + n
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalString(to []byte, s []byte) []byte {
|
|
||||||
to[0] = byte(len(s) >> 24)
|
|
||||||
to[1] = byte(len(s) >> 16)
|
|
||||||
to[2] = byte(len(s) >> 8)
|
|
||||||
to[3] = byte(len(s))
|
|
||||||
to = to[4:]
|
|
||||||
copy(to, s)
|
|
||||||
return to[len(s):]
|
|
||||||
}
|
|
||||||
|
|
||||||
var bigIntType = reflect.TypeOf((*big.Int)(nil))
|
|
||||||
|
|
||||||
// Decode a packet into its corresponding message.
|
|
||||||
func decode(packet []byte) (interface{}, error) {
|
|
||||||
var msg interface{}
|
|
||||||
switch packet[0] {
|
|
||||||
case msgDisconnect:
|
|
||||||
msg = new(disconnectMsg)
|
|
||||||
case msgServiceRequest:
|
|
||||||
msg = new(serviceRequestMsg)
|
|
||||||
case msgServiceAccept:
|
|
||||||
msg = new(serviceAcceptMsg)
|
|
||||||
case msgKexInit:
|
|
||||||
msg = new(kexInitMsg)
|
|
||||||
case msgKexDHInit:
|
|
||||||
msg = new(kexDHInitMsg)
|
|
||||||
case msgKexDHReply:
|
|
||||||
msg = new(kexDHReplyMsg)
|
|
||||||
case msgUserAuthRequest:
|
|
||||||
msg = new(userAuthRequestMsg)
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
msg = new(userAuthFailureMsg)
|
|
||||||
case msgUserAuthPubKeyOk:
|
|
||||||
msg = new(userAuthPubKeyOkMsg)
|
|
||||||
case msgGlobalRequest:
|
|
||||||
msg = new(globalRequestMsg)
|
|
||||||
case msgRequestSuccess:
|
|
||||||
msg = new(globalRequestSuccessMsg)
|
|
||||||
case msgRequestFailure:
|
|
||||||
msg = new(globalRequestFailureMsg)
|
|
||||||
case msgChannelOpen:
|
|
||||||
msg = new(channelOpenMsg)
|
|
||||||
case msgChannelOpenConfirm:
|
|
||||||
msg = new(channelOpenConfirmMsg)
|
|
||||||
case msgChannelOpenFailure:
|
|
||||||
msg = new(channelOpenFailureMsg)
|
|
||||||
case msgChannelWindowAdjust:
|
|
||||||
msg = new(windowAdjustMsg)
|
|
||||||
case msgChannelEOF:
|
|
||||||
msg = new(channelEOFMsg)
|
|
||||||
case msgChannelClose:
|
|
||||||
msg = new(channelCloseMsg)
|
|
||||||
case msgChannelRequest:
|
|
||||||
msg = new(channelRequestMsg)
|
|
||||||
case msgChannelSuccess:
|
|
||||||
msg = new(channelRequestSuccessMsg)
|
|
||||||
case msgChannelFailure:
|
|
||||||
msg = new(channelRequestFailureMsg)
|
|
||||||
default:
|
|
||||||
return nil, unexpectedMessageError(0, packet[0])
|
|
||||||
}
|
|
||||||
if err := Unmarshal(packet, msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
|
@ -1,244 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"math/big"
|
|
||||||
"math/rand"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"testing/quick"
|
|
||||||
)
|
|
||||||
|
|
||||||
var intLengthTests = []struct {
|
|
||||||
val, length int
|
|
||||||
}{
|
|
||||||
{0, 4 + 0},
|
|
||||||
{1, 4 + 1},
|
|
||||||
{127, 4 + 1},
|
|
||||||
{128, 4 + 2},
|
|
||||||
{-1, 4 + 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntLength(t *testing.T) {
|
|
||||||
for _, test := range intLengthTests {
|
|
||||||
v := new(big.Int).SetInt64(int64(test.val))
|
|
||||||
length := intLength(v)
|
|
||||||
if length != test.length {
|
|
||||||
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type msgAllTypes struct {
|
|
||||||
Bool bool `sshtype:"21"`
|
|
||||||
Array [16]byte
|
|
||||||
Uint64 uint64
|
|
||||||
Uint32 uint32
|
|
||||||
Uint8 uint8
|
|
||||||
String string
|
|
||||||
Strings []string
|
|
||||||
Bytes []byte
|
|
||||||
Int *big.Int
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
|
|
||||||
m := &msgAllTypes{}
|
|
||||||
m.Bool = rand.Intn(2) == 1
|
|
||||||
randomBytes(m.Array[:], rand)
|
|
||||||
m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
|
|
||||||
m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
|
|
||||||
m.Uint8 = uint8(rand.Intn(1 << 8))
|
|
||||||
m.String = string(m.Array[:])
|
|
||||||
m.Strings = randomNameList(rand)
|
|
||||||
m.Bytes = m.Array[:]
|
|
||||||
m.Int = randomInt(rand)
|
|
||||||
m.Rest = m.Array[:]
|
|
||||||
return reflect.ValueOf(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarshalUnmarshal(t *testing.T) {
|
|
||||||
rand := rand.New(rand.NewSource(0))
|
|
||||||
iface := &msgAllTypes{}
|
|
||||||
ty := reflect.ValueOf(iface).Type()
|
|
||||||
|
|
||||||
n := 100
|
|
||||||
if testing.Short() {
|
|
||||||
n = 5
|
|
||||||
}
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
v, ok := quick.Value(ty, rand)
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("failed to create value")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
m1 := v.Elem().Interface()
|
|
||||||
m2 := iface
|
|
||||||
|
|
||||||
marshaled := Marshal(m1)
|
|
||||||
if err := Unmarshal(marshaled, m2); err != nil {
|
|
||||||
t.Errorf("Unmarshal %#v: %s", m1, err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(v.Interface(), m2) {
|
|
||||||
t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnmarshalEmptyPacket(t *testing.T) {
|
|
||||||
var b []byte
|
|
||||||
var m channelRequestSuccessMsg
|
|
||||||
if err := Unmarshal(b, &m); err == nil {
|
|
||||||
t.Fatalf("unmarshal of empty slice succeeded")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnmarshalUnexpectedPacket(t *testing.T) {
|
|
||||||
type S struct {
|
|
||||||
I uint32 `sshtype:"43"`
|
|
||||||
S string
|
|
||||||
B bool
|
|
||||||
}
|
|
||||||
|
|
||||||
s := S{11, "hello", true}
|
|
||||||
packet := Marshal(s)
|
|
||||||
packet[0] = 42
|
|
||||||
roundtrip := S{}
|
|
||||||
err := Unmarshal(packet, &roundtrip)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, not nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarshalPtr(t *testing.T) {
|
|
||||||
s := struct {
|
|
||||||
S string
|
|
||||||
}{"hello"}
|
|
||||||
|
|
||||||
m1 := Marshal(s)
|
|
||||||
m2 := Marshal(&s)
|
|
||||||
if !bytes.Equal(m1, m2) {
|
|
||||||
t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBareMarshalUnmarshal(t *testing.T) {
|
|
||||||
type S struct {
|
|
||||||
I uint32
|
|
||||||
S string
|
|
||||||
B bool
|
|
||||||
}
|
|
||||||
|
|
||||||
s := S{42, "hello", true}
|
|
||||||
packet := Marshal(s)
|
|
||||||
roundtrip := S{}
|
|
||||||
Unmarshal(packet, &roundtrip)
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(s, roundtrip) {
|
|
||||||
t.Errorf("got %#v, want %#v", roundtrip, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBareMarshal(t *testing.T) {
|
|
||||||
type S2 struct {
|
|
||||||
I uint32
|
|
||||||
}
|
|
||||||
s := S2{42}
|
|
||||||
packet := Marshal(s)
|
|
||||||
i, rest, ok := parseUint32(packet)
|
|
||||||
if len(rest) > 0 || !ok {
|
|
||||||
t.Errorf("parseInt(%q): parse error", packet)
|
|
||||||
}
|
|
||||||
if i != s.I {
|
|
||||||
t.Errorf("got %d, want %d", i, s.I)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomBytes(out []byte, rand *rand.Rand) {
|
|
||||||
for i := 0; i < len(out); i++ {
|
|
||||||
out[i] = byte(rand.Int31())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomNameList(rand *rand.Rand) []string {
|
|
||||||
ret := make([]string, rand.Int31()&15)
|
|
||||||
for i := range ret {
|
|
||||||
s := make([]byte, 1+(rand.Int31()&15))
|
|
||||||
for j := range s {
|
|
||||||
s[j] = 'a' + uint8(rand.Int31()&15)
|
|
||||||
}
|
|
||||||
ret[i] = string(s)
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomInt(rand *rand.Rand) *big.Int {
|
|
||||||
return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|
||||||
ki := &kexInitMsg{}
|
|
||||||
randomBytes(ki.Cookie[:], rand)
|
|
||||||
ki.KexAlgos = randomNameList(rand)
|
|
||||||
ki.ServerHostKeyAlgos = randomNameList(rand)
|
|
||||||
ki.CiphersClientServer = randomNameList(rand)
|
|
||||||
ki.CiphersServerClient = randomNameList(rand)
|
|
||||||
ki.MACsClientServer = randomNameList(rand)
|
|
||||||
ki.MACsServerClient = randomNameList(rand)
|
|
||||||
ki.CompressionClientServer = randomNameList(rand)
|
|
||||||
ki.CompressionServerClient = randomNameList(rand)
|
|
||||||
ki.LanguagesClientServer = randomNameList(rand)
|
|
||||||
ki.LanguagesServerClient = randomNameList(rand)
|
|
||||||
if rand.Int31()&1 == 1 {
|
|
||||||
ki.FirstKexFollows = true
|
|
||||||
}
|
|
||||||
return reflect.ValueOf(ki)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|
||||||
dhi := &kexDHInitMsg{}
|
|
||||||
dhi.X = randomInt(rand)
|
|
||||||
return reflect.ValueOf(dhi)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
|
|
||||||
_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
|
|
||||||
|
|
||||||
_kexInit = Marshal(_kexInitMsg)
|
|
||||||
_kexDHInit = Marshal(_kexDHInitMsg)
|
|
||||||
)
|
|
||||||
|
|
||||||
func BenchmarkMarshalKexInitMsg(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
Marshal(_kexInitMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
|
|
||||||
m := new(kexInitMsg)
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
Unmarshal(_kexInit, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
Marshal(_kexDHInitMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
|
|
||||||
m := new(kexDHInitMsg)
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
Unmarshal(_kexDHInit, m)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,356 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// debugMux, if set, causes messages in the connection protocol to be
|
|
||||||
// logged.
|
|
||||||
const debugMux = false
|
|
||||||
|
|
||||||
// chanList is a thread safe channel list.
|
|
||||||
type chanList struct {
|
|
||||||
// protects concurrent access to chans
|
|
||||||
sync.Mutex
|
|
||||||
|
|
||||||
// chans are indexed by the local id of the channel, which the
|
|
||||||
// other side should send in the PeersId field.
|
|
||||||
chans []*channel
|
|
||||||
|
|
||||||
// This is a debugging aid: it offsets all IDs by this
|
|
||||||
// amount. This helps distinguish otherwise identical
|
|
||||||
// server/client muxes
|
|
||||||
offset uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assigns a channel ID to the given channel.
|
|
||||||
func (c *chanList) add(ch *channel) uint32 {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
for i := range c.chans {
|
|
||||||
if c.chans[i] == nil {
|
|
||||||
c.chans[i] = ch
|
|
||||||
return uint32(i) + c.offset
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.chans = append(c.chans, ch)
|
|
||||||
return uint32(len(c.chans)-1) + c.offset
|
|
||||||
}
|
|
||||||
|
|
||||||
// getChan returns the channel for the given ID.
|
|
||||||
func (c *chanList) getChan(id uint32) *channel {
|
|
||||||
id -= c.offset
|
|
||||||
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
if id < uint32(len(c.chans)) {
|
|
||||||
return c.chans[id]
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *chanList) remove(id uint32) {
|
|
||||||
id -= c.offset
|
|
||||||
c.Lock()
|
|
||||||
if id < uint32(len(c.chans)) {
|
|
||||||
c.chans[id] = nil
|
|
||||||
}
|
|
||||||
c.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropAll forgets all channels it knows, returning them in a slice.
|
|
||||||
func (c *chanList) dropAll() []*channel {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
var r []*channel
|
|
||||||
|
|
||||||
for _, ch := range c.chans {
|
|
||||||
if ch == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
r = append(r, ch)
|
|
||||||
}
|
|
||||||
c.chans = nil
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// mux represents the state for the SSH connection protocol, which
|
|
||||||
// multiplexes many channels onto a single packet transport.
|
|
||||||
type mux struct {
|
|
||||||
conn packetConn
|
|
||||||
chanList chanList
|
|
||||||
|
|
||||||
incomingChannels chan NewChannel
|
|
||||||
|
|
||||||
globalSentMu sync.Mutex
|
|
||||||
globalResponses chan interface{}
|
|
||||||
incomingRequests chan *Request
|
|
||||||
|
|
||||||
errCond *sync.Cond
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// When debugging, each new chanList instantiation has a different
|
|
||||||
// offset.
|
|
||||||
var globalOff uint32
|
|
||||||
|
|
||||||
func (m *mux) Wait() error {
|
|
||||||
m.errCond.L.Lock()
|
|
||||||
defer m.errCond.L.Unlock()
|
|
||||||
for m.err == nil {
|
|
||||||
m.errCond.Wait()
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMux returns a mux that runs over the given connection.
|
|
||||||
func newMux(p packetConn) *mux {
|
|
||||||
m := &mux{
|
|
||||||
conn: p,
|
|
||||||
incomingChannels: make(chan NewChannel, 16),
|
|
||||||
globalResponses: make(chan interface{}, 1),
|
|
||||||
incomingRequests: make(chan *Request, 16),
|
|
||||||
errCond: newCond(),
|
|
||||||
}
|
|
||||||
if debugMux {
|
|
||||||
m.chanList.offset = atomic.AddUint32(&globalOff, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go m.loop()
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) sendMessage(msg interface{}) error {
|
|
||||||
p := Marshal(msg)
|
|
||||||
return m.conn.writePacket(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
|
|
||||||
if wantReply {
|
|
||||||
m.globalSentMu.Lock()
|
|
||||||
defer m.globalSentMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.sendMessage(globalRequestMsg{
|
|
||||||
Type: name,
|
|
||||||
WantReply: wantReply,
|
|
||||||
Data: payload,
|
|
||||||
}); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !wantReply {
|
|
||||||
return false, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, ok := <-m.globalResponses
|
|
||||||
if !ok {
|
|
||||||
return false, nil, io.EOF
|
|
||||||
}
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *globalRequestFailureMsg:
|
|
||||||
return false, msg.Data, nil
|
|
||||||
case *globalRequestSuccessMsg:
|
|
||||||
return true, msg.Data, nil
|
|
||||||
default:
|
|
||||||
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ackRequest must be called after processing a global request that
|
|
||||||
// has WantReply set.
|
|
||||||
func (m *mux) ackRequest(ok bool, data []byte) error {
|
|
||||||
if ok {
|
|
||||||
return m.sendMessage(globalRequestSuccessMsg{Data: data})
|
|
||||||
}
|
|
||||||
return m.sendMessage(globalRequestFailureMsg{Data: data})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(hanwen): Disconnect is a transport layer message. We should
|
|
||||||
// probably send and receive Disconnect somewhere in the transport
|
|
||||||
// code.
|
|
||||||
|
|
||||||
// Disconnect sends a disconnect message.
|
|
||||||
func (m *mux) Disconnect(reason uint32, message string) error {
|
|
||||||
return m.sendMessage(disconnectMsg{
|
|
||||||
Reason: reason,
|
|
||||||
Message: message,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) Close() error {
|
|
||||||
return m.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// loop runs the connection machine. It will process packets until an
|
|
||||||
// error is encountered. To synchronize on loop exit, use mux.Wait.
|
|
||||||
func (m *mux) loop() {
|
|
||||||
var err error
|
|
||||||
for err == nil {
|
|
||||||
err = m.onePacket()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ch := range m.chanList.dropAll() {
|
|
||||||
ch.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
close(m.incomingChannels)
|
|
||||||
close(m.incomingRequests)
|
|
||||||
close(m.globalResponses)
|
|
||||||
|
|
||||||
m.conn.Close()
|
|
||||||
|
|
||||||
m.errCond.L.Lock()
|
|
||||||
m.err = err
|
|
||||||
m.errCond.Broadcast()
|
|
||||||
m.errCond.L.Unlock()
|
|
||||||
|
|
||||||
if debugMux {
|
|
||||||
log.Println("loop exit", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// onePacket reads and processes one packet.
|
|
||||||
func (m *mux) onePacket() error {
|
|
||||||
packet, err := m.conn.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if debugMux {
|
|
||||||
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
|
|
||||||
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
|
|
||||||
} else {
|
|
||||||
p, _ := decode(packet)
|
|
||||||
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch packet[0] {
|
|
||||||
case msgNewKeys:
|
|
||||||
// Ignore notification of key change.
|
|
||||||
return nil
|
|
||||||
case msgDisconnect:
|
|
||||||
return m.handleDisconnect(packet)
|
|
||||||
case msgChannelOpen:
|
|
||||||
return m.handleChannelOpen(packet)
|
|
||||||
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
|
|
||||||
return m.handleGlobalPacket(packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assume a channel packet.
|
|
||||||
if len(packet) < 5 {
|
|
||||||
return parseError(packet[0])
|
|
||||||
}
|
|
||||||
id := binary.BigEndian.Uint32(packet[1:])
|
|
||||||
ch := m.chanList.getChan(id)
|
|
||||||
if ch == nil {
|
|
||||||
return fmt.Errorf("ssh: invalid channel %d", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch.handlePacket(packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) handleDisconnect(packet []byte) error {
|
|
||||||
var d disconnectMsg
|
|
||||||
if err := Unmarshal(packet, &d); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if debugMux {
|
|
||||||
log.Printf("caught disconnect: %v", d)
|
|
||||||
}
|
|
||||||
return &d
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) handleGlobalPacket(packet []byte) error {
|
|
||||||
msg, err := decode(packet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *globalRequestMsg:
|
|
||||||
m.incomingRequests <- &Request{
|
|
||||||
Type: msg.Type,
|
|
||||||
WantReply: msg.WantReply,
|
|
||||||
Payload: msg.Data,
|
|
||||||
mux: m,
|
|
||||||
}
|
|
||||||
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
|
|
||||||
m.globalResponses <- msg
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("not a global message %#v", msg))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleChannelOpen schedules a channel to be Accept()ed.
|
|
||||||
func (m *mux) handleChannelOpen(packet []byte) error {
|
|
||||||
var msg channelOpenMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
|
|
||||||
failMsg := channelOpenFailureMsg{
|
|
||||||
PeersId: msg.PeersId,
|
|
||||||
Reason: ConnectionFailed,
|
|
||||||
Message: "invalid request",
|
|
||||||
Language: "en_US.UTF-8",
|
|
||||||
}
|
|
||||||
return m.sendMessage(failMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
|
|
||||||
c.remoteId = msg.PeersId
|
|
||||||
c.maxRemotePayload = msg.MaxPacketSize
|
|
||||||
c.remoteWin.add(msg.PeersWindow)
|
|
||||||
m.incomingChannels <- c
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
|
|
||||||
ch, err := m.openChannel(chanType, extra)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch, ch.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
|
|
||||||
ch := m.newChannel(chanType, channelOutbound, extra)
|
|
||||||
|
|
||||||
ch.maxIncomingPayload = channelMaxPacket
|
|
||||||
|
|
||||||
open := channelOpenMsg{
|
|
||||||
ChanType: chanType,
|
|
||||||
PeersWindow: ch.myWindow,
|
|
||||||
MaxPacketSize: ch.maxIncomingPayload,
|
|
||||||
TypeSpecificData: extra,
|
|
||||||
PeersId: ch.localId,
|
|
||||||
}
|
|
||||||
if err := m.sendMessage(open); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := (<-ch.msg).(type) {
|
|
||||||
case *channelOpenConfirmMsg:
|
|
||||||
return ch, nil
|
|
||||||
case *channelOpenFailureMsg:
|
|
||||||
return nil, &OpenChannelError{msg.Reason, msg.Message}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,525 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func muxPair() (*mux, *mux) {
|
|
||||||
a, b := memPipe()
|
|
||||||
|
|
||||||
s := newMux(a)
|
|
||||||
c := newMux(b)
|
|
||||||
|
|
||||||
return s, c
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns both ends of a channel, and the mux for the the 2nd
|
|
||||||
// channel.
|
|
||||||
func channelPair(t *testing.T) (*channel, *channel, *mux) {
|
|
||||||
c, s := muxPair()
|
|
||||||
|
|
||||||
res := make(chan *channel, 1)
|
|
||||||
go func() {
|
|
||||||
newCh, ok := <-s.incomingChannels
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("No incoming channel")
|
|
||||||
}
|
|
||||||
if newCh.ChannelType() != "chan" {
|
|
||||||
t.Fatalf("got type %q want chan", newCh.ChannelType())
|
|
||||||
}
|
|
||||||
ch, _, err := newCh.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Accept %v", err)
|
|
||||||
}
|
|
||||||
res <- ch.(*channel)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ch, err := c.openChannel("chan", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("OpenChannel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return <-res, ch, c
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that stderr and stdout can be addressed from different
|
|
||||||
// goroutines. This is intended for use with the race detector.
|
|
||||||
func TestMuxChannelExtendedThreadSafety(t *testing.T) {
|
|
||||||
writer, reader, mux := channelPair(t)
|
|
||||||
defer writer.Close()
|
|
||||||
defer reader.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
var wr, rd sync.WaitGroup
|
|
||||||
magic := "hello world"
|
|
||||||
|
|
||||||
wr.Add(2)
|
|
||||||
go func() {
|
|
||||||
io.WriteString(writer, magic)
|
|
||||||
wr.Done()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
io.WriteString(writer.Stderr(), magic)
|
|
||||||
wr.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
rd.Add(2)
|
|
||||||
go func() {
|
|
||||||
c, err := ioutil.ReadAll(reader)
|
|
||||||
if string(c) != magic {
|
|
||||||
t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
|
|
||||||
}
|
|
||||||
rd.Done()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
c, err := ioutil.ReadAll(reader.Stderr())
|
|
||||||
if string(c) != magic {
|
|
||||||
t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
|
|
||||||
}
|
|
||||||
rd.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
wr.Wait()
|
|
||||||
writer.CloseWrite()
|
|
||||||
rd.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxReadWrite(t *testing.T) {
|
|
||||||
s, c, mux := channelPair(t)
|
|
||||||
defer s.Close()
|
|
||||||
defer c.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
magic := "hello world"
|
|
||||||
magicExt := "hello stderr"
|
|
||||||
go func() {
|
|
||||||
_, err := s.Write([]byte(magic))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Write: %v", err)
|
|
||||||
}
|
|
||||||
_, err = s.Extended(1).Write([]byte(magicExt))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Write: %v", err)
|
|
||||||
}
|
|
||||||
err = s.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Close: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var buf [1024]byte
|
|
||||||
n, err := c.Read(buf[:])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("server Read: %v", err)
|
|
||||||
}
|
|
||||||
got := string(buf[:n])
|
|
||||||
if got != magic {
|
|
||||||
t.Fatalf("server: got %q want %q", got, magic)
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = c.Extended(1).Read(buf[:])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("server Read: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
got = string(buf[:n])
|
|
||||||
if got != magicExt {
|
|
||||||
t.Fatalf("server: got %q want %q", got, magic)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxChannelOverflow(t *testing.T) {
|
|
||||||
reader, writer, mux := channelPair(t)
|
|
||||||
defer reader.Close()
|
|
||||||
defer writer.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
wDone := make(chan int, 1)
|
|
||||||
go func() {
|
|
||||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
|
|
||||||
t.Errorf("could not fill window: %v", err)
|
|
||||||
}
|
|
||||||
writer.Write(make([]byte, 1))
|
|
||||||
wDone <- 1
|
|
||||||
}()
|
|
||||||
writer.remoteWin.waitWriterBlocked()
|
|
||||||
|
|
||||||
// Send 1 byte.
|
|
||||||
packet := make([]byte, 1+4+4+1)
|
|
||||||
packet[0] = msgChannelData
|
|
||||||
marshalUint32(packet[1:], writer.remoteId)
|
|
||||||
marshalUint32(packet[5:], uint32(1))
|
|
||||||
packet[9] = 42
|
|
||||||
|
|
||||||
if err := writer.mux.conn.writePacket(packet); err != nil {
|
|
||||||
t.Errorf("could not send packet")
|
|
||||||
}
|
|
||||||
if _, err := reader.SendRequest("hello", true, nil); err == nil {
|
|
||||||
t.Errorf("SendRequest succeeded.")
|
|
||||||
}
|
|
||||||
<-wDone
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
|
|
||||||
reader, writer, mux := channelPair(t)
|
|
||||||
defer reader.Close()
|
|
||||||
defer writer.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
wDone := make(chan int, 1)
|
|
||||||
go func() {
|
|
||||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
|
|
||||||
t.Errorf("could not fill window: %v", err)
|
|
||||||
}
|
|
||||||
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
|
|
||||||
t.Errorf("got %v, want EOF for unblock write", err)
|
|
||||||
}
|
|
||||||
wDone <- 1
|
|
||||||
}()
|
|
||||||
|
|
||||||
writer.remoteWin.waitWriterBlocked()
|
|
||||||
reader.Close()
|
|
||||||
<-wDone
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
|
|
||||||
reader, writer, mux := channelPair(t)
|
|
||||||
defer reader.Close()
|
|
||||||
defer writer.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
wDone := make(chan int, 1)
|
|
||||||
go func() {
|
|
||||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
|
|
||||||
t.Errorf("could not fill window: %v", err)
|
|
||||||
}
|
|
||||||
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
|
|
||||||
t.Errorf("got %v, want EOF for unblock write", err)
|
|
||||||
}
|
|
||||||
wDone <- 1
|
|
||||||
}()
|
|
||||||
|
|
||||||
writer.remoteWin.waitWriterBlocked()
|
|
||||||
mux.Close()
|
|
||||||
<-wDone
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxReject(t *testing.T) {
|
|
||||||
client, server := muxPair()
|
|
||||||
defer server.Close()
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
ch, ok := <-server.incomingChannels
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("Accept")
|
|
||||||
}
|
|
||||||
if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
|
|
||||||
t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
|
|
||||||
}
|
|
||||||
ch.Reject(RejectionReason(42), "message")
|
|
||||||
}()
|
|
||||||
|
|
||||||
ch, err := client.openChannel("ch", []byte("extra"))
|
|
||||||
if ch != nil {
|
|
||||||
t.Fatal("openChannel not rejected")
|
|
||||||
}
|
|
||||||
|
|
||||||
ocf, ok := err.(*OpenChannelError)
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("got %#v want *OpenChannelError", err)
|
|
||||||
} else if ocf.Reason != 42 || ocf.Message != "message" {
|
|
||||||
t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
|
|
||||||
}
|
|
||||||
|
|
||||||
want := "ssh: rejected: unknown reason 42 (message)"
|
|
||||||
if err.Error() != want {
|
|
||||||
t.Errorf("got %q, want %q", err.Error(), want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxChannelRequest(t *testing.T) {
|
|
||||||
client, server, mux := channelPair(t)
|
|
||||||
defer server.Close()
|
|
||||||
defer client.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
var received int
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
for r := range server.incomingRequests {
|
|
||||||
received++
|
|
||||||
r.Reply(r.Type == "yes", nil)
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
_, err := client.SendRequest("yes", false, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("SendRequest: %v", err)
|
|
||||||
}
|
|
||||||
ok, err := client.SendRequest("yes", true, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("SendRequest: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("SendRequest(yes): %v", ok)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = client.SendRequest("no", true, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("SendRequest: %v", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
t.Errorf("SendRequest(no): %v", ok)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
client.Close()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if received != 3 {
|
|
||||||
t.Errorf("got %d requests, want %d", received, 3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxGlobalRequest(t *testing.T) {
|
|
||||||
clientMux, serverMux := muxPair()
|
|
||||||
defer serverMux.Close()
|
|
||||||
defer clientMux.Close()
|
|
||||||
|
|
||||||
var seen bool
|
|
||||||
go func() {
|
|
||||||
for r := range serverMux.incomingRequests {
|
|
||||||
seen = seen || r.Type == "peek"
|
|
||||||
if r.WantReply {
|
|
||||||
err := r.Reply(r.Type == "yes",
|
|
||||||
append([]byte(r.Type), r.Payload...))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("AckRequest: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, _, err := clientMux.SendRequest("peek", false, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("SendRequest: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
|
|
||||||
if !ok || string(data) != "yesa" || err != nil {
|
|
||||||
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
|
|
||||||
ok, data, err)
|
|
||||||
}
|
|
||||||
if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
|
|
||||||
t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
|
|
||||||
ok, data, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
|
|
||||||
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
|
|
||||||
ok, data, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
clientMux.Disconnect(0, "")
|
|
||||||
if !seen {
|
|
||||||
t.Errorf("never saw 'peek' request")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxGlobalRequestUnblock(t *testing.T) {
|
|
||||||
clientMux, serverMux := muxPair()
|
|
||||||
defer serverMux.Close()
|
|
||||||
defer clientMux.Close()
|
|
||||||
|
|
||||||
result := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
_, _, err := clientMux.SendRequest("hello", true, nil)
|
|
||||||
result <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-serverMux.incomingRequests
|
|
||||||
serverMux.conn.Close()
|
|
||||||
err := <-result
|
|
||||||
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Errorf("want EOF, got %v", io.EOF)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxChannelRequestUnblock(t *testing.T) {
|
|
||||||
a, b, connB := channelPair(t)
|
|
||||||
defer a.Close()
|
|
||||||
defer b.Close()
|
|
||||||
defer connB.Close()
|
|
||||||
|
|
||||||
result := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
_, err := a.SendRequest("hello", true, nil)
|
|
||||||
result <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-b.incomingRequests
|
|
||||||
connB.conn.Close()
|
|
||||||
err := <-result
|
|
||||||
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Errorf("want EOF, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxDisconnect(t *testing.T) {
|
|
||||||
a, b := muxPair()
|
|
||||||
defer a.Close()
|
|
||||||
defer b.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for r := range b.incomingRequests {
|
|
||||||
r.Reply(true, nil)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
a.Disconnect(42, "whatever")
|
|
||||||
ok, _, err := a.SendRequest("hello", true, nil)
|
|
||||||
if ok || err == nil {
|
|
||||||
t.Errorf("got reply after disconnecting")
|
|
||||||
}
|
|
||||||
err = b.Wait()
|
|
||||||
if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
|
|
||||||
t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxCloseChannel(t *testing.T) {
|
|
||||||
r, w, mux := channelPair(t)
|
|
||||||
defer mux.Close()
|
|
||||||
defer r.Close()
|
|
||||||
defer w.Close()
|
|
||||||
|
|
||||||
result := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
var b [1024]byte
|
|
||||||
_, err := r.Read(b[:])
|
|
||||||
result <- err
|
|
||||||
}()
|
|
||||||
if err := w.Close(); err != nil {
|
|
||||||
t.Errorf("w.Close: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := w.Write([]byte("hello")); err != io.EOF {
|
|
||||||
t.Errorf("got err %v, want io.EOF after Close", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := <-result; err != io.EOF {
|
|
||||||
t.Errorf("got %v (%T), want io.EOF", err, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxCloseWriteChannel(t *testing.T) {
|
|
||||||
r, w, mux := channelPair(t)
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
result := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
var b [1024]byte
|
|
||||||
_, err := r.Read(b[:])
|
|
||||||
result <- err
|
|
||||||
}()
|
|
||||||
if err := w.CloseWrite(); err != nil {
|
|
||||||
t.Errorf("w.CloseWrite: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := w.Write([]byte("hello")); err != io.EOF {
|
|
||||||
t.Errorf("got err %v, want io.EOF after CloseWrite", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := <-result; err != io.EOF {
|
|
||||||
t.Errorf("got %v (%T), want io.EOF", err, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxInvalidRecord(t *testing.T) {
|
|
||||||
a, b := muxPair()
|
|
||||||
defer a.Close()
|
|
||||||
defer b.Close()
|
|
||||||
|
|
||||||
packet := make([]byte, 1+4+4+1)
|
|
||||||
packet[0] = msgChannelData
|
|
||||||
marshalUint32(packet[1:], 29348723 /* invalid channel id */)
|
|
||||||
marshalUint32(packet[5:], 1)
|
|
||||||
packet[9] = 42
|
|
||||||
|
|
||||||
a.conn.writePacket(packet)
|
|
||||||
go a.SendRequest("hello", false, nil)
|
|
||||||
// 'a' wrote an invalid packet, so 'b' has exited.
|
|
||||||
req, ok := <-b.incomingRequests
|
|
||||||
if ok {
|
|
||||||
t.Errorf("got request %#v after receiving invalid packet", req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestZeroWindowAdjust(t *testing.T) {
|
|
||||||
a, b, mux := channelPair(t)
|
|
||||||
defer a.Close()
|
|
||||||
defer b.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
io.WriteString(a, "hello")
|
|
||||||
// bogus adjust.
|
|
||||||
a.sendMessage(windowAdjustMsg{})
|
|
||||||
io.WriteString(a, "world")
|
|
||||||
a.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
want := "helloworld"
|
|
||||||
c, _ := ioutil.ReadAll(b)
|
|
||||||
if string(c) != want {
|
|
||||||
t.Errorf("got %q want %q", c, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMuxMaxPacketSize(t *testing.T) {
|
|
||||||
a, b, mux := channelPair(t)
|
|
||||||
defer a.Close()
|
|
||||||
defer b.Close()
|
|
||||||
defer mux.Close()
|
|
||||||
|
|
||||||
large := make([]byte, a.maxRemotePayload+1)
|
|
||||||
packet := make([]byte, 1+4+4+1+len(large))
|
|
||||||
packet[0] = msgChannelData
|
|
||||||
marshalUint32(packet[1:], a.remoteId)
|
|
||||||
marshalUint32(packet[5:], uint32(len(large)))
|
|
||||||
packet[9] = 42
|
|
||||||
|
|
||||||
if err := a.mux.conn.writePacket(packet); err != nil {
|
|
||||||
t.Errorf("could not send packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
go a.SendRequest("hello", false, nil)
|
|
||||||
|
|
||||||
_, ok := <-b.incomingRequests
|
|
||||||
if ok {
|
|
||||||
t.Errorf("connection still alive after receiving large packet.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't ship code with debug=true.
|
|
||||||
func TestDebug(t *testing.T) {
|
|
||||||
if debugMux {
|
|
||||||
t.Error("mux debug switched on")
|
|
||||||
}
|
|
||||||
if debugHandshake {
|
|
||||||
t.Error("handshake debug switched on")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,477 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The Permissions type holds fine-grained permissions that are
|
|
||||||
// specific to a user or a specific authentication method for a
|
|
||||||
// user. Permissions, except for "source-address", must be enforced in
|
|
||||||
// the server application layer, after successful authentication. The
|
|
||||||
// Permissions are passed on in ServerConn so a server implementation
|
|
||||||
// can honor them.
|
|
||||||
type Permissions struct {
|
|
||||||
// Critical options restrict default permissions. Common
|
|
||||||
// restrictions are "source-address" and "force-command". If
|
|
||||||
// the server cannot enforce the restriction, or does not
|
|
||||||
// recognize it, the user should not authenticate.
|
|
||||||
CriticalOptions map[string]string
|
|
||||||
|
|
||||||
// Extensions are extra functionality that the server may
|
|
||||||
// offer on authenticated connections. Common extensions are
|
|
||||||
// "permit-agent-forwarding", "permit-X11-forwarding". Lack of
|
|
||||||
// support for an extension does not preclude authenticating a
|
|
||||||
// user.
|
|
||||||
Extensions map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerConfig holds server specific configuration data.
|
|
||||||
type ServerConfig struct {
|
|
||||||
// Config contains configuration shared between client and server.
|
|
||||||
Config
|
|
||||||
|
|
||||||
hostKeys []Signer
|
|
||||||
|
|
||||||
// NoClientAuth is true if clients are allowed to connect without
|
|
||||||
// authenticating.
|
|
||||||
NoClientAuth bool
|
|
||||||
|
|
||||||
// PasswordCallback, if non-nil, is called when a user
|
|
||||||
// attempts to authenticate using a password.
|
|
||||||
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
|
|
||||||
|
|
||||||
// PublicKeyCallback, if non-nil, is called when a client attempts public
|
|
||||||
// key authentication. It must return true if the given public key is
|
|
||||||
// valid for the given user. For example, see CertChecker.Authenticate.
|
|
||||||
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
|
|
||||||
|
|
||||||
// KeyboardInteractiveCallback, if non-nil, is called when
|
|
||||||
// keyboard-interactive authentication is selected (RFC
|
|
||||||
// 4256). The client object's Challenge function should be
|
|
||||||
// used to query the user. The callback may offer multiple
|
|
||||||
// Challenge rounds. To avoid information leaks, the client
|
|
||||||
// should be presented a challenge even if the user is
|
|
||||||
// unknown.
|
|
||||||
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
|
|
||||||
|
|
||||||
// AuthLogCallback, if non-nil, is called to log all authentication
|
|
||||||
// attempts.
|
|
||||||
AuthLogCallback func(conn ConnMetadata, method string, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHostKey adds a private key as a host key. If an existing host
|
|
||||||
// key exists with the same algorithm, it is overwritten. Each server
|
|
||||||
// config must have at least one host key.
|
|
||||||
func (s *ServerConfig) AddHostKey(key Signer) {
|
|
||||||
for i, k := range s.hostKeys {
|
|
||||||
if k.PublicKey().Type() == key.PublicKey().Type() {
|
|
||||||
s.hostKeys[i] = key
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.hostKeys = append(s.hostKeys, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cachedPubKey contains the results of querying whether a public key is
|
|
||||||
// acceptable for a user.
|
|
||||||
type cachedPubKey struct {
|
|
||||||
user string
|
|
||||||
pubKeyData []byte
|
|
||||||
result error
|
|
||||||
perms *Permissions
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxCachedPubKeys = 16
|
|
||||||
|
|
||||||
// pubKeyCache caches tests for public keys. Since SSH clients
|
|
||||||
// will query whether a public key is acceptable before attempting to
|
|
||||||
// authenticate with it, we end up with duplicate queries for public
|
|
||||||
// key validity. The cache only applies to a single ServerConn.
|
|
||||||
type pubKeyCache struct {
|
|
||||||
keys []cachedPubKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// get returns the result for a given user/algo/key tuple.
|
|
||||||
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
|
|
||||||
for _, k := range c.keys {
|
|
||||||
if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) {
|
|
||||||
return k, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cachedPubKey{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// add adds the given tuple to the cache.
|
|
||||||
func (c *pubKeyCache) add(candidate cachedPubKey) {
|
|
||||||
if len(c.keys) < maxCachedPubKeys {
|
|
||||||
c.keys = append(c.keys, candidate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerConn is an authenticated SSH connection, as seen from the
|
|
||||||
// server
|
|
||||||
type ServerConn struct {
|
|
||||||
Conn
|
|
||||||
|
|
||||||
// If the succeeding authentication callback returned a
|
|
||||||
// non-nil Permissions pointer, it is stored here.
|
|
||||||
Permissions *Permissions
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServerConn starts a new SSH server with c as the underlying
|
|
||||||
// transport. It starts with a handshake and, if the handshake is
|
|
||||||
// unsuccessful, it closes the connection and returns an error. The
|
|
||||||
// Request and NewChannel channels must be serviced, or the connection
|
|
||||||
// will hang.
|
|
||||||
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
|
|
||||||
fullConf := *config
|
|
||||||
fullConf.SetDefaults()
|
|
||||||
s := &connection{
|
|
||||||
sshConn: sshConn{conn: c},
|
|
||||||
}
|
|
||||||
perms, err := s.serverHandshake(&fullConf)
|
|
||||||
if err != nil {
|
|
||||||
c.Close()
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// signAndMarshal signs the data with the appropriate algorithm,
|
|
||||||
// and serializes the result in SSH wire format.
|
|
||||||
func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
|
|
||||||
sig, err := k.Sign(rand, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return Marshal(sig), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handshake performs key exchange and user authentication.
|
|
||||||
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
|
|
||||||
if len(config.hostKeys) == 0 {
|
|
||||||
return nil, errors.New("ssh: server has no host keys")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
s.serverVersion = []byte(packageVersion)
|
|
||||||
s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
|
|
||||||
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
|
|
||||||
|
|
||||||
if err := s.transport.requestKeyChange(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if packet, err := s.transport.readPacket(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if packet[0] != msgNewKeys {
|
|
||||||
return nil, unexpectedMessageError(msgNewKeys, packet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var packet []byte
|
|
||||||
if packet, err = s.transport.readPacket(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var serviceRequest serviceRequestMsg
|
|
||||||
if err = Unmarshal(packet, &serviceRequest); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if serviceRequest.Service != serviceUserAuth {
|
|
||||||
return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
|
|
||||||
}
|
|
||||||
serviceAccept := serviceAcceptMsg{
|
|
||||||
Service: serviceUserAuth,
|
|
||||||
}
|
|
||||||
if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
perms, err := s.serverAuthenticate(config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s.mux = newMux(s.transport)
|
|
||||||
return perms, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func isAcceptableAlgo(algo string) bool {
|
|
||||||
switch algo {
|
|
||||||
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
|
|
||||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSourceAddress(addr net.Addr, sourceAddr string) error {
|
|
||||||
if addr == nil {
|
|
||||||
return errors.New("ssh: no address known for client, but source-address match required")
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpAddr, ok := addr.(*net.TCPAddr)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
|
|
||||||
if bytes.Equal(allowedIP, tcpAddr.IP) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_, ipNet, err := net.ParseCIDR(sourceAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipNet.Contains(tcpAddr.IP) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
|
|
||||||
var err error
|
|
||||||
var cache pubKeyCache
|
|
||||||
var perms *Permissions
|
|
||||||
|
|
||||||
userAuthLoop:
|
|
||||||
for {
|
|
||||||
var userAuthReq userAuthRequestMsg
|
|
||||||
if packet, err := s.transport.readPacket(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err = Unmarshal(packet, &userAuthReq); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if userAuthReq.Service != serviceSSH {
|
|
||||||
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.user = userAuthReq.User
|
|
||||||
perms = nil
|
|
||||||
authErr := errors.New("no auth passed yet")
|
|
||||||
|
|
||||||
switch userAuthReq.Method {
|
|
||||||
case "none":
|
|
||||||
if config.NoClientAuth {
|
|
||||||
s.user = ""
|
|
||||||
authErr = nil
|
|
||||||
}
|
|
||||||
case "password":
|
|
||||||
if config.PasswordCallback == nil {
|
|
||||||
authErr = errors.New("ssh: password auth not configured")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
payload := userAuthReq.Payload
|
|
||||||
if len(payload) < 1 || payload[0] != 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
payload = payload[1:]
|
|
||||||
password, payload, ok := parseString(payload)
|
|
||||||
if !ok || len(payload) > 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
perms, authErr = config.PasswordCallback(s, password)
|
|
||||||
case "keyboard-interactive":
|
|
||||||
if config.KeyboardInteractiveCallback == nil {
|
|
||||||
authErr = errors.New("ssh: keyboard-interactive auth not configubred")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
prompter := &sshClientKeyboardInteractive{s}
|
|
||||||
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
|
|
||||||
case "publickey":
|
|
||||||
if config.PublicKeyCallback == nil {
|
|
||||||
authErr = errors.New("ssh: publickey auth not configured")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
payload := userAuthReq.Payload
|
|
||||||
if len(payload) < 1 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
isQuery := payload[0] == 0
|
|
||||||
payload = payload[1:]
|
|
||||||
algoBytes, payload, ok := parseString(payload)
|
|
||||||
if !ok {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
algo := string(algoBytes)
|
|
||||||
if !isAcceptableAlgo(algo) {
|
|
||||||
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKeyData, payload, ok := parseString(payload)
|
|
||||||
if !ok {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey, err := ParsePublicKey(pubKeyData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
candidate, ok := cache.get(s.user, pubKeyData)
|
|
||||||
if !ok {
|
|
||||||
candidate.user = s.user
|
|
||||||
candidate.pubKeyData = pubKeyData
|
|
||||||
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
|
|
||||||
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
|
|
||||||
candidate.result = checkSourceAddress(
|
|
||||||
s.RemoteAddr(),
|
|
||||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption])
|
|
||||||
}
|
|
||||||
cache.add(candidate)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isQuery {
|
|
||||||
// The client can query if the given public key
|
|
||||||
// would be okay.
|
|
||||||
if len(payload) > 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if candidate.result == nil {
|
|
||||||
okMsg := userAuthPubKeyOkMsg{
|
|
||||||
Algo: algo,
|
|
||||||
PubKey: pubKeyData,
|
|
||||||
}
|
|
||||||
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
continue userAuthLoop
|
|
||||||
}
|
|
||||||
authErr = candidate.result
|
|
||||||
} else {
|
|
||||||
sig, payload, ok := parseSignature(payload)
|
|
||||||
if !ok || len(payload) > 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
// Ensure the public key algo and signature algo
|
|
||||||
// are supported. Compare the private key
|
|
||||||
// algorithm name that corresponds to algo with
|
|
||||||
// sig.Format. This is usually the same, but
|
|
||||||
// for certs, the names differ.
|
|
||||||
if !isAcceptableAlgo(sig.Format) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
|
|
||||||
|
|
||||||
if err := pubKey.Verify(signedData, sig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authErr = candidate.result
|
|
||||||
perms = candidate.perms
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.AuthLogCallback != nil {
|
|
||||||
config.AuthLogCallback(s, userAuthReq.Method, authErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if authErr == nil {
|
|
||||||
break userAuthLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
var failureMsg userAuthFailureMsg
|
|
||||||
if config.PasswordCallback != nil {
|
|
||||||
failureMsg.Methods = append(failureMsg.Methods, "password")
|
|
||||||
}
|
|
||||||
if config.PublicKeyCallback != nil {
|
|
||||||
failureMsg.Methods = append(failureMsg.Methods, "publickey")
|
|
||||||
}
|
|
||||||
if config.KeyboardInteractiveCallback != nil {
|
|
||||||
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(failureMsg.Methods) == 0 {
|
|
||||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return perms, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
|
|
||||||
// asking the client on the other side of a ServerConn.
|
|
||||||
type sshClientKeyboardInteractive struct {
|
|
||||||
*connection
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
|
||||||
if len(questions) != len(echos) {
|
|
||||||
return nil, errors.New("ssh: echos and questions must have equal length")
|
|
||||||
}
|
|
||||||
|
|
||||||
var prompts []byte
|
|
||||||
for i := range questions {
|
|
||||||
prompts = appendString(prompts, questions[i])
|
|
||||||
prompts = appendBool(prompts, echos[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
|
|
||||||
Instruction: instruction,
|
|
||||||
NumPrompts: uint32(len(questions)),
|
|
||||||
Prompts: prompts,
|
|
||||||
})); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := c.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if packet[0] != msgUserAuthInfoResponse {
|
|
||||||
return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
|
|
||||||
}
|
|
||||||
packet = packet[1:]
|
|
||||||
|
|
||||||
n, packet, ok := parseUint32(packet)
|
|
||||||
if !ok || int(n) != len(questions) {
|
|
||||||
return nil, parseError(msgUserAuthInfoResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := uint32(0); i < n; i++ {
|
|
||||||
ans, rest, ok := parseString(packet)
|
|
||||||
if !ok {
|
|
||||||
return nil, parseError(msgUserAuthInfoResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
answers = append(answers, string(ans))
|
|
||||||
packet = rest
|
|
||||||
}
|
|
||||||
if len(packet) != 0 {
|
|
||||||
return nil, errors.New("ssh: junk at end of message")
|
|
||||||
}
|
|
||||||
|
|
||||||
return answers, nil
|
|
||||||
}
|
|
|
@ -1,605 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
// Session implements an interactive session described in
|
|
||||||
// "RFC 4254, section 6".
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Signal string
|
|
||||||
|
|
||||||
// POSIX signals as listed in RFC 4254 Section 6.10.
|
|
||||||
const (
|
|
||||||
SIGABRT Signal = "ABRT"
|
|
||||||
SIGALRM Signal = "ALRM"
|
|
||||||
SIGFPE Signal = "FPE"
|
|
||||||
SIGHUP Signal = "HUP"
|
|
||||||
SIGILL Signal = "ILL"
|
|
||||||
SIGINT Signal = "INT"
|
|
||||||
SIGKILL Signal = "KILL"
|
|
||||||
SIGPIPE Signal = "PIPE"
|
|
||||||
SIGQUIT Signal = "QUIT"
|
|
||||||
SIGSEGV Signal = "SEGV"
|
|
||||||
SIGTERM Signal = "TERM"
|
|
||||||
SIGUSR1 Signal = "USR1"
|
|
||||||
SIGUSR2 Signal = "USR2"
|
|
||||||
)
|
|
||||||
|
|
||||||
var signals = map[Signal]int{
|
|
||||||
SIGABRT: 6,
|
|
||||||
SIGALRM: 14,
|
|
||||||
SIGFPE: 8,
|
|
||||||
SIGHUP: 1,
|
|
||||||
SIGILL: 4,
|
|
||||||
SIGINT: 2,
|
|
||||||
SIGKILL: 9,
|
|
||||||
SIGPIPE: 13,
|
|
||||||
SIGQUIT: 3,
|
|
||||||
SIGSEGV: 11,
|
|
||||||
SIGTERM: 15,
|
|
||||||
}
|
|
||||||
|
|
||||||
type TerminalModes map[uint8]uint32
|
|
||||||
|
|
||||||
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
|
|
||||||
const (
|
|
||||||
tty_OP_END = 0
|
|
||||||
VINTR = 1
|
|
||||||
VQUIT = 2
|
|
||||||
VERASE = 3
|
|
||||||
VKILL = 4
|
|
||||||
VEOF = 5
|
|
||||||
VEOL = 6
|
|
||||||
VEOL2 = 7
|
|
||||||
VSTART = 8
|
|
||||||
VSTOP = 9
|
|
||||||
VSUSP = 10
|
|
||||||
VDSUSP = 11
|
|
||||||
VREPRINT = 12
|
|
||||||
VWERASE = 13
|
|
||||||
VLNEXT = 14
|
|
||||||
VFLUSH = 15
|
|
||||||
VSWTCH = 16
|
|
||||||
VSTATUS = 17
|
|
||||||
VDISCARD = 18
|
|
||||||
IGNPAR = 30
|
|
||||||
PARMRK = 31
|
|
||||||
INPCK = 32
|
|
||||||
ISTRIP = 33
|
|
||||||
INLCR = 34
|
|
||||||
IGNCR = 35
|
|
||||||
ICRNL = 36
|
|
||||||
IUCLC = 37
|
|
||||||
IXON = 38
|
|
||||||
IXANY = 39
|
|
||||||
IXOFF = 40
|
|
||||||
IMAXBEL = 41
|
|
||||||
ISIG = 50
|
|
||||||
ICANON = 51
|
|
||||||
XCASE = 52
|
|
||||||
ECHO = 53
|
|
||||||
ECHOE = 54
|
|
||||||
ECHOK = 55
|
|
||||||
ECHONL = 56
|
|
||||||
NOFLSH = 57
|
|
||||||
TOSTOP = 58
|
|
||||||
IEXTEN = 59
|
|
||||||
ECHOCTL = 60
|
|
||||||
ECHOKE = 61
|
|
||||||
PENDIN = 62
|
|
||||||
OPOST = 70
|
|
||||||
OLCUC = 71
|
|
||||||
ONLCR = 72
|
|
||||||
OCRNL = 73
|
|
||||||
ONOCR = 74
|
|
||||||
ONLRET = 75
|
|
||||||
CS7 = 90
|
|
||||||
CS8 = 91
|
|
||||||
PARENB = 92
|
|
||||||
PARODD = 93
|
|
||||||
TTY_OP_ISPEED = 128
|
|
||||||
TTY_OP_OSPEED = 129
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Session represents a connection to a remote command or shell.
|
|
||||||
type Session struct {
|
|
||||||
// Stdin specifies the remote process's standard input.
|
|
||||||
// If Stdin is nil, the remote process reads from an empty
|
|
||||||
// bytes.Buffer.
|
|
||||||
Stdin io.Reader
|
|
||||||
|
|
||||||
// Stdout and Stderr specify the remote process's standard
|
|
||||||
// output and error.
|
|
||||||
//
|
|
||||||
// If either is nil, Run connects the corresponding file
|
|
||||||
// descriptor to an instance of ioutil.Discard. There is a
|
|
||||||
// fixed amount of buffering that is shared for the two streams.
|
|
||||||
// If either blocks it may eventually cause the remote
|
|
||||||
// command to block.
|
|
||||||
Stdout io.Writer
|
|
||||||
Stderr io.Writer
|
|
||||||
|
|
||||||
ch Channel // the channel backing this session
|
|
||||||
started bool // true once Start, Run or Shell is invoked.
|
|
||||||
copyFuncs []func() error
|
|
||||||
errors chan error // one send per copyFunc
|
|
||||||
|
|
||||||
// true if pipe method is active
|
|
||||||
stdinpipe, stdoutpipe, stderrpipe bool
|
|
||||||
|
|
||||||
// stdinPipeWriter is non-nil if StdinPipe has not been called
|
|
||||||
// and Stdin was specified by the user; it is the write end of
|
|
||||||
// a pipe connecting Session.Stdin to the stdin channel.
|
|
||||||
stdinPipeWriter io.WriteCloser
|
|
||||||
|
|
||||||
exitStatus chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendRequest sends an out-of-band channel request on the SSH channel
|
|
||||||
// underlying the session.
|
|
||||||
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
|
||||||
return s.ch.SendRequest(name, wantReply, payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) Close() error {
|
|
||||||
return s.ch.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.4.
|
|
||||||
type setenvRequest struct {
|
|
||||||
Name string
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setenv sets an environment variable that will be applied to any
|
|
||||||
// command executed by Shell or Run.
|
|
||||||
func (s *Session) Setenv(name, value string) error {
|
|
||||||
msg := setenvRequest{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
}
|
|
||||||
ok, err := s.ch.SendRequest("env", true, Marshal(&msg))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: setenv failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.2.
|
|
||||||
type ptyRequestMsg struct {
|
|
||||||
Term string
|
|
||||||
Columns uint32
|
|
||||||
Rows uint32
|
|
||||||
Width uint32
|
|
||||||
Height uint32
|
|
||||||
Modelist string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestPty requests the association of a pty with the session on the remote host.
|
|
||||||
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error {
|
|
||||||
var tm []byte
|
|
||||||
for k, v := range termmodes {
|
|
||||||
kv := struct {
|
|
||||||
Key byte
|
|
||||||
Val uint32
|
|
||||||
}{k, v}
|
|
||||||
|
|
||||||
tm = append(tm, Marshal(&kv)...)
|
|
||||||
}
|
|
||||||
tm = append(tm, tty_OP_END)
|
|
||||||
req := ptyRequestMsg{
|
|
||||||
Term: term,
|
|
||||||
Columns: uint32(w),
|
|
||||||
Rows: uint32(h),
|
|
||||||
Width: uint32(w * 8),
|
|
||||||
Height: uint32(h * 8),
|
|
||||||
Modelist: string(tm),
|
|
||||||
}
|
|
||||||
ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: pty-req failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.5.
|
|
||||||
type subsystemRequestMsg struct {
|
|
||||||
Subsystem string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
|
|
||||||
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
|
|
||||||
func (s *Session) RequestSubsystem(subsystem string) error {
|
|
||||||
msg := subsystemRequestMsg{
|
|
||||||
Subsystem: subsystem,
|
|
||||||
}
|
|
||||||
ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: subsystem request failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.9.
|
|
||||||
type signalMsg struct {
|
|
||||||
Signal string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal sends the given signal to the remote process.
|
|
||||||
// sig is one of the SIG* constants.
|
|
||||||
func (s *Session) Signal(sig Signal) error {
|
|
||||||
msg := signalMsg{
|
|
||||||
Signal: string(sig),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := s.ch.SendRequest("signal", false, Marshal(&msg))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.5.
|
|
||||||
type execMsg struct {
|
|
||||||
Command string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start runs cmd on the remote host. Typically, the remote
|
|
||||||
// server passes cmd to the shell for interpretation.
|
|
||||||
// A Session only accepts one call to Run, Start or Shell.
|
|
||||||
func (s *Session) Start(cmd string) error {
|
|
||||||
if s.started {
|
|
||||||
return errors.New("ssh: session already started")
|
|
||||||
}
|
|
||||||
req := execMsg{
|
|
||||||
Command: cmd,
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := s.ch.SendRequest("exec", true, Marshal(&req))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = fmt.Errorf("ssh: command %v failed", cmd)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.start()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run runs cmd on the remote host. Typically, the remote
|
|
||||||
// server passes cmd to the shell for interpretation.
|
|
||||||
// A Session only accepts one call to Run, Start, Shell, Output,
|
|
||||||
// or CombinedOutput.
|
|
||||||
//
|
|
||||||
// The returned error is nil if the command runs, has no problems
|
|
||||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
|
||||||
// status.
|
|
||||||
//
|
|
||||||
// If the command fails to run or doesn't complete successfully, the
|
|
||||||
// error is of type *ExitError. Other error types may be
|
|
||||||
// returned for I/O problems.
|
|
||||||
func (s *Session) Run(cmd string) error {
|
|
||||||
err := s.Start(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output runs cmd on the remote host and returns its standard output.
|
|
||||||
func (s *Session) Output(cmd string) ([]byte, error) {
|
|
||||||
if s.Stdout != nil {
|
|
||||||
return nil, errors.New("ssh: Stdout already set")
|
|
||||||
}
|
|
||||||
var b bytes.Buffer
|
|
||||||
s.Stdout = &b
|
|
||||||
err := s.Run(cmd)
|
|
||||||
return b.Bytes(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
type singleWriter struct {
|
|
||||||
b bytes.Buffer
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *singleWriter) Write(p []byte) (int, error) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
return w.b.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CombinedOutput runs cmd on the remote host and returns its combined
|
|
||||||
// standard output and standard error.
|
|
||||||
func (s *Session) CombinedOutput(cmd string) ([]byte, error) {
|
|
||||||
if s.Stdout != nil {
|
|
||||||
return nil, errors.New("ssh: Stdout already set")
|
|
||||||
}
|
|
||||||
if s.Stderr != nil {
|
|
||||||
return nil, errors.New("ssh: Stderr already set")
|
|
||||||
}
|
|
||||||
var b singleWriter
|
|
||||||
s.Stdout = &b
|
|
||||||
s.Stderr = &b
|
|
||||||
err := s.Run(cmd)
|
|
||||||
return b.b.Bytes(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shell starts a login shell on the remote host. A Session only
|
|
||||||
// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
|
|
||||||
func (s *Session) Shell() error {
|
|
||||||
if s.started {
|
|
||||||
return errors.New("ssh: session already started")
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := s.ch.SendRequest("shell", true, nil)
|
|
||||||
if err == nil && !ok {
|
|
||||||
return fmt.Errorf("ssh: cound not start shell")
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.start()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) start() error {
|
|
||||||
s.started = true
|
|
||||||
|
|
||||||
type F func(*Session)
|
|
||||||
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} {
|
|
||||||
setupFd(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.errors = make(chan error, len(s.copyFuncs))
|
|
||||||
for _, fn := range s.copyFuncs {
|
|
||||||
go func(fn func() error) {
|
|
||||||
s.errors <- fn()
|
|
||||||
}(fn)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait waits for the remote command to exit.
|
|
||||||
//
|
|
||||||
// The returned error is nil if the command runs, has no problems
|
|
||||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
|
||||||
// status.
|
|
||||||
//
|
|
||||||
// If the command fails to run or doesn't complete successfully, the
|
|
||||||
// error is of type *ExitError. Other error types may be
|
|
||||||
// returned for I/O problems.
|
|
||||||
func (s *Session) Wait() error {
|
|
||||||
if !s.started {
|
|
||||||
return errors.New("ssh: session not started")
|
|
||||||
}
|
|
||||||
waitErr := <-s.exitStatus
|
|
||||||
|
|
||||||
if s.stdinPipeWriter != nil {
|
|
||||||
s.stdinPipeWriter.Close()
|
|
||||||
}
|
|
||||||
var copyError error
|
|
||||||
for _ = range s.copyFuncs {
|
|
||||||
if err := <-s.errors; err != nil && copyError == nil {
|
|
||||||
copyError = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if waitErr != nil {
|
|
||||||
return waitErr
|
|
||||||
}
|
|
||||||
return copyError
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) wait(reqs <-chan *Request) error {
|
|
||||||
wm := Waitmsg{status: -1}
|
|
||||||
// Wait for msg channel to be closed before returning.
|
|
||||||
for msg := range reqs {
|
|
||||||
switch msg.Type {
|
|
||||||
case "exit-status":
|
|
||||||
d := msg.Payload
|
|
||||||
wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
|
|
||||||
case "exit-signal":
|
|
||||||
var sigval struct {
|
|
||||||
Signal string
|
|
||||||
CoreDumped bool
|
|
||||||
Error string
|
|
||||||
Lang string
|
|
||||||
}
|
|
||||||
if err := Unmarshal(msg.Payload, &sigval); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must sanitize strings?
|
|
||||||
wm.signal = sigval.Signal
|
|
||||||
wm.msg = sigval.Error
|
|
||||||
wm.lang = sigval.Lang
|
|
||||||
default:
|
|
||||||
// This handles keepalives and matches
|
|
||||||
// OpenSSH's behaviour.
|
|
||||||
if msg.WantReply {
|
|
||||||
msg.Reply(false, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if wm.status == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if wm.status == -1 {
|
|
||||||
// exit-status was never sent from server
|
|
||||||
if wm.signal == "" {
|
|
||||||
return errors.New("wait: remote command exited without exit status or exit signal")
|
|
||||||
}
|
|
||||||
wm.status = 128
|
|
||||||
if _, ok := signals[Signal(wm.signal)]; ok {
|
|
||||||
wm.status += signals[Signal(wm.signal)]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &ExitError{wm}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) stdin() {
|
|
||||||
if s.stdinpipe {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var stdin io.Reader
|
|
||||||
if s.Stdin == nil {
|
|
||||||
stdin = new(bytes.Buffer)
|
|
||||||
} else {
|
|
||||||
r, w := io.Pipe()
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(w, s.Stdin)
|
|
||||||
w.CloseWithError(err)
|
|
||||||
}()
|
|
||||||
stdin, s.stdinPipeWriter = r, w
|
|
||||||
}
|
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
|
||||||
_, err := io.Copy(s.ch, stdin)
|
|
||||||
if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
|
|
||||||
err = err1
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) stdout() {
|
|
||||||
if s.stdoutpipe {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.Stdout == nil {
|
|
||||||
s.Stdout = ioutil.Discard
|
|
||||||
}
|
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
|
||||||
_, err := io.Copy(s.Stdout, s.ch)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) stderr() {
|
|
||||||
if s.stderrpipe {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.Stderr == nil {
|
|
||||||
s.Stderr = ioutil.Discard
|
|
||||||
}
|
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
|
||||||
_, err := io.Copy(s.Stderr, s.ch.Stderr())
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// sessionStdin reroutes Close to CloseWrite.
|
|
||||||
type sessionStdin struct {
|
|
||||||
io.Writer
|
|
||||||
ch Channel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sessionStdin) Close() error {
|
|
||||||
return s.ch.CloseWrite()
|
|
||||||
}
|
|
||||||
|
|
||||||
// StdinPipe returns a pipe that will be connected to the
|
|
||||||
// remote command's standard input when the command starts.
|
|
||||||
func (s *Session) StdinPipe() (io.WriteCloser, error) {
|
|
||||||
if s.Stdin != nil {
|
|
||||||
return nil, errors.New("ssh: Stdin already set")
|
|
||||||
}
|
|
||||||
if s.started {
|
|
||||||
return nil, errors.New("ssh: StdinPipe after process started")
|
|
||||||
}
|
|
||||||
s.stdinpipe = true
|
|
||||||
return &sessionStdin{s.ch, s.ch}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StdoutPipe returns a pipe that will be connected to the
|
|
||||||
// remote command's standard output when the command starts.
|
|
||||||
// There is a fixed amount of buffering that is shared between
|
|
||||||
// stdout and stderr streams. If the StdoutPipe reader is
|
|
||||||
// not serviced fast enough it may eventually cause the
|
|
||||||
// remote command to block.
|
|
||||||
func (s *Session) StdoutPipe() (io.Reader, error) {
|
|
||||||
if s.Stdout != nil {
|
|
||||||
return nil, errors.New("ssh: Stdout already set")
|
|
||||||
}
|
|
||||||
if s.started {
|
|
||||||
return nil, errors.New("ssh: StdoutPipe after process started")
|
|
||||||
}
|
|
||||||
s.stdoutpipe = true
|
|
||||||
return s.ch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StderrPipe returns a pipe that will be connected to the
|
|
||||||
// remote command's standard error when the command starts.
|
|
||||||
// There is a fixed amount of buffering that is shared between
|
|
||||||
// stdout and stderr streams. If the StderrPipe reader is
|
|
||||||
// not serviced fast enough it may eventually cause the
|
|
||||||
// remote command to block.
|
|
||||||
func (s *Session) StderrPipe() (io.Reader, error) {
|
|
||||||
if s.Stderr != nil {
|
|
||||||
return nil, errors.New("ssh: Stderr already set")
|
|
||||||
}
|
|
||||||
if s.started {
|
|
||||||
return nil, errors.New("ssh: StderrPipe after process started")
|
|
||||||
}
|
|
||||||
s.stderrpipe = true
|
|
||||||
return s.ch.Stderr(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newSession returns a new interactive session on the remote host.
|
|
||||||
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
|
|
||||||
s := &Session{
|
|
||||||
ch: ch,
|
|
||||||
}
|
|
||||||
s.exitStatus = make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
s.exitStatus <- s.wait(reqs)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// An ExitError reports unsuccessful completion of a remote command.
|
|
||||||
type ExitError struct {
|
|
||||||
Waitmsg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ExitError) Error() string {
|
|
||||||
return e.Waitmsg.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Waitmsg stores the information about an exited remote command
|
|
||||||
// as reported by Wait.
|
|
||||||
type Waitmsg struct {
|
|
||||||
status int
|
|
||||||
signal string
|
|
||||||
msg string
|
|
||||||
lang string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExitStatus returns the exit status of the remote command.
|
|
||||||
func (w Waitmsg) ExitStatus() int {
|
|
||||||
return w.status
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal returns the exit signal of the remote command if
|
|
||||||
// it was terminated violently.
|
|
||||||
func (w Waitmsg) Signal() string {
|
|
||||||
return w.signal
|
|
||||||
}
|
|
||||||
|
|
||||||
// Msg returns the exit message given by the remote command
|
|
||||||
func (w Waitmsg) Msg() string {
|
|
||||||
return w.msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lang returns the language tag. See RFC 3066
|
|
||||||
func (w Waitmsg) Lang() string {
|
|
||||||
return w.lang
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w Waitmsg) String() string {
|
|
||||||
return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal)
|
|
||||||
}
|
|
|
@ -1,628 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
// Session tests.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
crypto_rand "crypto/rand"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
|
||||||
)
|
|
||||||
|
|
||||||
type serverType func(Channel, <-chan *Request, *testing.T)
|
|
||||||
|
|
||||||
// dial constructs a new test server and returns a *ClientConn.
|
|
||||||
func dial(handler serverType, t *testing.T) *Client {
|
|
||||||
c1, c2, err := netPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("netPipe: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer c1.Close()
|
|
||||||
conf := ServerConfig{
|
|
||||||
NoClientAuth: true,
|
|
||||||
}
|
|
||||||
conf.AddHostKey(testSigners["rsa"])
|
|
||||||
|
|
||||||
_, chans, reqs, err := NewServerConn(c1, &conf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to handshake: %v", err)
|
|
||||||
}
|
|
||||||
go DiscardRequests(reqs)
|
|
||||||
|
|
||||||
for newCh := range chans {
|
|
||||||
if newCh.ChannelType() != "session" {
|
|
||||||
newCh.Reject(UnknownChannelType, "unknown channel type")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ch, inReqs, err := newCh.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Accept: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
handler(ch, inReqs, t)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
config := &ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, chans, reqs, err := NewClientConn(c2, "", config)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to dial remote side: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewClient(conn, chans, reqs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test a simple string is returned to session.Stdout.
|
|
||||||
func TestSessionShell(t *testing.T) {
|
|
||||||
conn := dial(shellHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
stdout := new(bytes.Buffer)
|
|
||||||
session.Stdout = stdout
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %s", err)
|
|
||||||
}
|
|
||||||
if err := session.Wait(); err != nil {
|
|
||||||
t.Fatalf("Remote command did not exit cleanly: %v", err)
|
|
||||||
}
|
|
||||||
actual := stdout.String()
|
|
||||||
if actual != "golang" {
|
|
||||||
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
|
|
||||||
|
|
||||||
// Test a simple string is returned via StdoutPipe.
|
|
||||||
func TestSessionStdoutPipe(t *testing.T) {
|
|
||||||
conn := dial(shellHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
stdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request StdoutPipe(): %v", err)
|
|
||||||
}
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
done := make(chan bool, 1)
|
|
||||||
go func() {
|
|
||||||
if _, err := io.Copy(&buf, stdout); err != nil {
|
|
||||||
t.Errorf("Copy of stdout failed: %v", err)
|
|
||||||
}
|
|
||||||
done <- true
|
|
||||||
}()
|
|
||||||
if err := session.Wait(); err != nil {
|
|
||||||
t.Fatalf("Remote command did not exit cleanly: %v", err)
|
|
||||||
}
|
|
||||||
<-done
|
|
||||||
actual := buf.String()
|
|
||||||
if actual != "golang" {
|
|
||||||
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that a simple string is returned via the Output helper,
|
|
||||||
// and that stderr is discarded.
|
|
||||||
func TestSessionOutput(t *testing.T) {
|
|
||||||
conn := dial(fixedOutputHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
|
|
||||||
if err != nil {
|
|
||||||
t.Error("Remote command did not exit cleanly:", err)
|
|
||||||
}
|
|
||||||
w := "this-is-stdout."
|
|
||||||
g := string(buf)
|
|
||||||
if g != w {
|
|
||||||
t.Error("Remote command did not return expected string:")
|
|
||||||
t.Logf("want %q", w)
|
|
||||||
t.Logf("got %q", g)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that both stdout and stderr are returned
|
|
||||||
// via the CombinedOutput helper.
|
|
||||||
func TestSessionCombinedOutput(t *testing.T) {
|
|
||||||
conn := dial(fixedOutputHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
|
|
||||||
if err != nil {
|
|
||||||
t.Error("Remote command did not exit cleanly:", err)
|
|
||||||
}
|
|
||||||
const stdout = "this-is-stdout."
|
|
||||||
const stderr = "this-is-stderr."
|
|
||||||
g := string(buf)
|
|
||||||
if g != stdout+stderr && g != stderr+stdout {
|
|
||||||
t.Error("Remote command did not return expected string:")
|
|
||||||
t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
|
|
||||||
t.Logf("got %q", g)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test non-0 exit status is returned correctly.
|
|
||||||
func TestExitStatusNonZero(t *testing.T) {
|
|
||||||
conn := dial(exitStatusNonZeroHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Wait()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected command to fail but it didn't")
|
|
||||||
}
|
|
||||||
e, ok := err.(*ExitError)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected *ExitError but got %T", err)
|
|
||||||
}
|
|
||||||
if e.ExitStatus() != 15 {
|
|
||||||
t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test 0 exit status is returned correctly.
|
|
||||||
func TestExitStatusZero(t *testing.T) {
|
|
||||||
conn := dial(exitStatusZeroHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Wait()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("expected nil but got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test exit signal and status are both returned correctly.
|
|
||||||
func TestExitSignalAndStatus(t *testing.T) {
|
|
||||||
conn := dial(exitSignalAndStatusHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Wait()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected command to fail but it didn't")
|
|
||||||
}
|
|
||||||
e, ok := err.(*ExitError)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected *ExitError but got %T", err)
|
|
||||||
}
|
|
||||||
if e.Signal() != "TERM" || e.ExitStatus() != 15 {
|
|
||||||
t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test exit signal and status are both returned correctly.
|
|
||||||
func TestKnownExitSignalOnly(t *testing.T) {
|
|
||||||
conn := dial(exitSignalHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Wait()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected command to fail but it didn't")
|
|
||||||
}
|
|
||||||
e, ok := err.(*ExitError)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected *ExitError but got %T", err)
|
|
||||||
}
|
|
||||||
if e.Signal() != "TERM" || e.ExitStatus() != 143 {
|
|
||||||
t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test exit signal and status are both returned correctly.
|
|
||||||
func TestUnknownExitSignal(t *testing.T) {
|
|
||||||
conn := dial(exitSignalUnknownHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Wait()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected command to fail but it didn't")
|
|
||||||
}
|
|
||||||
e, ok := err.(*ExitError)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected *ExitError but got %T", err)
|
|
||||||
}
|
|
||||||
if e.Signal() != "SYS" || e.ExitStatus() != 128 {
|
|
||||||
t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test WaitMsg is not returned if the channel closes abruptly.
|
|
||||||
func TestExitWithoutStatusOrSignal(t *testing.T) {
|
|
||||||
conn := dial(exitWithoutSignalOrStatus, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to request new session: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Wait()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected command to fail but it didn't")
|
|
||||||
}
|
|
||||||
_, ok := err.(*ExitError)
|
|
||||||
if ok {
|
|
||||||
// you can't actually test for errors.errorString
|
|
||||||
// because it's not exported.
|
|
||||||
t.Fatalf("expected *errorString but got %T", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// windowTestBytes is the number of bytes that we'll send to the SSH server.
|
|
||||||
const windowTestBytes = 16000 * 200
|
|
||||||
|
|
||||||
// TestServerWindow writes random data to the server. The server is expected to echo
|
|
||||||
// the same data back, which is compared against the original.
|
|
||||||
func TestServerWindow(t *testing.T) {
|
|
||||||
origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
|
|
||||||
io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
|
|
||||||
origBytes := origBuf.Bytes()
|
|
||||||
|
|
||||||
conn := dial(echoHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
result := make(chan []byte)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(result)
|
|
||||||
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
|
|
||||||
serverStdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("StdoutPipe failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
|
|
||||||
}
|
|
||||||
result <- echoedBuf.Bytes()
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverStdin, err := session.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StdinPipe failed: %v", err)
|
|
||||||
}
|
|
||||||
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to copy origBuf to serverStdin: %v", err)
|
|
||||||
}
|
|
||||||
if written != windowTestBytes {
|
|
||||||
t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
echoedBytes := <-result
|
|
||||||
|
|
||||||
if !bytes.Equal(origBytes, echoedBytes) {
|
|
||||||
t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the client can handle a keepalive packet from the server.
|
|
||||||
func TestClientHandlesKeepalives(t *testing.T) {
|
|
||||||
conn := dial(channelKeepaliveSender, t)
|
|
||||||
defer conn.Close()
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
if err := session.Shell(); err != nil {
|
|
||||||
t.Fatalf("Unable to execute command: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Wait()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("expected nil but got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type exitStatusMsg struct {
|
|
||||||
Status uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type exitSignalMsg struct {
|
|
||||||
Signal string
|
|
||||||
CoreDumped bool
|
|
||||||
Errmsg string
|
|
||||||
Lang string
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleTerminalRequests(in <-chan *Request) {
|
|
||||||
for req := range in {
|
|
||||||
ok := false
|
|
||||||
switch req.Type {
|
|
||||||
case "shell":
|
|
||||||
ok = true
|
|
||||||
if len(req.Payload) > 0 {
|
|
||||||
// We don't accept any commands, only the default shell.
|
|
||||||
ok = false
|
|
||||||
}
|
|
||||||
case "env":
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
req.Reply(ok, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
|
|
||||||
term := terminal.NewTerminal(ch, prompt)
|
|
||||||
go handleTerminalRequests(in)
|
|
||||||
return term
|
|
||||||
}
|
|
||||||
|
|
||||||
func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
// this string is returned to stdout
|
|
||||||
shell := newServerShell(ch, in, "> ")
|
|
||||||
readLine(shell, t)
|
|
||||||
sendStatus(0, ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
shell := newServerShell(ch, in, "> ")
|
|
||||||
readLine(shell, t)
|
|
||||||
sendStatus(15, ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
shell := newServerShell(ch, in, "> ")
|
|
||||||
readLine(shell, t)
|
|
||||||
sendStatus(15, ch, t)
|
|
||||||
sendSignal("TERM", ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
shell := newServerShell(ch, in, "> ")
|
|
||||||
readLine(shell, t)
|
|
||||||
sendSignal("TERM", ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
shell := newServerShell(ch, in, "> ")
|
|
||||||
readLine(shell, t)
|
|
||||||
sendSignal("SYS", ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
shell := newServerShell(ch, in, "> ")
|
|
||||||
readLine(shell, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
// this string is returned to stdout
|
|
||||||
shell := newServerShell(ch, in, "golang")
|
|
||||||
readLine(shell, t)
|
|
||||||
sendStatus(0, ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignores the command, writes fixed strings to stderr and stdout.
|
|
||||||
// Strings are "this-is-stdout." and "this-is-stderr.".
|
|
||||||
func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
_, err := ch.Read(nil)
|
|
||||||
|
|
||||||
req, ok := <-in
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("error: expected channel request, got: %#v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ignore request, always send some text
|
|
||||||
req.Reply(true, nil)
|
|
||||||
|
|
||||||
_, err = io.WriteString(ch, "this-is-stdout.")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error writing on server: %v", err)
|
|
||||||
}
|
|
||||||
_, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error writing on server: %v", err)
|
|
||||||
}
|
|
||||||
sendStatus(0, ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readLine(shell *terminal.Terminal, t *testing.T) {
|
|
||||||
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
|
|
||||||
t.Errorf("unable to read line: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendStatus(status uint32, ch Channel, t *testing.T) {
|
|
||||||
msg := exitStatusMsg{
|
|
||||||
Status: status,
|
|
||||||
}
|
|
||||||
if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
|
|
||||||
t.Errorf("unable to send status: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendSignal(signal string, ch Channel, t *testing.T) {
|
|
||||||
sig := exitSignalMsg{
|
|
||||||
Signal: signal,
|
|
||||||
CoreDumped: false,
|
|
||||||
Errmsg: "Process terminated",
|
|
||||||
Lang: "en-GB-oed",
|
|
||||||
}
|
|
||||||
if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
|
|
||||||
t.Errorf("unable to send signal: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func discardHandler(ch Channel, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
io.Copy(ioutil.Discard, ch)
|
|
||||||
}
|
|
||||||
|
|
||||||
func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
|
|
||||||
t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
|
|
||||||
// buffer size to exercise more code paths.
|
|
||||||
func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
|
|
||||||
var (
|
|
||||||
buf = make([]byte, 32*1024)
|
|
||||||
written int
|
|
||||||
remaining = n
|
|
||||||
)
|
|
||||||
for remaining > 0 {
|
|
||||||
l := rand.Intn(1 << 15)
|
|
||||||
if remaining < l {
|
|
||||||
l = remaining
|
|
||||||
}
|
|
||||||
nr, er := src.Read(buf[:l])
|
|
||||||
nw, ew := dst.Write(buf[:nr])
|
|
||||||
remaining -= nw
|
|
||||||
written += nw
|
|
||||||
if ew != nil {
|
|
||||||
return written, ew
|
|
||||||
}
|
|
||||||
if nr != nw {
|
|
||||||
return written, io.ErrShortWrite
|
|
||||||
}
|
|
||||||
if er != nil && er != io.EOF {
|
|
||||||
return written, er
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return written, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
shell := newServerShell(ch, in, "> ")
|
|
||||||
readLine(shell, t)
|
|
||||||
if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
|
|
||||||
t.Errorf("unable to send channel keepalive request: %v", err)
|
|
||||||
}
|
|
||||||
sendStatus(0, ch, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientWriteEOF(t *testing.T) {
|
|
||||||
conn := dial(simpleEchoHandler, t)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
stdin, err := session.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StdinPipe failed: %v", err)
|
|
||||||
}
|
|
||||||
stdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StdoutPipe failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data := []byte(`0000`)
|
|
||||||
_, err = stdin.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Write failed: %v", err)
|
|
||||||
}
|
|
||||||
stdin.Close()
|
|
||||||
|
|
||||||
res, err := ioutil.ReadAll(stdout)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Read failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(data, res) {
|
|
||||||
t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
|
|
||||||
defer ch.Close()
|
|
||||||
data, err := ioutil.ReadAll(ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("handler read error: %v", err)
|
|
||||||
}
|
|
||||||
_, err = ch.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("handler write error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,404 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Listen requests the remote peer open a listening socket on
|
|
||||||
// addr. Incoming connections will be available by calling Accept on
|
|
||||||
// the returned net.Listener. The listener must be serviced, or the
|
|
||||||
// SSH connection may hang.
|
|
||||||
func (c *Client) Listen(n, addr string) (net.Listener, error) {
|
|
||||||
laddr, err := net.ResolveTCPAddr(n, addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return c.ListenTCP(laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Automatic port allocation is broken with OpenSSH before 6.0. See
|
|
||||||
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In
|
|
||||||
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
|
|
||||||
// rather than the actual port number. This means you can never open
|
|
||||||
// two different listeners with auto allocated ports. We work around
|
|
||||||
// this by trying explicit ports until we succeed.
|
|
||||||
|
|
||||||
const openSSHPrefix = "OpenSSH_"
|
|
||||||
|
|
||||||
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
||||||
|
|
||||||
// isBrokenOpenSSHVersion returns true if the given version string
|
|
||||||
// specifies a version of OpenSSH that is known to have a bug in port
|
|
||||||
// forwarding.
|
|
||||||
func isBrokenOpenSSHVersion(versionStr string) bool {
|
|
||||||
i := strings.Index(versionStr, openSSHPrefix)
|
|
||||||
if i < 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
i += len(openSSHPrefix)
|
|
||||||
j := i
|
|
||||||
for ; j < len(versionStr); j++ {
|
|
||||||
if versionStr[j] < '0' || versionStr[j] > '9' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
version, _ := strconv.Atoi(versionStr[i:j])
|
|
||||||
return version < 6
|
|
||||||
}
|
|
||||||
|
|
||||||
// autoPortListenWorkaround simulates automatic port allocation by
|
|
||||||
// trying random ports repeatedly.
|
|
||||||
func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
|
|
||||||
var sshListener net.Listener
|
|
||||||
var err error
|
|
||||||
const tries = 10
|
|
||||||
for i := 0; i < tries; i++ {
|
|
||||||
addr := *laddr
|
|
||||||
addr.Port = 1024 + portRandomizer.Intn(60000)
|
|
||||||
sshListener, err = c.ListenTCP(&addr)
|
|
||||||
if err == nil {
|
|
||||||
laddr.Port = addr.Port
|
|
||||||
return sshListener, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 7.1
|
|
||||||
type channelForwardMsg struct {
|
|
||||||
addr string
|
|
||||||
rport uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenTCP requests the remote peer open a listening socket
|
|
||||||
// on laddr. Incoming connections will be available by calling
|
|
||||||
// Accept on the returned net.Listener.
|
|
||||||
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
|
|
||||||
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
|
|
||||||
return c.autoPortListenWorkaround(laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := channelForwardMsg{
|
|
||||||
laddr.IP.String(),
|
|
||||||
uint32(laddr.Port),
|
|
||||||
}
|
|
||||||
// send message
|
|
||||||
ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("ssh: tcpip-forward request denied by peer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the original port was 0, then the remote side will
|
|
||||||
// supply a real port number in the response.
|
|
||||||
if laddr.Port == 0 {
|
|
||||||
var p struct {
|
|
||||||
Port uint32
|
|
||||||
}
|
|
||||||
if err := Unmarshal(resp, &p); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
laddr.Port = int(p.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register this forward, using the port number we obtained.
|
|
||||||
ch := c.forwards.add(*laddr)
|
|
||||||
|
|
||||||
return &tcpListener{laddr, c, ch}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// forwardList stores a mapping between remote
|
|
||||||
// forward requests and the tcpListeners.
|
|
||||||
type forwardList struct {
|
|
||||||
sync.Mutex
|
|
||||||
entries []forwardEntry
|
|
||||||
}
|
|
||||||
|
|
||||||
// forwardEntry represents an established mapping of a laddr on a
|
|
||||||
// remote ssh server to a channel connected to a tcpListener.
|
|
||||||
type forwardEntry struct {
|
|
||||||
laddr net.TCPAddr
|
|
||||||
c chan forward
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward represents an incoming forwarded tcpip connection. The
|
|
||||||
// arguments to add/remove/lookup should be address as specified in
|
|
||||||
// the original forward-request.
|
|
||||||
type forward struct {
|
|
||||||
newCh NewChannel // the ssh client channel underlying this forward
|
|
||||||
raddr *net.TCPAddr // the raddr of the incoming connection
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *forwardList) add(addr net.TCPAddr) chan forward {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
f := forwardEntry{
|
|
||||||
addr,
|
|
||||||
make(chan forward, 1),
|
|
||||||
}
|
|
||||||
l.entries = append(l.entries, f)
|
|
||||||
return f.c
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 7.2
|
|
||||||
type forwardedTCPPayload struct {
|
|
||||||
Addr string
|
|
||||||
Port uint32
|
|
||||||
OriginAddr string
|
|
||||||
OriginPort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
|
|
||||||
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
|
|
||||||
if port == 0 || port > 65535 {
|
|
||||||
return nil, fmt.Errorf("ssh: port number out of range: %d", port)
|
|
||||||
}
|
|
||||||
ip := net.ParseIP(string(addr))
|
|
||||||
if ip == nil {
|
|
||||||
return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
|
|
||||||
}
|
|
||||||
return &net.TCPAddr{IP: ip, Port: int(port)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *forwardList) handleChannels(in <-chan NewChannel) {
|
|
||||||
for ch := range in {
|
|
||||||
var payload forwardedTCPPayload
|
|
||||||
if err := Unmarshal(ch.ExtraData(), &payload); err != nil {
|
|
||||||
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 section 7.2 specifies that incoming
|
|
||||||
// addresses should list the address, in string
|
|
||||||
// format. It is implied that this should be an IP
|
|
||||||
// address, as it would be impossible to connect to it
|
|
||||||
// otherwise.
|
|
||||||
laddr, err := parseTCPAddr(payload.Addr, payload.Port)
|
|
||||||
if err != nil {
|
|
||||||
ch.Reject(ConnectionFailed, err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort)
|
|
||||||
if err != nil {
|
|
||||||
ch.Reject(ConnectionFailed, err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok := l.forward(*laddr, *raddr, ch); !ok {
|
|
||||||
// Section 7.2, implementations MUST reject spurious incoming
|
|
||||||
// connections.
|
|
||||||
ch.Reject(Prohibited, "no forward for address")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove removes the forward entry, and the channel feeding its
|
|
||||||
// listener.
|
|
||||||
func (l *forwardList) remove(addr net.TCPAddr) {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
for i, f := range l.entries {
|
|
||||||
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
|
|
||||||
l.entries = append(l.entries[:i], l.entries[i+1:]...)
|
|
||||||
close(f.c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeAll closes and clears all forwards.
|
|
||||||
func (l *forwardList) closeAll() {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
for _, f := range l.entries {
|
|
||||||
close(f.c)
|
|
||||||
}
|
|
||||||
l.entries = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
for _, f := range l.entries {
|
|
||||||
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
|
|
||||||
f.c <- forward{ch, &raddr}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpListener struct {
|
|
||||||
laddr *net.TCPAddr
|
|
||||||
|
|
||||||
conn *Client
|
|
||||||
in <-chan forward
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accept waits for and returns the next connection to the listener.
|
|
||||||
func (l *tcpListener) Accept() (net.Conn, error) {
|
|
||||||
s, ok := <-l.in
|
|
||||||
if !ok {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
ch, incoming, err := s.newCh.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go DiscardRequests(incoming)
|
|
||||||
|
|
||||||
return &tcpChanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: l.laddr,
|
|
||||||
raddr: s.raddr,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the listener.
|
|
||||||
func (l *tcpListener) Close() error {
|
|
||||||
m := channelForwardMsg{
|
|
||||||
l.laddr.IP.String(),
|
|
||||||
uint32(l.laddr.Port),
|
|
||||||
}
|
|
||||||
|
|
||||||
// this also closes the listener.
|
|
||||||
l.conn.forwards.remove(*l.laddr)
|
|
||||||
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: cancel-tcpip-forward failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Addr returns the listener's network address.
|
|
||||||
func (l *tcpListener) Addr() net.Addr {
|
|
||||||
return l.laddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial initiates a connection to the addr from the remote host.
|
|
||||||
// The resulting connection has a zero LocalAddr() and RemoteAddr().
|
|
||||||
func (c *Client) Dial(n, addr string) (net.Conn, error) {
|
|
||||||
// Parse the address into host and numeric port.
|
|
||||||
host, portString, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
port, err := strconv.ParseUint(portString, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Use a zero address for local and remote address.
|
|
||||||
zeroAddr := &net.TCPAddr{
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
Port: 0,
|
|
||||||
}
|
|
||||||
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &tcpChanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: zeroAddr,
|
|
||||||
raddr: zeroAddr,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialTCP connects to the remote address raddr on the network net,
|
|
||||||
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
|
|
||||||
// as the local address for the connection.
|
|
||||||
func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
|
|
||||||
if laddr == nil {
|
|
||||||
laddr = &net.TCPAddr{
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
Port: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &tcpChanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: laddr,
|
|
||||||
raddr: raddr,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 7.2
|
|
||||||
type channelOpenDirectMsg struct {
|
|
||||||
raddr string
|
|
||||||
rport uint32
|
|
||||||
laddr string
|
|
||||||
lport uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
|
|
||||||
msg := channelOpenDirectMsg{
|
|
||||||
raddr: raddr,
|
|
||||||
rport: uint32(rport),
|
|
||||||
laddr: laddr,
|
|
||||||
lport: uint32(lport),
|
|
||||||
}
|
|
||||||
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
|
|
||||||
go DiscardRequests(in)
|
|
||||||
return ch, err
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpChan struct {
|
|
||||||
Channel // the backing channel
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpChanConn fulfills the net.Conn interface without
|
|
||||||
// the tcpChan having to hold laddr or raddr directly.
|
|
||||||
type tcpChanConn struct {
|
|
||||||
Channel
|
|
||||||
laddr, raddr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAddr returns the local network address.
|
|
||||||
func (t *tcpChanConn) LocalAddr() net.Addr {
|
|
||||||
return t.laddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteAddr returns the remote network address.
|
|
||||||
func (t *tcpChanConn) RemoteAddr() net.Addr {
|
|
||||||
return t.raddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDeadline sets the read and write deadlines associated
|
|
||||||
// with the connection.
|
|
||||||
func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
|
|
||||||
if err := t.SetReadDeadline(deadline); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return t.SetWriteDeadline(deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetReadDeadline sets the read deadline.
|
|
||||||
// A zero value for t means Read will not time out.
|
|
||||||
// After the deadline, the error from Read will implement net.Error
|
|
||||||
// with Timeout() == true.
|
|
||||||
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
|
|
||||||
return errors.New("ssh: tcpChan: deadline not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWriteDeadline exists to satisfy the net.Conn interface
|
|
||||||
// but is not implemented by this type. It always returns an error.
|
|
||||||
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
|
|
||||||
return errors.New("ssh: tcpChan: deadline not supported")
|
|
||||||
}
|
|
|
@ -1,20 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAutoPortListenBroken(t *testing.T) {
|
|
||||||
broken := "SSH-2.0-OpenSSH_5.9hh11"
|
|
||||||
works := "SSH-2.0-OpenSSH_6.1"
|
|
||||||
if !isBrokenOpenSSHVersion(broken) {
|
|
||||||
t.Errorf("version %q not marked as broken", broken)
|
|
||||||
}
|
|
||||||
if isBrokenOpenSSHVersion(works) {
|
|
||||||
t.Errorf("version %q marked as broken", works)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,888 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
"unicode/utf8"
|
|
||||||
)
|
|
||||||
|
|
||||||
// EscapeCodes contains escape sequences that can be written to the terminal in
|
|
||||||
// order to achieve different styles of text.
|
|
||||||
type EscapeCodes struct {
|
|
||||||
// Foreground colors
|
|
||||||
Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte
|
|
||||||
|
|
||||||
// Reset all attributes
|
|
||||||
Reset []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var vt100EscapeCodes = EscapeCodes{
|
|
||||||
Black: []byte{keyEscape, '[', '3', '0', 'm'},
|
|
||||||
Red: []byte{keyEscape, '[', '3', '1', 'm'},
|
|
||||||
Green: []byte{keyEscape, '[', '3', '2', 'm'},
|
|
||||||
Yellow: []byte{keyEscape, '[', '3', '3', 'm'},
|
|
||||||
Blue: []byte{keyEscape, '[', '3', '4', 'm'},
|
|
||||||
Magenta: []byte{keyEscape, '[', '3', '5', 'm'},
|
|
||||||
Cyan: []byte{keyEscape, '[', '3', '6', 'm'},
|
|
||||||
White: []byte{keyEscape, '[', '3', '7', 'm'},
|
|
||||||
|
|
||||||
Reset: []byte{keyEscape, '[', '0', 'm'},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Terminal contains the state for running a VT100 terminal that is capable of
|
|
||||||
// reading lines of input.
|
|
||||||
type Terminal struct {
|
|
||||||
// AutoCompleteCallback, if non-null, is called for each keypress with
|
|
||||||
// the full input line and the current position of the cursor (in
|
|
||||||
// bytes, as an index into |line|). If it returns ok=false, the key
|
|
||||||
// press is processed normally. Otherwise it returns a replacement line
|
|
||||||
// and the new cursor position.
|
|
||||||
AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool)
|
|
||||||
|
|
||||||
// Escape contains a pointer to the escape codes for this terminal.
|
|
||||||
// It's always a valid pointer, although the escape codes themselves
|
|
||||||
// may be empty if the terminal doesn't support them.
|
|
||||||
Escape *EscapeCodes
|
|
||||||
|
|
||||||
// lock protects the terminal and the state in this object from
|
|
||||||
// concurrent processing of a key press and a Write() call.
|
|
||||||
lock sync.Mutex
|
|
||||||
|
|
||||||
c io.ReadWriter
|
|
||||||
prompt []rune
|
|
||||||
|
|
||||||
// line is the current line being entered.
|
|
||||||
line []rune
|
|
||||||
// pos is the logical position of the cursor in line
|
|
||||||
pos int
|
|
||||||
// echo is true if local echo is enabled
|
|
||||||
echo bool
|
|
||||||
// pasteActive is true iff there is a bracketed paste operation in
|
|
||||||
// progress.
|
|
||||||
pasteActive bool
|
|
||||||
|
|
||||||
// cursorX contains the current X value of the cursor where the left
|
|
||||||
// edge is 0. cursorY contains the row number where the first row of
|
|
||||||
// the current line is 0.
|
|
||||||
cursorX, cursorY int
|
|
||||||
// maxLine is the greatest value of cursorY so far.
|
|
||||||
maxLine int
|
|
||||||
|
|
||||||
termWidth, termHeight int
|
|
||||||
|
|
||||||
// outBuf contains the terminal data to be sent.
|
|
||||||
outBuf []byte
|
|
||||||
// remainder contains the remainder of any partial key sequences after
|
|
||||||
// a read. It aliases into inBuf.
|
|
||||||
remainder []byte
|
|
||||||
inBuf [256]byte
|
|
||||||
|
|
||||||
// history contains previously entered commands so that they can be
|
|
||||||
// accessed with the up and down keys.
|
|
||||||
history stRingBuffer
|
|
||||||
// historyIndex stores the currently accessed history entry, where zero
|
|
||||||
// means the immediately previous entry.
|
|
||||||
historyIndex int
|
|
||||||
// When navigating up and down the history it's possible to return to
|
|
||||||
// the incomplete, initial line. That value is stored in
|
|
||||||
// historyPending.
|
|
||||||
historyPending string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
|
|
||||||
// a local terminal, that terminal must first have been put into raw mode.
|
|
||||||
// prompt is a string that is written at the start of each input line (i.e.
|
|
||||||
// "> ").
|
|
||||||
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
|
|
||||||
return &Terminal{
|
|
||||||
Escape: &vt100EscapeCodes,
|
|
||||||
c: c,
|
|
||||||
prompt: []rune(prompt),
|
|
||||||
termWidth: 80,
|
|
||||||
termHeight: 24,
|
|
||||||
echo: true,
|
|
||||||
historyIndex: -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
keyCtrlD = 4
|
|
||||||
keyCtrlU = 21
|
|
||||||
keyEnter = '\r'
|
|
||||||
keyEscape = 27
|
|
||||||
keyBackspace = 127
|
|
||||||
keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota
|
|
||||||
keyUp
|
|
||||||
keyDown
|
|
||||||
keyLeft
|
|
||||||
keyRight
|
|
||||||
keyAltLeft
|
|
||||||
keyAltRight
|
|
||||||
keyHome
|
|
||||||
keyEnd
|
|
||||||
keyDeleteWord
|
|
||||||
keyDeleteLine
|
|
||||||
keyClearScreen
|
|
||||||
keyPasteStart
|
|
||||||
keyPasteEnd
|
|
||||||
)
|
|
||||||
|
|
||||||
var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'}
|
|
||||||
var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'}
|
|
||||||
|
|
||||||
// bytesToKey tries to parse a key sequence from b. If successful, it returns
|
|
||||||
// the key and the remainder of the input. Otherwise it returns utf8.RuneError.
|
|
||||||
func bytesToKey(b []byte, pasteActive bool) (rune, []byte) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return utf8.RuneError, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pasteActive {
|
|
||||||
switch b[0] {
|
|
||||||
case 1: // ^A
|
|
||||||
return keyHome, b[1:]
|
|
||||||
case 5: // ^E
|
|
||||||
return keyEnd, b[1:]
|
|
||||||
case 8: // ^H
|
|
||||||
return keyBackspace, b[1:]
|
|
||||||
case 11: // ^K
|
|
||||||
return keyDeleteLine, b[1:]
|
|
||||||
case 12: // ^L
|
|
||||||
return keyClearScreen, b[1:]
|
|
||||||
case 23: // ^W
|
|
||||||
return keyDeleteWord, b[1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if b[0] != keyEscape {
|
|
||||||
if !utf8.FullRune(b) {
|
|
||||||
return utf8.RuneError, b
|
|
||||||
}
|
|
||||||
r, l := utf8.DecodeRune(b)
|
|
||||||
return r, b[l:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
|
|
||||||
switch b[2] {
|
|
||||||
case 'A':
|
|
||||||
return keyUp, b[3:]
|
|
||||||
case 'B':
|
|
||||||
return keyDown, b[3:]
|
|
||||||
case 'C':
|
|
||||||
return keyRight, b[3:]
|
|
||||||
case 'D':
|
|
||||||
return keyLeft, b[3:]
|
|
||||||
case 'H':
|
|
||||||
return keyHome, b[3:]
|
|
||||||
case 'F':
|
|
||||||
return keyEnd, b[3:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
|
|
||||||
switch b[5] {
|
|
||||||
case 'C':
|
|
||||||
return keyAltRight, b[6:]
|
|
||||||
case 'D':
|
|
||||||
return keyAltLeft, b[6:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) {
|
|
||||||
return keyPasteStart, b[6:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) {
|
|
||||||
return keyPasteEnd, b[6:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we get here then we have a key that we don't recognise, or a
|
|
||||||
// partial sequence. It's not clear how one should find the end of a
|
|
||||||
// sequence without knowing them all, but it seems that [a-zA-Z~] only
|
|
||||||
// appears at the end of a sequence.
|
|
||||||
for i, c := range b[0:] {
|
|
||||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' {
|
|
||||||
return keyUnknown, b[i+1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return utf8.RuneError, b
|
|
||||||
}
|
|
||||||
|
|
||||||
// queue appends data to the end of t.outBuf
|
|
||||||
func (t *Terminal) queue(data []rune) {
|
|
||||||
t.outBuf = append(t.outBuf, []byte(string(data))...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'}
|
|
||||||
var space = []rune{' '}
|
|
||||||
|
|
||||||
func isPrintable(key rune) bool {
|
|
||||||
isInSurrogateArea := key >= 0xd800 && key <= 0xdbff
|
|
||||||
return key >= 32 && !isInSurrogateArea
|
|
||||||
}
|
|
||||||
|
|
||||||
// moveCursorToPos appends data to t.outBuf which will move the cursor to the
|
|
||||||
// given, logical position in the text.
|
|
||||||
func (t *Terminal) moveCursorToPos(pos int) {
|
|
||||||
if !t.echo {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
x := visualLength(t.prompt) + pos
|
|
||||||
y := x / t.termWidth
|
|
||||||
x = x % t.termWidth
|
|
||||||
|
|
||||||
up := 0
|
|
||||||
if y < t.cursorY {
|
|
||||||
up = t.cursorY - y
|
|
||||||
}
|
|
||||||
|
|
||||||
down := 0
|
|
||||||
if y > t.cursorY {
|
|
||||||
down = y - t.cursorY
|
|
||||||
}
|
|
||||||
|
|
||||||
left := 0
|
|
||||||
if x < t.cursorX {
|
|
||||||
left = t.cursorX - x
|
|
||||||
}
|
|
||||||
|
|
||||||
right := 0
|
|
||||||
if x > t.cursorX {
|
|
||||||
right = x - t.cursorX
|
|
||||||
}
|
|
||||||
|
|
||||||
t.cursorX = x
|
|
||||||
t.cursorY = y
|
|
||||||
t.move(up, down, left, right)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) move(up, down, left, right int) {
|
|
||||||
movement := make([]rune, 3*(up+down+left+right))
|
|
||||||
m := movement
|
|
||||||
for i := 0; i < up; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'A'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
for i := 0; i < down; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'B'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
for i := 0; i < left; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'D'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
for i := 0; i < right; i++ {
|
|
||||||
m[0] = keyEscape
|
|
||||||
m[1] = '['
|
|
||||||
m[2] = 'C'
|
|
||||||
m = m[3:]
|
|
||||||
}
|
|
||||||
|
|
||||||
t.queue(movement)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) clearLineToRight() {
|
|
||||||
op := []rune{keyEscape, '[', 'K'}
|
|
||||||
t.queue(op)
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxLineLength = 4096
|
|
||||||
|
|
||||||
func (t *Terminal) setLine(newLine []rune, newPos int) {
|
|
||||||
if t.echo {
|
|
||||||
t.moveCursorToPos(0)
|
|
||||||
t.writeLine(newLine)
|
|
||||||
for i := len(newLine); i < len(t.line); i++ {
|
|
||||||
t.writeLine(space)
|
|
||||||
}
|
|
||||||
t.moveCursorToPos(newPos)
|
|
||||||
}
|
|
||||||
t.line = newLine
|
|
||||||
t.pos = newPos
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) advanceCursor(places int) {
|
|
||||||
t.cursorX += places
|
|
||||||
t.cursorY += t.cursorX / t.termWidth
|
|
||||||
if t.cursorY > t.maxLine {
|
|
||||||
t.maxLine = t.cursorY
|
|
||||||
}
|
|
||||||
t.cursorX = t.cursorX % t.termWidth
|
|
||||||
|
|
||||||
if places > 0 && t.cursorX == 0 {
|
|
||||||
// Normally terminals will advance the current position
|
|
||||||
// when writing a character. But that doesn't happen
|
|
||||||
// for the last character in a line. However, when
|
|
||||||
// writing a character (except a new line) that causes
|
|
||||||
// a line wrap, the position will be advanced two
|
|
||||||
// places.
|
|
||||||
//
|
|
||||||
// So, if we are stopping at the end of a line, we
|
|
||||||
// need to write a newline so that our cursor can be
|
|
||||||
// advanced to the next line.
|
|
||||||
t.outBuf = append(t.outBuf, '\n')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) eraseNPreviousChars(n int) {
|
|
||||||
if n == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.pos < n {
|
|
||||||
n = t.pos
|
|
||||||
}
|
|
||||||
t.pos -= n
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
|
|
||||||
copy(t.line[t.pos:], t.line[n+t.pos:])
|
|
||||||
t.line = t.line[:len(t.line)-n]
|
|
||||||
if t.echo {
|
|
||||||
t.writeLine(t.line[t.pos:])
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
t.queue(space)
|
|
||||||
}
|
|
||||||
t.advanceCursor(n)
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// countToLeftWord returns then number of characters from the cursor to the
|
|
||||||
// start of the previous word.
|
|
||||||
func (t *Terminal) countToLeftWord() int {
|
|
||||||
if t.pos == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
pos := t.pos - 1
|
|
||||||
for pos > 0 {
|
|
||||||
if t.line[pos] != ' ' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pos--
|
|
||||||
}
|
|
||||||
for pos > 0 {
|
|
||||||
if t.line[pos] == ' ' {
|
|
||||||
pos++
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pos--
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.pos - pos
|
|
||||||
}
|
|
||||||
|
|
||||||
// countToRightWord returns then number of characters from the cursor to the
|
|
||||||
// start of the next word.
|
|
||||||
func (t *Terminal) countToRightWord() int {
|
|
||||||
pos := t.pos
|
|
||||||
for pos < len(t.line) {
|
|
||||||
if t.line[pos] == ' ' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pos++
|
|
||||||
}
|
|
||||||
for pos < len(t.line) {
|
|
||||||
if t.line[pos] != ' ' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pos++
|
|
||||||
}
|
|
||||||
return pos - t.pos
|
|
||||||
}
|
|
||||||
|
|
||||||
// visualLength returns the number of visible glyphs in s.
|
|
||||||
func visualLength(runes []rune) int {
|
|
||||||
inEscapeSeq := false
|
|
||||||
length := 0
|
|
||||||
|
|
||||||
for _, r := range runes {
|
|
||||||
switch {
|
|
||||||
case inEscapeSeq:
|
|
||||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') {
|
|
||||||
inEscapeSeq = false
|
|
||||||
}
|
|
||||||
case r == '\x1b':
|
|
||||||
inEscapeSeq = true
|
|
||||||
default:
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return length
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleKey processes the given key and, optionally, returns a line of text
|
|
||||||
// that the user has entered.
|
|
||||||
func (t *Terminal) handleKey(key rune) (line string, ok bool) {
|
|
||||||
if t.pasteActive && key != keyEnter {
|
|
||||||
t.addKeyToLine(key)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch key {
|
|
||||||
case keyBackspace:
|
|
||||||
if t.pos == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.eraseNPreviousChars(1)
|
|
||||||
case keyAltLeft:
|
|
||||||
// move left by a word.
|
|
||||||
t.pos -= t.countToLeftWord()
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
case keyAltRight:
|
|
||||||
// move right by a word.
|
|
||||||
t.pos += t.countToRightWord()
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
case keyLeft:
|
|
||||||
if t.pos == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.pos--
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
case keyRight:
|
|
||||||
if t.pos == len(t.line) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.pos++
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
case keyHome:
|
|
||||||
if t.pos == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.pos = 0
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
case keyEnd:
|
|
||||||
if t.pos == len(t.line) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.pos = len(t.line)
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
case keyUp:
|
|
||||||
entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if t.historyIndex == -1 {
|
|
||||||
t.historyPending = string(t.line)
|
|
||||||
}
|
|
||||||
t.historyIndex++
|
|
||||||
runes := []rune(entry)
|
|
||||||
t.setLine(runes, len(runes))
|
|
||||||
case keyDown:
|
|
||||||
switch t.historyIndex {
|
|
||||||
case -1:
|
|
||||||
return
|
|
||||||
case 0:
|
|
||||||
runes := []rune(t.historyPending)
|
|
||||||
t.setLine(runes, len(runes))
|
|
||||||
t.historyIndex--
|
|
||||||
default:
|
|
||||||
entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1)
|
|
||||||
if ok {
|
|
||||||
t.historyIndex--
|
|
||||||
runes := []rune(entry)
|
|
||||||
t.setLine(runes, len(runes))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case keyEnter:
|
|
||||||
t.moveCursorToPos(len(t.line))
|
|
||||||
t.queue([]rune("\r\n"))
|
|
||||||
line = string(t.line)
|
|
||||||
ok = true
|
|
||||||
t.line = t.line[:0]
|
|
||||||
t.pos = 0
|
|
||||||
t.cursorX = 0
|
|
||||||
t.cursorY = 0
|
|
||||||
t.maxLine = 0
|
|
||||||
case keyDeleteWord:
|
|
||||||
// Delete zero or more spaces and then one or more characters.
|
|
||||||
t.eraseNPreviousChars(t.countToLeftWord())
|
|
||||||
case keyDeleteLine:
|
|
||||||
// Delete everything from the current cursor position to the
|
|
||||||
// end of line.
|
|
||||||
for i := t.pos; i < len(t.line); i++ {
|
|
||||||
t.queue(space)
|
|
||||||
t.advanceCursor(1)
|
|
||||||
}
|
|
||||||
t.line = t.line[:t.pos]
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
case keyCtrlD:
|
|
||||||
// Erase the character under the current position.
|
|
||||||
// The EOF case when the line is empty is handled in
|
|
||||||
// readLine().
|
|
||||||
if t.pos < len(t.line) {
|
|
||||||
t.pos++
|
|
||||||
t.eraseNPreviousChars(1)
|
|
||||||
}
|
|
||||||
case keyCtrlU:
|
|
||||||
t.eraseNPreviousChars(t.pos)
|
|
||||||
case keyClearScreen:
|
|
||||||
// Erases the screen and moves the cursor to the home position.
|
|
||||||
t.queue([]rune("\x1b[2J\x1b[H"))
|
|
||||||
t.queue(t.prompt)
|
|
||||||
t.cursorX, t.cursorY = 0, 0
|
|
||||||
t.advanceCursor(visualLength(t.prompt))
|
|
||||||
t.setLine(t.line, t.pos)
|
|
||||||
default:
|
|
||||||
if t.AutoCompleteCallback != nil {
|
|
||||||
prefix := string(t.line[:t.pos])
|
|
||||||
suffix := string(t.line[t.pos:])
|
|
||||||
|
|
||||||
t.lock.Unlock()
|
|
||||||
newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key)
|
|
||||||
t.lock.Lock()
|
|
||||||
|
|
||||||
if completeOk {
|
|
||||||
t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos]))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !isPrintable(key) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(t.line) == maxLineLength {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.addKeyToLine(key)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// addKeyToLine inserts the given key at the current position in the current
|
|
||||||
// line.
|
|
||||||
func (t *Terminal) addKeyToLine(key rune) {
|
|
||||||
if len(t.line) == cap(t.line) {
|
|
||||||
newLine := make([]rune, len(t.line), 2*(1+len(t.line)))
|
|
||||||
copy(newLine, t.line)
|
|
||||||
t.line = newLine
|
|
||||||
}
|
|
||||||
t.line = t.line[:len(t.line)+1]
|
|
||||||
copy(t.line[t.pos+1:], t.line[t.pos:])
|
|
||||||
t.line[t.pos] = key
|
|
||||||
if t.echo {
|
|
||||||
t.writeLine(t.line[t.pos:])
|
|
||||||
}
|
|
||||||
t.pos++
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) writeLine(line []rune) {
|
|
||||||
for len(line) != 0 {
|
|
||||||
remainingOnLine := t.termWidth - t.cursorX
|
|
||||||
todo := len(line)
|
|
||||||
if todo > remainingOnLine {
|
|
||||||
todo = remainingOnLine
|
|
||||||
}
|
|
||||||
t.queue(line[:todo])
|
|
||||||
t.advanceCursor(visualLength(line[:todo]))
|
|
||||||
line = line[todo:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) Write(buf []byte) (n int, err error) {
|
|
||||||
t.lock.Lock()
|
|
||||||
defer t.lock.Unlock()
|
|
||||||
|
|
||||||
if t.cursorX == 0 && t.cursorY == 0 {
|
|
||||||
// This is the easy case: there's nothing on the screen that we
|
|
||||||
// have to move out of the way.
|
|
||||||
return t.c.Write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We have a prompt and possibly user input on the screen. We
|
|
||||||
// have to clear it first.
|
|
||||||
t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */)
|
|
||||||
t.cursorX = 0
|
|
||||||
t.clearLineToRight()
|
|
||||||
|
|
||||||
for t.cursorY > 0 {
|
|
||||||
t.move(1 /* up */, 0, 0, 0)
|
|
||||||
t.cursorY--
|
|
||||||
t.clearLineToRight()
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = t.c.Write(t.outBuf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.outBuf = t.outBuf[:0]
|
|
||||||
|
|
||||||
if n, err = t.c.Write(buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t.writeLine(t.prompt)
|
|
||||||
if t.echo {
|
|
||||||
t.writeLine(t.line)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
|
|
||||||
if _, err = t.c.Write(t.outBuf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.outBuf = t.outBuf[:0]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadPassword temporarily changes the prompt and reads a password, without
|
|
||||||
// echo, from the terminal.
|
|
||||||
func (t *Terminal) ReadPassword(prompt string) (line string, err error) {
|
|
||||||
t.lock.Lock()
|
|
||||||
defer t.lock.Unlock()
|
|
||||||
|
|
||||||
oldPrompt := t.prompt
|
|
||||||
t.prompt = []rune(prompt)
|
|
||||||
t.echo = false
|
|
||||||
|
|
||||||
line, err = t.readLine()
|
|
||||||
|
|
||||||
t.prompt = oldPrompt
|
|
||||||
t.echo = true
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadLine returns a line of input from the terminal.
|
|
||||||
func (t *Terminal) ReadLine() (line string, err error) {
|
|
||||||
t.lock.Lock()
|
|
||||||
defer t.lock.Unlock()
|
|
||||||
|
|
||||||
return t.readLine()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) readLine() (line string, err error) {
|
|
||||||
// t.lock must be held at this point
|
|
||||||
|
|
||||||
if t.cursorX == 0 && t.cursorY == 0 {
|
|
||||||
t.writeLine(t.prompt)
|
|
||||||
t.c.Write(t.outBuf)
|
|
||||||
t.outBuf = t.outBuf[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
lineIsPasted := t.pasteActive
|
|
||||||
|
|
||||||
for {
|
|
||||||
rest := t.remainder
|
|
||||||
lineOk := false
|
|
||||||
for !lineOk {
|
|
||||||
var key rune
|
|
||||||
key, rest = bytesToKey(rest, t.pasteActive)
|
|
||||||
if key == utf8.RuneError {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !t.pasteActive {
|
|
||||||
if key == keyCtrlD {
|
|
||||||
if len(t.line) == 0 {
|
|
||||||
return "", io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if key == keyPasteStart {
|
|
||||||
t.pasteActive = true
|
|
||||||
if len(t.line) == 0 {
|
|
||||||
lineIsPasted = true
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else if key == keyPasteEnd {
|
|
||||||
t.pasteActive = false
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !t.pasteActive {
|
|
||||||
lineIsPasted = false
|
|
||||||
}
|
|
||||||
line, lineOk = t.handleKey(key)
|
|
||||||
}
|
|
||||||
if len(rest) > 0 {
|
|
||||||
n := copy(t.inBuf[:], rest)
|
|
||||||
t.remainder = t.inBuf[:n]
|
|
||||||
} else {
|
|
||||||
t.remainder = nil
|
|
||||||
}
|
|
||||||
t.c.Write(t.outBuf)
|
|
||||||
t.outBuf = t.outBuf[:0]
|
|
||||||
if lineOk {
|
|
||||||
if t.echo {
|
|
||||||
t.historyIndex = -1
|
|
||||||
t.history.Add(line)
|
|
||||||
}
|
|
||||||
if lineIsPasted {
|
|
||||||
err = ErrPasteIndicator
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// t.remainder is a slice at the beginning of t.inBuf
|
|
||||||
// containing a partial key sequence
|
|
||||||
readBuf := t.inBuf[len(t.remainder):]
|
|
||||||
var n int
|
|
||||||
|
|
||||||
t.lock.Unlock()
|
|
||||||
n, err = t.c.Read(readBuf)
|
|
||||||
t.lock.Lock()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t.remainder = t.inBuf[:n+len(t.remainder)]
|
|
||||||
}
|
|
||||||
|
|
||||||
panic("unreachable") // for Go 1.0.
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPrompt sets the prompt to be used when reading subsequent lines.
|
|
||||||
func (t *Terminal) SetPrompt(prompt string) {
|
|
||||||
t.lock.Lock()
|
|
||||||
defer t.lock.Unlock()
|
|
||||||
|
|
||||||
t.prompt = []rune(prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) {
|
|
||||||
// Move cursor to column zero at the start of the line.
|
|
||||||
t.move(t.cursorY, 0, t.cursorX, 0)
|
|
||||||
t.cursorX, t.cursorY = 0, 0
|
|
||||||
t.clearLineToRight()
|
|
||||||
for t.cursorY < numPrevLines {
|
|
||||||
// Move down a line
|
|
||||||
t.move(0, 1, 0, 0)
|
|
||||||
t.cursorY++
|
|
||||||
t.clearLineToRight()
|
|
||||||
}
|
|
||||||
// Move back to beginning.
|
|
||||||
t.move(t.cursorY, 0, 0, 0)
|
|
||||||
t.cursorX, t.cursorY = 0, 0
|
|
||||||
|
|
||||||
t.queue(t.prompt)
|
|
||||||
t.advanceCursor(visualLength(t.prompt))
|
|
||||||
t.writeLine(t.line)
|
|
||||||
t.moveCursorToPos(t.pos)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) SetSize(width, height int) error {
|
|
||||||
t.lock.Lock()
|
|
||||||
defer t.lock.Unlock()
|
|
||||||
|
|
||||||
if width == 0 {
|
|
||||||
width = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
oldWidth := t.termWidth
|
|
||||||
t.termWidth, t.termHeight = width, height
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case width == oldWidth:
|
|
||||||
// If the width didn't change then nothing else needs to be
|
|
||||||
// done.
|
|
||||||
return nil
|
|
||||||
case width < oldWidth:
|
|
||||||
// Some terminals (e.g. xterm) will truncate lines that were
|
|
||||||
// too long when shinking. Others, (e.g. gnome-terminal) will
|
|
||||||
// attempt to wrap them. For the former, repainting t.maxLine
|
|
||||||
// works great, but that behaviour goes badly wrong in the case
|
|
||||||
// of the latter because they have doubled every full line.
|
|
||||||
|
|
||||||
// We assume that we are working on a terminal that wraps lines
|
|
||||||
// and adjust the cursor position based on every previous line
|
|
||||||
// wrapping and turning into two. This causes the prompt on
|
|
||||||
// xterms to move upwards, which isn't great, but it avoids a
|
|
||||||
// huge mess with gnome-terminal.
|
|
||||||
if t.cursorX >= t.termWidth {
|
|
||||||
t.cursorX = t.termWidth - 1
|
|
||||||
}
|
|
||||||
t.cursorY *= 2
|
|
||||||
t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2)
|
|
||||||
case width > oldWidth:
|
|
||||||
// If the terminal expands then our position calculations will
|
|
||||||
// be wrong in the future because we think the cursor is
|
|
||||||
// |t.pos| chars into the string, but there will be a gap at
|
|
||||||
// the end of any wrapped line.
|
|
||||||
//
|
|
||||||
// But the position will actually be correct until we move, so
|
|
||||||
// we can move back to the beginning and repaint everything.
|
|
||||||
t.clearAndRepaintLinePlusNPrevious(t.maxLine)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := t.c.Write(t.outBuf)
|
|
||||||
t.outBuf = t.outBuf[:0]
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
type pasteIndicatorError struct{}
|
|
||||||
|
|
||||||
func (pasteIndicatorError) Error() string {
|
|
||||||
return "terminal: ErrPasteIndicator not correctly handled"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrPasteIndicator may be returned from ReadLine as the error, in addition
|
|
||||||
// to valid line data. It indicates that bracketed paste mode is enabled and
|
|
||||||
// that the returned line consists only of pasted data. Programs may wish to
|
|
||||||
// interpret pasted data more literally than typed data.
|
|
||||||
var ErrPasteIndicator = pasteIndicatorError{}
|
|
||||||
|
|
||||||
// SetBracketedPasteMode requests that the terminal bracket paste operations
|
|
||||||
// with markers. Not all terminals support this but, if it is supported, then
|
|
||||||
// enabling this mode will stop any autocomplete callback from running due to
|
|
||||||
// pastes. Additionally, any lines that are completely pasted will be returned
|
|
||||||
// from ReadLine with the error set to ErrPasteIndicator.
|
|
||||||
func (t *Terminal) SetBracketedPasteMode(on bool) {
|
|
||||||
if on {
|
|
||||||
io.WriteString(t.c, "\x1b[?2004h")
|
|
||||||
} else {
|
|
||||||
io.WriteString(t.c, "\x1b[?2004l")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// stRingBuffer is a ring buffer of strings.
|
|
||||||
type stRingBuffer struct {
|
|
||||||
// entries contains max elements.
|
|
||||||
entries []string
|
|
||||||
max int
|
|
||||||
// head contains the index of the element most recently added to the ring.
|
|
||||||
head int
|
|
||||||
// size contains the number of elements in the ring.
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stRingBuffer) Add(a string) {
|
|
||||||
if s.entries == nil {
|
|
||||||
const defaultNumEntries = 100
|
|
||||||
s.entries = make([]string, defaultNumEntries)
|
|
||||||
s.max = defaultNumEntries
|
|
||||||
}
|
|
||||||
|
|
||||||
s.head = (s.head + 1) % s.max
|
|
||||||
s.entries[s.head] = a
|
|
||||||
if s.size < s.max {
|
|
||||||
s.size++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NthPreviousEntry returns the value passed to the nth previous call to Add.
|
|
||||||
// If n is zero then the immediately prior value is returned, if one, then the
|
|
||||||
// next most recent, and so on. If such an element doesn't exist then ok is
|
|
||||||
// false.
|
|
||||||
func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) {
|
|
||||||
if n >= s.size {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
index := s.head - n
|
|
||||||
if index < 0 {
|
|
||||||
index += s.max
|
|
||||||
}
|
|
||||||
return s.entries[index], true
|
|
||||||
}
|
|
|
@ -1,243 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockTerminal struct {
|
|
||||||
toSend []byte
|
|
||||||
bytesPerRead int
|
|
||||||
received []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *MockTerminal) Read(data []byte) (n int, err error) {
|
|
||||||
n = len(data)
|
|
||||||
if n == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if n > len(c.toSend) {
|
|
||||||
n = len(c.toSend)
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
if c.bytesPerRead > 0 && n > c.bytesPerRead {
|
|
||||||
n = c.bytesPerRead
|
|
||||||
}
|
|
||||||
copy(data, c.toSend[:n])
|
|
||||||
c.toSend = c.toSend[n:]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *MockTerminal) Write(data []byte) (n int, err error) {
|
|
||||||
c.received = append(c.received, data...)
|
|
||||||
return len(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClose(t *testing.T) {
|
|
||||||
c := &MockTerminal{}
|
|
||||||
ss := NewTerminal(c, "> ")
|
|
||||||
line, err := ss.ReadLine()
|
|
||||||
if line != "" {
|
|
||||||
t.Errorf("Expected empty line but got: %s", line)
|
|
||||||
}
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Errorf("Error should have been EOF but got: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var keyPressTests = []struct {
|
|
||||||
in string
|
|
||||||
line string
|
|
||||||
err error
|
|
||||||
throwAwayLines int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
err: io.EOF,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "\r",
|
|
||||||
line: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "foo\r",
|
|
||||||
line: "foo",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a\x1b[Cb\r", // right
|
|
||||||
line: "ab",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a\x1b[Db\r", // left
|
|
||||||
line: "ba",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a\177b\r", // backspace
|
|
||||||
line: "b",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "\x1b[A\r", // up
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "\x1b[B\r", // down
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "line\x1b[A\x1b[B\r", // up then down
|
|
||||||
line: "line",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "line1\rline2\x1b[A\r", // recall previous line.
|
|
||||||
line: "line1",
|
|
||||||
throwAwayLines: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// recall two previous lines and append.
|
|
||||||
in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r",
|
|
||||||
line: "line1xxx",
|
|
||||||
throwAwayLines: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Ctrl-A to move to beginning of line followed by ^K to kill
|
|
||||||
// line.
|
|
||||||
in: "a b \001\013\r",
|
|
||||||
line: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Ctrl-A to move to beginning of line, Ctrl-E to move to end,
|
|
||||||
// finally ^K to kill nothing.
|
|
||||||
in: "a b \001\005\013\r",
|
|
||||||
line: "a b ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "\027\r",
|
|
||||||
line: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a\027\r",
|
|
||||||
line: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a \027\r",
|
|
||||||
line: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a b\027\r",
|
|
||||||
line: "a ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a b \027\r",
|
|
||||||
line: "a ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "one two thr\x1b[D\027\r",
|
|
||||||
line: "one two r",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "\013\r",
|
|
||||||
line: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "a\013\r",
|
|
||||||
line: "a",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "ab\x1b[D\013\r",
|
|
||||||
line: "a",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "Ξεσκεπάζω\r",
|
|
||||||
line: "Ξεσκεπάζω",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace.
|
|
||||||
line: "",
|
|
||||||
throwAwayLines: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter.
|
|
||||||
line: "£",
|
|
||||||
throwAwayLines: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Ctrl-D at the end of the line should be ignored.
|
|
||||||
in: "a\004\r",
|
|
||||||
line: "a",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// a, b, left, Ctrl-D should erase the b.
|
|
||||||
in: "ab\x1b[D\004\r",
|
|
||||||
line: "a",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// a, b, c, d, left, left, ^U should erase to the beginning of
|
|
||||||
// the line.
|
|
||||||
in: "abcd\x1b[D\x1b[D\025\r",
|
|
||||||
line: "cd",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Bracketed paste mode: control sequences should be returned
|
|
||||||
// verbatim in paste mode.
|
|
||||||
in: "abc\x1b[200~de\177f\x1b[201~\177\r",
|
|
||||||
line: "abcde\177",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Enter in bracketed paste mode should still work.
|
|
||||||
in: "abc\x1b[200~d\refg\x1b[201~h\r",
|
|
||||||
line: "efgh",
|
|
||||||
throwAwayLines: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Lines consisting entirely of pasted data should be indicated as such.
|
|
||||||
in: "\x1b[200~a\r",
|
|
||||||
line: "a",
|
|
||||||
err: ErrPasteIndicator,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyPresses(t *testing.T) {
|
|
||||||
for i, test := range keyPressTests {
|
|
||||||
for j := 1; j < len(test.in); j++ {
|
|
||||||
c := &MockTerminal{
|
|
||||||
toSend: []byte(test.in),
|
|
||||||
bytesPerRead: j,
|
|
||||||
}
|
|
||||||
ss := NewTerminal(c, "> ")
|
|
||||||
for k := 0; k < test.throwAwayLines; k++ {
|
|
||||||
_, err := ss.ReadLine()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
line, err := ss.ReadLine()
|
|
||||||
if line != test.line {
|
|
||||||
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != test.err {
|
|
||||||
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPasswordNotSaved(t *testing.T) {
|
|
||||||
c := &MockTerminal{
|
|
||||||
toSend: []byte("password\r\x1b[A\r"),
|
|
||||||
bytesPerRead: 1,
|
|
||||||
}
|
|
||||||
ss := NewTerminal(c, "> ")
|
|
||||||
pw, _ := ss.ReadPassword("> ")
|
|
||||||
if pw != "password" {
|
|
||||||
t.Fatalf("failed to read password, got %s", pw)
|
|
||||||
}
|
|
||||||
line, _ := ss.ReadLine()
|
|
||||||
if len(line) > 0 {
|
|
||||||
t.Fatalf("password was saved in history")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,128 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd
|
|
||||||
|
|
||||||
// Package terminal provides support functions for dealing with terminals, as
|
|
||||||
// commonly found on UNIX systems.
|
|
||||||
//
|
|
||||||
// Putting a terminal into raw mode is the most common requirement:
|
|
||||||
//
|
|
||||||
// oldState, err := terminal.MakeRaw(0)
|
|
||||||
// if err != nil {
|
|
||||||
// panic(err)
|
|
||||||
// }
|
|
||||||
// defer terminal.Restore(0, oldState)
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
// State contains the state of a terminal.
|
|
||||||
type State struct {
|
|
||||||
termios syscall.Termios
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
|
||||||
func IsTerminal(fd int) bool {
|
|
||||||
var termios syscall.Termios
|
|
||||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
|
|
||||||
return err == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeRaw put the terminal connected to the given file descriptor into raw
|
|
||||||
// mode and returns the previous state of the terminal so that it can be
|
|
||||||
// restored.
|
|
||||||
func MakeRaw(fd int) (*State, error) {
|
|
||||||
var oldState State
|
|
||||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
newState := oldState.termios
|
|
||||||
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
|
|
||||||
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
|
|
||||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &oldState, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetState returns the current state of a terminal which may be useful to
|
|
||||||
// restore the terminal after a signal.
|
|
||||||
func GetState(fd int) (*State, error) {
|
|
||||||
var oldState State
|
|
||||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &oldState, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore restores the terminal connected to the given file descriptor to a
|
|
||||||
// previous state.
|
|
||||||
func Restore(fd int, state *State) error {
|
|
||||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSize returns the dimensions of the given terminal.
|
|
||||||
func GetSize(fd int) (width, height int, err error) {
|
|
||||||
var dimensions [4]uint16
|
|
||||||
|
|
||||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 {
|
|
||||||
return -1, -1, err
|
|
||||||
}
|
|
||||||
return int(dimensions[1]), int(dimensions[0]), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
|
||||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
|
||||||
// returned does not include the \n.
|
|
||||||
func ReadPassword(fd int) ([]byte, error) {
|
|
||||||
var oldState syscall.Termios
|
|
||||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
newState := oldState
|
|
||||||
newState.Lflag &^= syscall.ECHO
|
|
||||||
newState.Lflag |= syscall.ICANON | syscall.ISIG
|
|
||||||
newState.Iflag |= syscall.ICRNL
|
|
||||||
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var buf [16]byte
|
|
||||||
var ret []byte
|
|
||||||
for {
|
|
||||||
n, err := syscall.Read(fd, buf[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
if len(ret) == 0 {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if buf[n-1] == '\n' {
|
|
||||||
n--
|
|
||||||
}
|
|
||||||
ret = append(ret, buf[:n]...)
|
|
||||||
if n < len(buf) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
|
@ -1,12 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build darwin dragonfly freebsd netbsd openbsd
|
|
||||||
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
const ioctlReadTermios = syscall.TIOCGETA
|
|
||||||
const ioctlWriteTermios = syscall.TIOCSETA
|
|
|
@ -1,11 +0,0 @@
|
||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
// These constants are declared here, rather than importing
|
|
||||||
// them from the syscall package as some syscall packages, even
|
|
||||||
// on linux, for example gccgo, do not declare them.
|
|
||||||
const ioctlReadTermios = 0x5401 // syscall.TCGETS
|
|
||||||
const ioctlWriteTermios = 0x5402 // syscall.TCSETS
|
|
|
@ -1,174 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build windows
|
|
||||||
|
|
||||||
// Package terminal provides support functions for dealing with terminals, as
|
|
||||||
// commonly found on UNIX systems.
|
|
||||||
//
|
|
||||||
// Putting a terminal into raw mode is the most common requirement:
|
|
||||||
//
|
|
||||||
// oldState, err := terminal.MakeRaw(0)
|
|
||||||
// if err != nil {
|
|
||||||
// panic(err)
|
|
||||||
// }
|
|
||||||
// defer terminal.Restore(0, oldState)
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
enableLineInput = 2
|
|
||||||
enableEchoInput = 4
|
|
||||||
enableProcessedInput = 1
|
|
||||||
enableWindowInput = 8
|
|
||||||
enableMouseInput = 16
|
|
||||||
enableInsertMode = 32
|
|
||||||
enableQuickEditMode = 64
|
|
||||||
enableExtendedFlags = 128
|
|
||||||
enableAutoPosition = 256
|
|
||||||
enableProcessedOutput = 1
|
|
||||||
enableWrapAtEolOutput = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
|
||||||
|
|
||||||
var (
|
|
||||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
|
||||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
|
||||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
short int16
|
|
||||||
word uint16
|
|
||||||
|
|
||||||
coord struct {
|
|
||||||
x short
|
|
||||||
y short
|
|
||||||
}
|
|
||||||
smallRect struct {
|
|
||||||
left short
|
|
||||||
top short
|
|
||||||
right short
|
|
||||||
bottom short
|
|
||||||
}
|
|
||||||
consoleScreenBufferInfo struct {
|
|
||||||
size coord
|
|
||||||
cursorPosition coord
|
|
||||||
attributes word
|
|
||||||
window smallRect
|
|
||||||
maximumWindowSize coord
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type State struct {
|
|
||||||
mode uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
|
||||||
func IsTerminal(fd int) bool {
|
|
||||||
var st uint32
|
|
||||||
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
|
||||||
return r != 0 && e == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeRaw put the terminal connected to the given file descriptor into raw
|
|
||||||
// mode and returns the previous state of the terminal so that it can be
|
|
||||||
// restored.
|
|
||||||
func MakeRaw(fd int) (*State, error) {
|
|
||||||
var st uint32
|
|
||||||
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
|
||||||
if e != 0 {
|
|
||||||
return nil, error(e)
|
|
||||||
}
|
|
||||||
st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
|
|
||||||
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
|
|
||||||
if e != 0 {
|
|
||||||
return nil, error(e)
|
|
||||||
}
|
|
||||||
return &State{st}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetState returns the current state of a terminal which may be useful to
|
|
||||||
// restore the terminal after a signal.
|
|
||||||
func GetState(fd int) (*State, error) {
|
|
||||||
var st uint32
|
|
||||||
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
|
||||||
if e != 0 {
|
|
||||||
return nil, error(e)
|
|
||||||
}
|
|
||||||
return &State{st}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore restores the terminal connected to the given file descriptor to a
|
|
||||||
// previous state.
|
|
||||||
func Restore(fd int, state *State) error {
|
|
||||||
_, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSize returns the dimensions of the given terminal.
|
|
||||||
func GetSize(fd int) (width, height int, err error) {
|
|
||||||
var info consoleScreenBufferInfo
|
|
||||||
_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0)
|
|
||||||
if e != 0 {
|
|
||||||
return 0, 0, error(e)
|
|
||||||
}
|
|
||||||
return int(info.size.x), int(info.size.y), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
|
||||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
|
||||||
// returned does not include the \n.
|
|
||||||
func ReadPassword(fd int) ([]byte, error) {
|
|
||||||
var st uint32
|
|
||||||
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
|
|
||||||
if e != 0 {
|
|
||||||
return nil, error(e)
|
|
||||||
}
|
|
||||||
old := st
|
|
||||||
|
|
||||||
st &^= (enableEchoInput)
|
|
||||||
st |= (enableProcessedInput | enableLineInput | enableProcessedOutput)
|
|
||||||
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
|
|
||||||
if e != 0 {
|
|
||||||
return nil, error(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var buf [16]byte
|
|
||||||
var ret []byte
|
|
||||||
for {
|
|
||||||
n, err := syscall.Read(syscall.Handle(fd), buf[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
if len(ret) == 0 {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if buf[n-1] == '\n' {
|
|
||||||
n--
|
|
||||||
}
|
|
||||||
if n > 0 && buf[n-1] == '\r' {
|
|
||||||
n--
|
|
||||||
}
|
|
||||||
ret = append(ret, buf[:n]...)
|
|
||||||
if n < len(buf) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
|
@ -1,50 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build darwin dragonfly freebsd linux netbsd openbsd
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"golang.org/x/crypto/ssh/agent"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAgentForward(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
keyring := agent.NewKeyring()
|
|
||||||
keyring.Add(testPrivateKeys["dsa"], nil, "")
|
|
||||||
pub := testPublicKeys["dsa"]
|
|
||||||
|
|
||||||
sess, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewSession: %v", err)
|
|
||||||
}
|
|
||||||
if err := agent.RequestAgentForwarding(sess); err != nil {
|
|
||||||
t.Fatalf("RequestAgentForwarding: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := agent.ForwardToAgent(conn, keyring); err != nil {
|
|
||||||
t.Fatalf("SetupForwardKeyring: %v", err)
|
|
||||||
}
|
|
||||||
out, err := sess.CombinedOutput("ssh-add -L")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("running ssh-add: %v, out %s", err, out)
|
|
||||||
}
|
|
||||||
key, _, _, _, err := ssh.ParseAuthorizedKey(out)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ParseAuthorizedKey(%q): %v", out, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(key.Marshal(), pub.Marshal()) {
|
|
||||||
t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,47 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build darwin dragonfly freebsd linux netbsd openbsd
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCertLogin(t *testing.T) {
|
|
||||||
s := newServer(t)
|
|
||||||
defer s.Shutdown()
|
|
||||||
|
|
||||||
// Use a key different from the default.
|
|
||||||
clientKey := testSigners["dsa"]
|
|
||||||
caAuthKey := testSigners["ecdsa"]
|
|
||||||
cert := &ssh.Certificate{
|
|
||||||
Key: clientKey.PublicKey(),
|
|
||||||
ValidPrincipals: []string{username()},
|
|
||||||
CertType: ssh.UserCert,
|
|
||||||
ValidBefore: ssh.CertTimeInfinity,
|
|
||||||
}
|
|
||||||
if err := cert.SignCert(rand.Reader, caAuthKey); err != nil {
|
|
||||||
t.Fatalf("SetSignature: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
certSigner, err := ssh.NewCertSigner(cert, clientKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewCertSigner: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conf := &ssh.ClientConfig{
|
|
||||||
User: username(),
|
|
||||||
}
|
|
||||||
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner))
|
|
||||||
client, err := s.TryDial(conf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("TryDial: %v", err)
|
|
||||||
}
|
|
||||||
client.Close()
|
|
||||||
}
|
|
|
@ -1,7 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// This package contains integration tests for the
|
|
||||||
// code.google.com/p/go.crypto/ssh package.
|
|
||||||
package test
|
|
|
@ -1,160 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build darwin dragonfly freebsd linux netbsd openbsd
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPortForward(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
sshListener, err := conn.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
sshConn, err := sshListener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("listen.Accept failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = io.Copy(sshConn, sshConn)
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
t.Fatalf("ssh client copy: %v", err)
|
|
||||||
}
|
|
||||||
sshConn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
forwardedAddr := sshListener.Addr().String()
|
|
||||||
tcpConn, err := net.Dial("tcp", forwardedAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("TCP dial failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
readChan := make(chan []byte)
|
|
||||||
go func() {
|
|
||||||
data, _ := ioutil.ReadAll(tcpConn)
|
|
||||||
readChan <- data
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Invent some data.
|
|
||||||
data := make([]byte, 100*1000)
|
|
||||||
for i := range data {
|
|
||||||
data[i] = byte(i % 255)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sent []byte
|
|
||||||
for len(sent) < 1000*1000 {
|
|
||||||
// Send random sized chunks
|
|
||||||
m := rand.Intn(len(data))
|
|
||||||
n, err := tcpConn.Write(data[:m])
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
sent = append(sent, data[:n]...)
|
|
||||||
}
|
|
||||||
if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
|
|
||||||
t.Errorf("tcpConn.CloseWrite: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
read := <-readChan
|
|
||||||
|
|
||||||
if len(sent) != len(read) {
|
|
||||||
t.Fatalf("got %d bytes, want %d", len(read), len(sent))
|
|
||||||
}
|
|
||||||
if bytes.Compare(sent, read) != 0 {
|
|
||||||
t.Fatalf("read back data does not match")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := sshListener.Close(); err != nil {
|
|
||||||
t.Fatalf("sshListener.Close: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the forward disappeared.
|
|
||||||
tcpConn, err = net.Dial("tcp", forwardedAddr)
|
|
||||||
if err == nil {
|
|
||||||
tcpConn.Close()
|
|
||||||
t.Errorf("still listening to %s after closing", forwardedAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAcceptClose(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
|
|
||||||
sshListener, err := conn.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
quit := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
c, err := sshListener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
quit <- err
|
|
||||||
break
|
|
||||||
}
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
sshListener.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
t.Errorf("timeout: listener did not close.")
|
|
||||||
case err := <-quit:
|
|
||||||
t.Logf("quit as expected (error %v)", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that listeners exit if the underlying client transport dies.
|
|
||||||
func TestPortForwardConnectionClose(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
|
|
||||||
sshListener, err := conn.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
quit := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
c, err := sshListener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
quit <- err
|
|
||||||
break
|
|
||||||
}
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// It would be even nicer if we closed the server side, but it
|
|
||||||
// is more involved as the fd for that side is dup()ed.
|
|
||||||
server.clientConn.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
t.Errorf("timeout: listener did not close.")
|
|
||||||
case err := <-quit:
|
|
||||||
t.Logf("quit as expected (error %v)", err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,317 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build !windows
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
// Session functional tests.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRunCommandSuccess(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
err = session.Run("true")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHostKeyCheck(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
|
|
||||||
conf := clientConfig()
|
|
||||||
hostDB := hostKeyDB()
|
|
||||||
conf.HostKeyCallback = hostDB.Check
|
|
||||||
|
|
||||||
// change the keys.
|
|
||||||
hostDB.keys[ssh.KeyAlgoRSA][25]++
|
|
||||||
hostDB.keys[ssh.KeyAlgoDSA][25]++
|
|
||||||
hostDB.keys[ssh.KeyAlgoECDSA256][25]++
|
|
||||||
|
|
||||||
conn, err := server.TryDial(conf)
|
|
||||||
if err == nil {
|
|
||||||
conn.Close()
|
|
||||||
t.Fatalf("dial should have failed.")
|
|
||||||
} else if !strings.Contains(err.Error(), "host key mismatch") {
|
|
||||||
t.Fatalf("'host key mismatch' not found in %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCommandStdin(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
r, w := io.Pipe()
|
|
||||||
defer r.Close()
|
|
||||||
defer w.Close()
|
|
||||||
session.Stdin = r
|
|
||||||
|
|
||||||
err = session.Run("true")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCommandStdinError(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
r, w := io.Pipe()
|
|
||||||
defer r.Close()
|
|
||||||
session.Stdin = r
|
|
||||||
pipeErr := errors.New("closing write end of pipe")
|
|
||||||
w.CloseWithError(pipeErr)
|
|
||||||
|
|
||||||
err = session.Run("true")
|
|
||||||
if err != pipeErr {
|
|
||||||
t.Fatalf("expected %v, found %v", pipeErr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCommandFailed(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
err = session.Run(`bash -c "kill -9 $$"`)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("session succeeded: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCommandWeClosed(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Shell()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("shell failed: %v", err)
|
|
||||||
}
|
|
||||||
err = session.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("shell failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFuncLargeRead(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create new session: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to acquire stdout pipe: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = session.Start("dd if=/dev/urandom bs=2048 count=1024")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to execute remote command: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
n, err := io.Copy(buf, stdout)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error reading from remote stdout: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if n != 2048*1024 {
|
|
||||||
t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeyChange(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conf := clientConfig()
|
|
||||||
hostDB := hostKeyDB()
|
|
||||||
conf.HostKeyCallback = hostDB.Check
|
|
||||||
conf.RekeyThreshold = 1024
|
|
||||||
conn := server.Dial(conf)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create new session: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to acquire stdout pipe: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = session.Start("dd if=/dev/urandom bs=1024 count=1")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to execute remote command: %s", err)
|
|
||||||
}
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
n, err := io.Copy(buf, stdout)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error reading from remote stdout: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
want := int64(1024)
|
|
||||||
if n != want {
|
|
||||||
t.Fatalf("Expected %d bytes but read only %d from remote command", want, n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if changes := hostDB.checkCount; changes < 4 {
|
|
||||||
t.Errorf("got %d key changes, want 4", changes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvalidTerminalMode(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
if err = session.RequestPty("vt100", 80, 40, ssh.TerminalModes{255: 1984}); err == nil {
|
|
||||||
t.Fatalf("req-pty failed: successful request with invalid mode")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidTerminalMode(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conn := server.Dial(clientConfig())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %v", err)
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
stdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to acquire stdout pipe: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stdin, err := session.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to acquire stdin pipe: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tm := ssh.TerminalModes{ssh.ECHO: 0}
|
|
||||||
if err = session.RequestPty("xterm", 80, 40, tm); err != nil {
|
|
||||||
t.Fatalf("req-pty failed: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = session.Shell()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("session failed: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stdin.Write([]byte("stty -a && exit\n"))
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if _, err := io.Copy(&buf, stdout); err != nil {
|
|
||||||
t.Fatalf("reading failed: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") {
|
|
||||||
t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCiphers(t *testing.T) {
|
|
||||||
var config ssh.Config
|
|
||||||
config.SetDefaults()
|
|
||||||
cipherOrder := config.Ciphers
|
|
||||||
|
|
||||||
for _, ciph := range cipherOrder {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conf := clientConfig()
|
|
||||||
conf.Ciphers = []string{ciph}
|
|
||||||
// Don't fail if sshd doesnt have the cipher.
|
|
||||||
conf.Ciphers = append(conf.Ciphers, cipherOrder...)
|
|
||||||
conn, err := server.TryDial(conf)
|
|
||||||
if err == nil {
|
|
||||||
conn.Close()
|
|
||||||
} else {
|
|
||||||
t.Fatalf("failed for cipher %q", ciph)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMACs(t *testing.T) {
|
|
||||||
var config ssh.Config
|
|
||||||
config.SetDefaults()
|
|
||||||
macOrder := config.MACs
|
|
||||||
|
|
||||||
for _, mac := range macOrder {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
conf := clientConfig()
|
|
||||||
conf.MACs = []string{mac}
|
|
||||||
// Don't fail if sshd doesnt have the MAC.
|
|
||||||
conf.MACs = append(conf.MACs, macOrder...)
|
|
||||||
if conn, err := server.TryDial(conf); err == nil {
|
|
||||||
conn.Close()
|
|
||||||
} else {
|
|
||||||
t.Fatalf("failed for MAC %q", mac)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,46 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build !windows
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
// direct-tcpip functional tests
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDial(t *testing.T) {
|
|
||||||
server := newServer(t)
|
|
||||||
defer server.Shutdown()
|
|
||||||
sshConn := server.Dial(clientConfig())
|
|
||||||
defer sshConn.Close()
|
|
||||||
|
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Listen: %v", err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
c, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
io.WriteString(c, c.RemoteAddr().String())
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
conn, err := sshConn.Dial("tcp", l.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Dial: %v", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
}
|
|
|
@ -1,261 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build darwin dragonfly freebsd linux netbsd openbsd plan9
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
// functional test harness for unix.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"os/user"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"golang.org/x/crypto/ssh/testdata"
|
|
||||||
)
|
|
||||||
|
|
||||||
const sshd_config = `
|
|
||||||
Protocol 2
|
|
||||||
HostKey {{.Dir}}/id_rsa
|
|
||||||
HostKey {{.Dir}}/id_dsa
|
|
||||||
HostKey {{.Dir}}/id_ecdsa
|
|
||||||
Pidfile {{.Dir}}/sshd.pid
|
|
||||||
#UsePrivilegeSeparation no
|
|
||||||
KeyRegenerationInterval 3600
|
|
||||||
ServerKeyBits 768
|
|
||||||
SyslogFacility AUTH
|
|
||||||
LogLevel DEBUG2
|
|
||||||
LoginGraceTime 120
|
|
||||||
PermitRootLogin no
|
|
||||||
StrictModes no
|
|
||||||
RSAAuthentication yes
|
|
||||||
PubkeyAuthentication yes
|
|
||||||
AuthorizedKeysFile {{.Dir}}/id_user.pub
|
|
||||||
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
|
|
||||||
IgnoreRhosts yes
|
|
||||||
RhostsRSAAuthentication no
|
|
||||||
HostbasedAuthentication no
|
|
||||||
`
|
|
||||||
|
|
||||||
var configTmpl = template.Must(template.New("").Parse(sshd_config))
|
|
||||||
|
|
||||||
type server struct {
|
|
||||||
t *testing.T
|
|
||||||
cleanup func() // executed during Shutdown
|
|
||||||
configfile string
|
|
||||||
cmd *exec.Cmd
|
|
||||||
output bytes.Buffer // holds stderr from sshd process
|
|
||||||
|
|
||||||
// Client half of the network connection.
|
|
||||||
clientConn net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func username() string {
|
|
||||||
var username string
|
|
||||||
if user, err := user.Current(); err == nil {
|
|
||||||
username = user.Username
|
|
||||||
} else {
|
|
||||||
// user.Current() currently requires cgo. If an error is
|
|
||||||
// returned attempt to get the username from the environment.
|
|
||||||
log.Printf("user.Current: %v; falling back on $USER", err)
|
|
||||||
username = os.Getenv("USER")
|
|
||||||
}
|
|
||||||
if username == "" {
|
|
||||||
panic("Unable to get username")
|
|
||||||
}
|
|
||||||
return username
|
|
||||||
}
|
|
||||||
|
|
||||||
type storedHostKey struct {
|
|
||||||
// keys map from an algorithm string to binary key data.
|
|
||||||
keys map[string][]byte
|
|
||||||
|
|
||||||
// checkCount counts the Check calls. Used for testing
|
|
||||||
// rekeying.
|
|
||||||
checkCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *storedHostKey) Add(key ssh.PublicKey) {
|
|
||||||
if k.keys == nil {
|
|
||||||
k.keys = map[string][]byte{}
|
|
||||||
}
|
|
||||||
k.keys[key.Type()] = key.Marshal()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
|
|
||||||
k.checkCount++
|
|
||||||
algo := key.Type()
|
|
||||||
|
|
||||||
if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
|
|
||||||
return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func hostKeyDB() *storedHostKey {
|
|
||||||
keyChecker := &storedHostKey{}
|
|
||||||
keyChecker.Add(testPublicKeys["ecdsa"])
|
|
||||||
keyChecker.Add(testPublicKeys["rsa"])
|
|
||||||
keyChecker.Add(testPublicKeys["dsa"])
|
|
||||||
return keyChecker
|
|
||||||
}
|
|
||||||
|
|
||||||
func clientConfig() *ssh.ClientConfig {
|
|
||||||
config := &ssh.ClientConfig{
|
|
||||||
User: username(),
|
|
||||||
Auth: []ssh.AuthMethod{
|
|
||||||
ssh.PublicKeys(testSigners["user"]),
|
|
||||||
},
|
|
||||||
HostKeyCallback: hostKeyDB().Check,
|
|
||||||
}
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
// unixConnection creates two halves of a connected net.UnixConn. It
|
|
||||||
// is used for connecting the Go SSH client with sshd without opening
|
|
||||||
// ports.
|
|
||||||
func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
|
|
||||||
dir, err := ioutil.TempDir("", "unixConnection")
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
defer os.Remove(dir)
|
|
||||||
|
|
||||||
addr := filepath.Join(dir, "ssh")
|
|
||||||
listener, err := net.Listen("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
c1, err := net.Dial("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c2, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
c1.Close()
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
|
|
||||||
sshd, err := exec.LookPath("sshd")
|
|
||||||
if err != nil {
|
|
||||||
s.t.Skipf("skipping test: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c1, c2, err := unixConnection()
|
|
||||||
if err != nil {
|
|
||||||
s.t.Fatalf("unixConnection: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
|
|
||||||
f, err := c2.File()
|
|
||||||
if err != nil {
|
|
||||||
s.t.Fatalf("UnixConn.File: %v", err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
s.cmd.Stdin = f
|
|
||||||
s.cmd.Stdout = f
|
|
||||||
s.cmd.Stderr = &s.output
|
|
||||||
if err := s.cmd.Start(); err != nil {
|
|
||||||
s.t.Fail()
|
|
||||||
s.Shutdown()
|
|
||||||
s.t.Fatalf("s.cmd.Start: %v", err)
|
|
||||||
}
|
|
||||||
s.clientConn = c1
|
|
||||||
conn, chans, reqs, err := ssh.NewClientConn(c1, "", config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ssh.NewClient(conn, chans, reqs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
|
|
||||||
conn, err := s.TryDial(config)
|
|
||||||
if err != nil {
|
|
||||||
s.t.Fail()
|
|
||||||
s.Shutdown()
|
|
||||||
s.t.Fatalf("ssh.Client: %v", err)
|
|
||||||
}
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) Shutdown() {
|
|
||||||
if s.cmd != nil && s.cmd.Process != nil {
|
|
||||||
// Don't check for errors; if it fails it's most
|
|
||||||
// likely "os: process already finished", and we don't
|
|
||||||
// care about that. Use os.Interrupt, so child
|
|
||||||
// processes are killed too.
|
|
||||||
s.cmd.Process.Signal(os.Interrupt)
|
|
||||||
s.cmd.Wait()
|
|
||||||
}
|
|
||||||
if s.t.Failed() {
|
|
||||||
// log any output from sshd process
|
|
||||||
s.t.Logf("sshd: %s", s.output.String())
|
|
||||||
}
|
|
||||||
s.cleanup()
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeFile(path string, contents []byte) {
|
|
||||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
if _, err := f.Write(contents); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// newServer returns a new mock ssh server.
|
|
||||||
func newServer(t *testing.T) *server {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skipping test due to -short")
|
|
||||||
}
|
|
||||||
dir, err := ioutil.TempDir("", "sshtest")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
f, err := os.Create(filepath.Join(dir, "sshd_config"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = configTmpl.Execute(f, map[string]string{
|
|
||||||
"Dir": dir,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
f.Close()
|
|
||||||
|
|
||||||
for k, v := range testdata.PEMBytes {
|
|
||||||
filename := "id_" + k
|
|
||||||
writeFile(filepath.Join(dir, filename), v)
|
|
||||||
writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &server{
|
|
||||||
t: t,
|
|
||||||
configfile: f.Name(),
|
|
||||||
cleanup: func() {
|
|
||||||
if err := os.RemoveAll(dir); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,64 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
|
||||||
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
|
||||||
// instances.
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"golang.org/x/crypto/ssh/testdata"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
testPrivateKeys map[string]interface{}
|
|
||||||
testSigners map[string]ssh.Signer
|
|
||||||
testPublicKeys map[string]ssh.PublicKey
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
n := len(testdata.PEMBytes)
|
|
||||||
testPrivateKeys = make(map[string]interface{}, n)
|
|
||||||
testSigners = make(map[string]ssh.Signer, n)
|
|
||||||
testPublicKeys = make(map[string]ssh.PublicKey, n)
|
|
||||||
for t, k := range testdata.PEMBytes {
|
|
||||||
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
|
|
||||||
}
|
|
||||||
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
|
|
||||||
}
|
|
||||||
testPublicKeys[t] = testSigners[t].PublicKey()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a cert and sign it for use in tests.
|
|
||||||
testCert := &ssh.Certificate{
|
|
||||||
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
|
||||||
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
|
||||||
ValidAfter: 0, // unix epoch
|
|
||||||
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
|
|
||||||
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
|
||||||
Key: testPublicKeys["ecdsa"],
|
|
||||||
SignatureKey: testPublicKeys["rsa"],
|
|
||||||
Permissions: ssh.Permissions{
|
|
||||||
CriticalOptions: map[string]string{},
|
|
||||||
Extensions: map[string]string{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
testCert.SignCert(rand.Reader, testSigners["rsa"])
|
|
||||||
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
|
|
||||||
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,8 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// This package contains test data shared between the various subpackages of
|
|
||||||
// the code.google.com/p/go.crypto/ssh package. Under no circumstance should
|
|
||||||
// this data be used for production code.
|
|
||||||
package testdata
|
|
|
@ -1,43 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package testdata
|
|
||||||
|
|
||||||
var PEMBytes = map[string][]byte{
|
|
||||||
"dsa": []byte(`-----BEGIN DSA PRIVATE KEY-----
|
|
||||||
MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB
|
|
||||||
lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
|
|
||||||
EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD
|
|
||||||
nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV
|
|
||||||
2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r
|
|
||||||
juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr
|
|
||||||
FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz
|
|
||||||
DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj
|
|
||||||
nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY
|
|
||||||
Fmsr0W6fHB9nhS4/UXM8
|
|
||||||
-----END DSA PRIVATE KEY-----
|
|
||||||
`),
|
|
||||||
"ecdsa": []byte(`-----BEGIN EC PRIVATE KEY-----
|
|
||||||
MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
|
|
||||||
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
|
|
||||||
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
|
|
||||||
-----END EC PRIVATE KEY-----
|
|
||||||
`),
|
|
||||||
"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
|
|
||||||
MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
|
|
||||||
r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
|
|
||||||
tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
|
|
||||||
nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
|
|
||||||
2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
|
|
||||||
y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
|
|
||||||
rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
|
|
||||||
-----END RSA PRIVATE KEY-----
|
|
||||||
`),
|
|
||||||
"user": []byte(`-----BEGIN EC PRIVATE KEY-----
|
|
||||||
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
|
|
||||||
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD
|
|
||||||
PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w==
|
|
||||||
-----END EC PRIVATE KEY-----
|
|
||||||
`),
|
|
||||||
}
|
|
|
@ -1,63 +0,0 @@
|
||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
|
||||||
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
|
||||||
// instances.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh/testdata"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
testPrivateKeys map[string]interface{}
|
|
||||||
testSigners map[string]Signer
|
|
||||||
testPublicKeys map[string]PublicKey
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
n := len(testdata.PEMBytes)
|
|
||||||
testPrivateKeys = make(map[string]interface{}, n)
|
|
||||||
testSigners = make(map[string]Signer, n)
|
|
||||||
testPublicKeys = make(map[string]PublicKey, n)
|
|
||||||
for t, k := range testdata.PEMBytes {
|
|
||||||
testPrivateKeys[t], err = ParseRawPrivateKey(k)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
|
|
||||||
}
|
|
||||||
testSigners[t], err = NewSignerFromKey(testPrivateKeys[t])
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
|
|
||||||
}
|
|
||||||
testPublicKeys[t] = testSigners[t].PublicKey()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a cert and sign it for use in tests.
|
|
||||||
testCert := &Certificate{
|
|
||||||
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
|
||||||
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
|
||||||
ValidAfter: 0, // unix epoch
|
|
||||||
ValidBefore: CertTimeInfinity, // The end of currently representable time.
|
|
||||||
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
|
||||||
Key: testPublicKeys["ecdsa"],
|
|
||||||
SignatureKey: testPublicKeys["rsa"],
|
|
||||||
Permissions: Permissions{
|
|
||||||
CriticalOptions: map[string]string{},
|
|
||||||
Extensions: map[string]string{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
testCert.SignCert(rand.Reader, testSigners["rsa"])
|
|
||||||
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
|
|
||||||
testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"])
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,327 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
gcmCipherID = "aes128-gcm@openssh.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
// packetConn represents a transport that implements packet based
|
|
||||||
// operations.
|
|
||||||
type packetConn interface {
|
|
||||||
// Encrypt and send a packet of data to the remote peer.
|
|
||||||
writePacket(packet []byte) error
|
|
||||||
|
|
||||||
// Read a packet from the connection
|
|
||||||
readPacket() ([]byte, error)
|
|
||||||
|
|
||||||
// Close closes the write-side of the connection.
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// transport is the keyingTransport that implements the SSH packet
|
|
||||||
// protocol.
|
|
||||||
type transport struct {
|
|
||||||
reader connectionState
|
|
||||||
writer connectionState
|
|
||||||
|
|
||||||
bufReader *bufio.Reader
|
|
||||||
bufWriter *bufio.Writer
|
|
||||||
rand io.Reader
|
|
||||||
|
|
||||||
io.Closer
|
|
||||||
|
|
||||||
// Initial H used for the session ID. Once assigned this does
|
|
||||||
// not change, even during subsequent key exchanges.
|
|
||||||
sessionID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *transport) getSessionID() []byte {
|
|
||||||
if t.sessionID == nil {
|
|
||||||
panic("session ID not set yet")
|
|
||||||
}
|
|
||||||
s := make([]byte, len(t.sessionID))
|
|
||||||
copy(s, t.sessionID)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// packetCipher represents a combination of SSH encryption/MAC
|
|
||||||
// protocol. A single instance should be used for one direction only.
|
|
||||||
type packetCipher interface {
|
|
||||||
// writePacket encrypts the packet and writes it to w. The
|
|
||||||
// contents of the packet are generally scrambled.
|
|
||||||
writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
|
|
||||||
|
|
||||||
// readPacket reads and decrypts a packet of data. The
|
|
||||||
// returned packet may be overwritten by future calls of
|
|
||||||
// readPacket.
|
|
||||||
readPacket(seqnum uint32, r io.Reader) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// connectionState represents one side (read or write) of the
|
|
||||||
// connection. This is necessary because each direction has its own
|
|
||||||
// keys, and can even have its own algorithms
|
|
||||||
type connectionState struct {
|
|
||||||
packetCipher
|
|
||||||
seqNum uint32
|
|
||||||
dir direction
|
|
||||||
pendingKeyChange chan packetCipher
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareKeyChange sets up key material for a keychange. The key changes in
|
|
||||||
// both directions are triggered by reading and writing a msgNewKey packet
|
|
||||||
// respectively.
|
|
||||||
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
|
|
||||||
if t.sessionID == nil {
|
|
||||||
t.sessionID = kexResult.H
|
|
||||||
}
|
|
||||||
|
|
||||||
kexResult.SessionID = t.sessionID
|
|
||||||
|
|
||||||
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
|
|
||||||
return err
|
|
||||||
} else {
|
|
||||||
t.reader.pendingKeyChange <- ciph
|
|
||||||
}
|
|
||||||
|
|
||||||
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil {
|
|
||||||
return err
|
|
||||||
} else {
|
|
||||||
t.writer.pendingKeyChange <- ciph
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read and decrypt next packet.
|
|
||||||
func (t *transport) readPacket() ([]byte, error) {
|
|
||||||
return t.reader.readPacket(t.bufReader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
|
|
||||||
packet, err := s.packetCipher.readPacket(s.seqNum, r)
|
|
||||||
s.seqNum++
|
|
||||||
if err == nil && len(packet) == 0 {
|
|
||||||
err = errors.New("ssh: zero length packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(packet) > 0 && packet[0] == msgNewKeys {
|
|
||||||
select {
|
|
||||||
case cipher := <-s.pendingKeyChange:
|
|
||||||
s.packetCipher = cipher
|
|
||||||
default:
|
|
||||||
return nil, errors.New("ssh: got bogus newkeys message.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The packet may point to an internal buffer, so copy the
|
|
||||||
// packet out here.
|
|
||||||
fresh := make([]byte, len(packet))
|
|
||||||
copy(fresh, packet)
|
|
||||||
|
|
||||||
return fresh, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *transport) writePacket(packet []byte) error {
|
|
||||||
return t.writer.writePacket(t.bufWriter, t.rand, packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
|
|
||||||
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
|
|
||||||
|
|
||||||
err := s.packetCipher.writePacket(s.seqNum, w, rand, packet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = w.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.seqNum++
|
|
||||||
if changeKeys {
|
|
||||||
select {
|
|
||||||
case cipher := <-s.pendingKeyChange:
|
|
||||||
s.packetCipher = cipher
|
|
||||||
default:
|
|
||||||
panic("ssh: no key material for msgNewKeys")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
|
|
||||||
t := &transport{
|
|
||||||
bufReader: bufio.NewReader(rwc),
|
|
||||||
bufWriter: bufio.NewWriter(rwc),
|
|
||||||
rand: rand,
|
|
||||||
reader: connectionState{
|
|
||||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
|
|
||||||
pendingKeyChange: make(chan packetCipher, 1),
|
|
||||||
},
|
|
||||||
writer: connectionState{
|
|
||||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
|
|
||||||
pendingKeyChange: make(chan packetCipher, 1),
|
|
||||||
},
|
|
||||||
Closer: rwc,
|
|
||||||
}
|
|
||||||
if isClient {
|
|
||||||
t.reader.dir = serverKeys
|
|
||||||
t.writer.dir = clientKeys
|
|
||||||
} else {
|
|
||||||
t.reader.dir = clientKeys
|
|
||||||
t.writer.dir = serverKeys
|
|
||||||
}
|
|
||||||
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
type direction struct {
|
|
||||||
ivTag []byte
|
|
||||||
keyTag []byte
|
|
||||||
macKeyTag []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
|
|
||||||
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
|
|
||||||
)
|
|
||||||
|
|
||||||
// generateKeys generates key material for IV, MAC and encryption.
|
|
||||||
func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
|
|
||||||
cipherMode := cipherModes[algs.Cipher]
|
|
||||||
macMode := macModes[algs.MAC]
|
|
||||||
|
|
||||||
iv = make([]byte, cipherMode.ivSize)
|
|
||||||
key = make([]byte, cipherMode.keySize)
|
|
||||||
macKey = make([]byte, macMode.keySize)
|
|
||||||
|
|
||||||
generateKeyMaterial(iv, d.ivTag, kex)
|
|
||||||
generateKeyMaterial(key, d.keyTag, kex)
|
|
||||||
generateKeyMaterial(macKey, d.macKeyTag, kex)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
|
|
||||||
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
|
||||||
// (to setup server->client keys) or clientKeys (for client->server keys).
|
|
||||||
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
|
|
||||||
iv, key, macKey := generateKeys(d, algs, kex)
|
|
||||||
|
|
||||||
if algs.Cipher == gcmCipherID {
|
|
||||||
return newGCMCipher(iv, key, macKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &streamPacketCipher{
|
|
||||||
mac: macModes[algs.MAC].new(macKey),
|
|
||||||
}
|
|
||||||
c.macResult = make([]byte, c.mac.Size())
|
|
||||||
|
|
||||||
var err error
|
|
||||||
c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateKeyMaterial fills out with key material generated from tag, K, H
|
|
||||||
// and sessionId, as specified in RFC 4253, section 7.2.
|
|
||||||
func generateKeyMaterial(out, tag []byte, r *kexResult) {
|
|
||||||
var digestsSoFar []byte
|
|
||||||
|
|
||||||
h := r.Hash.New()
|
|
||||||
for len(out) > 0 {
|
|
||||||
h.Reset()
|
|
||||||
h.Write(r.K)
|
|
||||||
h.Write(r.H)
|
|
||||||
|
|
||||||
if len(digestsSoFar) == 0 {
|
|
||||||
h.Write(tag)
|
|
||||||
h.Write(r.SessionID)
|
|
||||||
} else {
|
|
||||||
h.Write(digestsSoFar)
|
|
||||||
}
|
|
||||||
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
n := copy(out, digest)
|
|
||||||
out = out[n:]
|
|
||||||
if len(out) > 0 {
|
|
||||||
digestsSoFar = append(digestsSoFar, digest...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const packageVersion = "SSH-2.0-Go"
|
|
||||||
|
|
||||||
// Sends and receives a version line. The versionLine string should
|
|
||||||
// be US ASCII, start with "SSH-2.0-", and should not include a
|
|
||||||
// newline. exchangeVersions returns the other side's version line.
|
|
||||||
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
|
|
||||||
// Contrary to the RFC, we do not ignore lines that don't
|
|
||||||
// start with "SSH-2.0-" to make the library usable with
|
|
||||||
// nonconforming servers.
|
|
||||||
for _, c := range versionLine {
|
|
||||||
// The spec disallows non US-ASCII chars, and
|
|
||||||
// specifically forbids null chars.
|
|
||||||
if c < 32 {
|
|
||||||
return nil, errors.New("ssh: junk character in version line")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
them, err = readVersion(rw)
|
|
||||||
return them, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// maxVersionStringBytes is the maximum number of bytes that we'll
|
|
||||||
// accept as a version string. RFC 4253 section 4.2 limits this at 255
|
|
||||||
// chars
|
|
||||||
const maxVersionStringBytes = 255
|
|
||||||
|
|
||||||
// Read version string as specified by RFC 4253, section 4.2.
|
|
||||||
func readVersion(r io.Reader) ([]byte, error) {
|
|
||||||
versionString := make([]byte, 0, 64)
|
|
||||||
var ok bool
|
|
||||||
var buf [1]byte
|
|
||||||
|
|
||||||
for len(versionString) < maxVersionStringBytes {
|
|
||||||
_, err := io.ReadFull(r, buf[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// The RFC says that the version should be terminated with \r\n
|
|
||||||
// but several SSH servers actually only send a \n.
|
|
||||||
if buf[0] == '\n' {
|
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// non ASCII chars are disallowed, but we are lenient,
|
|
||||||
// since Go doesn't use null-terminated strings.
|
|
||||||
|
|
||||||
// The RFC allows a comment after a space, however,
|
|
||||||
// all of it (version and comments) goes into the
|
|
||||||
// session hash.
|
|
||||||
versionString = append(versionString, buf[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("ssh: overflow reading version string")
|
|
||||||
}
|
|
||||||
|
|
||||||
// There might be a '\r' on the end which we should remove.
|
|
||||||
if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
|
|
||||||
versionString = versionString[:len(versionString)-1]
|
|
||||||
}
|
|
||||||
return versionString, nil
|
|
||||||
}
|
|
|
@ -1,109 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/binary"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestReadVersion(t *testing.T) {
|
|
||||||
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
|
|
||||||
cases := map[string]string{
|
|
||||||
"SSH-2.0-bla\r\n": "SSH-2.0-bla",
|
|
||||||
"SSH-2.0-bla\n": "SSH-2.0-bla",
|
|
||||||
longversion + "\r\n": longversion,
|
|
||||||
}
|
|
||||||
|
|
||||||
for in, want := range cases {
|
|
||||||
result, err := readVersion(bytes.NewBufferString(in))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("readVersion(%q): %s", in, err)
|
|
||||||
}
|
|
||||||
got := string(result)
|
|
||||||
if got != want {
|
|
||||||
t.Errorf("got %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadVersionError(t *testing.T) {
|
|
||||||
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
|
|
||||||
cases := []string{
|
|
||||||
longversion + "too-long\r\n",
|
|
||||||
}
|
|
||||||
for _, in := range cases {
|
|
||||||
if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
|
|
||||||
t.Errorf("readVersion(%q) should have failed", in)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExchangeVersionsBasic(t *testing.T) {
|
|
||||||
v := "SSH-2.0-bla"
|
|
||||||
buf := bytes.NewBufferString(v + "\r\n")
|
|
||||||
them, err := exchangeVersions(buf, []byte("xyz"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("exchangeVersions: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if want := "SSH-2.0-bla"; string(them) != want {
|
|
||||||
t.Errorf("got %q want %q for our version", them, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExchangeVersions(t *testing.T) {
|
|
||||||
cases := []string{
|
|
||||||
"not\x000allowed",
|
|
||||||
"not allowed\n",
|
|
||||||
}
|
|
||||||
for _, c := range cases {
|
|
||||||
buf := bytes.NewBufferString("SSH-2.0-bla\r\n")
|
|
||||||
if _, err := exchangeVersions(buf, []byte(c)); err == nil {
|
|
||||||
t.Errorf("exchangeVersions(%q): should have failed", c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type closerBuffer struct {
|
|
||||||
bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *closerBuffer) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTransportMaxPacketWrite(t *testing.T) {
|
|
||||||
buf := &closerBuffer{}
|
|
||||||
tr := newTransport(buf, rand.Reader, true)
|
|
||||||
huge := make([]byte, maxPacket+1)
|
|
||||||
err := tr.writePacket(huge)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("transport accepted write for a huge packet.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTransportMaxPacketReader(t *testing.T) {
|
|
||||||
var header [5]byte
|
|
||||||
huge := make([]byte, maxPacket+128)
|
|
||||||
binary.BigEndian.PutUint32(header[0:], uint32(len(huge)))
|
|
||||||
// padding.
|
|
||||||
header[4] = 0
|
|
||||||
|
|
||||||
buf := &closerBuffer{}
|
|
||||||
buf.Write(header[:])
|
|
||||||
buf.Write(huge)
|
|
||||||
|
|
||||||
tr := newTransport(buf, rand.Reader, true)
|
|
||||||
_, err := tr.readPacket()
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("transport succeeded reading huge packet.")
|
|
||||||
} else if !strings.Contains(err.Error(), "large") {
|
|
||||||
t.Errorf("got %q, should mention %q", err.Error(), "large")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,447 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package context defines the Context type, which carries deadlines,
|
||||||
|
// cancelation signals, and other request-scoped values across API boundaries
|
||||||
|
// and between processes.
|
||||||
|
//
|
||||||
|
// Incoming requests to a server should create a Context, and outgoing calls to
|
||||||
|
// servers should accept a Context. The chain of function calls between must
|
||||||
|
// propagate the Context, optionally replacing it with a modified copy created
|
||||||
|
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
|
||||||
|
//
|
||||||
|
// Programs that use Contexts should follow these rules to keep interfaces
|
||||||
|
// consistent across packages and enable static analysis tools to check context
|
||||||
|
// propagation:
|
||||||
|
//
|
||||||
|
// Do not store Contexts inside a struct type; instead, pass a Context
|
||||||
|
// explicitly to each function that needs it. The Context should be the first
|
||||||
|
// parameter, typically named ctx:
|
||||||
|
//
|
||||||
|
// func DoSomething(ctx context.Context, arg Arg) error {
|
||||||
|
// // ... use ctx ...
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
|
||||||
|
// if you are unsure about which Context to use.
|
||||||
|
//
|
||||||
|
// Use context Values only for request-scoped data that transits processes and
|
||||||
|
// APIs, not for passing optional parameters to functions.
|
||||||
|
//
|
||||||
|
// The same Context may be passed to functions running in different goroutines;
|
||||||
|
// Contexts are safe for simultaneous use by multiple goroutines.
|
||||||
|
//
|
||||||
|
// See http://blog.golang.org/context for example code for a server that uses
|
||||||
|
// Contexts.
|
||||||
|
package context
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Context carries a deadline, a cancelation signal, and other values across
|
||||||
|
// API boundaries.
|
||||||
|
//
|
||||||
|
// Context's methods may be called by multiple goroutines simultaneously.
|
||||||
|
type Context interface {
|
||||||
|
// Deadline returns the time when work done on behalf of this context
|
||||||
|
// should be canceled. Deadline returns ok==false when no deadline is
|
||||||
|
// set. Successive calls to Deadline return the same results.
|
||||||
|
Deadline() (deadline time.Time, ok bool)
|
||||||
|
|
||||||
|
// Done returns a channel that's closed when work done on behalf of this
|
||||||
|
// context should be canceled. Done may return nil if this context can
|
||||||
|
// never be canceled. Successive calls to Done return the same value.
|
||||||
|
//
|
||||||
|
// WithCancel arranges for Done to be closed when cancel is called;
|
||||||
|
// WithDeadline arranges for Done to be closed when the deadline
|
||||||
|
// expires; WithTimeout arranges for Done to be closed when the timeout
|
||||||
|
// elapses.
|
||||||
|
//
|
||||||
|
// Done is provided for use in select statements:
|
||||||
|
//
|
||||||
|
// // Stream generates values with DoSomething and sends them to out
|
||||||
|
// // until DoSomething returns an error or ctx.Done is closed.
|
||||||
|
// func Stream(ctx context.Context, out <-chan Value) error {
|
||||||
|
// for {
|
||||||
|
// v, err := DoSomething(ctx)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// select {
|
||||||
|
// case <-ctx.Done():
|
||||||
|
// return ctx.Err()
|
||||||
|
// case out <- v:
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// See http://blog.golang.org/pipelines for more examples of how to use
|
||||||
|
// a Done channel for cancelation.
|
||||||
|
Done() <-chan struct{}
|
||||||
|
|
||||||
|
// Err returns a non-nil error value after Done is closed. Err returns
|
||||||
|
// Canceled if the context was canceled or DeadlineExceeded if the
|
||||||
|
// context's deadline passed. No other values for Err are defined.
|
||||||
|
// After Done is closed, successive calls to Err return the same value.
|
||||||
|
Err() error
|
||||||
|
|
||||||
|
// Value returns the value associated with this context for key, or nil
|
||||||
|
// if no value is associated with key. Successive calls to Value with
|
||||||
|
// the same key returns the same result.
|
||||||
|
//
|
||||||
|
// Use context values only for request-scoped data that transits
|
||||||
|
// processes and API boundaries, not for passing optional parameters to
|
||||||
|
// functions.
|
||||||
|
//
|
||||||
|
// A key identifies a specific value in a Context. Functions that wish
|
||||||
|
// to store values in Context typically allocate a key in a global
|
||||||
|
// variable then use that key as the argument to context.WithValue and
|
||||||
|
// Context.Value. A key can be any type that supports equality;
|
||||||
|
// packages should define keys as an unexported type to avoid
|
||||||
|
// collisions.
|
||||||
|
//
|
||||||
|
// Packages that define a Context key should provide type-safe accessors
|
||||||
|
// for the values stores using that key:
|
||||||
|
//
|
||||||
|
// // Package user defines a User type that's stored in Contexts.
|
||||||
|
// package user
|
||||||
|
//
|
||||||
|
// import "golang.org/x/net/context"
|
||||||
|
//
|
||||||
|
// // User is the type of value stored in the Contexts.
|
||||||
|
// type User struct {...}
|
||||||
|
//
|
||||||
|
// // key is an unexported type for keys defined in this package.
|
||||||
|
// // This prevents collisions with keys defined in other packages.
|
||||||
|
// type key int
|
||||||
|
//
|
||||||
|
// // userKey is the key for user.User values in Contexts. It is
|
||||||
|
// // unexported; clients use user.NewContext and user.FromContext
|
||||||
|
// // instead of using this key directly.
|
||||||
|
// var userKey key = 0
|
||||||
|
//
|
||||||
|
// // NewContext returns a new Context that carries value u.
|
||||||
|
// func NewContext(ctx context.Context, u *User) context.Context {
|
||||||
|
// return context.WithValue(ctx, userKey, u)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // FromContext returns the User value stored in ctx, if any.
|
||||||
|
// func FromContext(ctx context.Context) (*User, bool) {
|
||||||
|
// u, ok := ctx.Value(userKey).(*User)
|
||||||
|
// return u, ok
|
||||||
|
// }
|
||||||
|
Value(key interface{}) interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Canceled is the error returned by Context.Err when the context is canceled.
|
||||||
|
var Canceled = errors.New("context canceled")
|
||||||
|
|
||||||
|
// DeadlineExceeded is the error returned by Context.Err when the context's
|
||||||
|
// deadline passes.
|
||||||
|
var DeadlineExceeded = errors.New("context deadline exceeded")
|
||||||
|
|
||||||
|
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
|
||||||
|
// struct{}, since vars of this type must have distinct addresses.
|
||||||
|
type emptyCtx int
|
||||||
|
|
||||||
|
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*emptyCtx) Done() <-chan struct{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*emptyCtx) Err() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*emptyCtx) Value(key interface{}) interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *emptyCtx) String() string {
|
||||||
|
switch e {
|
||||||
|
case background:
|
||||||
|
return "context.Background"
|
||||||
|
case todo:
|
||||||
|
return "context.TODO"
|
||||||
|
}
|
||||||
|
return "unknown empty Context"
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
background = new(emptyCtx)
|
||||||
|
todo = new(emptyCtx)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Background returns a non-nil, empty Context. It is never canceled, has no
|
||||||
|
// values, and has no deadline. It is typically used by the main function,
|
||||||
|
// initialization, and tests, and as the top-level Context for incoming
|
||||||
|
// requests.
|
||||||
|
func Background() Context {
|
||||||
|
return background
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO returns a non-nil, empty Context. Code should use context.TODO when
|
||||||
|
// it's unclear which Context to use or it's is not yet available (because the
|
||||||
|
// surrounding function has not yet been extended to accept a Context
|
||||||
|
// parameter). TODO is recognized by static analysis tools that determine
|
||||||
|
// whether Contexts are propagated correctly in a program.
|
||||||
|
func TODO() Context {
|
||||||
|
return todo
|
||||||
|
}
|
||||||
|
|
||||||
|
// A CancelFunc tells an operation to abandon its work.
|
||||||
|
// A CancelFunc does not wait for the work to stop.
|
||||||
|
// After the first call, subsequent calls to a CancelFunc do nothing.
|
||||||
|
type CancelFunc func()
|
||||||
|
|
||||||
|
// WithCancel returns a copy of parent with a new Done channel. The returned
|
||||||
|
// context's Done channel is closed when the returned cancel function is called
|
||||||
|
// or when the parent context's Done channel is closed, whichever happens first.
|
||||||
|
//
|
||||||
|
// Canceling this context releases resources associated with it, so code should
|
||||||
|
// call cancel as soon as the operations running in this Context complete.
|
||||||
|
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
|
||||||
|
c := newCancelCtx(parent)
|
||||||
|
propagateCancel(parent, &c)
|
||||||
|
return &c, func() { c.cancel(true, Canceled) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCancelCtx returns an initialized cancelCtx.
|
||||||
|
func newCancelCtx(parent Context) cancelCtx {
|
||||||
|
return cancelCtx{
|
||||||
|
Context: parent,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// propagateCancel arranges for child to be canceled when parent is.
|
||||||
|
func propagateCancel(parent Context, child canceler) {
|
||||||
|
if parent.Done() == nil {
|
||||||
|
return // parent is never canceled
|
||||||
|
}
|
||||||
|
if p, ok := parentCancelCtx(parent); ok {
|
||||||
|
p.mu.Lock()
|
||||||
|
if p.err != nil {
|
||||||
|
// parent has already been canceled
|
||||||
|
child.cancel(false, p.err)
|
||||||
|
} else {
|
||||||
|
if p.children == nil {
|
||||||
|
p.children = make(map[canceler]bool)
|
||||||
|
}
|
||||||
|
p.children[child] = true
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
} else {
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-parent.Done():
|
||||||
|
child.cancel(false, parent.Err())
|
||||||
|
case <-child.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parentCancelCtx follows a chain of parent references until it finds a
|
||||||
|
// *cancelCtx. This function understands how each of the concrete types in this
|
||||||
|
// package represents its parent.
|
||||||
|
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
|
||||||
|
for {
|
||||||
|
switch c := parent.(type) {
|
||||||
|
case *cancelCtx:
|
||||||
|
return c, true
|
||||||
|
case *timerCtx:
|
||||||
|
return &c.cancelCtx, true
|
||||||
|
case *valueCtx:
|
||||||
|
parent = c.Context
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeChild removes a context from its parent.
|
||||||
|
func removeChild(parent Context, child canceler) {
|
||||||
|
p, ok := parentCancelCtx(parent)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
if p.children != nil {
|
||||||
|
delete(p.children, child)
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// A canceler is a context type that can be canceled directly. The
|
||||||
|
// implementations are *cancelCtx and *timerCtx.
|
||||||
|
type canceler interface {
|
||||||
|
cancel(removeFromParent bool, err error)
|
||||||
|
Done() <-chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A cancelCtx can be canceled. When canceled, it also cancels any children
|
||||||
|
// that implement canceler.
|
||||||
|
type cancelCtx struct {
|
||||||
|
Context
|
||||||
|
|
||||||
|
done chan struct{} // closed by the first cancel call.
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
children map[canceler]bool // set to nil by the first cancel call
|
||||||
|
err error // set to non-nil by the first cancel call
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cancelCtx) Done() <-chan struct{} {
|
||||||
|
return c.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cancelCtx) Err() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cancelCtx) String() string {
|
||||||
|
return fmt.Sprintf("%v.WithCancel", c.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancel closes c.done, cancels each of c's children, and, if
|
||||||
|
// removeFromParent is true, removes c from its parent's children.
|
||||||
|
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
|
||||||
|
if err == nil {
|
||||||
|
panic("context: internal error: missing cancel error")
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.err != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return // already canceled
|
||||||
|
}
|
||||||
|
c.err = err
|
||||||
|
close(c.done)
|
||||||
|
for child := range c.children {
|
||||||
|
// NOTE: acquiring the child's lock while holding parent's lock.
|
||||||
|
child.cancel(false, err)
|
||||||
|
}
|
||||||
|
c.children = nil
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if removeFromParent {
|
||||||
|
removeChild(c.Context, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
||||||
|
// to be no later than d. If the parent's deadline is already earlier than d,
|
||||||
|
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
||||||
|
// context's Done channel is closed when the deadline expires, when the returned
|
||||||
|
// cancel function is called, or when the parent context's Done channel is
|
||||||
|
// closed, whichever happens first.
|
||||||
|
//
|
||||||
|
// Canceling this context releases resources associated with it, so code should
|
||||||
|
// call cancel as soon as the operations running in this Context complete.
|
||||||
|
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
|
||||||
|
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) {
|
||||||
|
// The current deadline is already sooner than the new one.
|
||||||
|
return WithCancel(parent)
|
||||||
|
}
|
||||||
|
c := &timerCtx{
|
||||||
|
cancelCtx: newCancelCtx(parent),
|
||||||
|
deadline: deadline,
|
||||||
|
}
|
||||||
|
propagateCancel(parent, c)
|
||||||
|
d := deadline.Sub(time.Now())
|
||||||
|
if d <= 0 {
|
||||||
|
c.cancel(true, DeadlineExceeded) // deadline has already passed
|
||||||
|
return c, func() { c.cancel(true, Canceled) }
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.err == nil {
|
||||||
|
c.timer = time.AfterFunc(d, func() {
|
||||||
|
c.cancel(true, DeadlineExceeded)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return c, func() { c.cancel(true, Canceled) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
|
||||||
|
// implement Done and Err. It implements cancel by stopping its timer then
|
||||||
|
// delegating to cancelCtx.cancel.
|
||||||
|
type timerCtx struct {
|
||||||
|
cancelCtx
|
||||||
|
timer *time.Timer // Under cancelCtx.mu.
|
||||||
|
|
||||||
|
deadline time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
|
||||||
|
return c.deadline, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *timerCtx) String() string {
|
||||||
|
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *timerCtx) cancel(removeFromParent bool, err error) {
|
||||||
|
c.cancelCtx.cancel(false, err)
|
||||||
|
if removeFromParent {
|
||||||
|
// Remove this timerCtx from its parent cancelCtx's children.
|
||||||
|
removeChild(c.cancelCtx.Context, c)
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.timer != nil {
|
||||||
|
c.timer.Stop()
|
||||||
|
c.timer = nil
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
||||||
|
//
|
||||||
|
// Canceling this context releases resources associated with it, so code should
|
||||||
|
// call cancel as soon as the operations running in this Context complete:
|
||||||
|
//
|
||||||
|
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
||||||
|
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||||
|
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
||||||
|
// return slowOperation(ctx)
|
||||||
|
// }
|
||||||
|
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
|
||||||
|
return WithDeadline(parent, time.Now().Add(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithValue returns a copy of parent in which the value associated with key is
|
||||||
|
// val.
|
||||||
|
//
|
||||||
|
// Use context Values only for request-scoped data that transits processes and
|
||||||
|
// APIs, not for passing optional parameters to functions.
|
||||||
|
func WithValue(parent Context, key interface{}, val interface{}) Context {
|
||||||
|
return &valueCtx{parent, key, val}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A valueCtx carries a key-value pair. It implements Value for that key and
|
||||||
|
// delegates all other calls to the embedded Context.
|
||||||
|
type valueCtx struct {
|
||||||
|
Context
|
||||||
|
key, val interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *valueCtx) String() string {
|
||||||
|
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *valueCtx) Value(key interface{}) interface{} {
|
||||||
|
if c.key == key {
|
||||||
|
return c.val
|
||||||
|
}
|
||||||
|
return c.Context.Value(key)
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue