Skip to content

Commit 581b665

Browse files
committed
Resolve race condition in codespaces connection
1 parent ede1705 commit 581b665

10 files changed

Lines changed: 67 additions & 20 deletions

File tree

internal/codespaces/connection/connection.go

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@ const (
1616
clientName = "gh"
1717
)
1818

19+
type TunnelClient struct {
20+
*tunnels.Client
21+
connected bool
22+
}
23+
1924
type CodespaceConnection struct {
2025
tunnelProperties api.TunnelProperties
2126
TunnelManager *tunnels.Manager
22-
TunnelClient *tunnels.Client
27+
TunnelClient *TunnelClient
2328
Options *tunnels.TunnelRequestOptions
2429
Tunnel *tunnels.Tunnel
2530
AllowedPortPrivacySettings []string
@@ -74,6 +79,38 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC
7479
}, nil
7580
}
7681

82+
// Connect connects the client to the tunnel.
83+
func (c *CodespaceConnection) Connect(ctx context.Context) error {
84+
// If already connected, return
85+
if c.TunnelClient.connected {
86+
return nil
87+
}
88+
89+
// Connect to the tunnel
90+
if err := c.TunnelClient.Client.Connect(ctx, ""); err != nil {
91+
return fmt.Errorf("error connecting to tunnel: %w", err)
92+
}
93+
94+
// Set the connected flag so we know we're connected
95+
c.TunnelClient.connected = true
96+
97+
return nil
98+
}
99+
100+
// Close closes the underlying tunnel client SSH connection.
101+
func (c *CodespaceConnection) Close() error {
102+
// Don't close if we're not connected
103+
if c.TunnelClient != nil && c.TunnelClient.connected {
104+
if err := c.TunnelClient.Close(); err != nil {
105+
return fmt.Errorf("failed to close tunnel client connection: %w", err)
106+
}
107+
108+
c.TunnelClient.connected = false
109+
}
110+
111+
return nil
112+
}
113+
77114
// getTunnelManager creates a tunnel manager for the given codespace.
78115
// The tunnel manager is used to get the tunnel hosted in the codespace that we
79116
// want to connect to and perform operations on ports (add, remove, list, etc.).
@@ -96,7 +133,7 @@ func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Cl
96133
// getTunnelClient creates a tunnel client for the given tunnel.
97134
// The tunnel client is used to connect to the the tunnel and allows
98135
// for ports to be forwarded locally.
99-
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) {
136+
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *TunnelClient, err error) {
100137
// Get the tunnel that we want to connect to
101138
codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
102139
if err != nil {
@@ -107,10 +144,15 @@ func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel
107144
codespaceTunnel.AccessTokens = tunnel.AccessTokens
108145

109146
// We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
110-
tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
147+
client, err := tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
111148
if err != nil {
112149
return nil, fmt.Errorf("error creating tunnel client: %w", err)
113150
}
114151

152+
tunnelClient = &TunnelClient{
153+
Client: client,
154+
connected: false,
155+
}
156+
115157
return tunnelClient, nil
116158
}

internal/codespaces/connection/connection_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ func TestNewCodespaceConnection(t *testing.T) {
4141
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
4242
}
4343

44+
// Verify closing before connected doesn't throw
45+
err = conn.Close()
46+
if err != nil {
47+
t.Fatalf("Close returned an error: %v", err)
48+
}
49+
4450
// Check that the connection was created successfully
4551
if conn == nil {
4652
t.Fatal("NewCodespaceConnection returned nil")

internal/codespaces/portforwarder/port_forwarder.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ type PortForwarder interface {
4848
UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error
4949
KeepAlive(reason string)
5050
GetKeepAliveReason() string
51-
CloseSSHConnection()
51+
Close() error
5252
}
5353

5454
// NewPortForwarder returns a new PortForwarder for the specified codespace.
@@ -66,9 +66,6 @@ func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, o
6666
return fmt.Errorf("error forwarding port: %w", err)
6767
}
6868

69-
// Close the SSH connection when we're done
70-
defer fwd.CloseSSHConnection()
71-
7269
done := make(chan error)
7370
go func() {
7471
// Convert the port number to a uint16
@@ -151,15 +148,14 @@ func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts Forwar
151148
}
152149

153150
// Connect to the tunnel
154-
err = fwd.connection.TunnelClient.Connect(ctx, "")
151+
err = fwd.connection.Connect(ctx)
155152
if err != nil {
156153
return fmt.Errorf("connect failed: %v", err)
157154
}
158155

159156
// Inform the host that we've forwarded the port locally
160157
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
161158
if err != nil {
162-
fwd.CloseSSHConnection()
163159
return fmt.Errorf("refresh ports failed: %v", err)
164160
}
165161

@@ -257,15 +253,12 @@ func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, re
257253
done := make(chan error)
258254
go func() {
259255
// Connect to the tunnel
260-
err = fwd.connection.TunnelClient.Connect(ctx, "")
256+
err = fwd.connection.Connect(ctx)
261257
if err != nil {
262258
done <- fmt.Errorf("connect failed: %v", err)
263259
return
264260
}
265261

266-
// Close the SSH connection when we're done
267-
defer fwd.CloseSSHConnection()
268-
269262
// Inform the host that we've deleted the port
270263
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
271264
if err != nil {
@@ -316,8 +309,8 @@ func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string {
316309
}
317310

318311
// Close closes the port forwarder's tunnel client connection.
319-
func (fwd *CodespacesPortForwarder) CloseSSHConnection() {
320-
_ = fwd.connection.TunnelClient.Close()
312+
func (fwd *CodespacesPortForwarder) Close() error {
313+
return fwd.connection.Close()
321314
}
322315

323316
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.

internal/codespaces/rpc/test/port_forwarder.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ import (
1313
type PortForwarder struct{}
1414

1515
// Close implements portforwarder.PortForwarder.
16-
func (PortForwarder) CloseSSHConnection() {
17-
panic("unimplemented")
16+
func (PortForwarder) Close() error {
17+
return nil
1818
}
1919

2020
// ConnectToForwardedPort implements portforwarder.PortForwarder.

internal/codespaces/states.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
4747
if err != nil {
4848
return fmt.Errorf("failed to create port forwarder: %w", err)
4949
}
50+
defer safeClose(fwd, &err)
5051

5152
// Ensure local port is listening before client (getPostCreateOutput) connects.
5253
listen, localPort, err := ListenTCP(0, false)

pkg/cmd/codespace/jupyter.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
4848
if err != nil {
4949
return fmt.Errorf("failed to create port forwarder: %w", err)
5050
}
51+
defer safeClose(fwd, &err)
5152

5253
var (
5354
invoker rpc.Invoker

pkg/cmd/codespace/logs.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
5151
if err != nil {
5252
return fmt.Errorf("failed to create port forwarder: %w", err)
5353
}
54+
defer safeClose(fwd, &err)
5455

5556
// Ensure local port is listening before client (getPostCreateOutput) connects.
5657
listen, localPort, err := codespaces.ListenTCP(0, false)

pkg/cmd/codespace/ports.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export
6666
if err != nil {
6767
return fmt.Errorf("failed to create port forwarder: %w", err)
6868
}
69+
defer safeClose(fwd, &err)
6970

7071
var ports []*tunnels.TunnelPort
7172
err = a.RunWithProgress("Fetching ports", func() (err error) {
@@ -246,6 +247,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec
246247
if err != nil {
247248
return fmt.Errorf("failed to create port forwarder: %w", err)
248249
}
250+
defer safeClose(fwd, &err)
249251

250252
// TODO: check if port visibility can be updated in parallel instead of sequentially
251253
for _, port := range ports {
@@ -337,6 +339,7 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
337339
if err != nil {
338340
return fmt.Errorf("failed to create port forwarder: %w", err)
339341
}
342+
defer safeClose(fwd, &err)
340343

341344
opts := portforwarder.ForwardPortOpts{
342345
Port: pair.remote,

pkg/cmd/codespace/rebuild.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo
6060
if err != nil {
6161
return fmt.Errorf("failed to create port forwarder: %w", err)
6262
}
63+
defer safeClose(fwd, &err)
6364

6465
invoker, err := rpc.CreateInvoker(ctx, fwd)
6566
if err != nil {

pkg/cmd/codespace/ssh.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
202202
if err != nil {
203203
return fmt.Errorf("failed to create port forwarder: %w", err)
204204
}
205+
defer safeClose(fwd, &err)
205206

206207
var (
207208
invoker rpc.Invoker
@@ -238,9 +239,6 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
238239
return fmt.Errorf("failed to forward port: %w", err)
239240
}
240241

241-
// Close the SSH connection when we're done
242-
defer fwd.CloseSSHConnection()
243-
244242
// Connect to the forwarded port
245243
err = fwd.ConnectToForwardedPort(ctx, stdio, opts)
246244
if err != nil {
@@ -584,6 +582,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
584582
sshUsers <- result
585583
return
586584
}
585+
defer safeClose(fwd, &err)
587586

588587
invoker, err := rpc.CreateInvoker(ctx, fwd)
589588
if err != nil {

0 commit comments

Comments
 (0)