Commit d92204a

feat: enhance user usage calculation with threading support

- Introduced `calculate_users_usage` to aggregate user usage across nodes using ThreadPoolExecutor for improved performance.
- Added helper functions `_process_node_chunk` and `_merge_usage_dicts` for efficient processing and merging of usage data.
- Implemented a synchronous fallback for small datasets to optimize performance and reduce overhead.
- Refactored `_record_user_usages_impl` to use the new calculation method, improving clarity and efficiency in usage recording.

These changes optimize the handling of user usage data, leveraging concurrency for better performance in high-load scenarios.
1 parent 439796f commit d92204a
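
Before the diff, a minimal, self-contained sketch of the fan-out/fan-in pattern the commit applies: per-node work is fanned out to a thread pool via run_in_executor, then the partial results are merged. All names and numbers here are illustrative, not taken from the app.

import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

def process(rows):
    """Sum usage per uid for one node's rows (illustrative helper)."""
    out = defaultdict(int)
    for uid, value in rows:
        out[uid] += value
    return dict(out)

async def main():
    loop = asyncio.get_running_loop()
    node_rows = [[(1, 10), (2, 5)], [(1, 7)]]  # one list of (uid, value) per node
    with ThreadPoolExecutor(max_workers=4) as pool:
        # Fan out: one executor task per node
        parts = await asyncio.gather(
            *(loop.run_in_executor(pool, process, rows) for rows in node_rows)
        )
    totals = defaultdict(int)  # Fan in: merge the per-node partial sums
    for part in parts:
        for uid, value in part.items():
            totals[uid] += value
    print(dict(totals))  # {1: 17, 2: 5}

asyncio.run(main())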

1 file changed: app/jobs/record_usages.py (95 additions, 12 deletions)
@@ -93,8 +93,32 @@ async def _cleanup_thread_pool():
     _thread_pool = None
     logger.info("ThreadPoolExecutor shut down successfully")
 
+# Helper functions for threading (lightweight operations offloaded from the event loop)
+def _process_node_chunk(chunk_data: tuple) -> dict:
+    """
+    Process one node's usage rows: a lightweight CPU operation.
+    Cheap enough for ThreadPoolExecutor; offloading keeps the event loop responsive.
+    """
+    node_id, params, coeff = chunk_data
+    users_usage = defaultdict(int)
+    for param in params:
+        uid = int(param["uid"])
+        value = int(param["value"] * coeff)
+        users_usage[uid] += value
+    return dict(users_usage)
 
 
+def _merge_usage_dicts(dicts: list[dict]) -> dict:
+    """
+    Merge multiple per-node usage dictionaries.
+    Lightweight enough for ThreadPoolExecutor.
+    """
+    merged = defaultdict(int)
+    for d in dicts:
+        for uid, value in d.items():
+            merged[uid] += value
+    return dict(merged)
+
 
 async def get_dialect() -> str:
     """Get the database dialect name without holding the session open."""
@@ -542,6 +566,76 @@ async def calculate_admin_usage(users_usage: list) -> tuple[dict, set[int]]:
     return admin_usage, set(user_admin_map.keys())
 
 
+async def calculate_users_usage(api_params: dict, usage_coefficient: dict) -> list:
+    """Calculate aggregated user usage across all nodes with coefficients applied.
+
+    Offloads the lightweight dict/arithmetic work to a ThreadPoolExecutor; for tasks
+    this small, threads beat a ProcessPoolExecutor because they avoid pickling and
+    process-startup overhead.
+    """
+    if not api_params:
+        return []
+
+    def _process_usage_sync(chunks_data: list[tuple[int, list[dict], float]]):
+        """Synchronous fallback used for small batches or on executor failures."""
+        users_usage = defaultdict(int)
+        for _, params, coeff in chunks_data:
+            for param in params:
+                uid = int(param["uid"])
+                value = int(param["value"] * coeff)
+                users_usage[uid] += value
+        return [{"uid": uid, "value": value} for uid, value in users_usage.items()]
+
+    # Prepare chunks for parallel processing
+    chunks = [
+        (node_id, params, usage_coefficient.get(node_id, 1))
+        for node_id, params in api_params.items()
+        if params  # Skip nodes with empty params
+    ]
+
+    if not chunks:
+        return []
+
+    # For small datasets, process synchronously to avoid executor overhead
+    total_params = sum(len(params) for _, params, _ in chunks)
+    if total_params < 1000:
+        return _process_usage_sync(chunks)
+
+    # Large dataset: use the ThreadPoolExecutor
+    loop = asyncio.get_running_loop()
+    try:
+        thread_pool = await _get_thread_pool()
+    except Exception:
+        logger.exception("Falling back to synchronous user usage calculation: failed to init thread pool")
+        return _process_usage_sync(chunks)
+
+    try:
+        # Process chunks in parallel worker threads (less overhead than processes)
+        tasks = [loop.run_in_executor(thread_pool, _process_node_chunk, chunk) for chunk in chunks]
+        chunk_results = await asyncio.gather(*tasks)
+
+        # Merge the per-node results, also in worker threads
+        if len(chunk_results) > 4:
+            # Split the merge into smaller batches, then merge the partials
+            chunk_size = max(1, len(chunk_results) // 4)
+            merge_chunks = [
+                chunk_results[i : i + chunk_size]
+                for i in range(0, len(chunk_results), chunk_size)
+            ]
+            merge_tasks = [
+                loop.run_in_executor(thread_pool, _merge_usage_dicts, merge_chunk)
+                for merge_chunk in merge_chunks
+            ]
+            partial_results = await asyncio.gather(*merge_tasks)
+            final_result = _merge_usage_dicts(partial_results)
+        else:
+            final_result = _merge_usage_dicts(chunk_results)
+
+        return [{"uid": uid, "value": value} for uid, value in final_result.items()]
+    except Exception:
+        logger.exception("Falling back to synchronous user usage calculation: executor merge failed")
+        return _process_usage_sync(chunks)
+
+
 async def _record_user_usages_impl():
     """
     Internal implementation of record_user_usages.
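
For reference, a hypothetical call to the new function (inputs invented; shapes match the diff above). With only three usage rows it stays below the 1000-row threshold and takes the synchronous path:

import asyncio
from app.jobs.record_usages import calculate_users_usage

async def demo() -> None:
    api_params = {
        1: [{"uid": 7, "value": 100.0}],                            # node 1 rows
        2: [{"uid": 7, "value": 50.0}, {"uid": 9, "value": 10.0}],  # node 2 rows
        3: [],                                                      # empty node, skipped
    }
    usage_coefficient = {2: 2.0}  # nodes 1 and 3 fall back to coefficient 1
    usage = await calculate_users_usage(api_params, usage_coefficient)
    print(usage)  # [{'uid': 7, 'value': 200}, {'uid': 9, 'value': 20}]

asyncio.run(demo())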
@@ -578,18 +672,7 @@ async def _record_user_usages_impl():
         else:
             api_params[node_id] = result
 
-    # Aggregate user usage across all nodes with coefficients applied
-    users_usage_dict = defaultdict(int)
-    for node_id, params in api_params.items():
-        if not params:
-            continue
-        coeff = usage_coefficient.get(node_id, 1.0)
-        for param in params:
-            uid = int(param["uid"])
-            value = int(param["value"] * coeff)
-            users_usage_dict[uid] += value
-
-    users_usage = [{"uid": uid, "value": value} for uid, value in users_usage_dict.items()]
+    users_usage = await calculate_users_usage(api_params, usage_coefficient)
     if not users_usage:
         logger.debug("No user usage to record")
         return
