Skip to content

Commit c15ade0

Browse files
committed
fix: ctx cancellation on login prompt
Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com>
1 parent 01dd6ab commit c15ade0

5 files changed

Lines changed: 160 additions & 34 deletions

File tree

cli/command/registry.go

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package command
22

33
import (
4-
"bufio"
54
"context"
65
"fmt"
7-
"io"
86
"os"
97
"runtime"
108
"strings"
@@ -18,7 +16,6 @@ import (
1816
"github.com/docker/docker/api/types"
1917
registrytypes "github.com/docker/docker/api/types/registry"
2018
"github.com/docker/docker/registry"
21-
"github.com/moby/term"
2219
"github.com/pkg/errors"
2320
)
2421

@@ -44,7 +41,7 @@ func RegistryAuthenticationPrivilegedFunc(cli Cli, index *registrytypes.IndexInf
4441
default:
4542
}
4643

47-
err = ConfigureAuth(cli, "", "", &authConfig, isDefaultRegistry)
44+
err = ConfigureAuth(ctx, cli, "", "", &authConfig, isDefaultRegistry)
4845
if err != nil {
4946
return "", err
5047
}
@@ -90,7 +87,7 @@ func GetDefaultAuthConfig(cfg *configfile.ConfigFile, checkCredStore bool, serve
9087
}
9188

9289
// ConfigureAuth handles prompting of user's username and password if needed
93-
func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes.AuthConfig, isDefaultRegistry bool) error {
90+
func ConfigureAuth(ctx context.Context, cli Cli, flUser, flPassword string, authconfig *registrytypes.AuthConfig, isDefaultRegistry bool) error {
9491
// On Windows, force the use of the regular OS stdin stream.
9592
//
9693
// See:
@@ -125,9 +122,15 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes
125122
fmt.Fprintln(cli.Out())
126123
}
127124
}
128-
promptWithDefault(cli.Out(), "Username", authconfig.Username)
125+
126+
var prompt string
127+
if authconfig.Username == "" {
128+
prompt = "Username: "
129+
} else {
130+
prompt = fmt.Sprintf("Username (%s): ", authconfig.Username)
131+
}
129132
var err error
130-
flUser, err = readInput(cli.In())
133+
flUser, err = PromptForInput(ctx, cli.In(), cli.Out(), prompt)
131134
if err != nil {
132135
return err
133136
}
@@ -139,16 +142,13 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes
139142
return errors.Errorf("Error: Non-null Username Required")
140143
}
141144
if flPassword == "" {
142-
oldState, err := term.SaveState(cli.In().FD())
145+
restoreInput, err := DisableInputEcho(cli.In())
143146
if err != nil {
144147
return err
145148
}
146-
fmt.Fprintf(cli.Out(), "Password: ")
147-
_ = term.DisableEcho(cli.In().FD(), oldState)
148-
defer func() {
149-
_ = term.RestoreTerminal(cli.In().FD(), oldState)
150-
}()
151-
flPassword, err = readInput(cli.In())
149+
defer restoreInput()
150+
151+
flPassword, err = PromptForInput(ctx, cli.In(), cli.Out(), "Password: ")
152152
if err != nil {
153153
return err
154154
}
@@ -164,25 +164,6 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes
164164
return nil
165165
}
166166

167-
// readInput reads, and returns user input from in. It tries to return a
168-
// single line, not including the end-of-line bytes, and trims leading
169-
// and trailing whitespace.
170-
func readInput(in io.Reader) (string, error) {
171-
line, _, err := bufio.NewReader(in).ReadLine()
172-
if err != nil {
173-
return "", errors.Wrap(err, "error while reading input")
174-
}
175-
return strings.TrimSpace(string(line)), nil
176-
}
177-
178-
func promptWithDefault(out io.Writer, prompt string, configDefault string) {
179-
if configDefault == "" {
180-
fmt.Fprintf(out, "%s: ", prompt)
181-
} else {
182-
fmt.Fprintf(out, "%s (%s): ", prompt, configDefault)
183-
}
184-
}
185-
186167
// RetrieveAuthTokenFromImage retrieves an encoded auth token given a complete
187168
// image. The auth configuration is serialized as a base64url encoded RFC4648,
188169
// section 5) JSON string for sending through the X-Registry-Auth header.

cli/command/registry/login.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func runLogin(ctx context.Context, dockerCli command.Cli, opts loginOptions) err
121121
response, err = loginWithCredStoreCreds(ctx, dockerCli, &authConfig)
122122
}
123123
if err != nil || authConfig.Username == "" || authConfig.Password == "" {
124-
err = command.ConfigureAuth(dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry)
124+
err = command.ConfigureAuth(ctx, dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry)
125125
if err != nil {
126126
return err
127127
}

cli/command/registry/login_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ import (
66
"errors"
77
"fmt"
88
"testing"
9+
"time"
910

11+
"github.com/creack/pty"
12+
"github.com/docker/cli/cli/command"
1013
configtypes "github.com/docker/cli/cli/config/types"
1114
"github.com/docker/cli/cli/streams"
1215
"github.com/docker/cli/internal/test"
@@ -185,3 +188,41 @@ func TestRunLogin(t *testing.T) {
185188
})
186189
}
187190
}
191+
192+
func TestLoginTermination(t *testing.T) {
193+
p, tty, err := pty.Open()
194+
assert.NilError(t, err)
195+
196+
t.Cleanup(func() {
197+
_ = tty.Close()
198+
_ = p.Close()
199+
})
200+
201+
cli := test.NewFakeCli(&fakeClient{}, func(fc *test.FakeCli) {
202+
fc.SetOut(streams.NewOut(tty))
203+
fc.SetIn(streams.NewIn(tty))
204+
})
205+
tmpFile := fs.NewFile(t, "test-login-termination")
206+
defer tmpFile.Remove()
207+
208+
configFile := cli.ConfigFile()
209+
configFile.Filename = tmpFile.Path()
210+
211+
ctx, cancel := context.WithCancel(context.Background())
212+
t.Cleanup(cancel)
213+
214+
runErr := make(chan error)
215+
go func() {
216+
runErr <- runLogin(ctx, cli, loginOptions{})
217+
}()
218+
219+
// Let the prompt get canceled by the context
220+
cancel()
221+
222+
select {
223+
case <-time.After(1 * time.Second):
224+
t.Fatal("timed out after 1 second. `runLogin` did not return")
225+
case err := <-runErr:
226+
assert.ErrorIs(t, err, command.ErrPromptTerminated)
227+
}
228+
}

cli/command/utils.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/docker/docker/api/types/versions"
2020
"github.com/docker/docker/errdefs"
2121
"github.com/moby/sys/sequential"
22+
"github.com/moby/term"
2223
"github.com/pkg/errors"
2324
"github.com/spf13/pflag"
2425
)
@@ -76,6 +77,48 @@ func PrettyPrint(i any) string {
7677

7778
var ErrPromptTerminated = errdefs.Cancelled(errors.New("prompt terminated"))
7879

80+
// DisableInputEcho disables input echo on the provided streams.In.
81+
// This is useful when the user provides sensitive information like passwords.
82+
// The function returns a restore function that should be called to restore the
83+
// terminal state.
84+
func DisableInputEcho(ins *streams.In) (restore func() error, err error) {
85+
oldState, err := term.SaveState(ins.FD())
86+
if err != nil {
87+
return nil, err
88+
}
89+
restore = func() error {
90+
return term.RestoreTerminal(ins.FD(), oldState)
91+
}
92+
return restore, term.DisableEcho(ins.FD(), oldState)
93+
}
94+
95+
// PromptForInput requests input from the user.
96+
//
97+
// If the user terminates the CLI with SIGINT or SIGTERM while the prompt is
98+
// active, the prompt will return an empty string ("") with an ErrPromptTerminated error.
99+
// When the prompt returns an error, the caller should propagate the error up
100+
// the stack and close the io.Reader used for the prompt which will prevent the
101+
// background goroutine from blocking indefinitely.
102+
func PromptForInput(ctx context.Context, in io.Reader, out io.Writer, message string) (string, error) {
103+
_, _ = fmt.Fprint(out, message)
104+
105+
result := make(chan string)
106+
go func() {
107+
scanner := bufio.NewScanner(in)
108+
if scanner.Scan() {
109+
result <- strings.TrimSpace(scanner.Text())
110+
}
111+
}()
112+
113+
select {
114+
case <-ctx.Done():
115+
_, _ = fmt.Fprintln(out, "")
116+
return "", ErrPromptTerminated
117+
case r := <-result:
118+
return r, nil
119+
}
120+
}
121+
79122
// PromptForConfirmation requests and checks confirmation from the user.
80123
// This will display the provided message followed by ' [y/N] '. If the user
81124
// input 'y' or 'Y' it returns true otherwise false. If no message is provided,

cli/command/utils_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"time"
1616

1717
"github.com/docker/cli/cli/command"
18+
"github.com/docker/cli/cli/streams"
1819
"github.com/docker/cli/internal/test"
1920
"github.com/pkg/errors"
2021
"gotest.tools/v3/assert"
@@ -80,6 +81,66 @@ func TestValidateOutputPath(t *testing.T) {
8081
}
8182
}
8283

84+
func TestPromptForInput(t *testing.T) {
85+
t.Run("case=cancelling the context", func(t *testing.T) {
86+
ctx, cancel := context.WithCancel(context.Background())
87+
t.Cleanup(cancel)
88+
reader, _ := io.Pipe()
89+
90+
buf := new(bytes.Buffer)
91+
bufioWriter := bufio.NewWriter(buf)
92+
93+
wroteHook := make(chan struct{}, 1)
94+
promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) {
95+
wroteHook <- struct{}{}
96+
})
97+
98+
promptErr := make(chan error, 1)
99+
go func() {
100+
_, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something")
101+
promptErr <- err
102+
}()
103+
104+
select {
105+
case <-time.After(1 * time.Second):
106+
t.Fatal("timeout waiting for prompt to write to buffer")
107+
case <-wroteHook:
108+
cancel()
109+
}
110+
111+
select {
112+
case <-time.After(1 * time.Second):
113+
t.Fatal("timeout waiting for prompt to be canceled")
114+
case err := <-promptErr:
115+
assert.ErrorIs(t, err, command.ErrPromptTerminated)
116+
}
117+
})
118+
119+
t.Run("case=user input should be properly trimmed", func(t *testing.T) {
120+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
121+
t.Cleanup(cancel)
122+
123+
reader, writer := io.Pipe()
124+
125+
buf := new(bytes.Buffer)
126+
bufioWriter := bufio.NewWriter(buf)
127+
128+
wroteHook := make(chan struct{}, 1)
129+
promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) {
130+
wroteHook <- struct{}{}
131+
})
132+
133+
go func() {
134+
<-wroteHook
135+
writer.Write([]byte(" foo \n"))
136+
}()
137+
138+
answer, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something")
139+
assert.NilError(t, err)
140+
assert.Equal(t, answer, "foo")
141+
})
142+
}
143+
83144
func TestPromptForConfirmation(t *testing.T) {
84145
ctx, cancel := context.WithCancel(context.Background())
85146
t.Cleanup(cancel)

0 commit comments

Comments
 (0)