Skip to content

Commit e413afe

Browse files
committed
fix openssl
1 parent 07f4427 commit e413afe

File tree

1 file changed

+112
-33
lines changed

1 file changed

+112
-33
lines changed

crates/stdlib/src/openssl.rs

Lines changed: 112 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -706,32 +706,40 @@ mod _ssl {
706706
}
707707
}
708708

709+
// OpenSSL record type constants for msg_callback
710+
const SSL3_RT_CHANGE_CIPHER_SPEC: i32 = 20;
711+
const SSL3_RT_ALERT: i32 = 21;
712+
const SSL3_RT_HANDSHAKE: i32 = 22;
713+
const SSL3_RT_HEADER: i32 = 256;
714+
const SSL3_RT_INNER_CONTENT_TYPE: i32 = 257;
715+
// Special value for change cipher spec (CPython compatibility)
716+
const SSL3_MT_CHANGE_CIPHER_SPEC: i32 = 0x0101;
717+
709718
// Message callback function called by OpenSSL
710-
// NOTE: This callback is intentionally a no-op to avoid deadlocks.
711-
// The msg_callback can be called during various SSL operations (read, write, handshake),
712-
// and invoking Python code from within these operations can cause deadlocks
713-
// (see CPython bpo-43577). A proper implementation would require careful lock ordering.
719+
// Called during SSL operations to report protocol messages.
720+
// debughelpers.c:_PySSL_msg_callback
714721
unsafe extern "C" fn _msg_callback(
715-
_write_p: libc::c_int,
716-
_version: libc::c_int,
717-
_content_type: libc::c_int,
718-
_buf: *const libc::c_void,
719-
_len: usize,
720-
_ssl_ptr: *mut sys::SSL,
722+
write_p: libc::c_int,
723+
mut version: libc::c_int,
724+
content_type: libc::c_int,
725+
buf: *const libc::c_void,
726+
len: usize,
727+
ssl_ptr: *mut sys::SSL,
721728
_arg: *mut libc::c_void,
722729
) {
723730
if ssl_ptr.is_null() {
724731
return;
725732
}
726733

727734
unsafe {
728-
// Get SSL socket from SSL_get_app_data (index 0)
735+
// Get SSL socket from SSL_get_ex_data (index 0)
729736
let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0);
730737
if ssl_socket_ptr.is_null() {
731738
return;
732739
}
733740

734-
let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket);
741+
// ssl_socket_ptr is a pointer to Py<PySslSocket>, set in _wrap_socket/_wrap_bio
742+
let ssl_socket: &Py<PySslSocket> = &*(ssl_socket_ptr as *const Py<PySslSocket>);
735743

736744
// Get the callback from the context
737745
let callback_opt = ssl_socket.ctx.read().msg_callback.lock().clone();
@@ -761,16 +769,50 @@ mod _ssl {
761769
// Determine direction string
762770
let direction_str = if write_p != 0 { "write" } else { "read" };
763771

772+
// Calculate msg_type based on content_type (debughelpers.c behavior)
773+
let msg_type = match content_type {
774+
SSL3_RT_CHANGE_CIPHER_SPEC => SSL3_MT_CHANGE_CIPHER_SPEC,
775+
SSL3_RT_ALERT => {
776+
// byte 1 is alert type
777+
if len >= 2 { buf_slice[1] as i32 } else { -1 }
778+
}
779+
SSL3_RT_HANDSHAKE => {
780+
// byte 0 is handshake type
781+
if !buf_slice.is_empty() {
782+
buf_slice[0] as i32
783+
} else {
784+
-1
785+
}
786+
}
787+
SSL3_RT_HEADER => {
788+
// Frame header: version in bytes 1..2, type in byte 0
789+
if len >= 3 {
790+
version = ((buf_slice[1] as i32) << 8) | (buf_slice[2] as i32);
791+
buf_slice[0] as i32
792+
} else {
793+
-1
794+
}
795+
}
796+
SSL3_RT_INNER_CONTENT_TYPE => {
797+
// Inner content type in byte 0
798+
if !buf_slice.is_empty() {
799+
buf_slice[0] as i32
800+
} else {
801+
-1
802+
}
803+
}
804+
_ => -1,
805+
};
806+
764807
// Call the Python callback
765808
// Signature: callback(conn, direction, version, content_type, msg_type, data)
766-
// For simplicity, we'll pass msg_type as 0 (would need more parsing to get the actual type)
767809
match callback.call(
768810
(
769811
ssl_socket_obj,
770812
vm.ctx.new_str(direction_str),
771813
vm.ctx.new_int(version),
772814
vm.ctx.new_int(content_type),
773-
vm.ctx.new_int(0), // msg_type - would need parsing
815+
vm.ctx.new_int(msg_type),
774816
msg_bytes,
775817
),
776818
vm,
@@ -1300,7 +1342,7 @@ mod _ssl {
13001342
if let Some(cadata) = args.cadata {
13011343
let certs = match cadata {
13021344
Either::A(s) => {
1303-
if !s.is_ascii() {
1345+
if !s.as_str().is_ascii() {
13041346
return Err(invalid_cadata(vm));
13051347
}
13061348
X509::stack_from_pem(s.as_bytes())
@@ -1762,12 +1804,19 @@ mod _ssl {
17621804
// Check if SNI callback is configured (minimize lock time)
17631805
let has_sni_callback = zelf.sni_callback.lock().is_some();
17641806

1765-
// Set SNI callback data if needed (after releasing the lock)
1766-
if has_sni_callback {
1767-
let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?;
1768-
unsafe {
1769-
let ssl_ptr = py_ref.connection.read().ssl().as_ptr();
1807+
// Set up ex_data for callbacks
1808+
unsafe {
1809+
let ssl_ptr = py_ref.connection.read().ssl().as_ptr();
17701810

1811+
// Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data)
1812+
// This is safe because ssl_socket owns the SSL object and outlives it
1813+
// We store a pointer to Py<PySslSocket>, which msg_callback can dereference
1814+
let py_ptr: *const Py<PySslSocket> = &*py_ref;
1815+
sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _);
1816+
1817+
// Set SNI callback data if needed
1818+
if has_sni_callback {
1819+
let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?;
17711820
// Store callback data in SSL ex_data - use weak reference to avoid cycle
17721821
let callback_data = Box::new(SniCallbackData {
17731822
ssl_context: zelf.clone(),
@@ -1823,12 +1872,19 @@ mod _ssl {
18231872
// Check if SNI callback is configured (minimize lock time)
18241873
let has_sni_callback = zelf.sni_callback.lock().is_some();
18251874

1826-
// Set SNI callback data if needed (after releasing the lock)
1827-
if has_sni_callback {
1828-
let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?;
1829-
unsafe {
1830-
let ssl_ptr = py_ref.connection.read().ssl().as_ptr();
1875+
// Set up ex_data for callbacks
1876+
unsafe {
1877+
let ssl_ptr = py_ref.connection.read().ssl().as_ptr();
1878+
1879+
// Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data)
1880+
// This is safe because ssl_socket owns the SSL object and outlives it
1881+
// We store a pointer to Py<PySslSocket>, which msg_callback can dereference
1882+
let py_ptr: *const Py<PySslSocket> = &*py_ref;
1883+
sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _);
18311884

1885+
// Set SNI callback data if needed
1886+
if has_sni_callback {
1887+
let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?;
18321888
// Store callback data in SSL ex_data - use weak reference to avoid cycle
18331889
let callback_data = Box::new(SniCallbackData {
18341890
ssl_context: zelf.clone(),
@@ -1924,12 +1980,14 @@ mod _ssl {
19241980
Some(s) => s,
19251981
None => return SelectRet::Closed,
19261982
};
1927-
let deadline = match &deadline {
1983+
// For blocking sockets without timeout, call sock_select with None timeout
1984+
// to actually block waiting for data instead of busy-looping
1985+
let timeout = match &deadline {
19281986
Ok(deadline) => match deadline.checked_duration_since(Instant::now()) {
1929-
Some(deadline) => deadline,
1987+
Some(d) => Some(d),
19301988
None => return SelectRet::TimedOut,
19311989
},
1932-
Err(true) => return SelectRet::IsBlocking,
1990+
Err(true) => None, // Blocking: no timeout, wait indefinitely
19331991
Err(false) => return SelectRet::Nonblocking,
19341992
};
19351993
let res = socket::sock_select(
@@ -1938,7 +1996,7 @@ mod _ssl {
19381996
SslNeeds::Read => socket::SelectKind::Read,
19391997
SslNeeds::Write => socket::SelectKind::Write,
19401998
},
1941-
Some(deadline),
1999+
timeout,
19422000
);
19432001
match res {
19442002
Ok(true) => SelectRet::TimedOut,
@@ -2672,12 +2730,33 @@ mod _ssl {
26722730
#[pymethod]
26732731
fn read(
26742732
&self,
2675-
n: usize,
2733+
n: isize,
26762734
buffer: OptionalArg<ArgMemoryBuffer>,
26772735
vm: &VirtualMachine,
26782736
) -> PyResult {
2737+
// Handle negative n:
2738+
// - If buffer is None and n < 0: raise ValueError
2739+
// - If buffer is present and n <= 0: use buffer length
2740+
// This matches _ssl__SSLSocket_read_impl in CPython
2741+
let read_len: usize = match &buffer {
2742+
OptionalArg::Present(buf) => {
2743+
let buf_len = buf.borrow_buf_mut().len();
2744+
if n <= 0 || (n as usize) > buf_len {
2745+
buf_len
2746+
} else {
2747+
n as usize
2748+
}
2749+
}
2750+
OptionalArg::Missing => {
2751+
if n < 0 {
2752+
return Err(vm.new_value_error("size should not be negative".to_owned()));
2753+
}
2754+
n as usize
2755+
}
2756+
};
2757+
26792758
// Special case: reading 0 bytes should return empty bytes immediately
2680-
if n == 0 {
2759+
if read_len == 0 {
26812760
return if buffer.is_present() {
26822761
Ok(vm.ctx.new_int(0).into())
26832762
} else {
@@ -2689,13 +2768,13 @@ mod _ssl {
26892768
let mut inner_buffer = if let OptionalArg::Present(buffer) = &buffer {
26902769
Either::A(buffer.borrow_buf_mut())
26912770
} else {
2692-
Either::B(vec![0u8; n])
2771+
Either::B(vec![0u8; read_len])
26932772
};
26942773
let buf = match &mut inner_buffer {
26952774
Either::A(b) => &mut **b,
26962775
Either::B(b) => b.as_mut_slice(),
26972776
};
2698-
let buf = match buf.get_mut(..n) {
2777+
let buf = match buf.get_mut(..read_len) {
26992778
Some(b) => b,
27002779
None => buf,
27012780
};

0 commit comments

Comments
 (0)