@@ -1014,7 +1014,7 @@ impl ModelClientSession {
10141014 pub async fn preconnect_websocket (
10151015 & mut self ,
10161016 session_telemetry : & SessionTelemetry ,
1017- _model_info : & ModelInfo ,
1017+ model_info : & ModelInfo ,
10181018 ) -> std:: result:: Result < ( ) , ApiError > {
10191019 if !self . client . responses_websocket_enabled ( ) {
10201020 return Ok ( ( ) ) ;
@@ -1033,7 +1033,7 @@ impl ModelClientSession {
10331033 client_setup. api_auth . as_ref ( ) ,
10341034 PendingUnauthorizedRetry :: default ( ) ,
10351035 ) ;
1036- let connection = self
1036+ let connection = match self
10371037 . client
10381038 . connect_websocket (
10391039 session_telemetry,
@@ -1044,7 +1044,15 @@ impl ModelClientSession {
10441044 auth_context,
10451045 RequestRouteTelemetry :: for_endpoint ( RESPONSES_ENDPOINT ) ,
10461046 )
1047- . await ?;
1047+ . await
1048+ {
1049+ Ok ( connection) => connection,
1050+ Err ( err) if should_fallback_to_http_after_websocket_connect_error ( & err) => {
1051+ self . try_switch_fallback_transport ( session_telemetry, model_info) ;
1052+ return Ok ( ( ) ) ;
1053+ }
1054+ Err ( err) => return Err ( err) ,
1055+ } ;
10481056 self . websocket_session . connection = Some ( connection) ;
10491057 self . websocket_session
10501058 . set_connection_reused ( /*connection_reused*/ false ) ;
@@ -1335,6 +1343,10 @@ impl ModelClientSession {
13351343 {
13361344 return Ok ( WebsocketStreamOutcome :: FallbackToHttp ) ;
13371345 }
1346+ Err ( err) if should_fallback_to_http_after_websocket_connect_error ( & err) => {
1347+ self . reset_websocket_session ( ) ;
1348+ return Ok ( WebsocketStreamOutcome :: FallbackToHttp ) ;
1349+ }
13381350 Err ( ApiError :: Transport (
13391351 unauthorized_transport @ TransportError :: Http { status, .. } ,
13401352 ) ) if status == StatusCode :: UNAUTHORIZED => {
@@ -1905,6 +1917,13 @@ fn api_error_http_status(error: &ApiError) -> Option<u16> {
19051917 }
19061918}
19071919
1920+ fn should_fallback_to_http_after_websocket_connect_error ( error : & ApiError ) -> bool {
1921+ matches ! (
1922+ error,
1923+ ApiError :: Transport ( TransportError :: Timeout | TransportError :: Network ( _) )
1924+ )
1925+ }
1926+
19081927struct ApiTelemetry {
19091928 session_telemetry : SessionTelemetry ,
19101929 auth_context : AuthRequestTelemetryContext ,
0 commit comments