@@ -16,10 +16,15 @@ const (
1616 clientName = "gh"
1717)
1818
19+ type TunnelClient struct {
20+ * tunnels.Client
21+ connected bool
22+ }
23+
1924type CodespaceConnection struct {
2025 tunnelProperties api.TunnelProperties
2126 TunnelManager * tunnels.Manager
22- TunnelClient * tunnels. Client
27+ TunnelClient * TunnelClient
2328 Options * tunnels.TunnelRequestOptions
2429 Tunnel * tunnels.Tunnel
2530 AllowedPortPrivacySettings []string
@@ -74,6 +79,38 @@ func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpC
7479 }, nil
7580}
7681
82+ // Connect connects the client to the tunnel.
83+ func (c * CodespaceConnection ) Connect (ctx context.Context ) error {
84+ // If already connected, return
85+ if c .TunnelClient .connected {
86+ return nil
87+ }
88+
89+ // Connect to the tunnel
90+ if err := c .TunnelClient .Client .Connect (ctx , "" ); err != nil {
91+ return fmt .Errorf ("error connecting to tunnel: %w" , err )
92+ }
93+
94+ // Set the connected flag so we know we're connected
95+ c .TunnelClient .connected = true
96+
97+ return nil
98+ }
99+
100+ // Close closes the underlying tunnel client SSH connection.
101+ func (c * CodespaceConnection ) Close () error {
102+ // Don't close if we're not connected
103+ if c .TunnelClient != nil && c .TunnelClient .connected {
104+ if err := c .TunnelClient .Close (); err != nil {
105+ return fmt .Errorf ("failed to close tunnel client connection: %w" , err )
106+ }
107+
108+ c .TunnelClient .connected = false
109+ }
110+
111+ return nil
112+ }
113+
77114// getTunnelManager creates a tunnel manager for the given codespace.
78115// The tunnel manager is used to get the tunnel hosted in the codespace that we
79116// want to connect to and perform operations on ports (add, remove, list, etc.).
@@ -96,7 +133,7 @@ func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Cl
96133// getTunnelClient creates a tunnel client for the given tunnel.
97134// The tunnel client is used to connect to the the tunnel and allows
98135// for ports to be forwarded locally.
99- func getTunnelClient (ctx context.Context , tunnelManager * tunnels.Manager , tunnel * tunnels.Tunnel , options * tunnels.TunnelRequestOptions ) (tunnelClient * tunnels. Client , err error ) {
136+ func getTunnelClient (ctx context.Context , tunnelManager * tunnels.Manager , tunnel * tunnels.Tunnel , options * tunnels.TunnelRequestOptions ) (tunnelClient * TunnelClient , err error ) {
100137 // Get the tunnel that we want to connect to
101138 codespaceTunnel , err := tunnelManager .GetTunnel (ctx , tunnel , options )
102139 if err != nil {
@@ -107,10 +144,15 @@ func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel
107144 codespaceTunnel .AccessTokens = tunnel .AccessTokens
108145
109146 // We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
110- tunnelClient , err = tunnels .NewClient (log .New (io .Discard , "" , log .LstdFlags ), codespaceTunnel , false )
147+ client , err : = tunnels .NewClient (log .New (io .Discard , "" , log .LstdFlags ), codespaceTunnel , false )
111148 if err != nil {
112149 return nil , fmt .Errorf ("error creating tunnel client: %w" , err )
113150 }
114151
152+ tunnelClient = & TunnelClient {
153+ Client : client ,
154+ connected : false ,
155+ }
156+
115157 return tunnelClient , nil
116158}
0 commit comments