@@ -37,6 +37,10 @@ def _parse_connect_response(sock: Socket) -> Tuple[Optional[int], str]:
3737 return status , "\n " .join (lines )
3838
3939
40+ def _use_or_create_ssl_context (ssl_context : Optional [ssl .SSLContext ] = None ):
41+ return ssl_context if ssl_context is not None else ssl .create_default_context ()
42+
43+
4044def _establish_new_socket_connection (
4145 session_id : str ,
4246 server_hostname : str ,
@@ -47,7 +51,11 @@ def _establish_new_socket_connection(
4751 proxy : Optional [str ],
4852 proxy_headers : Optional [Dict [str , str ]],
4953 trace_enabled : bool ,
54+ ssl_context : Optional [ssl .SSLContext ] = None ,
5055) -> Union [ssl .SSLSocket , Socket ]:
56+
57+ ssl_context = _use_or_create_ssl_context (ssl_context )
58+
5159 if proxy is not None :
5260 parsed_proxy = urlparse (proxy )
5361 proxy_host , proxy_port = parsed_proxy .hostname , parsed_proxy .port or 80
@@ -83,7 +91,7 @@ def _establish_new_socket_connection(
8391 f"Failed to connect to the proxy (proxy: { proxy } , connect status code: { status } )"
8492 )
8593
86- sock = ssl . create_default_context () .wrap_socket (
94+ sock = ssl_context .wrap_socket (
8795 sock ,
8896 do_handshake_on_connect = True ,
8997 suppress_ragged_eofs = True ,
@@ -100,7 +108,7 @@ def _establish_new_socket_connection(
100108 return sock
101109
102110 sock = socket .create_connection ((server_hostname , server_port ), receive_timeout )
103- sock = ssl . create_default_context () .wrap_socket (
111+ sock = ssl_context .wrap_socket (
104112 sock ,
105113 do_handshake_on_connect = True ,
106114 suppress_ragged_eofs = True ,
0 commit comments