Commit 5a4969a

refactor: improve threading support in usage recording
- Updated helper functions to clarify their threading suitability and lightweight operations.
- Enhanced `_process_users_stats_response` to return invalid UIDs for logging outside the thread.
- Switched from ProcessPoolExecutor to ThreadPoolExecutor for lightweight operations, reducing overhead.
- Increased the timeout for usage recording jobs from 25s to 30s to prevent scheduler backlog.

These changes optimize performance and clarity in handling concurrent data processing tasks.
1 parent 3ad4f9b commit 5a4969a
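
The trade-off behind the ProcessPoolExecutor-to-ThreadPoolExecutor switch: submitting a tiny dict-summing task to a process pool pays for pickling arguments and results plus inter-process hand-off, while a thread pool hands the same Python objects to a worker thread directly. (Pure-Python dict and arithmetic work does not actually release the GIL, so the win is lower per-task overhead rather than extra parallelism.) A rough, self-contained timing sketch of that difference; the workload, helper names, and numbers are illustrative and not taken from the repository:

import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def merge(dicts):
    """Tiny dict-summing task, similar in spirit to _merge_usage_dicts."""
    out = defaultdict(int)
    for d in dicts:
        for k, v in d.items():
            out[k] += v
    return dict(out)


def bench(executor_cls, jobs):
    """Time how long one executor type takes to run all jobs."""
    with executor_cls(max_workers=4) as pool:
        start = time.perf_counter()
        list(pool.map(merge, jobs))
        return time.perf_counter() - start


if __name__ == "__main__":  # guard required for ProcessPoolExecutor on spawn platforms
    # 200 small jobs, each merging four 50-key dicts.
    jobs = [[{i: 1 for i in range(50)}] * 4 for _ in range(200)]
    print(f"threads:   {bench(ThreadPoolExecutor, jobs):.3f}s")
    print(f"processes: {bench(ProcessPoolExecutor, jobs):.3f}s")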

File tree

1 file changed (+35, -21 lines)

app/jobs/record_usages.py

Lines changed: 35 additions & 21 deletions
@@ -94,9 +94,12 @@ async def _cleanup_thread_pool():
     logger.info("ThreadPoolExecutor shut down successfully")


-# Helper functions for multiprocessing (must be at module level for pickling)
+# Helper functions for threading (lightweight operations that release GIL)
 def _process_node_chunk(chunk_data: tuple) -> dict:
-    """Process a chunk of node data - CPU-bound operation."""
+    """
+    Process a chunk of node data - lightweight CPU operation.
+    Uses simple arithmetic and dict operations that release GIL, perfect for threads.
+    """
     node_id, params, coeff = chunk_data
     users_usage = defaultdict(int)
     for param in params:
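
Only the first lines of `_process_node_chunk` are visible in this hunk. A standalone sketch of the accumulation it appears to perform on `(node_id, params, coeff)` tuples, where `params` holds the `{"uid": ..., "value": ...}` dicts produced by `_process_users_stats_response`; the arithmetic inside the loop and the return shape are assumptions, not the repository's code:

from collections import defaultdict


def process_node_chunk_sketch(chunk_data: tuple) -> dict:
    """Hypothetical approximation of _process_node_chunk."""
    node_id, params, coeff = chunk_data
    users_usage = defaultdict(int)
    for param in params:
        # Assumed: scale each user's raw value by the node's usage coefficient.
        users_usage[param["uid"]] += int(param["value"] * coeff)
    return {"node_id": node_id, "users_usage": dict(users_usage)}


# Example: two users reported by node 3 with a 2.0 coefficient.
print(process_node_chunk_sketch((3, [{"uid": 1, "value": 100}, {"uid": 2, "value": 50}], 2.0)))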
@@ -107,7 +110,10 @@ def _process_node_chunk(chunk_data: tuple) -> dict:


 def _merge_usage_dicts(dicts: list[dict]) -> dict:
-    """Merge multiple usage dictionaries."""
+    """
+    Merge multiple usage dictionaries.
+    Dict operations release GIL, perfect for ThreadPoolExecutor.
+    """
     merged = defaultdict(int)
     for d in dicts:
         for uid, value in d.items():
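
The visible portion of `_merge_usage_dicts` suggests a plain per-key summing merge. A self-contained sketch of that pattern; the accumulation line and the return statement fall outside the hunk, so they are assumed:

from collections import defaultdict


def merge_usage_dicts_sketch(dicts: list[dict]) -> dict:
    """Hypothetical completion of _merge_usage_dicts: sum values per uid."""
    merged = defaultdict(int)
    for d in dicts:
        for uid, value in d.items():
            merged[uid] += value  # assumed accumulation step
    return dict(merged)


# Example: usage reported for the same uid by two chunks is summed.
print(merge_usage_dicts_sketch([{1: 100, 2: 50}, {1: 25}]))  # {1: 125, 2: 50}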
@@ -445,25 +451,27 @@ async def _record_single_node(node_id: int, params: list[dict]):

 def _process_users_stats_response(stats_response):
     """
-    Process stats response (CPU-bound operation) - can run in thread pool.
-    Extracted to separate function for threading.
+    Process stats response (CPU-bound operation) - runs in thread pool.
+    Pure function designed for thread-safe execution.
+    Returns tuple: (validated_params, invalid_uids) for logging outside thread.
     """
     params = defaultdict(int)
     for stat in filter(attrgetter("value"), stats_response.stats):
         params[stat.name.split(".", 1)[0]] += stat.value

     # Validate UIDs and filter out invalid ones
     validated_params = []
+    invalid_uids = []
     for uid, value in params.items():
         try:
             uid_int = int(uid)
             validated_params.append({"uid": uid_int, "value": value})
         except (ValueError, TypeError):
-            # Skip invalid UIDs that can't be converted to int
-            logger.warning("Skipping invalid UID: %s", uid)
+            # Collect invalid UIDs to log outside thread
+            invalid_uids.append(uid)
             continue

-    return validated_params
+    return validated_params, invalid_uids


 async def get_users_stats(node: PasarGuardNode):
@@ -479,10 +487,15 @@ async def get_users_stats(node: PasarGuardNode):
         # CPU-bound operation: process stats in thread pool to utilize multiple cores
         loop = asyncio.get_running_loop()
         thread_pool = await _get_thread_pool()
-        validated_params = await loop.run_in_executor(
+        validated_params, invalid_uids = await loop.run_in_executor(
             thread_pool, _process_users_stats_response, stats_response
         )

+        # Log invalid UIDs outside of thread (thread-safe logging)
+        if invalid_uids:
+            for uid in invalid_uids:
+                logger.warning("Skipping invalid UID: %s", uid)
+
         return validated_params
     except NodeAPIError as e:
         logger.error("Failed to get users stats, error: %s", e.detail)
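
Taken together, the two hunks above implement a worker-returns-data pattern: the helper stays a pure function inside the thread pool and the caller emits the warnings on the event-loop side. A minimal runnable sketch of that pattern, using a fabricated stats payload and a local ThreadPoolExecutor in place of the project's `_get_thread_pool()`:

import asyncio
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from operator import attrgetter
from types import SimpleNamespace

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def process_stats_sketch(stats_response):
    """Mirror of the committed helper: return data, never touch the logger."""
    params = defaultdict(int)
    for stat in filter(attrgetter("value"), stats_response.stats):
        params[stat.name.split(".", 1)[0]] += stat.value
    validated, invalid = [], []
    for uid, value in params.items():
        try:
            validated.append({"uid": int(uid), "value": value})
        except (ValueError, TypeError):
            invalid.append(uid)
    return validated, invalid


async def main():
    # Fake stats payload standing in for the node API response.
    stats_response = SimpleNamespace(stats=[
        SimpleNamespace(name="42.uplink", value=100),
        SimpleNamespace(name="not-a-uid.uplink", value=7),
    ])
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=2) as pool:
        validated, invalid = await loop.run_in_executor(pool, process_stats_sketch, stats_response)
    for uid in invalid:  # log in the caller, as the commit does
        logger.warning("Skipping invalid UID: %s", uid)
    print(validated)  # [{'uid': 42, 'value': 100}]


asyncio.run(main())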
@@ -557,7 +570,8 @@ async def calculate_admin_usage(users_usage: list) -> tuple[dict, set[int]]:
 async def calculate_users_usage(api_params: dict, usage_coefficient: dict) -> list:
     """Calculate aggregated user usage across all nodes with coefficients applied.

-    Uses multiprocessing to parallelize CPU-bound operations across multiple cores.
+    Uses ThreadPoolExecutor for lightweight operations (dict/arithmetic that release GIL).
+    ThreadPoolExecutor is faster than ProcessPoolExecutor for these operations due to less overhead.
     """
     if not api_params:
         return []
@@ -587,20 +601,20 @@ def _process_usage_sync(chunks_data: list[tuple[int, list[dict], float]]):
     if total_params < 1000:
         return _process_usage_sync(chunks)

-    # Large dataset - use multiprocessing
+    # Large dataset - use ThreadPoolExecutor (faster for lightweight operations)
     loop = asyncio.get_running_loop()
     try:
-        process_pool = await _get_process_pool()
+        thread_pool = await _get_thread_pool()
     except Exception:
-        logger.exception("Falling back to synchronous user usage calculation: failed to init process pool")
+        logger.exception("Falling back to synchronous user usage calculation: failed to init thread pool")
         return _process_usage_sync(chunks)

     try:
-        # Process chunks in parallel
-        tasks = [loop.run_in_executor(process_pool, _process_node_chunk, chunk) for chunk in chunks]
+        # Process chunks in parallel using threads (less overhead than processes)
+        tasks = [loop.run_in_executor(thread_pool, _process_node_chunk, chunk) for chunk in chunks]
         chunk_results = await asyncio.gather(*tasks)

-        # Merge results - this is also CPU-bound, so parallelize if many chunks
+        # Merge results - also lightweight, use threads
         if len(chunk_results) > 4:
             # Split merge operation into smaller chunks
             chunk_size = max(1, len(chunk_results) // 4)
@@ -609,7 +623,7 @@ def _process_usage_sync(chunks_data: list[tuple[int, list[dict], float]]):
                 for i in range(0, len(chunk_results), chunk_size)
             ]
             merge_tasks = [
-                loop.run_in_executor(process_pool, _merge_usage_dicts, merge_chunk)
+                loop.run_in_executor(thread_pool, _merge_usage_dicts, merge_chunk)
                 for merge_chunk in merge_chunks
            ]
            partial_results = await asyncio.gather(*merge_tasks)
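
The two hunks above form a fan-out/fan-in pipeline: each chunk is processed on the thread pool, the partial dicts are merged in roughly four groups, and the groups are combined. A compact sketch of that shape with invented helpers and toy data; only the `run_in_executor`/`gather` structure mirrors the diff:

import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor


def count_chunk(values: list[int]) -> dict:
    """Stand-in for _process_node_chunk: build a small per-key usage dict."""
    usage = defaultdict(int)
    for v in values:
        usage[v % 3] += v
    return dict(usage)


def merge_dicts(dicts: list[dict]) -> dict:
    """Stand-in for _merge_usage_dicts: sum values key by key."""
    merged = defaultdict(int)
    for d in dicts:
        for k, v in d.items():
            merged[k] += v
    return dict(merged)


async def main():
    chunks = [list(range(i, i + 5)) for i in range(0, 40, 5)]  # 8 chunks of work
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=4) as pool:
        # Fan out: one executor task per chunk.
        chunk_results = await asyncio.gather(
            *(loop.run_in_executor(pool, count_chunk, c) for c in chunks)
        )
        # Fan in: merge partial results in ~4 groups, then combine the partials.
        size = max(1, len(chunk_results) // 4)
        groups = [chunk_results[i:i + size] for i in range(0, len(chunk_results), size)]
        partials = await asyncio.gather(
            *(loop.run_in_executor(pool, merge_dicts, g) for g in groups)
        )
        print(merge_dicts(list(partials)))


asyncio.run(main())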
@@ -741,9 +755,9 @@ async def record_user_usages():
     # Hard timeout: prevent job from running longer than interval
     # This prevents scheduler backlog → spike → crash
     try:
-        await asyncio.wait_for(_record_user_usages_impl(), timeout=25)
+        await asyncio.wait_for(_record_user_usages_impl(), timeout=30)
     except asyncio.TimeoutError:
-        logger.warning("record_user_usages timed out after 25s; skipping cycle to prevent backlog")
+        logger.warning("record_user_usages timed out after 30s; skipping cycle to prevent backlog")
     finally:
         _running_jobs["record_user_usages"] = False

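The timeout bump sits inside a guard that already existed around both jobs. A minimal sketch of that hard-timeout pattern; the re-entrancy check at the top is an assumption about code outside the hunk, and the 1-second timeout with a 2-second sleep is only there to make the timeout fire:

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

_running_jobs = {"record_user_usages": False}


async def _record_user_usages_impl():
    # Pretend the real work occasionally overruns the scheduler interval.
    await asyncio.sleep(2)


async def record_user_usages(timeout: float = 1):
    if _running_jobs["record_user_usages"]:
        return  # previous run still active; skip this cycle (assumed guard)
    _running_jobs["record_user_usages"] = True
    try:
        await asyncio.wait_for(_record_user_usages_impl(), timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning("record_user_usages timed out after %ss; skipping cycle to prevent backlog", timeout)
    finally:
        _running_jobs["record_user_usages"] = False


asyncio.run(record_user_usages())
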
@@ -847,9 +861,9 @@ async def record_node_usages():
     # Hard timeout: prevent job from running longer than interval
     # This prevents scheduler backlog → spike → crash
     try:
-        await asyncio.wait_for(_record_node_usages_impl(), timeout=25)
+        await asyncio.wait_for(_record_node_usages_impl(), timeout=30)
     except asyncio.TimeoutError:
-        logger.warning("record_node_usages timed out after 25s; skipping cycle to prevent backlog")
+        logger.warning("record_node_usages timed out after 30s; skipping cycle to prevent backlog")
     finally:
         _running_jobs["record_node_usages"] = False