Skip to content

Commit 0f5ffbd

Browse files
committed
minimize ssl lock
1 parent db95946 commit 0f5ffbd

File tree

1 file changed

+117
-93
lines changed

1 file changed

+117
-93
lines changed

crates/stdlib/src/ssl.rs

Lines changed: 117 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,47 +1357,69 @@ mod _ssl {
13571357
);
13581358
}
13591359

1360-
// Get mutable references to store and ca_certs_der
1360+
// Parse arguments BEFORE acquiring locks to reduce lock scope
1361+
let cafile_path = if let OptionalArg::Present(Some(ref cafile_obj)) = args.cafile {
1362+
Some(Self::parse_path_arg(cafile_obj, vm)?)
1363+
} else {
1364+
None
1365+
};
1366+
1367+
let capath_dir = if let OptionalArg::Present(Some(ref capath_obj)) = args.capath {
1368+
Some(Self::parse_path_arg(capath_obj, vm)?)
1369+
} else {
1370+
None
1371+
};
1372+
1373+
let cadata_parsed = if let OptionalArg::Present(ref cadata_obj) = args.cadata
1374+
&& !vm.is_none(cadata_obj)
1375+
{
1376+
let is_string = PyStrRef::try_from_object(vm, cadata_obj.clone()).is_ok();
1377+
let data_vec = self.parse_cadata_arg(cadata_obj, vm)?;
1378+
Some((data_vec, is_string))
1379+
} else {
1380+
None
1381+
};
1382+
1383+
// Check for CRL before acquiring main locks
1384+
let (crl_opt, cafile_is_crl) = if let Some(ref path) = cafile_path {
1385+
let crl = self.load_crl_from_file(path, vm)?;
1386+
let is_crl = crl.is_some();
1387+
(crl, is_crl)
1388+
} else {
1389+
(None, false)
1390+
};
1391+
1392+
// If it's a CRL, just add it (separate lock, no conflict with root_store)
1393+
if let Some(crl) = crl_opt {
1394+
self.crls.write().push(crl);
1395+
}
1396+
1397+
// Now acquire write locks for certificate loading
13611398
let mut root_store = self.root_certs.write();
13621399
let mut ca_certs_der = self.ca_certs_der.write();
13631400

1364-
// Load from file
1365-
if let OptionalArg::Present(Some(ref cafile_obj)) = args.cafile {
1366-
let path = Self::parse_path_arg(cafile_obj, vm)?;
1367-
1368-
// Try to load as CRL first
1369-
if let Some(crl) = self.load_crl_from_file(&path, vm)? {
1370-
self.crls.write().push(crl);
1371-
} else {
1372-
// Not a CRL, load as certificate
1373-
let stats = self.load_certs_from_file_helper(
1374-
&mut root_store,
1375-
&mut ca_certs_der,
1376-
&path,
1377-
vm,
1378-
)?;
1379-
self.update_cert_stats(stats);
1380-
}
1401+
// Load from file (if not CRL)
1402+
if let Some(ref path) = cafile_path
1403+
&& !cafile_is_crl
1404+
{
1405+
// Not a CRL, load as certificate
1406+
let stats =
1407+
self.load_certs_from_file_helper(&mut root_store, &mut ca_certs_der, path, vm)?;
1408+
self.update_cert_stats(stats);
13811409
}
13821410

13831411
// Load from directory (don't add to ca_certs_der)
1384-
if let OptionalArg::Present(Some(ref capath_obj)) = args.capath {
1385-
let dir_path = Self::parse_path_arg(capath_obj, vm)?;
1386-
let stats = self.load_certs_from_dir_helper(&mut root_store, &dir_path, vm)?;
1412+
if let Some(ref dir_path) = capath_dir {
1413+
let stats = self.load_certs_from_dir_helper(&mut root_store, dir_path, vm)?;
13871414
self.update_cert_stats(stats);
13881415
}
13891416

13901417
// Load from bytes or str
1391-
if let OptionalArg::Present(cadata_obj) = args.cadata
1392-
&& !vm.is_none(&cadata_obj)
1393-
{
1394-
// Check if input is string or bytes
1395-
let is_string = PyStrRef::try_from_object(vm, cadata_obj.clone()).is_ok();
1396-
let data_vec = self.parse_cadata_arg(&cadata_obj, vm)?;
1418+
if let Some((ref data_vec, is_string)) = cadata_parsed {
13971419
let stats = self.load_certs_from_bytes_helper(
13981420
&mut root_store,
13991421
&mut ca_certs_der,
1400-
&data_vec,
1422+
data_vec,
14011423
is_string, // PEM only for strings
14021424
vm,
14031425
)?;
@@ -2547,48 +2569,51 @@ mod _ssl {
25472569
/// This simulates lazy loading behavior: capath certificates
25482570
/// are only added to get_ca_certs() after they're actually used in a handshake.
25492571
fn track_used_ca_from_capath(&self) -> Result<(), String> {
2550-
let context = self.context.read();
2551-
let capath_certs = context.capath_certs_der.read();
2552-
2553-
// No capath certs to track
2554-
if capath_certs.is_empty() {
2555-
return Ok(());
2556-
}
2557-
2558-
// Get peer certificate chain
2559-
let conn_guard = self.connection.lock();
2560-
let conn = conn_guard.as_ref().ok_or("No connection")?;
2561-
2562-
let peer_certs = conn.peer_certificates().ok_or("No peer certificates")?;
2572+
// Extract capath_certs, releasing context lock quickly
2573+
let capath_certs = {
2574+
let context = self.context.read();
2575+
let certs = context.capath_certs_der.read();
2576+
if certs.is_empty() {
2577+
return Ok(());
2578+
}
2579+
certs.clone()
2580+
};
25632581

2564-
if peer_certs.is_empty() {
2565-
return Ok(());
2566-
}
2582+
// Extract peer certificates, releasing connection lock quickly
2583+
let top_cert_der = {
2584+
let conn_guard = self.connection.lock();
2585+
let conn = conn_guard.as_ref().ok_or("No connection")?;
2586+
let peer_certs = conn.peer_certificates().ok_or("No peer certificates")?;
2587+
if peer_certs.is_empty() {
2588+
return Ok(());
2589+
}
2590+
peer_certs
2591+
.iter()
2592+
.map(|c| c.as_ref().to_vec())
2593+
.next_back()
2594+
.expect("is_empty checked above")
2595+
};
25672596

25682597
// Get the top certificate in the chain (closest to root)
25692598
// Note: Server usually doesn't send the root CA, so we check the last cert's issuer
2570-
let top_cert_der = peer_certs.last().unwrap();
2571-
let (_, top_cert) = x509_parser::parse_x509_certificate(top_cert_der)
2599+
let (_, top_cert) = x509_parser::parse_x509_certificate(&top_cert_der)
25722600
.map_err(|e| format!("Failed to parse top cert: {e}"))?;
25732601

25742602
let top_issuer = top_cert.issuer();
25752603

2576-
// Find matching CA in capath certs
2577-
for ca_der in capath_certs.iter() {
2578-
let (_, ca) = x509_parser::parse_x509_certificate(ca_der)
2579-
.map_err(|e| format!("Failed to parse CA: {e}"))?;
2604+
// Find matching CA in capath certs (skip unparseable certificates)
2605+
let matching_ca = capath_certs.iter().find_map(|ca_der| {
2606+
let (_, ca) = x509_parser::parse_x509_certificate(ca_der).ok()?;
2607+
// Check if this CA is self-signed (root CA) and matches the issuer
2608+
(ca.subject() == ca.issuer() && ca.subject() == top_issuer).then(|| ca_der.clone())
2609+
});
25802610

2581-
// Check if this CA is self-signed and matches the issuer
2582-
if ca.subject() == ca.issuer() // Self-signed (root CA)
2583-
&& ca.subject() == top_issuer
2584-
// Matches top cert's issuer
2585-
{
2586-
// Check if not already in ca_certs_der
2587-
let mut ca_certs_der = context.ca_certs_der.write();
2588-
if !ca_certs_der.iter().any(|c| c == ca_der) {
2589-
ca_certs_der.push(ca_der.clone());
2590-
}
2591-
break;
2611+
// Update ca_certs_der if we found a match
2612+
if let Some(ca_der) = matching_ca {
2613+
let context = self.context.read();
2614+
let mut ca_certs_der = context.ca_certs_der.write();
2615+
if !ca_certs_der.iter().any(|c| c == &ca_der) {
2616+
ca_certs_der.push(ca_der);
25922617
}
25932618
}
25942619

@@ -2675,6 +2700,7 @@ mod _ssl {
26752700

26762701
/// Check if SNI callback is configured
26772702
pub(crate) fn has_sni_callback(&self) -> bool {
2703+
// Nested read locks are safe
26782704
self.context.read().sni_callback.read().is_some()
26792705
}
26802706

@@ -2685,10 +2711,9 @@ mod _ssl {
26852711

26862712
/// Get the extracted SNI name from resolver
26872713
pub(crate) fn get_extracted_sni_name(&self) -> Option<String> {
2688-
self.sni_state
2689-
.read()
2690-
.as_ref()
2691-
.and_then(|arc| arc.lock().1.clone())
2714+
// Clone the Arc option to avoid nested lock (sni_state.read -> arc.lock)
2715+
let sni_state_opt = self.sni_state.read().clone();
2716+
sni_state_opt.as_ref().and_then(|arc| arc.lock().1.clone())
26922717
}
26932718

26942719
/// Invoke the Python SNI callback
@@ -3516,27 +3541,24 @@ mod _ssl {
35163541
return Err(vm.new_value_error("handshake not done yet"));
35173542
}
35183543

3519-
// Get peer certificates from TLS connection
3520-
let conn_guard = self.connection.lock();
3521-
let conn = conn_guard
3522-
.as_ref()
3523-
.ok_or_else(|| vm.new_value_error("No TLS connection established"))?;
3524-
3525-
let certs = conn.peer_certificates();
3544+
// Extract DER bytes from connection, releasing lock quickly
3545+
let der_bytes = {
3546+
let conn_guard = self.connection.lock();
3547+
let conn = conn_guard
3548+
.as_ref()
3549+
.ok_or_else(|| vm.new_value_error("No TLS connection established"))?;
35263550

3527-
// Return None if no peer certificate
3528-
let Some(certs) = certs else {
3529-
return Ok(None);
3551+
let Some(peer_certificates) = conn.peer_certificates() else {
3552+
return Ok(None);
3553+
};
3554+
let cert = peer_certificates
3555+
.first()
3556+
.ok_or_else(|| vm.new_value_error("No peer certificate available"))?;
3557+
cert.as_ref().to_vec()
35303558
};
35313559

3532-
// Get first certificate (peer's certificate)
3533-
let cert_der = certs
3534-
.first()
3535-
.ok_or_else(|| vm.new_value_error("No peer certificate available"))?;
3536-
35373560
if binary {
35383561
// Return DER-encoded certificate as bytes
3539-
let der_bytes = cert_der.as_ref().to_vec();
35403562
return Ok(Some(vm.ctx.new_bytes(der_bytes).into()));
35413563
}
35423564

@@ -3548,22 +3570,22 @@ mod _ssl {
35483570
return Ok(Some(vm.ctx.new_dict().into()));
35493571
}
35503572

3551-
// Parse DER certificate and convert to dict
3552-
let der_bytes = cert_der.as_ref();
3553-
let (_, cert) = x509_parser::parse_x509_certificate(der_bytes)
3573+
// Parse DER certificate and convert to dict (outside lock)
3574+
let (_, cert) = x509_parser::parse_x509_certificate(&der_bytes)
35543575
.map_err(|e| vm.new_value_error(format!("Failed to parse certificate: {e}")))?;
35553576

35563577
cert::cert_to_dict(vm, &cert).map(Some)
35573578
}
35583579

35593580
#[pymethod]
35603581
fn cipher(&self) -> Option<(String, String, i32)> {
3561-
let conn_guard = self.connection.lock();
3562-
let conn = conn_guard.as_ref()?;
3563-
3564-
let suite = conn.negotiated_cipher_suite()?;
3582+
// Extract cipher suite, releasing lock quickly
3583+
let suite = {
3584+
let conn_guard = self.connection.lock();
3585+
conn_guard.as_ref()?.negotiated_cipher_suite()?
3586+
};
35653587

3566-
// Extract cipher information using unified helper
3588+
// Extract cipher information outside the lock
35673589
let cipher_info = extract_cipher_info(&suite);
35683590

35693591
// Note: returns a 3-tuple (name, protocol_version, bits)
@@ -3577,11 +3599,13 @@ mod _ssl {
35773599

35783600
#[pymethod]
35793601
fn version(&self) -> Option<String> {
3580-
let conn_guard = self.connection.lock();
3581-
let conn = conn_guard.as_ref()?;
3582-
3583-
let suite = conn.negotiated_cipher_suite()?;
3602+
// Extract cipher suite, releasing lock quickly
3603+
let suite = {
3604+
let conn_guard = self.connection.lock();
3605+
conn_guard.as_ref()?.negotiated_cipher_suite()?
3606+
};
35843607

3608+
// Convert to string outside the lock
35853609
let version_str = match suite.version().version {
35863610
rustls::ProtocolVersion::TLSv1_2 => "TLSv1.2",
35873611
rustls::ProtocolVersion::TLSv1_3 => "TLSv1.3",

0 commit comments

Comments
 (0)