@@ -302,6 +302,7 @@ def test_start_tls_client_buf_proto_1(self):
302302
303303 server_context = test_utils .simple_server_sslcontext ()
304304 client_context = test_utils .simple_client_sslcontext ()
305+ client_con_made_calls = 0
305306
306307 def serve (sock ):
307308 sock .settimeout (self .TIMEOUT )
@@ -315,20 +316,21 @@ def serve(sock):
315316 data = sock .recv_all (len (HELLO_MSG ))
316317 self .assertEqual (len (data ), len (HELLO_MSG ))
317318
319+ sock .sendall (b'2' )
320+ data = sock .recv_all (len (HELLO_MSG ))
321+ self .assertEqual (len (data ), len (HELLO_MSG ))
322+
318323 sock .shutdown (socket .SHUT_RDWR )
319324 sock .close ()
320325
321- class ClientProto (asyncio .BufferedProtocol ):
322- def __init__ (self , on_data , on_eof ):
326+ class ClientProtoFirst (asyncio .BufferedProtocol ):
327+ def __init__ (self , on_data ):
323328 self .on_data = on_data
324- self .on_eof = on_eof
325- self .con_made_cnt = 0
326329 self .buf = bytearray (1 )
327330
328- def connection_made (proto , tr ):
329- proto .con_made_cnt += 1
330- # Ensure connection_made gets called only once.
331- self .assertEqual (proto .con_made_cnt , 1 )
331+ def connection_made (self , tr ):
332+ nonlocal client_con_made_calls
333+ client_con_made_calls += 1
332334
333335 def get_buffer (self , sizehint ):
334336 return self .buf
@@ -337,27 +339,50 @@ def buffer_updated(self, nsize):
337339 assert nsize == 1
338340 self .on_data .set_result (bytes (self .buf [:nsize ]))
339341
342+ class ClientProtoSecond (asyncio .Protocol ):
343+ def __init__ (self , on_data , on_eof ):
344+ self .on_data = on_data
345+ self .on_eof = on_eof
346+ self .con_made_cnt = 0
347+
348+ def connection_made (self , tr ):
349+ nonlocal client_con_made_calls
350+ client_con_made_calls += 1
351+
352+ def data_received (self , data ):
353+ self .on_data .set_result (data )
354+
340355 def eof_received (self ):
341356 self .on_eof .set_result (True )
342357
343358 async def client (addr ):
344359 await asyncio .sleep (0.5 , loop = self .loop )
345360
346- on_data = self .loop .create_future ()
361+ on_data1 = self .loop .create_future ()
362+ on_data2 = self .loop .create_future ()
347363 on_eof = self .loop .create_future ()
348364
349365 tr , proto = await self .loop .create_connection (
350- lambda : ClientProto ( on_data , on_eof ), * addr )
366+ lambda : ClientProtoFirst ( on_data1 ), * addr )
351367
352368 tr .write (HELLO_MSG )
353369 new_tr = await self .loop .start_tls (tr , proto , client_context )
354370
355- self .assertEqual (await on_data , b'O' )
371+ self .assertEqual (await on_data1 , b'O' )
372+ new_tr .write (HELLO_MSG )
373+
374+ new_tr .set_protocol (ClientProtoSecond (on_data2 , on_eof ))
375+ self .assertEqual (await on_data2 , b'2' )
356376 new_tr .write (HELLO_MSG )
357377 await on_eof
358378
359379 new_tr .close ()
360380
381+ # connection_made() should be called only once -- when
382+ # we establish connection for the first time. Start TLS
383+ # doesn't call connection_made() on application protocols.
384+ self .assertEqual (client_con_made_calls , 1 )
385+
361386 with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
362387 self .loop .run_until_complete (
363388 asyncio .wait_for (client (srv .addr ),
0 commit comments