166166
167167socket_error = OSError # keep that public name in module namespace
168168
169- if _ssl .HAS_TLS_UNIQUE :
170- CHANNEL_BINDING_TYPES = ['tls-unique' ]
171- else :
172- CHANNEL_BINDING_TYPES = []
169+ CHANNEL_BINDING_TYPES = ['tls-unique' ]
173170
174171HAS_NEVER_CHECK_COMMON_NAME = hasattr (_ssl , 'HOSTFLAG_NEVER_CHECK_SUBJECT' )
175172
@@ -407,11 +404,11 @@ def wrap_bio(self, incoming, outgoing, server_side=False,
407404 server_hostname = None , session = None ):
408405 # Need to encode server_hostname here because _wrap_bio() can only
409406 # handle ASCII str.
410- sslobj = self ._wrap_bio (
407+ return self .sslobject_class (
411408 incoming , outgoing , server_side = server_side ,
412- server_hostname = self ._encode_hostname (server_hostname )
409+ server_hostname = self ._encode_hostname (server_hostname ),
410+ session = session , _context = self ,
413411 )
414- return self .sslobject_class (sslobj , session = session )
415412
416413 def set_npn_protocols (self , npn_protocols ):
417414 protos = bytearray ()
@@ -616,12 +613,13 @@ class SSLObject:
616613 * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
617614 """
618615
619- def __init__ (self , sslobj , owner = None , session = None ):
620- self ._sslobj = sslobj
621- # Note: _sslobj takes a weak reference to owner
622- self ._sslobj .owner = owner or self
623- if session is not None :
624- self ._sslobj .session = session
616+ def __init__ (self , incoming , outgoing , server_side = False ,
617+ server_hostname = None , session = None , _context = None ):
618+ self ._sslobj = _context ._wrap_bio (
619+ incoming , outgoing , server_side = server_side ,
620+ server_hostname = server_hostname ,
621+ owner = self , session = session
622+ )
625623
626624 @property
627625 def context (self ):
@@ -684,7 +682,7 @@ def getpeercert(self, binary_form=False):
684682 Return None if no certificate was provided, {} if a certificate was
685683 provided, but not validated.
686684 """
687- return self ._sslobj .peer_certificate (binary_form )
685+ return self ._sslobj .getpeercert (binary_form )
688686
689687 def selected_npn_protocol (self ):
690688 """Return the currently selected NPN protocol as a string, or ``None``
@@ -732,13 +730,7 @@ def get_channel_binding(self, cb_type="tls-unique"):
732730 """Get channel binding data for current connection. Raise ValueError
733731 if the requested `cb_type` is not supported. Return bytes of the data
734732 or None if the data is not available (e.g. before the handshake)."""
735- if cb_type not in CHANNEL_BINDING_TYPES :
736- raise ValueError ("Unsupported channel binding type" )
737- if cb_type != "tls-unique" :
738- raise NotImplementedError (
739- "{0} channel binding type not implemented"
740- .format (cb_type ))
741- return self ._sslobj .tls_unique_cb ()
733+ return self ._sslobj .get_channel_binding (cb_type )
742734
743735 def version (self ):
744736 """Return a string identifying the protocol version used by the
@@ -832,10 +824,10 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
832824 if connected :
833825 # create the SSL object
834826 try :
835- sslobj = self ._context ._wrap_socket (self , server_side ,
836- self .server_hostname )
837- self . _sslobj = SSLObject ( sslobj , owner = self ,
838- session = self . _session )
827+ self . _sslobj = self ._context ._wrap_socket (
828+ self , server_side , self .server_hostname ,
829+ owner = self , session = self . _session ,
830+ )
839831 if do_handshake_on_connect :
840832 timeout = self .gettimeout ()
841833 if timeout == 0.0 :
@@ -895,10 +887,13 @@ def read(self, len=1024, buffer=None):
895887 Return zero-length string on EOF."""
896888
897889 self ._checkClosed ()
898- if not self ._sslobj :
890+ if self ._sslobj is None :
899891 raise ValueError ("Read on closed or unwrapped SSL socket." )
900892 try :
901- return self ._sslobj .read (len , buffer )
893+ if buffer is not None :
894+ return self ._sslobj .read (len , buffer )
895+ else :
896+ return self ._sslobj .read (len )
902897 except SSLError as x :
903898 if x .args [0 ] == SSL_ERROR_EOF and self .suppress_ragged_eofs :
904899 if buffer is not None :
@@ -913,7 +908,7 @@ def write(self, data):
913908 number of bytes of DATA actually transmitted."""
914909
915910 self ._checkClosed ()
916- if not self ._sslobj :
911+ if self ._sslobj is None :
917912 raise ValueError ("Write on closed or unwrapped SSL socket." )
918913 return self ._sslobj .write (data )
919914
@@ -929,41 +924,42 @@ def getpeercert(self, binary_form=False):
929924
930925 def selected_npn_protocol (self ):
931926 self ._checkClosed ()
932- if not self ._sslobj or not _ssl .HAS_NPN :
927+ if self ._sslobj is None or not _ssl .HAS_NPN :
933928 return None
934929 else :
935930 return self ._sslobj .selected_npn_protocol ()
936931
937932 def selected_alpn_protocol (self ):
938933 self ._checkClosed ()
939- if not self ._sslobj or not _ssl .HAS_ALPN :
934+ if self ._sslobj is None or not _ssl .HAS_ALPN :
940935 return None
941936 else :
942937 return self ._sslobj .selected_alpn_protocol ()
943938
944939 def cipher (self ):
945940 self ._checkClosed ()
946- if not self ._sslobj :
941+ if self ._sslobj is None :
947942 return None
948943 else :
949944 return self ._sslobj .cipher ()
950945
951946 def shared_ciphers (self ):
952947 self ._checkClosed ()
953- if not self ._sslobj :
948+ if self ._sslobj is None :
954949 return None
955- return self ._sslobj .shared_ciphers ()
950+ else :
951+ return self ._sslobj .shared_ciphers ()
956952
957953 def compression (self ):
958954 self ._checkClosed ()
959- if not self ._sslobj :
955+ if self ._sslobj is None :
960956 return None
961957 else :
962958 return self ._sslobj .compression ()
963959
964960 def send (self , data , flags = 0 ):
965961 self ._checkClosed ()
966- if self ._sslobj :
962+ if self ._sslobj is not None :
967963 if flags != 0 :
968964 raise ValueError (
969965 "non-zero flags not allowed in calls to send() on %s" %
@@ -974,7 +970,7 @@ def send(self, data, flags=0):
974970
975971 def sendto (self , data , flags_or_addr , addr = None ):
976972 self ._checkClosed ()
977- if self ._sslobj :
973+ if self ._sslobj is not None :
978974 raise ValueError ("sendto not allowed on instances of %s" %
979975 self .__class__ )
980976 elif addr is None :
@@ -990,7 +986,7 @@ def sendmsg(self, *args, **kwargs):
990986
991987 def sendall (self , data , flags = 0 ):
992988 self ._checkClosed ()
993- if self ._sslobj :
989+ if self ._sslobj is not None :
994990 if flags != 0 :
995991 raise ValueError (
996992 "non-zero flags not allowed in calls to sendall() on %s" %
@@ -1008,15 +1004,15 @@ def sendfile(self, file, offset=0, count=None):
10081004 """Send a file, possibly by using os.sendfile() if this is a
10091005 clear-text socket. Return the total number of bytes sent.
10101006 """
1011- if self ._sslobj is None :
1007+ if self ._sslobj is not None :
1008+ return self ._sendfile_use_send (file , offset , count )
1009+ else :
10121010 # os.sendfile() works with plain sockets only
10131011 return super ().sendfile (file , offset , count )
1014- else :
1015- return self ._sendfile_use_send (file , offset , count )
10161012
10171013 def recv (self , buflen = 1024 , flags = 0 ):
10181014 self ._checkClosed ()
1019- if self ._sslobj :
1015+ if self ._sslobj is not None :
10201016 if flags != 0 :
10211017 raise ValueError (
10221018 "non-zero flags not allowed in calls to recv() on %s" %
@@ -1031,7 +1027,7 @@ def recv_into(self, buffer, nbytes=None, flags=0):
10311027 nbytes = len (buffer )
10321028 elif nbytes is None :
10331029 nbytes = 1024
1034- if self ._sslobj :
1030+ if self ._sslobj is not None :
10351031 if flags != 0 :
10361032 raise ValueError (
10371033 "non-zero flags not allowed in calls to recv_into() on %s" %
@@ -1042,15 +1038,15 @@ def recv_into(self, buffer, nbytes=None, flags=0):
10421038
10431039 def recvfrom (self , buflen = 1024 , flags = 0 ):
10441040 self ._checkClosed ()
1045- if self ._sslobj :
1041+ if self ._sslobj is not None :
10461042 raise ValueError ("recvfrom not allowed on instances of %s" %
10471043 self .__class__ )
10481044 else :
10491045 return super ().recvfrom (buflen , flags )
10501046
10511047 def recvfrom_into (self , buffer , nbytes = None , flags = 0 ):
10521048 self ._checkClosed ()
1053- if self ._sslobj :
1049+ if self ._sslobj is not None :
10541050 raise ValueError ("recvfrom_into not allowed on instances of %s" %
10551051 self .__class__ )
10561052 else :
@@ -1066,7 +1062,7 @@ def recvmsg_into(self, *args, **kwargs):
10661062
10671063 def pending (self ):
10681064 self ._checkClosed ()
1069- if self ._sslobj :
1065+ if self ._sslobj is not None :
10701066 return self ._sslobj .pending ()
10711067 else :
10721068 return 0
@@ -1078,7 +1074,7 @@ def shutdown(self, how):
10781074
10791075 def unwrap (self ):
10801076 if self ._sslobj :
1081- s = self ._sslobj .unwrap ()
1077+ s = self ._sslobj .shutdown ()
10821078 self ._sslobj = None
10831079 return s
10841080 else :
@@ -1096,6 +1092,11 @@ def do_handshake(self, block=False):
10961092 if timeout == 0.0 and block :
10971093 self .settimeout (None )
10981094 self ._sslobj .do_handshake ()
1095+ if self .context .check_hostname :
1096+ if not self .server_hostname :
1097+ raise ValueError ("check_hostname needs server_hostname "
1098+ "argument" )
1099+ match_hostname (self .getpeercert (), self .server_hostname )
10991100 finally :
11001101 self .settimeout (timeout )
11011102
@@ -1104,11 +1105,12 @@ def _real_connect(self, addr, connect_ex):
11041105 raise ValueError ("can't connect in server-side mode" )
11051106 # Here we assume that the socket is client-side, and not
11061107 # connected at the time of the call. We connect it, then wrap it.
1107- if self ._connected :
1108+ if self ._connected or self . _sslobj is not None :
11081109 raise ValueError ("attempt to connect already-connected SSLSocket!" )
1109- sslobj = self .context ._wrap_socket (self , False , self .server_hostname )
1110- self ._sslobj = SSLObject (sslobj , owner = self ,
1111- session = self ._session )
1110+ self ._sslobj = self .context ._wrap_socket (
1111+ self , False , self .server_hostname ,
1112+ owner = self , session = self ._session
1113+ )
11121114 try :
11131115 if connect_ex :
11141116 rc = super ().connect_ex (addr )
@@ -1151,18 +1153,24 @@ def get_channel_binding(self, cb_type="tls-unique"):
11511153 if the requested `cb_type` is not supported. Return bytes of the data
11521154 or None if the data is not available (e.g. before the handshake).
11531155 """
1154- if self ._sslobj is None :
1156+ if self ._sslobj is not None :
1157+ return self ._sslobj .get_channel_binding (cb_type )
1158+ else :
1159+ if cb_type not in CHANNEL_BINDING_TYPES :
1160+ raise ValueError (
1161+ "{0} channel binding type not implemented" .format (cb_type )
1162+ )
11551163 return None
1156- return self ._sslobj .get_channel_binding (cb_type )
11571164
11581165 def version (self ):
11591166 """
11601167 Return a string identifying the protocol version used by the
11611168 current SSL channel, or None if there is no established channel.
11621169 """
1163- if self ._sslobj is None :
1170+ if self ._sslobj is not None :
1171+ return self ._sslobj .version ()
1172+ else :
11641173 return None
1165- return self ._sslobj .version ()
11661174
11671175
11681176# Python does not support forward declaration of types.
0 commit comments