Skip to content

Commit 3d6b8e8

Browse files
committed
feat: parallelize CPU-bound usage processing and optimize async code
- Add multiprocessing support for CPU-intensive usage calculations - Use ProcessPoolExecutor to distribute work across multiple CPU cores - Fix race conditions with proper async lock synchronization - Remove redundant asyncio.create_task() calls - use gather() directly - Replace deprecated get_event_loop() with get_running_loop() - Fix typos: stats_respons -> stats_response - Fix type hints: usage_coefficient int -> float - Remove redundant import alias - Optimize task creation and filtering logic - Add proper cleanup handler for process pool on shutdown Performance improvements: - Large datasets (>1000 params) now process in parallel across cores - Eliminates single-core CPU bottleneck - Reduces async overhead by removing unnecessary task wrapping
1 parent 33a40d7 commit 3d6b8e8

File tree

1 file changed

+134
-40
lines changed

1 file changed

+134
-40
lines changed

app/jobs/record_usages.py

Lines changed: 134 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
2+
import multiprocessing
23
import random
34
from collections import defaultdict
5+
from concurrent.futures import ProcessPoolExecutor
46
from datetime import datetime as dt, timedelta as td, timezone as tz
57
from operator import attrgetter
68

@@ -12,11 +14,11 @@
1214
from sqlalchemy.exc import DatabaseError, OperationalError
1315
from sqlalchemy.sql.expression import Insert
1416

15-
from app import scheduler
17+
from app import on_shutdown, scheduler
1618
from app.db import GetDB
1719
from app.db.base import engine
1820
from app.db.models import Admin, Node, NodeUsage, NodeUserUsage, System, User
19-
from app.node import node_manager as node_manager
21+
from app.node import node_manager
2022
from app.utils.logger import get_logger
2123
from config import (
2224
DISABLE_RECORDING_NODE_USAGE,
@@ -26,6 +28,55 @@
2628

2729
logger = get_logger("record-usages")
2830

31+
# Process pool executor for CPU-bound operations
32+
# Use number of CPU cores, but cap at reasonable limit to avoid overhead
33+
_process_pool = None
34+
_process_pool_lock = asyncio.Lock()
35+
36+
37+
async def _get_process_pool():
38+
"""Get or create the process pool executor (thread-safe)."""
39+
global _process_pool
40+
async with _process_pool_lock:
41+
if _process_pool is None:
42+
num_workers = min(multiprocessing.cpu_count(), 8) # Cap at 8 workers
43+
_process_pool = ProcessPoolExecutor(max_workers=num_workers)
44+
logger.info(f"Initialized ProcessPoolExecutor with {num_workers} workers")
45+
return _process_pool
46+
47+
48+
@on_shutdown
49+
async def _cleanup_process_pool():
50+
"""Cleanup process pool on shutdown (thread-safe)."""
51+
global _process_pool
52+
async with _process_pool_lock:
53+
if _process_pool is not None:
54+
logger.info("Shutting down ProcessPoolExecutor...")
55+
_process_pool.shutdown(wait=True)
56+
_process_pool = None
57+
logger.info("ProcessPoolExecutor shut down successfully")
58+
59+
60+
# Helper functions for multiprocessing (must be at module level for pickling)
61+
def _process_node_chunk(chunk_data: tuple) -> dict:
62+
"""Process a chunk of node data - CPU-bound operation."""
63+
node_id, params, coeff = chunk_data
64+
users_usage = defaultdict(int)
65+
for param in params:
66+
uid = int(param["uid"])
67+
value = int(param["value"] * coeff)
68+
users_usage[uid] += value
69+
return dict(users_usage)
70+
71+
72+
def _merge_usage_dicts(dicts: list[dict]) -> dict:
73+
"""Merge multiple usage dictionaries."""
74+
merged = defaultdict(int)
75+
for d in dicts:
76+
for uid, value in d.items():
77+
merged[uid] += value
78+
return dict(merged)
79+
2980

3081
async def get_dialect() -> str:
3182
"""Get the database dialect name without holding the session open."""
@@ -244,14 +295,14 @@ async def safe_execute(stmt, params=None, max_retries: int = 5):
244295
raise
245296

246297

247-
async def record_user_stats(params: list[dict], node_id: int, usage_coefficient: int = 1):
298+
async def record_user_stats(params: list[dict], node_id: int, usage_coefficient: float = 1.0):
248299
"""
249300
Record user statistics for a specific node using UPSERT for efficiency.
250301
251302
Args:
252303
params (list[dict]): User statistic parameters
253304
node_id (int): Node identifier
254-
usage_coefficient (int, optional): usage multiplier
305+
usage_coefficient (float, optional): Usage multiplier (default: 1.0)
255306
"""
256307
if not params:
257308
return
@@ -313,9 +364,9 @@ async def record_node_stats(params: list[dict], node_id: int):
313364

314365
async def get_users_stats(node: PasarGuardNode):
315366
try:
316-
stats_respons = await node.get_stats(stat_type=StatType.UsersStat, reset=True, timeout=30)
367+
stats_response = await node.get_stats(stat_type=StatType.UsersStat, reset=True, timeout=30)
317368
params = defaultdict(int)
318-
for stat in filter(attrgetter("value"), stats_respons.stats):
369+
for stat in filter(attrgetter("value"), stats_response.stats):
319370
params[stat.name.split(".", 1)[0]] += stat.value
320371

321372
# Validate UIDs and filter out invalid ones
@@ -340,10 +391,10 @@ async def get_users_stats(node: PasarGuardNode):
340391

341392
async def get_outbounds_stats(node: PasarGuardNode):
342393
try:
343-
stats_respons = await node.get_stats(stat_type=StatType.Outbounds, reset=True, timeout=10)
394+
stats_response = await node.get_stats(stat_type=StatType.Outbounds, reset=True, timeout=10)
344395
params = [
345396
{"up": stat.value, "down": 0} if stat.type == "uplink" else {"up": 0, "down": stat.value}
346-
for stat in filter(attrgetter("value"), stats_respons.stats)
397+
for stat in filter(attrgetter("value"), stats_response.stats)
347398
]
348399
return params
349400
except NodeAPIError as e:
@@ -379,30 +430,77 @@ async def calculate_admin_usage(users_usage: list) -> tuple[dict, set[int]]:
379430

380431

381432
async def calculate_users_usage(api_params: dict, usage_coefficient: dict) -> list:
382-
"""Calculate aggregated user usage across all nodes with coefficients applied"""
383-
users_usage = defaultdict(int)
433+
"""Calculate aggregated user usage across all nodes with coefficients applied.
434+
435+
Uses multiprocessing to parallelize CPU-bound operations across multiple cores.
436+
"""
437+
if not api_params:
438+
return []
384439

385-
# Process all node data in a single pass
386-
for node_id, params in api_params.items():
387-
coeff = usage_coefficient.get(node_id, 1)
388-
# Use generator to avoid intermediate lists
389-
node_usage = ((int(param["uid"]), int(param["value"] * coeff)) for param in params)
390-
for uid, value in node_usage:
391-
users_usage[uid] += value
440+
# Prepare chunks for parallel processing
441+
chunks = [
442+
(node_id, params, usage_coefficient.get(node_id, 1))
443+
for node_id, params in api_params.items()
444+
if params # Skip empty params
445+
]
392446

393-
return [{"uid": uid, "value": value} for uid, value in users_usage.items()]
447+
if not chunks:
448+
return []
449+
450+
# For small datasets, process synchronously to avoid overhead
451+
total_params = sum(len(params) for _, params, _ in chunks)
452+
if total_params < 1000:
453+
# Small dataset - process synchronously
454+
users_usage = defaultdict(int)
455+
for node_id, params, coeff in chunks:
456+
for param in params:
457+
uid = int(param["uid"])
458+
value = int(param["value"] * coeff)
459+
users_usage[uid] += value
460+
return [{"uid": uid, "value": value} for uid, value in users_usage.items()]
461+
462+
# Large dataset - use multiprocessing
463+
loop = asyncio.get_running_loop()
464+
process_pool = await _get_process_pool()
465+
466+
# Process chunks in parallel
467+
tasks = [
468+
loop.run_in_executor(process_pool, _process_node_chunk, chunk)
469+
for chunk in chunks
470+
]
471+
472+
chunk_results = await asyncio.gather(*tasks)
473+
474+
# Merge results - this is also CPU-bound, so parallelize if many chunks
475+
if len(chunk_results) > 4:
476+
# Split merge operation into smaller chunks
477+
chunk_size = max(1, len(chunk_results) // 4)
478+
merge_chunks = [
479+
chunk_results[i:i + chunk_size]
480+
for i in range(0, len(chunk_results), chunk_size)
481+
]
482+
merge_tasks = [
483+
loop.run_in_executor(process_pool, _merge_usage_dicts, merge_chunk)
484+
for merge_chunk in merge_chunks
485+
]
486+
partial_results = await asyncio.gather(*merge_tasks)
487+
final_result = _merge_usage_dicts(partial_results)
488+
else:
489+
final_result = _merge_usage_dicts(chunk_results)
490+
491+
return [{"uid": uid, "value": value} for uid, value in final_result.items()]
394492

395493

396494
async def record_user_usages():
397495
nodes: tuple[int, PasarGuardNode] = await node_manager.get_healthy_nodes()
398496

399-
node_data = await asyncio.gather(*[asyncio.create_task(node.get_extra()) for _, node in nodes])
497+
# Gather node extra data directly without unnecessary task creation
498+
node_data = await asyncio.gather(*[node.get_extra() for _, node in nodes])
400499
usage_coefficient = {node_id: data.get("usage_coefficient", 1) for (node_id, _), data in zip(nodes, node_data)}
401500

402-
stats_tasks = [asyncio.create_task(get_users_stats(node)) for _, node in nodes]
403-
await asyncio.gather(*stats_tasks)
404-
405-
api_params = {nodes[i][0]: task.result() for i, task in enumerate(stats_tasks)}
501+
# Gather stats directly - asyncio.gather accepts coroutines, no need for create_task
502+
stats_results = await asyncio.gather(*[get_users_stats(node) for _, node in nodes])
503+
api_params = {nodes[i][0]: result for i, result in enumerate(stats_results)}
406504

407505
users_usage = await calculate_users_usage(api_params, usage_coefficient)
408506
if not users_usage:
@@ -413,6 +511,7 @@ async def record_user_usages():
413511
logger.warning("Skipping user usage recording; no matching users found for received stats")
414512
return
415513

514+
# Filter valid users - simple operation, no need to parallelize
416515
valid_users_usage = [usage for usage in users_usage if int(usage["uid"]) in valid_user_ids]
417516
if valid_users_usage:
418517
user_stmt = (
@@ -436,35 +535,28 @@ async def record_user_usages():
436535
if DISABLE_RECORDING_NODE_USAGE:
437536
return
438537

538+
# Create tasks only for nodes with valid filtered params
439539
record_tasks = []
440540
for node_id, params in api_params.items():
441541
filtered_params = [param for param in params if int(param["uid"]) in valid_user_ids]
442-
if not filtered_params:
443-
continue
444-
record_tasks.append(
445-
asyncio.create_task(
542+
if filtered_params:
543+
record_tasks.append(
446544
record_user_stats(
447545
params=filtered_params,
448546
node_id=node_id,
449-
usage_coefficient=usage_coefficient[node_id],
547+
usage_coefficient=usage_coefficient.get(node_id, 1.0),
450548
)
451549
)
452-
)
453550

454551
if record_tasks:
455552
await asyncio.gather(*record_tasks)
456553

457554

458555
async def record_node_usages():
459-
# Create tasks for all nodes
460-
tasks = {
461-
node_id: asyncio.create_task(get_outbounds_stats(node))
462-
for node_id, node in await node_manager.get_healthy_nodes()
463-
}
464-
465-
await asyncio.gather(*tasks.values())
466-
467-
api_params = {node_id: task.result() for node_id, task in tasks.items()}
556+
# Get healthy nodes and gather stats directly
557+
nodes = await node_manager.get_healthy_nodes()
558+
stats_results = await asyncio.gather(*[get_outbounds_stats(node) for _, node in nodes])
559+
api_params = {nodes[i][0]: result for i, result in enumerate(stats_results)}
468560

469561
# Calculate per-node totals
470562
node_totals = {
@@ -505,8 +597,10 @@ async def record_node_usages():
505597
if DISABLE_RECORDING_NODE_USAGE:
506598
return
507599

508-
record_tasks = [asyncio.create_task(record_node_stats(params, node_id)) for node_id, params in api_params.items()]
509-
await asyncio.gather(*record_tasks)
600+
# Gather record tasks directly without unnecessary task creation
601+
record_tasks = [record_node_stats(params, node_id) for node_id, params in api_params.items()]
602+
if record_tasks:
603+
await asyncio.gather(*record_tasks)
510604

511605

512606
scheduler.add_job(

0 commit comments

Comments
 (0)