
Commit 3ad4f9b

feat: enhance concurrency in usage recording with thread pool support
- Introduced a ThreadPoolExecutor for handling I/O-bound node API calls, improving data processing efficiency.
- Added thread pool initialization and cleanup functions to manage resources safely.
- Refactored user and outbounds stats retrieval to utilize the thread pool for CPU-bound processing, distributing workload across multiple cores.
- Extracted processing logic into separate functions for better modularity and clarity.

These changes aim to optimize performance during high-load scenarios by leveraging concurrent execution for data processing tasks.
1 parent: 4fdf7fc


app/jobs/record_usages.py

Lines changed: 111 additions & 38 deletions
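At its core the change applies the standard asyncio division of labor: the network call stays on the event loop, while CPU-heavy parsing is handed to an executor via `loop.run_in_executor`. Here is a minimal standalone sketch of that pattern (all names are illustrative, not from the codebase):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)

def parse_stats(raw: list[tuple[str, int]]) -> dict[str, int]:
    # CPU-bound aggregation; runs in a worker thread so the
    # event loop stays free to service other coroutines.
    totals: dict[str, int] = {}
    for name, value in raw:
        key = name.split(".", 1)[0]
        totals[key] = totals.get(key, 0) + value
    return totals

async def fetch_and_process() -> dict[str, int]:
    # Stand-in for the async node API call (I/O-bound, non-blocking).
    raw = [("1.up", 10), ("1.down", 20), ("2.up", 5)]
    loop = asyncio.get_running_loop()
    # Offload the aggregation to the thread pool.
    return await loop.run_in_executor(pool, parse_stats, raw)

print(asyncio.run(fetch_and_process()))  # {'1': 30, '2': 5}
```

The diff below has the same shape: `node.get_stats(...)` is awaited, and the extracted `_process_*_stats_response` helpers run in the shared pool.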
@@ -3,7 +3,7 @@
 import random
 import time
 from collections import defaultdict
-from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from datetime import datetime as dt, timedelta as td, timezone as tz
 from operator import attrgetter

@@ -41,6 +41,11 @@
 _process_pool = None
 _process_pool_lock = asyncio.Lock()

+# Thread pool executor for I/O-bound node API calls
+# Distributes workload across threads/cores for data collection
+_thread_pool = None
+_thread_pool_lock = asyncio.Lock()
+

 async def _get_process_pool():
     """Get or create the process pool executor (thread-safe)."""
@@ -53,6 +58,18 @@ async def _get_process_pool():
     return _process_pool


+async def _get_thread_pool():
+    """Get or create the thread pool executor (thread-safe)."""
+    global _thread_pool
+    async with _thread_pool_lock:
+        if _thread_pool is None:
+            # Use more threads for I/O-bound operations (2x CPU cores, cap at 16)
+            num_workers = min(multiprocessing.cpu_count() * 2, 16)
+            _thread_pool = ThreadPoolExecutor(max_workers=num_workers)
+            logger.info(f"Initialized ThreadPoolExecutor with {num_workers} workers")
+    return _thread_pool
+
+
 @on_shutdown
 async def _cleanup_process_pool():
     """Cleanup process pool on shutdown (thread-safe)."""
@@ -65,6 +82,18 @@ async def _cleanup_process_pool():
             logger.info("ProcessPoolExecutor shut down successfully")


+@on_shutdown
+async def _cleanup_thread_pool():
+    """Cleanup thread pool on shutdown (thread-safe)."""
+    global _thread_pool
+    async with _thread_pool_lock:
+        if _thread_pool is not None:
+            logger.info("Shutting down ThreadPoolExecutor...")
+            _thread_pool.shutdown(wait=True)
+            _thread_pool = None
+            logger.info("ThreadPoolExecutor shut down successfully")
+
+
 # Helper functions for multiprocessing (must be at module level for pickling)
 def _process_node_chunk(chunk_data: tuple) -> dict:
     """Process a chunk of node data - CPU-bound operation."""
@@ -414,47 +443,91 @@ async def _record_single_node(node_id: int, params: list[dict]):
     await asyncio.gather(*tasks, return_exceptions=True)


+def _process_users_stats_response(stats_response):
+    """
+    Process stats response (CPU-bound operation) - can run in thread pool.
+    Extracted to separate function for threading.
+    """
+    params = defaultdict(int)
+    for stat in filter(attrgetter("value"), stats_response.stats):
+        params[stat.name.split(".", 1)[0]] += stat.value
+
+    # Validate UIDs and filter out invalid ones
+    validated_params = []
+    for uid, value in params.items():
+        try:
+            uid_int = int(uid)
+            validated_params.append({"uid": uid_int, "value": value})
+        except (ValueError, TypeError):
+            # Skip invalid UIDs that can't be converted to int
+            logger.warning("Skipping invalid UID: %s", uid)
+            continue
+
+    return validated_params
+
+
 async def get_users_stats(node: PasarGuardNode):
-    try:
-        stats_response = await node.get_stats(stat_type=StatType.UsersStat, reset=True, timeout=30)
-        params = defaultdict(int)
-        for stat in filter(attrgetter("value"), stats_response.stats):
-            params[stat.name.split(".", 1)[0]] += stat.value
-
-        # Validate UIDs and filter out invalid ones
-        validated_params = []
-        for uid, value in params.items():
-            try:
-                uid_int = int(uid)
-                validated_params.append({"uid": uid_int, "value": value})
-            except (ValueError, TypeError):
-                # Skip invalid UIDs that can't be converted to int
-                logger.warning("Skipping invalid UID: %s", uid)
-                continue
-
-        return validated_params
-    except NodeAPIError as e:
-        logger.error("Failed to get users stats, error: %s", e.detail)
-        return []
-    except Exception as e:
-        logger.error("Failed to get users stats, unknown error: %s", e)
-        return []
+    """
+    Get user stats from node using thread pool for CPU-bound processing.
+    This distributes the heavy data processing workload across cores.
+    """
+    async with JOB_SEM:
+        try:
+            # I/O operation: fetch stats from node (async, non-blocking)
+            stats_response = await node.get_stats(stat_type=StatType.UsersStat, reset=True, timeout=30)
+
+            # CPU-bound operation: process stats in thread pool to utilize multiple cores
+            loop = asyncio.get_running_loop()
+            thread_pool = await _get_thread_pool()
+            validated_params = await loop.run_in_executor(
+                thread_pool, _process_users_stats_response, stats_response
+            )
+
+            return validated_params
+        except NodeAPIError as e:
+            logger.error("Failed to get users stats, error: %s", e.detail)
+            return []
+        except Exception as e:
+            logger.error("Failed to get users stats, unknown error: %s", e)
+            return []
+
+
+def _process_outbounds_stats_response(stats_response):
+    """
+    Process outbounds stats response (CPU-bound operation) - can run in thread pool.
+    Extracted to separate function for threading.
+    """
+    params = [
+        {"up": stat.value, "down": 0} if stat.type == "uplink" else {"up": 0, "down": stat.value}
+        for stat in filter(attrgetter("value"), stats_response.stats)
+    ]
+    return params


 async def get_outbounds_stats(node: PasarGuardNode):
-    try:
-        stats_response = await node.get_stats(stat_type=StatType.Outbounds, reset=True, timeout=10)
-        params = [
-            {"up": stat.value, "down": 0} if stat.type == "uplink" else {"up": 0, "down": stat.value}
-            for stat in filter(attrgetter("value"), stats_response.stats)
-        ]
-        return params
-    except NodeAPIError as e:
-        logger.error("Failed to get outbounds stats, error: %s", e.detail)
-        return []
-    except Exception as e:
-        logger.error("Failed to get outbounds stats, unknown error: %s", e)
-        return []
+    """
+    Get outbounds stats from node using thread pool for CPU-bound processing.
+    This distributes the heavy data processing workload across cores.
+    """
+    async with JOB_SEM:
+        try:
+            # I/O operation: fetch stats from node (async, non-blocking)
+            stats_response = await node.get_stats(stat_type=StatType.Outbounds, reset=True, timeout=10)
+
+            # CPU-bound operation: process stats in thread pool to utilize multiple cores
+            loop = asyncio.get_running_loop()
+            thread_pool = await _get_thread_pool()
+            params = await loop.run_in_executor(
+                thread_pool, _process_outbounds_stats_response, stats_response
+            )
+
+            return params
+        except NodeAPIError as e:
+            logger.error("Failed to get outbounds stats, error: %s", e.detail)
+            return []
+        except Exception as e:
+            logger.error("Failed to get outbounds stats, unknown error: %s", e)
+            return []


 async def calculate_admin_usage(users_usage: list) -> tuple[dict, set[int]]:
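Both collectors bound their own concurrency with `JOB_SEM`, so a caller can simply fan them out across nodes. A hedged sketch of such a caller (hypothetical; not part of this commit):

```python
import asyncio

# `get_users_stats` is the coroutine from this commit; `nodes` is a
# hypothetical list of connected PasarGuardNode instances.
async def collect_users_stats(nodes) -> list[list[dict]]:
    results = await asyncio.gather(
        *(get_users_stats(node) for node in nodes),
        return_exceptions=True,
    )
    # Each call already returns [] on handled errors; return_exceptions=True
    # additionally keeps one unexpected failure from cancelling the rest.
    return [r if not isinstance(r, BaseException) else [] for r in results]
```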
