diff --git a/cli/command/builder/prune.go b/cli/command/builder/prune.go index 4b24957690..5466d30022 100644 --- a/cli/command/builder/prune.go +++ b/cli/command/builder/prune.go @@ -9,6 +9,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/completion" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/opts" "github.com/docker/docker/api/types" "github.com/docker/docker/errdefs" @@ -69,7 +70,7 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) warning = allCacheWarning } if !options.force { - r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning) + r, err := prompt.Confirm(ctx, dockerCli.In(), dockerCli.Out(), warning) if err != nil { return 0, "", err } diff --git a/cli/command/container/prune.go b/cli/command/container/prune.go index 1c879002a8..74f741a7a7 100644 --- a/cli/command/container/prune.go +++ b/cli/command/container/prune.go @@ -7,6 +7,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/completion" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/opts" "github.com/docker/docker/errdefs" units "github.com/docker/go-units" @@ -56,7 +57,7 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) pruneFilters := command.PruneFilters(dockerCli, options.filter.Value()) if !options.force { - r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning) + r, err := prompt.Confirm(ctx, dockerCli.In(), dockerCli.Out(), warning) if err != nil { return 0, "", err } diff --git a/cli/command/image/prune.go b/cli/command/image/prune.go index 7bdb24d8c5..89e84b41f0 100644 --- a/cli/command/image/prune.go +++ b/cli/command/image/prune.go @@ -9,6 +9,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/completion" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/opts" "github.com/docker/docker/errdefs" units "github.com/docker/go-units" @@ -70,7 +71,7 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) warning = allImageWarning } if !options.force { - r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning) + r, err := prompt.Confirm(ctx, dockerCli.In(), dockerCli.Out(), warning) if err != nil { return 0, "", err } diff --git a/cli/command/network/prune.go b/cli/command/network/prune.go index 8069d4c995..8eadee1d81 100644 --- a/cli/command/network/prune.go +++ b/cli/command/network/prune.go @@ -6,6 +6,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/opts" "github.com/docker/docker/errdefs" "github.com/pkg/errors" @@ -52,7 +53,7 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) pruneFilters := command.PruneFilters(dockerCli, options.filter.Value()) if !options.force { - r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning) + r, err := prompt.Confirm(ctx, dockerCli.In(), dockerCli.Out(), warning) if err != nil { return "", err } diff --git a/cli/command/network/remove.go b/cli/command/network/remove.go index 151c531f23..86f9ff5e2c 100644 --- a/cli/command/network/remove.go +++ b/cli/command/network/remove.go @@ -8,6 +8,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/completion" + "github.com/docker/cli/internal/prompt" "github.com/docker/docker/api/types/network" "github.com/docker/docker/errdefs" "github.com/spf13/cobra" @@ -49,7 +50,7 @@ func runRemove(ctx context.Context, dockerCLI command.Cli, networks []string, op for _, name := range networks { nw, _, err := apiClient.NetworkInspectWithRaw(ctx, name, network.InspectOptions{}) if err == nil && nw.Ingress { - r, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), ingressWarning) + r, err := prompt.Confirm(ctx, dockerCLI.In(), dockerCLI.Out(), ingressWarning) if err != nil { return err } diff --git a/cli/command/plugin/install.go b/cli/command/plugin/install.go index 47b0c97ee9..9a7fe3ce57 100644 --- a/cli/command/plugin/install.go +++ b/cli/command/plugin/install.go @@ -10,6 +10,7 @@ import ( "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/image" "github.com/docker/cli/cli/internal/jsonstream" + "github.com/docker/cli/internal/prompt" "github.com/docker/docker/api/types" registrytypes "github.com/docker/docker/api/types/registry" "github.com/docker/docker/registry" @@ -133,12 +134,12 @@ func runInstall(ctx context.Context, dockerCLI command.Cli, opts pluginOptions) return nil } -func acceptPrivileges(dockerCLI command.Cli, name string) func(ctx context.Context, privileges types.PluginPrivileges) (bool, error) { +func acceptPrivileges(dockerCLI command.Streams, name string) func(ctx context.Context, privileges types.PluginPrivileges) (bool, error) { return func(ctx context.Context, privileges types.PluginPrivileges) (bool, error) { _, _ = fmt.Fprintf(dockerCLI.Out(), "Plugin %q is requesting the following privileges:\n", name) for _, privilege := range privileges { _, _ = fmt.Fprintf(dockerCLI.Out(), " - %s: %v\n", privilege.Name, privilege.Value) } - return command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), "Do you grant the above permissions?") + return prompt.Confirm(ctx, dockerCLI.In(), dockerCLI.Out(), "Do you grant the above permissions?") } } diff --git a/cli/command/plugin/upgrade.go b/cli/command/plugin/upgrade.go index d83c75e317..d7b14988ad 100644 --- a/cli/command/plugin/upgrade.go +++ b/cli/command/plugin/upgrade.go @@ -9,6 +9,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/internal/jsonstream" + "github.com/docker/cli/internal/prompt" "github.com/docker/docker/errdefs" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -64,7 +65,7 @@ func runUpgrade(ctx context.Context, dockerCLI command.Cli, opts pluginOptions) _, _ = fmt.Fprintf(dockerCLI.Out(), "Upgrading plugin %s from %s to %s\n", p.Name, reference.FamiliarString(old), reference.FamiliarString(remote)) if !opts.skipRemoteCheck && remote.String() != old.String() { - r, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), "Plugin images do not match, are you sure?") + r, err := prompt.Confirm(ctx, dockerCLI.In(), dockerCLI.Out(), "Plugin images do not match, are you sure?") if err != nil { return err } diff --git a/cli/command/registry.go b/cli/command/registry.go index acf0c7e635..3169a25ed6 100644 --- a/cli/command/registry.go +++ b/cli/command/registry.go @@ -13,6 +13,7 @@ import ( configtypes "github.com/docker/cli/cli/config/types" "github.com/docker/cli/cli/hints" "github.com/docker/cli/cli/streams" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/internal/tui" registrytypes "github.com/docker/docker/api/types/registry" "github.com/morikuni/aec" @@ -148,16 +149,16 @@ func PromptUserForCredentials(ctx context.Context, cli Cli, argUser, argPassword } } - var prompt string + var msg string defaultUsername = strings.TrimSpace(defaultUsername) if defaultUsername == "" { - prompt = "Username: " + msg = "Username: " } else { - prompt = fmt.Sprintf("Username (%s): ", defaultUsername) + msg = fmt.Sprintf("Username (%s): ", defaultUsername) } var err error - argUser, err = PromptForInput(ctx, cli.In(), cli.Out(), prompt) + argUser, err = prompt.ReadInput(ctx, cli.In(), cli.Out(), msg) if err != nil { return registrytypes.AuthConfig{}, err } @@ -171,7 +172,7 @@ func PromptUserForCredentials(ctx context.Context, cli Cli, argUser, argPassword argPassword = strings.TrimSpace(argPassword) if argPassword == "" { - restoreInput, err := DisableInputEcho(cli.In()) + restoreInput, err := prompt.DisableInputEcho(cli.In()) if err != nil { return registrytypes.AuthConfig{}, err } @@ -188,7 +189,7 @@ func PromptUserForCredentials(ctx context.Context, cli Cli, argUser, argPassword out := tui.NewOutput(cli.Err()) out.PrintNote("A Personal Access Token (PAT) can be used instead.\n" + "To create a PAT, visit " + aec.Underline.Apply("https://app.docker.com/settings") + "\n\n") - argPassword, err = PromptForInput(ctx, cli.In(), cli.Out(), "Password: ") + argPassword, err = prompt.ReadInput(ctx, cli.In(), cli.Out(), "Password: ") if err != nil { return registrytypes.AuthConfig{}, err } diff --git a/cli/command/registry/login_test.go b/cli/command/registry/login_test.go index 75bd6a3164..b12b5b6623 100644 --- a/cli/command/registry/login_test.go +++ b/cli/command/registry/login_test.go @@ -9,9 +9,9 @@ import ( "time" "github.com/creack/pty" - "github.com/docker/cli/cli/command" configtypes "github.com/docker/cli/cli/config/types" "github.com/docker/cli/cli/streams" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/internal/test" registrytypes "github.com/docker/docker/api/types/registry" "github.com/docker/docker/api/types/system" @@ -492,7 +492,7 @@ func TestLoginTermination(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("timed out after 1 second. `runLogin` did not return") case err := <-runErr: - assert.ErrorIs(t, err, command.ErrPromptTerminated) + assert.ErrorIs(t, err, prompt.ErrTerminated) } } diff --git a/cli/command/system/prune.go b/cli/command/system/prune.go index 815efdef76..c1fa339a5b 100644 --- a/cli/command/system/prune.go +++ b/cli/command/system/prune.go @@ -15,6 +15,7 @@ import ( "github.com/docker/cli/cli/command/image" "github.com/docker/cli/cli/command/network" "github.com/docker/cli/cli/command/volume" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/opts" "github.com/docker/docker/api/types/versions" "github.com/docker/docker/errdefs" @@ -77,7 +78,7 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) return errors.New(`ERROR: The "until" filter is not supported with "--volumes"`) } if !options.force { - r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), confirmationMessage(dockerCli, options)) + r, err := prompt.Confirm(ctx, dockerCli.In(), dockerCli.Out(), confirmationMessage(dockerCli, options)) if err != nil { return err } diff --git a/cli/command/trust/revoke.go b/cli/command/trust/revoke.go index 7d1651679e..ec32c797f2 100644 --- a/cli/command/trust/revoke.go +++ b/cli/command/trust/revoke.go @@ -8,6 +8,7 @@ import ( "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/image" "github.com/docker/cli/cli/trust" + "github.com/docker/cli/internal/prompt" "github.com/docker/docker/errdefs" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -44,7 +45,7 @@ func revokeTrust(ctx context.Context, dockerCLI command.Cli, remote string, opti return errors.New("cannot use a digest reference for IMAGE:TAG") } if imgRefAndAuth.Tag() == "" && !options.forceYes { - deleteRemote, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), fmt.Sprintf("Confirm you would like to delete all signature data for %s?", remote)) + deleteRemote, err := prompt.Confirm(ctx, dockerCLI.In(), dockerCLI.Out(), fmt.Sprintf("Confirm you would like to delete all signature data for %s?", remote)) if err != nil { return err } diff --git a/cli/command/trust/signer_remove.go b/cli/command/trust/signer_remove.go index ff6a29d5c5..10d2c2933b 100644 --- a/cli/command/trust/signer_remove.go +++ b/cli/command/trust/signer_remove.go @@ -9,6 +9,7 @@ import ( "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/image" "github.com/docker/cli/cli/trust" + "github.com/docker/cli/internal/prompt" "github.com/pkg/errors" "github.com/spf13/cobra" "github.com/theupdateframework/notary/client" @@ -82,11 +83,7 @@ func maybePromptForSignerRemoval(ctx context.Context, dockerCLI command.Cli, rep "Are you sure you want to continue?", signerName, repoName, repoName, ) - removeSigner, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), message) - if err != nil { - return false, err - } - return removeSigner, nil + return prompt.Confirm(ctx, dockerCLI.In(), dockerCLI.Out(), message) } return false, nil } diff --git a/cli/command/utils.go b/cli/command/utils.go index 60a508ccfd..373eaa095e 100644 --- a/cli/command/utils.go +++ b/cli/command/utils.go @@ -4,20 +4,17 @@ package command import ( - "bufio" "context" - "fmt" "io" "os" "path/filepath" - "runtime" "strings" "github.com/docker/cli/cli/config" "github.com/docker/cli/cli/streams" + "github.com/docker/cli/internal/prompt" "github.com/docker/docker/api/types/filters" "github.com/moby/sys/atomicwriter" - "github.com/moby/term" "github.com/pkg/errors" "github.com/spf13/pflag" ) @@ -35,29 +32,14 @@ func CopyToFile(outfile string, r io.Reader) error { return err } -const ErrPromptTerminated cancelledErr = "prompt terminated" - -type cancelledErr string - -func (e cancelledErr) Error() string { - return string(e) -} - -func (cancelledErr) Cancelled() {} +const ErrPromptTerminated = prompt.ErrTerminated // DisableInputEcho disables input echo on the provided streams.In. // This is useful when the user provides sensitive information like passwords. // The function returns a restore function that should be called to restore the // terminal state. func DisableInputEcho(ins *streams.In) (restore func() error, err error) { - oldState, err := term.SaveState(ins.FD()) - if err != nil { - return nil, err - } - restore = func() error { - return term.RestoreTerminal(ins.FD(), oldState) - } - return restore, term.DisableEcho(ins.FD(), oldState) + return prompt.DisableInputEcho(ins) } // PromptForInput requests input from the user. @@ -68,23 +50,7 @@ func DisableInputEcho(ins *streams.In) (restore func() error, err error) { // the stack and close the io.Reader used for the prompt which will prevent the // background goroutine from blocking indefinitely. func PromptForInput(ctx context.Context, in io.Reader, out io.Writer, message string) (string, error) { - _, _ = fmt.Fprint(out, message) - - result := make(chan string) - go func() { - scanner := bufio.NewScanner(in) - if scanner.Scan() { - result <- strings.TrimSpace(scanner.Text()) - } - }() - - select { - case <-ctx.Done(): - _, _ = fmt.Fprintln(out, "") - return "", ErrPromptTerminated - case r := <-result: - return r, nil - } + return prompt.ReadInput(ctx, in, out, message) } // PromptForConfirmation requests and checks confirmation from the user. @@ -98,39 +64,7 @@ func PromptForInput(ctx context.Context, in io.Reader, out io.Writer, message st // the stack and close the io.Reader used for the prompt which will prevent the // background goroutine from blocking indefinitely. func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, message string) (bool, error) { - if message == "" { - message = "Are you sure you want to proceed?" - } - message += " [y/N] " - - _, _ = fmt.Fprint(outs, message) - - // On Windows, force the use of the regular OS stdin stream. - if runtime.GOOS == "windows" { - ins = streams.NewIn(os.Stdin) - } - - result := make(chan bool) - - go func() { - var res bool - scanner := bufio.NewScanner(ins) - if scanner.Scan() { - answer := strings.TrimSpace(scanner.Text()) - if strings.EqualFold(answer, "y") { - res = true - } - } - result <- res - }() - - select { - case <-ctx.Done(): - _, _ = fmt.Fprintln(outs, "") - return false, ErrPromptTerminated - case r := <-result: - return r, nil - } + return prompt.Confirm(ctx, ins, outs, message) } // PruneFilters merges prune filters specified in config.json with those specified diff --git a/cli/command/utils_test.go b/cli/command/utils_test.go index 2cc0e2889b..6c8d484c6a 100644 --- a/cli/command/utils_test.go +++ b/cli/command/utils_test.go @@ -1,23 +1,12 @@ package command_test import ( - "bufio" - "bytes" - "context" "errors" - "fmt" - "io" "os" - "os/signal" "path/filepath" - "strings" - "syscall" "testing" - "time" "github.com/docker/cli/cli/command" - "github.com/docker/cli/cli/streams" - "github.com/docker/cli/internal/test" "gotest.tools/v3/assert" ) @@ -54,171 +43,3 @@ func TestValidateOutputPath(t *testing.T) { }) } } - -func TestPromptForInput(t *testing.T) { - t.Run("cancelling the context", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - reader, _ := io.Pipe() - - buf := new(bytes.Buffer) - bufioWriter := bufio.NewWriter(buf) - - wroteHook := make(chan struct{}, 1) - promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { - wroteHook <- struct{}{} - }) - - promptErr := make(chan error, 1) - go func() { - _, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something") - promptErr <- err - }() - - select { - case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for prompt to write to buffer") - case <-wroteHook: - cancel() - } - - select { - case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for prompt to be canceled") - case err := <-promptErr: - assert.ErrorIs(t, err, command.ErrPromptTerminated) - } - }) - - t.Run("user input should be properly trimmed", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - t.Cleanup(cancel) - - reader, writer := io.Pipe() - - buf := new(bytes.Buffer) - bufioWriter := bufio.NewWriter(buf) - - wroteHook := make(chan struct{}, 1) - promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { - wroteHook <- struct{}{} - }) - - go func() { - <-wroteHook - writer.Write([]byte(" foo \n")) - }() - - answer, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something") - assert.NilError(t, err) - assert.Equal(t, answer, "foo") - }) -} - -func TestPromptForConfirmation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - type promptResult struct { - result bool - err error - } - - buf := new(bytes.Buffer) - bufioWriter := bufio.NewWriter(buf) - - var ( - promptWriter *io.PipeWriter - promptReader *io.PipeReader - ) - - defer func() { - if promptWriter != nil { - promptWriter.Close() - } - if promptReader != nil { - promptReader.Close() - } - }() - - for _, tc := range []struct { - desc string - f func() error - expected promptResult - }{ - {"SIGINT", func() error { - syscall.Kill(syscall.Getpid(), syscall.SIGINT) - return nil - }, promptResult{false, command.ErrPromptTerminated}}, - {"no", func() error { - _, err := fmt.Fprintln(promptWriter, "n") - return err - }, promptResult{false, nil}}, - {"yes", func() error { - _, err := fmt.Fprintln(promptWriter, "y") - return err - }, promptResult{true, nil}}, - {"any", func() error { - _, err := fmt.Fprintln(promptWriter, "a") - return err - }, promptResult{false, nil}}, - {"with space", func() error { - _, err := fmt.Fprintln(promptWriter, " y") - return err - }, promptResult{true, nil}}, - {"reader closed", func() error { - return promptReader.Close() - }, promptResult{false, nil}}, - } { - t.Run(tc.desc, func(t *testing.T) { - notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) - t.Cleanup(notifyCancel) - - buf.Reset() - promptReader, promptWriter = io.Pipe() - - wroteHook := make(chan struct{}, 1) - promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { - wroteHook <- struct{}{} - }) - - result := make(chan promptResult, 1) - go func() { - r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "") - result <- promptResult{r, err} - }() - - select { - case <-time.After(100 * time.Millisecond): - case <-wroteHook: - } - - assert.NilError(t, bufioWriter.Flush()) - assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]") - - // wait for the Prompt to write to the buffer - drainChannel(ctx, wroteHook) - - assert.NilError(t, tc.f()) - - select { - case <-time.After(500 * time.Millisecond): - t.Fatal("timeout waiting for prompt result") - case r := <-result: - assert.Equal(t, r, tc.expected) - } - }) - } -} - -func drainChannel(ctx context.Context, ch <-chan struct{}) { - go func() { - for { - select { - case <-ctx.Done(): - return - case <-ch: - } - } - }() -} diff --git a/cli/command/volume/prune.go b/cli/command/volume/prune.go index 9aa931186b..c366e347f6 100644 --- a/cli/command/volume/prune.go +++ b/cli/command/volume/prune.go @@ -7,6 +7,7 @@ import ( "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command/completion" + "github.com/docker/cli/internal/prompt" "github.com/docker/cli/opts" "github.com/docker/docker/api/types/versions" "github.com/docker/docker/errdefs" @@ -77,7 +78,7 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) warning = allVolumesWarning } if !options.force { - r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning) + r, err := prompt.Confirm(ctx, dockerCli.In(), dockerCli.Out(), warning) if err != nil { return 0, "", err } diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go new file mode 100644 index 0000000000..cc054dfd84 --- /dev/null +++ b/internal/prompt/prompt.go @@ -0,0 +1,117 @@ +// Package prompt provides utilities to prompt the user for input. + +package prompt + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "runtime" + "strings" + + "github.com/docker/cli/cli/streams" + "github.com/moby/term" +) + +const ErrTerminated cancelledErr = "prompt terminated" + +type cancelledErr string + +func (e cancelledErr) Error() string { + return string(e) +} + +func (cancelledErr) Cancelled() {} + +// DisableInputEcho disables input echo on the provided streams.In. +// This is useful when the user provides sensitive information like passwords. +// The function returns a restore function that should be called to restore the +// terminal state. +// +// TODO(thaJeztah): implement without depending on streams? +func DisableInputEcho(ins *streams.In) (restore func() error, _ error) { + oldState, err := term.SaveState(ins.FD()) + if err != nil { + return nil, err + } + restore = func() error { + return term.RestoreTerminal(ins.FD(), oldState) + } + return restore, term.DisableEcho(ins.FD(), oldState) +} + +// ReadInput requests input from the user. +// +// It returns an empty string ("") with an [ErrTerminated] if the user terminates +// the CLI with SIGINT or SIGTERM while the prompt is active. If the prompt +// returns an error, the caller should close the [io.Reader] used for the prompt +// and propagate the error up the stack to prevent the background goroutine +// from blocking indefinitely. +func ReadInput(ctx context.Context, in io.Reader, out io.Writer, message string) (string, error) { + _, _ = fmt.Fprint(out, message) + + result := make(chan string) + go func() { + scanner := bufio.NewScanner(in) + if scanner.Scan() { + result <- strings.TrimSpace(scanner.Text()) + } + }() + + select { + case <-ctx.Done(): + _, _ = fmt.Fprintln(out, "") + return "", ErrTerminated + case r := <-result: + return r, nil + } +} + +// Confirm requests and checks confirmation from the user. +// +// It displays the provided message followed by "[y/N]". If the user +// input 'y' or 'Y' it returns true otherwise false. If no message is provided, +// "Are you sure you want to proceed? [y/N] " will be used instead. +// +// It returns false with an [ErrTerminated] if the user terminates +// the CLI with SIGINT or SIGTERM while the prompt is active. If the prompt +// returns an error, the caller should close the [io.Reader] used for the prompt +// and propagate the error up the stack to prevent the background goroutine +// from blocking indefinitely. +func Confirm(ctx context.Context, ins io.Reader, outs io.Writer, message string) (bool, error) { + if message == "" { + message = "Are you sure you want to proceed?" + } + message += " [y/N] " + + _, _ = fmt.Fprint(outs, message) + + // On Windows, force the use of the regular OS stdin stream. + if runtime.GOOS == "windows" { + ins = streams.NewIn(os.Stdin) + } + + result := make(chan bool) + + go func() { + var res bool + scanner := bufio.NewScanner(ins) + if scanner.Scan() { + answer := strings.TrimSpace(scanner.Text()) + if strings.EqualFold(answer, "y") { + res = true + } + } + result <- res + }() + + select { + case <-ctx.Done(): + _, _ = fmt.Fprintln(outs, "") + return false, ErrTerminated + case r := <-result: + return r, nil + } +} diff --git a/internal/prompt/prompt_test.go b/internal/prompt/prompt_test.go new file mode 100644 index 0000000000..18df7e237f --- /dev/null +++ b/internal/prompt/prompt_test.go @@ -0,0 +1,187 @@ +package prompt_test + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "os/signal" + "strings" + "syscall" + "testing" + "time" + + "github.com/docker/cli/cli/streams" + "github.com/docker/cli/internal/prompt" + "github.com/docker/cli/internal/test" + "gotest.tools/v3/assert" +) + +func TestReadInput(t *testing.T) { + t.Run("cancelling the context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + reader, _ := io.Pipe() + + buf := new(bytes.Buffer) + bufioWriter := bufio.NewWriter(buf) + + wroteHook := make(chan struct{}, 1) + promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { + wroteHook <- struct{}{} + }) + + promptErr := make(chan error, 1) + go func() { + _, err := prompt.ReadInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something") + promptErr <- err + }() + + select { + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for prompt to write to buffer") + case <-wroteHook: + cancel() + } + + select { + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for prompt to be canceled") + case err := <-promptErr: + assert.ErrorIs(t, err, prompt.ErrTerminated) + } + }) + + t.Run("user input should be properly trimmed", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + reader, writer := io.Pipe() + + buf := new(bytes.Buffer) + bufioWriter := bufio.NewWriter(buf) + + wroteHook := make(chan struct{}, 1) + promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { + wroteHook <- struct{}{} + }) + + go func() { + <-wroteHook + _, _ = writer.Write([]byte(" foo \n")) + }() + + answer, err := prompt.ReadInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something") + assert.NilError(t, err) + assert.Equal(t, answer, "foo") + }) +} + +func TestConfirm(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + type promptResult struct { + result bool + err error + } + + buf := new(bytes.Buffer) + bufioWriter := bufio.NewWriter(buf) + + var ( + promptWriter *io.PipeWriter + promptReader *io.PipeReader + ) + + defer func() { + if promptWriter != nil { + _ = promptWriter.Close() + } + if promptReader != nil { + _ = promptReader.Close() + } + }() + + for _, tc := range []struct { + desc string + f func() error + expected promptResult + }{ + {"SIGINT", func() error { + _ = syscall.Kill(syscall.Getpid(), syscall.SIGINT) + return nil + }, promptResult{false, prompt.ErrTerminated}}, + {"no", func() error { + _, err := fmt.Fprintln(promptWriter, "n") + return err + }, promptResult{false, nil}}, + {"yes", func() error { + _, err := fmt.Fprintln(promptWriter, "y") + return err + }, promptResult{true, nil}}, + {"any", func() error { + _, err := fmt.Fprintln(promptWriter, "a") + return err + }, promptResult{false, nil}}, + {"with space", func() error { + _, err := fmt.Fprintln(promptWriter, " y") + return err + }, promptResult{true, nil}}, + {"reader closed", func() error { + return promptReader.Close() + }, promptResult{false, nil}}, + } { + t.Run(tc.desc, func(t *testing.T) { + notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + t.Cleanup(notifyCancel) + + buf.Reset() + promptReader, promptWriter = io.Pipe() + + wroteHook := make(chan struct{}, 1) + promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { + wroteHook <- struct{}{} + }) + + result := make(chan promptResult, 1) + go func() { + r, err := prompt.Confirm(notifyCtx, promptReader, promptOut, "") + result <- promptResult{r, err} + }() + + select { + case <-time.After(100 * time.Millisecond): + case <-wroteHook: + } + + assert.NilError(t, bufioWriter.Flush()) + assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]") + + // wait for the Prompt to write to the buffer + drainChannel(ctx, wroteHook) + + assert.NilError(t, tc.f()) + + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout waiting for prompt result") + case r := <-result: + assert.Equal(t, r, tc.expected) + } + }) + } +} + +func drainChannel(ctx context.Context, ch <-chan struct{}) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-ch: + } + } + }() +} diff --git a/internal/test/cmd.go b/internal/test/cmd.go index 52b44d66f1..8b98d2df62 100644 --- a/internal/test/cmd.go +++ b/internal/test/cmd.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/streams" + "github.com/docker/cli/internal/prompt" "github.com/spf13/cobra" "gotest.tools/v3/assert" ) @@ -76,6 +76,6 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli t.Logf("command stderr:\n%s\n", cli.ErrBuffer().String()) t.Fatalf("command %s did not return after SIGINT", cmd.Name()) case err := <-errChan: - assert.ErrorIs(t, err, command.ErrPromptTerminated) + assert.ErrorIs(t, err, prompt.ErrTerminated) } }