import asyncio
+ import multiprocessing
import random
from collections import defaultdict
+ from concurrent.futures import ProcessPoolExecutor
from datetime import datetime as dt, timedelta as td, timezone as tz
from operator import attrgetter

from sqlalchemy.exc import DatabaseError, OperationalError
from sqlalchemy.sql.expression import Insert

- from app import scheduler
+ from app import on_shutdown, scheduler
from app.db import GetDB
from app.db.base import engine
from app.db.models import Admin, Node, NodeUsage, NodeUserUsage, System, User
- from app.node import node_manager as node_manager
+ from app.node import node_manager
from app.utils.logger import get_logger
from config import (
    DISABLE_RECORDING_NODE_USAGE,

logger = get_logger("record-usages")

+ # Process pool executor for CPU-bound operations
+ # Use number of CPU cores, but cap at reasonable limit to avoid overhead
+ _process_pool = None
+ _process_pool_lock = asyncio.Lock()
+
+
+ async def _get_process_pool():
+     """Get or create the process pool executor (thread-safe)."""
+     global _process_pool
+     async with _process_pool_lock:
+         if _process_pool is None:
+             num_workers = min(multiprocessing.cpu_count(), 8)  # Cap at 8 workers
+             _process_pool = ProcessPoolExecutor(max_workers=num_workers)
+             logger.info(f"Initialized ProcessPoolExecutor with {num_workers} workers")
+     return _process_pool
+
+
+ @on_shutdown
+ async def _cleanup_process_pool():
+     """Cleanup process pool on shutdown (thread-safe)."""
+     global _process_pool
+     async with _process_pool_lock:
+         if _process_pool is not None:
+             logger.info("Shutting down ProcessPoolExecutor...")
+             _process_pool.shutdown(wait=True)
+             _process_pool = None
+             logger.info("ProcessPoolExecutor shut down successfully")
+
+
+ # Helper functions for multiprocessing (must be at module level for pickling)
+ def _process_node_chunk(chunk_data: tuple) -> dict:
+     """Process a chunk of node data - CPU-bound operation."""
+     node_id, params, coeff = chunk_data
+     users_usage = defaultdict(int)
+     for param in params:
+         uid = int(param["uid"])
+         value = int(param["value"] * coeff)
+         users_usage[uid] += value
+     return dict(users_usage)
+
+
+ def _merge_usage_dicts(dicts: list[dict]) -> dict:
+     """Merge multiple usage dictionaries."""
+     merged = defaultdict(int)
+     for d in dicts:
+         for uid, value in d.items():
+             merged[uid] += value
+     return dict(merged)
+

async def get_dialect() -> str:
    """Get the database dialect name without holding the session open."""
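The body of get_dialect is collapsed in this view; a plausible sketch, assuming the SQLAlchemy engine imported above exposes its dialect directly, would be:

    async def get_dialect() -> str:
        """Get the database dialect name without holding the session open."""
        # Assumption: engine.dialect.name yields e.g. "postgresql", "mysql" or "sqlite"
        return engine.dialect.name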
@@ -244,14 +295,14 @@ async def safe_execute(stmt, params=None, max_retries: int = 5):
            raise


- async def record_user_stats(params: list[dict], node_id: int, usage_coefficient: int = 1):
+ async def record_user_stats(params: list[dict], node_id: int, usage_coefficient: float = 1.0):
    """
    Record user statistics for a specific node using UPSERT for efficiency.

    Args:
        params (list[dict]): User statistic parameters
        node_id (int): Node identifier
-         usage_coefficient (int, optional): usage multiplier
+         usage_coefficient (float, optional): Usage multiplier (default: 1.0)
    """
    if not params:
        return
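The UPSERT statement itself sits outside this hunk. As a rough sketch of the pattern the docstring refers to, on a PostgreSQL backend (the NodeUserUsage column names and unique constraint below are assumptions, not taken from the diff):

    from sqlalchemy.dialects.postgresql import insert as pg_insert

    # Hypothetical columns: (user_id, node_id, created_at, used_traffic),
    # unique on (user_id, node_id, created_at); created_at would be e.g. the current hour bucket
    insert_stmt = pg_insert(NodeUserUsage).values(
        [
            {
                "user_id": int(param["uid"]),
                "node_id": node_id,
                "created_at": created_at,
                "used_traffic": int(param["value"] * usage_coefficient),
            }
            for param in params
        ]
    )
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["user_id", "node_id", "created_at"],
        set_={"used_traffic": NodeUserUsage.used_traffic + insert_stmt.excluded.used_traffic},
    )
    await safe_execute(upsert_stmt)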
@@ -313,9 +364,9 @@ async def record_node_stats(params: list[dict], node_id: int):

async def get_users_stats(node: PasarGuardNode):
    try:
-         stats_respons = await node.get_stats(stat_type=StatType.UsersStat, reset=True, timeout=30)
+         stats_response = await node.get_stats(stat_type=StatType.UsersStat, reset=True, timeout=30)
        params = defaultdict(int)
-         for stat in filter(attrgetter("value"), stats_respons.stats):
+         for stat in filter(attrgetter("value"), stats_response.stats):
            params[stat.name.split(".", 1)[0]] += stat.value

        # Validate UIDs and filter out invalid ones
@@ -340,10 +391,10 @@ async def get_users_stats(node: PasarGuardNode):

async def get_outbounds_stats(node: PasarGuardNode):
    try:
-         stats_respons = await node.get_stats(stat_type=StatType.Outbounds, reset=True, timeout=10)
+         stats_response = await node.get_stats(stat_type=StatType.Outbounds, reset=True, timeout=10)
        params = [
            {"up": stat.value, "down": 0} if stat.type == "uplink" else {"up": 0, "down": stat.value}
-             for stat in filter(attrgetter("value"), stats_respons.stats)
+             for stat in filter(attrgetter("value"), stats_response.stats)
        ]
        return params
    except NodeAPIError as e:
@@ -379,30 +430,77 @@ async def calculate_admin_usage(users_usage: list) -> tuple[dict, set[int]]:


async def calculate_users_usage(api_params: dict, usage_coefficient: dict) -> list:
-     """Calculate aggregated user usage across all nodes with coefficients applied"""
-     users_usage = defaultdict(int)
+     """Calculate aggregated user usage across all nodes with coefficients applied.
+
+     Uses multiprocessing to parallelize CPU-bound operations across multiple cores.
+     """
+     if not api_params:
+         return []

-     # Process all node data in a single pass
-     for node_id, params in api_params.items():
-         coeff = usage_coefficient.get(node_id, 1)
-         # Use generator to avoid intermediate lists
-         node_usage = ((int(param["uid"]), int(param["value"] * coeff)) for param in params)
-         for uid, value in node_usage:
-             users_usage[uid] += value
+     # Prepare chunks for parallel processing
+     chunks = [
+         (node_id, params, usage_coefficient.get(node_id, 1))
+         for node_id, params in api_params.items()
+         if params  # Skip empty params
+     ]

-     return [{"uid": uid, "value": value} for uid, value in users_usage.items()]
+     if not chunks:
+         return []
+
+     # For small datasets, process synchronously to avoid overhead
+     total_params = sum(len(params) for _, params, _ in chunks)
+     if total_params < 1000:
+         # Small dataset - process synchronously
+         users_usage = defaultdict(int)
+         for node_id, params, coeff in chunks:
+             for param in params:
+                 uid = int(param["uid"])
+                 value = int(param["value"] * coeff)
+                 users_usage[uid] += value
+         return [{"uid": uid, "value": value} for uid, value in users_usage.items()]
+
+     # Large dataset - use multiprocessing
+     loop = asyncio.get_running_loop()
+     process_pool = await _get_process_pool()
+
+     # Process chunks in parallel
+     tasks = [
+         loop.run_in_executor(process_pool, _process_node_chunk, chunk)
+         for chunk in chunks
+     ]
+
+     chunk_results = await asyncio.gather(*tasks)
+
+     # Merge results - this is also CPU-bound, so parallelize if many chunks
+     if len(chunk_results) > 4:
+         # Split merge operation into smaller chunks
+         chunk_size = max(1, len(chunk_results) // 4)
+         merge_chunks = [
+             chunk_results[i : i + chunk_size]
+             for i in range(0, len(chunk_results), chunk_size)
+         ]
+         merge_tasks = [
+             loop.run_in_executor(process_pool, _merge_usage_dicts, merge_chunk)
+             for merge_chunk in merge_chunks
+         ]
+         partial_results = await asyncio.gather(*merge_tasks)
+         final_result = _merge_usage_dicts(partial_results)
+     else:
+         final_result = _merge_usage_dicts(chunk_results)
+
+     return [{"uid": uid, "value": value} for uid, value in final_result.items()]


async def record_user_usages():
    nodes: tuple[int, PasarGuardNode] = await node_manager.get_healthy_nodes()

-     node_data = await asyncio.gather(*[asyncio.create_task(node.get_extra()) for _, node in nodes])
+     # Gather node extra data directly without unnecessary task creation
+     node_data = await asyncio.gather(*[node.get_extra() for _, node in nodes])
    usage_coefficient = {node_id: data.get("usage_coefficient", 1) for (node_id, _), data in zip(nodes, node_data)}

-     stats_tasks = [asyncio.create_task(get_users_stats(node)) for _, node in nodes]
-     await asyncio.gather(*stats_tasks)
-
-     api_params = {nodes[i][0]: task.result() for i, task in enumerate(stats_tasks)}
+     # Gather stats directly - asyncio.gather accepts coroutines, no need for create_task
+     stats_results = await asyncio.gather(*[get_users_stats(node) for _, node in nodes])
+     api_params = {nodes[i][0]: result for i, result in enumerate(stats_results)}

    users_usage = await calculate_users_usage(api_params, usage_coefficient)
    if not users_usage:
@@ -413,6 +511,7 @@ async def record_user_usages():
        logger.warning("Skipping user usage recording; no matching users found for received stats")
        return

+     # Filter valid users - simple operation, no need to parallelize
    valid_users_usage = [usage for usage in users_usage if int(usage["uid"]) in valid_user_ids]
    if valid_users_usage:
        user_stmt = (
@@ -436,35 +535,28 @@ async def record_user_usages():
    if DISABLE_RECORDING_NODE_USAGE:
        return

+     # Create tasks only for nodes with valid filtered params
    record_tasks = []
    for node_id, params in api_params.items():
        filtered_params = [param for param in params if int(param["uid"]) in valid_user_ids]
-         if not filtered_params:
-             continue
-         record_tasks.append(
-             asyncio.create_task(
+         if filtered_params:
+             record_tasks.append(
                record_user_stats(
                    params=filtered_params,
                    node_id=node_id,
-                     usage_coefficient=usage_coefficient[node_id],
+                     usage_coefficient=usage_coefficient.get(node_id, 1.0),
                )
            )
-         )

    if record_tasks:
        await asyncio.gather(*record_tasks)


async def record_node_usages():
-     # Create tasks for all nodes
-     tasks = {
-         node_id: asyncio.create_task(get_outbounds_stats(node))
-         for node_id, node in await node_manager.get_healthy_nodes()
-     }
-
-     await asyncio.gather(*tasks.values())
-
-     api_params = {node_id: task.result() for node_id, task in tasks.items()}
+     # Get healthy nodes and gather stats directly
+     nodes = await node_manager.get_healthy_nodes()
+     stats_results = await asyncio.gather(*[get_outbounds_stats(node) for _, node in nodes])
+     api_params = {nodes[i][0]: result for i, result in enumerate(stats_results)}

    # Calculate per-node totals
    node_totals = {
@@ -505,8 +597,10 @@ async def record_node_usages():
    if DISABLE_RECORDING_NODE_USAGE:
        return

-     record_tasks = [asyncio.create_task(record_node_stats(params, node_id)) for node_id, params in api_params.items()]
-     await asyncio.gather(*record_tasks)
+     # Gather record tasks directly without unnecessary task creation
+     record_tasks = [record_node_stats(params, node_id) for node_id, params in api_params.items()]
+     if record_tasks:
+         await asyncio.gather(*record_tasks)


scheduler.add_job(