Skip to content

Commit 2284b7e

Browse files
committed
Fallback to HTTP on websocket connect failures
1 parent 87bc724 commit 2284b7e

2 files changed

Lines changed: 73 additions & 3 deletions

File tree

codex-rs/core/src/client.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ impl ModelClientSession {
10141014
pub async fn preconnect_websocket(
10151015
&mut self,
10161016
session_telemetry: &SessionTelemetry,
1017-
_model_info: &ModelInfo,
1017+
model_info: &ModelInfo,
10181018
) -> std::result::Result<(), ApiError> {
10191019
if !self.client.responses_websocket_enabled() {
10201020
return Ok(());
@@ -1033,7 +1033,7 @@ impl ModelClientSession {
10331033
client_setup.api_auth.as_ref(),
10341034
PendingUnauthorizedRetry::default(),
10351035
);
1036-
let connection = self
1036+
let connection = match self
10371037
.client
10381038
.connect_websocket(
10391039
session_telemetry,
@@ -1044,7 +1044,15 @@ impl ModelClientSession {
10441044
auth_context,
10451045
RequestRouteTelemetry::for_endpoint(RESPONSES_ENDPOINT),
10461046
)
1047-
.await?;
1047+
.await
1048+
{
1049+
Ok(connection) => connection,
1050+
Err(err) if should_fallback_to_http_after_websocket_connect_error(&err) => {
1051+
self.try_switch_fallback_transport(session_telemetry, model_info);
1052+
return Ok(());
1053+
}
1054+
Err(err) => return Err(err),
1055+
};
10481056
self.websocket_session.connection = Some(connection);
10491057
self.websocket_session
10501058
.set_connection_reused(/*connection_reused*/ false);
@@ -1335,6 +1343,10 @@ impl ModelClientSession {
13351343
{
13361344
return Ok(WebsocketStreamOutcome::FallbackToHttp);
13371345
}
1346+
Err(err) if should_fallback_to_http_after_websocket_connect_error(&err) => {
1347+
self.reset_websocket_session();
1348+
return Ok(WebsocketStreamOutcome::FallbackToHttp);
1349+
}
13381350
Err(ApiError::Transport(
13391351
unauthorized_transport @ TransportError::Http { status, .. },
13401352
)) if status == StatusCode::UNAUTHORIZED => {
@@ -1905,6 +1917,13 @@ fn api_error_http_status(error: &ApiError) -> Option<u16> {
19051917
}
19061918
}
19071919

1920+
fn should_fallback_to_http_after_websocket_connect_error(error: &ApiError) -> bool {
1921+
matches!(
1922+
error,
1923+
ApiError::Transport(TransportError::Timeout | TransportError::Network(_))
1924+
)
1925+
}
1926+
19081927
struct ApiTelemetry {
19091928
session_telemetry: SessionTelemetry,
19101929
auth_context: AuthRequestTelemetryContext,

codex-rs/core/tests/suite/websocket_fallback.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,57 @@ async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Re
7575
Ok(())
7676
}
7777

78+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
79+
async fn websocket_fallback_switches_to_http_on_connect_timeout() -> Result<()> {
80+
skip_if_no_network!(Ok(()));
81+
82+
let server = responses::start_mock_server().await;
83+
Mock::given(method("GET"))
84+
.and(path_regex(".*/responses$"))
85+
.respond_with(ResponseTemplate::new(426).set_delay(Duration::from_millis(500)))
86+
.mount(&server)
87+
.await;
88+
89+
let response_mock = mount_sse_once(
90+
&server,
91+
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
92+
)
93+
.await;
94+
95+
let mut builder = test_codex().with_config({
96+
let base_url = format!("{}/v1", server.uri());
97+
move |config| {
98+
config.model_provider.base_url = Some(base_url);
99+
config.model_provider.wire_api = WireApi::Responses;
100+
config.model_provider.supports_websockets = true;
101+
config.model_provider.stream_max_retries = Some(5);
102+
config.model_provider.request_max_retries = Some(0);
103+
config.model_provider.websocket_connect_timeout_ms = Some(50);
104+
}
105+
});
106+
let test = builder.build(&server).await?;
107+
108+
test.submit_turn("hello").await?;
109+
110+
let requests = server.received_requests().await.unwrap_or_default();
111+
let websocket_attempts = requests
112+
.iter()
113+
.filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses"))
114+
.count();
115+
let http_attempts = requests
116+
.iter()
117+
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
118+
.count();
119+
120+
// Timeout during the WebSocket handshake is a strong signal that the local transport path
121+
// cannot carry WebSockets, so fallback should activate before consuming stream retries.
122+
assert_eq!(websocket_attempts, 1);
123+
assert_eq!(http_attempts, 1);
124+
assert_eq!(response_mock.requests().len(), 1);
125+
126+
Ok(())
127+
}
128+
78129
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
79130
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
80131
skip_if_no_network!(Ok(()));

0 commit comments

Comments
 (0)