@@ -22,15 +22,16 @@ def setUp(self):
2222 self .loop = asyncio .new_event_loop ()
2323 self .set_event_loop (self .loop )
2424
25- def ssl_protocol (self , waiter = None ):
25+ def ssl_protocol (self , * , waiter = None , proto = None ):
2626 sslcontext = test_utils .dummy_ssl_context ()
27- app_proto = asyncio .Protocol ()
28- proto = sslproto .SSLProtocol (self .loop , app_proto , sslcontext , waiter )
29- self .assertIs (proto ._app_transport .get_protocol (), app_proto )
30- self .addCleanup (proto ._app_transport .close )
31- return proto
32-
33- def connection_made (self , ssl_proto , do_handshake = None ):
27+ if proto is None : # app protocol
28+ proto = asyncio .Protocol ()
29+ ssl_proto = sslproto .SSLProtocol (self .loop , proto , sslcontext , waiter )
30+ self .assertIs (ssl_proto ._app_transport .get_protocol (), proto )
31+ self .addCleanup (ssl_proto ._app_transport .close )
32+ return ssl_proto
33+
34+ def connection_made (self , ssl_proto , * , do_handshake = None ):
3435 transport = mock .Mock ()
3536 sslpipe = mock .Mock ()
3637 sslpipe .shutdown .return_value = b''
@@ -48,7 +49,7 @@ def test_cancel_handshake(self):
4849 # Python issue #23197: cancelling a handshake must not raise an
4950 # exception or log an error, even if the handshake failed
5051 waiter = asyncio .Future (loop = self .loop )
51- ssl_proto = self .ssl_protocol (waiter )
52+ ssl_proto = self .ssl_protocol (waiter = waiter )
5253 handshake_fut = asyncio .Future (loop = self .loop )
5354
5455 def do_handshake (callback ):
@@ -58,14 +59,14 @@ def do_handshake(callback):
5859 return []
5960
6061 waiter .cancel ()
61- self .connection_made (ssl_proto , do_handshake )
62+ self .connection_made (ssl_proto , do_handshake = do_handshake )
6263
6364 with test_utils .disable_logger ():
6465 self .loop .run_until_complete (handshake_fut )
6566
6667 def test_eof_received_waiter (self ):
6768 waiter = asyncio .Future (loop = self .loop )
68- ssl_proto = self .ssl_protocol (waiter )
69+ ssl_proto = self .ssl_protocol (waiter = waiter )
6970 self .connection_made (ssl_proto )
7071 ssl_proto .eof_received ()
7172 test_utils .run_briefly (self .loop )
@@ -76,7 +77,7 @@ def test_fatal_error_no_name_error(self):
7677 # _fatal_error() generates a NameError if sslproto.py
7778 # does not import base_events.
7879 waiter = asyncio .Future (loop = self .loop )
79- ssl_proto = self .ssl_protocol (waiter )
80+ ssl_proto = self .ssl_protocol (waiter = waiter )
8081 # Temporarily turn off error logging so as not to spoil test output.
8182 log_level = log .logger .getEffectiveLevel ()
8283 log .logger .setLevel (logging .FATAL )
@@ -90,7 +91,7 @@ def test_connection_lost(self):
9091 # From issue #472.
9192 # yield from waiter hang if lost_connection was called.
9293 waiter = asyncio .Future (loop = self .loop )
93- ssl_proto = self .ssl_protocol (waiter )
94+ ssl_proto = self .ssl_protocol (waiter = waiter )
9495 self .connection_made (ssl_proto )
9596 ssl_proto .connection_lost (ConnectionAbortedError )
9697 test_utils .run_briefly (self .loop )
@@ -99,10 +100,7 @@ def test_connection_lost(self):
99100 def test_close_during_handshake (self ):
100101 # bpo-29743 Closing transport during handshake process leaks socket
101102 waiter = asyncio .Future (loop = self .loop )
102- ssl_proto = self .ssl_protocol (waiter )
103-
104- def do_handshake (callback ):
105- return []
103+ ssl_proto = self .ssl_protocol (waiter = waiter )
106104
107105 transport = self .connection_made (ssl_proto )
108106 test_utils .run_briefly (self .loop )
@@ -112,7 +110,7 @@ def do_handshake(callback):
112110
113111 def test_get_extra_info_on_closed_connection (self ):
114112 waiter = asyncio .Future (loop = self .loop )
115- ssl_proto = self .ssl_protocol (waiter )
113+ ssl_proto = self .ssl_protocol (waiter = waiter )
116114 self .assertIsNone (ssl_proto ._get_extra_info ('socket' ))
117115 default = object ()
118116 self .assertIs (ssl_proto ._get_extra_info ('socket' , default ), default )
@@ -123,12 +121,31 @@ def test_get_extra_info_on_closed_connection(self):
123121
124122 def test_set_new_app_protocol (self ):
125123 waiter = asyncio .Future (loop = self .loop )
126- ssl_proto = self .ssl_protocol (waiter )
124+ ssl_proto = self .ssl_protocol (waiter = waiter )
127125 new_app_proto = asyncio .Protocol ()
128126 ssl_proto ._app_transport .set_protocol (new_app_proto )
129127 self .assertIs (ssl_proto ._app_transport .get_protocol (), new_app_proto )
130128 self .assertIs (ssl_proto ._app_protocol , new_app_proto )
131129
130+ def test_data_received_after_closing (self ):
131+ ssl_proto = self .ssl_protocol ()
132+ self .connection_made (ssl_proto )
133+ transp = ssl_proto ._app_transport
134+
135+ transp .close ()
136+
137+ # should not raise
138+ self .assertIsNone (ssl_proto .data_received (b'data' ))
139+
140+ def test_write_after_closing (self ):
141+ ssl_proto = self .ssl_protocol ()
142+ self .connection_made (ssl_proto )
143+ transp = ssl_proto ._app_transport
144+ transp .close ()
145+
146+ # should not raise
147+ self .assertIsNone (transp .write (b'data' ))
148+
132149
133150if __name__ == '__main__' :
134151 unittest .main ()
0 commit comments