Skip to content

Commit d514dc7

Browse files
committed
minimize ssl lock
1 parent db95946 commit d514dc7

File tree

1 file changed

+113
-85
lines changed

1 file changed

+113
-85
lines changed

crates/stdlib/src/ssl.rs

Lines changed: 113 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,47 +1357,73 @@ mod _ssl {
13571357
);
13581358
}
13591359

1360-
// Get mutable references to store and ca_certs_der
1360+
// Parse arguments BEFORE acquiring locks to reduce lock scope
1361+
let cafile_path = if let OptionalArg::Present(Some(ref cafile_obj)) = args.cafile {
1362+
Some(Self::parse_path_arg(cafile_obj, vm)?)
1363+
} else {
1364+
None
1365+
};
1366+
1367+
let capath_dir = if let OptionalArg::Present(Some(ref capath_obj)) = args.capath {
1368+
Some(Self::parse_path_arg(capath_obj, vm)?)
1369+
} else {
1370+
None
1371+
};
1372+
1373+
let cadata_parsed = if let OptionalArg::Present(ref cadata_obj) = args.cadata
1374+
&& !vm.is_none(cadata_obj)
1375+
{
1376+
let is_string = PyStrRef::try_from_object(vm, cadata_obj.clone()).is_ok();
1377+
let data_vec = self.parse_cadata_arg(cadata_obj, vm)?;
1378+
Some((data_vec, is_string))
1379+
} else {
1380+
None
1381+
};
1382+
1383+
// Check for CRL before acquiring main locks
1384+
let (crl_opt, cafile_is_crl) = if let Some(ref path) = cafile_path {
1385+
let crl = self.load_crl_from_file(path, vm)?;
1386+
let is_crl = crl.is_some();
1387+
(crl, is_crl)
1388+
} else {
1389+
(None, false)
1390+
};
1391+
1392+
// If it's a CRL, just add it (separate lock, no conflict with root_store)
1393+
if let Some(crl) = crl_opt {
1394+
self.crls.write().push(crl);
1395+
}
1396+
1397+
// Now acquire write locks for certificate loading
13611398
let mut root_store = self.root_certs.write();
13621399
let mut ca_certs_der = self.ca_certs_der.write();
13631400

1364-
// Load from file
1365-
if let OptionalArg::Present(Some(ref cafile_obj)) = args.cafile {
1366-
let path = Self::parse_path_arg(cafile_obj, vm)?;
1367-
1368-
// Try to load as CRL first
1369-
if let Some(crl) = self.load_crl_from_file(&path, vm)? {
1370-
self.crls.write().push(crl);
1371-
} else {
1401+
// Load from file (if not CRL)
1402+
if let Some(ref path) = cafile_path {
1403+
if !cafile_is_crl {
13721404
// Not a CRL, load as certificate
13731405
let stats = self.load_certs_from_file_helper(
13741406
&mut root_store,
13751407
&mut ca_certs_der,
1376-
&path,
1408+
path,
13771409
vm,
13781410
)?;
13791411
self.update_cert_stats(stats);
13801412
}
13811413
}
13821414

13831415
// Load from directory (don't add to ca_certs_der)
1384-
if let OptionalArg::Present(Some(ref capath_obj)) = args.capath {
1385-
let dir_path = Self::parse_path_arg(capath_obj, vm)?;
1386-
let stats = self.load_certs_from_dir_helper(&mut root_store, &dir_path, vm)?;
1416+
if let Some(ref dir_path) = capath_dir {
1417+
let stats = self.load_certs_from_dir_helper(&mut root_store, dir_path, vm)?;
13871418
self.update_cert_stats(stats);
13881419
}
13891420

13901421
// Load from bytes or str
1391-
if let OptionalArg::Present(cadata_obj) = args.cadata
1392-
&& !vm.is_none(&cadata_obj)
1393-
{
1394-
// Check if input is string or bytes
1395-
let is_string = PyStrRef::try_from_object(vm, cadata_obj.clone()).is_ok();
1396-
let data_vec = self.parse_cadata_arg(&cadata_obj, vm)?;
1422+
if let Some((ref data_vec, is_string)) = cadata_parsed {
13971423
let stats = self.load_certs_from_bytes_helper(
13981424
&mut root_store,
13991425
&mut ca_certs_der,
1400-
&data_vec,
1426+
data_vec,
14011427
is_string, // PEM only for strings
14021428
vm,
14031429
)?;
@@ -2547,48 +2573,51 @@ mod _ssl {
25472573
/// This simulates lazy loading behavior: capath certificates
25482574
/// are only added to get_ca_certs() after they're actually used in a handshake.
25492575
fn track_used_ca_from_capath(&self) -> Result<(), String> {
2550-
let context = self.context.read();
2551-
let capath_certs = context.capath_certs_der.read();
2552-
2553-
// No capath certs to track
2554-
if capath_certs.is_empty() {
2555-
return Ok(());
2556-
}
2557-
2558-
// Get peer certificate chain
2559-
let conn_guard = self.connection.lock();
2560-
let conn = conn_guard.as_ref().ok_or("No connection")?;
2561-
2562-
let peer_certs = conn.peer_certificates().ok_or("No peer certificates")?;
2576+
// Extract capath_certs, releasing context lock quickly
2577+
let capath_certs = {
2578+
let context = self.context.read();
2579+
let certs = context.capath_certs_der.read();
2580+
if certs.is_empty() {
2581+
return Ok(());
2582+
}
2583+
certs.clone()
2584+
};
25632585

2564-
if peer_certs.is_empty() {
2565-
return Ok(());
2566-
}
2586+
// Extract peer certificates, releasing connection lock quickly
2587+
let top_cert_der = {
2588+
let conn_guard = self.connection.lock();
2589+
let conn = conn_guard.as_ref().ok_or("No connection")?;
2590+
let peer_certs = conn.peer_certificates().ok_or("No peer certificates")?;
2591+
if peer_certs.is_empty() {
2592+
return Ok(());
2593+
}
2594+
peer_certs
2595+
.iter()
2596+
.map(|c| c.as_ref().to_vec())
2597+
.last()
2598+
.expect("is_empty checked above")
2599+
};
25672600

25682601
// Get the top certificate in the chain (closest to root)
25692602
// Note: Server usually doesn't send the root CA, so we check the last cert's issuer
2570-
let top_cert_der = peer_certs.last().unwrap();
2571-
let (_, top_cert) = x509_parser::parse_x509_certificate(top_cert_der)
2603+
let (_, top_cert) = x509_parser::parse_x509_certificate(&top_cert_der)
25722604
.map_err(|e| format!("Failed to parse top cert: {e}"))?;
25732605

25742606
let top_issuer = top_cert.issuer();
25752607

2576-
// Find matching CA in capath certs
2577-
for ca_der in capath_certs.iter() {
2578-
let (_, ca) = x509_parser::parse_x509_certificate(ca_der)
2579-
.map_err(|e| format!("Failed to parse CA: {e}"))?;
2608+
// Find matching CA in capath certs (skip unparseable certificates)
2609+
let matching_ca = capath_certs.iter().find_map(|ca_der| {
2610+
let (_, ca) = x509_parser::parse_x509_certificate(ca_der).ok()?;
2611+
// Check if this CA is self-signed (root CA) and matches the issuer
2612+
(ca.subject() == ca.issuer() && ca.subject() == top_issuer).then(|| ca_der.clone())
2613+
});
25802614

2581-
// Check if this CA is self-signed and matches the issuer
2582-
if ca.subject() == ca.issuer() // Self-signed (root CA)
2583-
&& ca.subject() == top_issuer
2584-
// Matches top cert's issuer
2585-
{
2586-
// Check if not already in ca_certs_der
2587-
let mut ca_certs_der = context.ca_certs_der.write();
2588-
if !ca_certs_der.iter().any(|c| c == ca_der) {
2589-
ca_certs_der.push(ca_der.clone());
2590-
}
2591-
break;
2615+
// Update ca_certs_der if we found a match
2616+
if let Some(ca_der) = matching_ca {
2617+
let context = self.context.read();
2618+
let mut ca_certs_der = context.ca_certs_der.write();
2619+
if !ca_certs_der.iter().any(|c| c == &ca_der) {
2620+
ca_certs_der.push(ca_der);
25922621
}
25932622
}
25942623

@@ -2675,6 +2704,7 @@ mod _ssl {
26752704

26762705
/// Check if SNI callback is configured
26772706
pub(crate) fn has_sni_callback(&self) -> bool {
2707+
// Nested read locks are safe
26782708
self.context.read().sni_callback.read().is_some()
26792709
}
26802710

@@ -2685,10 +2715,9 @@ mod _ssl {
26852715

26862716
/// Get the extracted SNI name from resolver
26872717
pub(crate) fn get_extracted_sni_name(&self) -> Option<String> {
2688-
self.sni_state
2689-
.read()
2690-
.as_ref()
2691-
.and_then(|arc| arc.lock().1.clone())
2718+
// Clone the Arc option to avoid nested lock (sni_state.read -> arc.lock)
2719+
let sni_state_opt = self.sni_state.read().clone();
2720+
sni_state_opt.as_ref().and_then(|arc| arc.lock().1.clone())
26922721
}
26932722

26942723
/// Invoke the Python SNI callback
@@ -3516,27 +3545,24 @@ mod _ssl {
35163545
return Err(vm.new_value_error("handshake not done yet"));
35173546
}
35183547

3519-
// Get peer certificates from TLS connection
3520-
let conn_guard = self.connection.lock();
3521-
let conn = conn_guard
3522-
.as_ref()
3523-
.ok_or_else(|| vm.new_value_error("No TLS connection established"))?;
3524-
3525-
let certs = conn.peer_certificates();
3548+
// Extract DER bytes from connection, releasing lock quickly
3549+
let der_bytes = {
3550+
let conn_guard = self.connection.lock();
3551+
let conn = conn_guard
3552+
.as_ref()
3553+
.ok_or_else(|| vm.new_value_error("No TLS connection established"))?;
35263554

3527-
// Return None if no peer certificate
3528-
let Some(certs) = certs else {
3529-
return Ok(None);
3555+
let Some(peer_certificates) = conn.peer_certificates() else {
3556+
return Ok(None);
3557+
};
3558+
let cert = peer_certificates
3559+
.first()
3560+
.ok_or_else(|| vm.new_value_error("No peer certificate available"))?;
3561+
cert.as_ref().to_vec()
35303562
};
35313563

3532-
// Get first certificate (peer's certificate)
3533-
let cert_der = certs
3534-
.first()
3535-
.ok_or_else(|| vm.new_value_error("No peer certificate available"))?;
3536-
35373564
if binary {
35383565
// Return DER-encoded certificate as bytes
3539-
let der_bytes = cert_der.as_ref().to_vec();
35403566
return Ok(Some(vm.ctx.new_bytes(der_bytes).into()));
35413567
}
35423568

@@ -3548,22 +3574,22 @@ mod _ssl {
35483574
return Ok(Some(vm.ctx.new_dict().into()));
35493575
}
35503576

3551-
// Parse DER certificate and convert to dict
3552-
let der_bytes = cert_der.as_ref();
3553-
let (_, cert) = x509_parser::parse_x509_certificate(der_bytes)
3577+
// Parse DER certificate and convert to dict (outside lock)
3578+
let (_, cert) = x509_parser::parse_x509_certificate(&der_bytes)
35543579
.map_err(|e| vm.new_value_error(format!("Failed to parse certificate: {e}")))?;
35553580

35563581
cert::cert_to_dict(vm, &cert).map(Some)
35573582
}
35583583

35593584
#[pymethod]
35603585
fn cipher(&self) -> Option<(String, String, i32)> {
3561-
let conn_guard = self.connection.lock();
3562-
let conn = conn_guard.as_ref()?;
3563-
3564-
let suite = conn.negotiated_cipher_suite()?;
3586+
// Extract cipher suite, releasing lock quickly
3587+
let suite = {
3588+
let conn_guard = self.connection.lock();
3589+
conn_guard.as_ref()?.negotiated_cipher_suite()?
3590+
};
35653591

3566-
// Extract cipher information using unified helper
3592+
// Extract cipher information outside the lock
35673593
let cipher_info = extract_cipher_info(&suite);
35683594

35693595
// Note: returns a 3-tuple (name, protocol_version, bits)
@@ -3577,11 +3603,13 @@ mod _ssl {
35773603

35783604
#[pymethod]
35793605
fn version(&self) -> Option<String> {
3580-
let conn_guard = self.connection.lock();
3581-
let conn = conn_guard.as_ref()?;
3582-
3583-
let suite = conn.negotiated_cipher_suite()?;
3606+
// Extract cipher suite, releasing lock quickly
3607+
let suite = {
3608+
let conn_guard = self.connection.lock();
3609+
conn_guard.as_ref()?.negotiated_cipher_suite()?
3610+
};
35843611

3612+
// Convert to string outside the lock
35853613
let version_str = match suite.version().version {
35863614
rustls::ProtocolVersion::TLSv1_2 => "TLSv1.2",
35873615
rustls::ProtocolVersion::TLSv1_3 => "TLSv1.3",

0 commit comments

Comments
 (0)