Task 003: Migrate Existing Policies to New Trait
Summary
Migrate all existing routing policies (Random, RoundRobin, CacheAware) to implement the new RoutingPolicy trait, and add a new PowerOfTwo policy, so that every policy works in both regular and PD (prefill/decode disaggregated) routing modes.
Motivation
- Eliminate code duplication between regular and PD routers
- Enable all policies to work in both routing modes
- Standardize policy implementation patterns
- Make policies truly pluggable
Implementation Plan
1. Migrate Random Policy
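```rust
// src/routing/policies/random.rs
use crate::core::Worker;
use crate::routing::policies::{RoutingError, RoutingPolicy};

// Stateless on purpose: `rand::rngs::ThreadRng` is not `Send`, so storing it in the
// policy (even behind `Arc<Mutex<_>>`) would make `RandomPolicy` neither `Send` nor
// `Sync`. Call `rand::thread_rng()` inside each selection instead.
#[derive(Debug, Default)]
pub struct RandomPolicy;

impl RandomPolicy {
    pub fn new() -> Self {
        Self
    }
}

impl RoutingPolicy for RandomPolicy {
    // Implementation from Task 002
}
```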
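The selection logic itself is small. This std-only sketch assumes a hypothetical `Worker` trait with just `url()` and `is_healthy()`, and uses a lock-free xorshift64 generator in place of the `rand` crate, to show a random policy that stays `Send + Sync` without a mutex:

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

// Hypothetical stand-ins for the Task 001/002 types.
pub trait Worker: Send + Sync {
    fn url(&self) -> &str;
    fn is_healthy(&self) -> bool;
}

#[derive(Debug, PartialEq)]
pub enum RoutingError {
    NoHealthyWorkers,
}

pub struct StaticWorker {
    pub url: String,
    pub healthy: bool,
}

impl Worker for StaticWorker {
    fn url(&self) -> &str {
        &self.url
    }
    fn is_healthy(&self) -> bool {
        self.healthy
    }
}

/// Random policy whose generator state is a single atomic, so `&self` is enough.
pub struct RandomPolicy {
    state: AtomicU64,
}

impl RandomPolicy {
    pub fn with_seed(seed: u64) -> Self {
        // xorshift64 must never hold 0; setting the low bit guarantees that.
        Self { state: AtomicU64::new(seed | 1) }
    }

    /// One xorshift64 step (shifts 13, 7, 17), applied atomically.
    fn next_u64(&self) -> u64 {
        let step = |mut x: u64| {
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            x
        };
        // `fetch_update` returns the previous state; recompute the value it stored.
        let prev = self
            .state
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(step(x)))
            .unwrap();
        step(prev)
    }

    pub fn select_single(&self, workers: &[Arc<dyn Worker>]) -> Result<Arc<dyn Worker>, RoutingError> {
        let healthy: Vec<&Arc<dyn Worker>> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy.is_empty() {
            return Err(RoutingError::NoHealthyWorkers);
        }
        let idx = (self.next_u64() % healthy.len() as u64) as usize;
        Ok(Arc::clone(healthy[idx]))
    }
}
```

In the real crate, calling `rand::thread_rng()` inside each call gives the same `Send + Sync` property; the atomic generator only keeps this sketch dependency-free.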
2. Migrate RoundRobin Policy
```rust
// src/routing/policies/round_robin.rs
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::core::Worker;
use crate::routing::policies::{RoutingError, RoutingPolicy};

#[derive(Debug, Default)]
pub struct RoundRobinPolicy {
    regular_counter: Arc<AtomicUsize>,
    prefill_counter: Arc<AtomicUsize>,
    decode_counter: Arc<AtomicUsize>,
}

impl RoundRobinPolicy {
    /// Advances `counter` over the healthy subset of `workers`.
    fn next_healthy(counter: &AtomicUsize, workers: &[Arc<dyn Worker>])
        -> Result<Arc<dyn Worker>, RoutingError> {
        let healthy: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy.is_empty() {
            // Returning early also avoids `x % 0`, which panics on an empty pool.
            return Err(RoutingError::NoHealthyWorkers);
        }
        let idx = counter.fetch_add(1, Ordering::Relaxed) % healthy.len();
        Ok(Arc::clone(healthy[idx]))
    }
}

impl RoutingPolicy for RoundRobinPolicy {
    async fn select_single(&self, workers: &[Arc<dyn Worker>], _: &serde_json::Value)
        -> Result<Arc<dyn Worker>, RoutingError> {
        Self::next_healthy(&self.regular_counter, workers)
    }

    async fn select_pair(&self, prefill: &[Arc<dyn Worker>], decode: &[Arc<dyn Worker>], _: &serde_json::Value)
        -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), RoutingError> {
        // Separate counters for prefill and decode; unhealthy workers are skipped in both pools.
        let p = Self::next_healthy(&self.prefill_counter, prefill)?;
        let d = Self::next_healthy(&self.decode_counter, decode)?;
        Ok((p, d))
    }
}
```
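The intended RoundRobin behaviour (health filtering, no `% 0` panic, and prefill/decode counters that advance independently) can be checked in isolation. This std-only model uses a hypothetical concrete `Worker` struct in place of the Task 001 trait and drops the async signature:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical minimal worker; the real type is the Task 001 `Worker` trait.
pub struct Worker {
    pub url: &'static str,
    pub healthy: bool,
}

#[derive(Debug, PartialEq)]
pub enum RoutingError {
    NoHealthyWorkers,
}

#[derive(Default)]
pub struct RoundRobinPolicy {
    prefill_counter: AtomicUsize,
    decode_counter: AtomicUsize,
}

/// Advances `counter` over the healthy subset of `workers`.
fn next_healthy(counter: &AtomicUsize, workers: &[Arc<Worker>]) -> Result<Arc<Worker>, RoutingError> {
    let healthy: Vec<&Arc<Worker>> = workers.iter().filter(|w| w.healthy).collect();
    if healthy.is_empty() {
        // Returning early also avoids `x % 0`, which panics.
        return Err(RoutingError::NoHealthyWorkers);
    }
    let idx = counter.fetch_add(1, Ordering::Relaxed) % healthy.len();
    Ok(Arc::clone(healthy[idx]))
}

impl RoundRobinPolicy {
    pub fn select_pair(
        &self,
        prefill: &[Arc<Worker>],
        decode: &[Arc<Worker>],
    ) -> Result<(Arc<Worker>, Arc<Worker>), RoutingError> {
        // Separate counters: each pool rotates at its own pace.
        let p = next_healthy(&self.prefill_counter, prefill)?;
        let d = next_healthy(&self.decode_counter, decode)?;
        Ok((p, d))
    }
}
```

With prefill `[p0, p1]` and decode `[d0, d1 (unhealthy), d2]`, four calls yield `(p0, d0), (p1, d2), (p0, d0), (p1, d2)`.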
3. Migrate CacheAware Policy
```rust
// src/routing/policies/cache_aware.rs
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use dashmap::DashMap;

use crate::core::Worker;
use crate::routing::policies::RoutingError;
// `Tree`, `CacheAwareConfig` and `JoinHandle` come from the existing cache-aware implementation.

pub struct CacheAwarePolicy {
    trees: Arc<DashMap<String, Tree>>,
    load_tracker: Arc<DashMap<String, AtomicUsize>>,
    config: CacheAwareConfig,
    eviction_handle: Option<JoinHandle<()>>,
}

impl CacheAwarePolicy {
    /// Balanced unless the load spread exceeds BOTH thresholds. Comparing with a
    /// multiplication instead of `max / min` keeps an idle cluster (min == 0) balanced.
    fn is_load_balanced(&self, workers: &[Arc<dyn Worker>]) -> bool {
        let loads: Vec<usize> = workers.iter()
            .map(|w| w.load().load(Ordering::Relaxed))
            .collect();
        let min = loads.iter().copied().min().unwrap_or(0);
        let max = loads.iter().copied().max().unwrap_or(0);
        max - min <= self.config.balance_abs_threshold
            || (max as f32) <= (min as f32) * self.config.balance_rel_threshold
    }

    fn select_by_cache_affinity(&self, workers: &[Arc<dyn Worker>], text: &str)
        -> Result<Arc<dyn Worker>, RoutingError> {
        // `max(1)` avoids 0 / 0 = NaN for an empty prompt.
        let text_len = text.len().max(1) as f32;
        let mut best_match: Option<(&Arc<dyn Worker>, f32)> = None;
        let mut smallest_tree: Option<(&Arc<dyn Worker>, usize)> = None;
        for worker in workers.iter().filter(|w| w.is_healthy()) {
            let tree = self.trees.entry(worker.url().to_string())
                .or_insert_with(Tree::new);
            let match_rate = tree.prefix_match(text) as f32 / text_len;
            let tree_size = tree.size();
            drop(tree); // release the DashMap shard lock as soon as the values are read
            if best_match.map_or(true, |(_, best)| match_rate > best) {
                best_match = Some((worker, match_rate));
            }
            if smallest_tree.map_or(true, |(_, smallest)| tree_size < smallest) {
                smallest_tree = Some((worker, tree_size));
            }
        }
        // Decide after the scan, so a weak match on a later worker can no longer
        // overwrite a strong match found earlier.
        let chosen = match best_match {
            Some((worker, rate)) if rate > self.config.cache_threshold => Some(worker),
            _ => smallest_tree.map(|(worker, _)| worker),
        };
        chosen.cloned().ok_or(RoutingError::NoHealthyWorkers)
    }
}
```
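The plan shows the two helpers but not how select_single combines them. Assuming the usual cache-aware flow (fall back to the least-loaded worker when loads are imbalanced, otherwise prefer cache affinity), the decision can be reduced to pure functions over per-worker numbers. `Thresholds` and `Candidate` are hypothetical types that mirror the config fields and tree lookups used above, and the threshold values in the usage are illustrative only:

```rust
/// Hypothetical thresholds mirroring the `CacheAwareConfig` fields used above.
pub struct Thresholds {
    pub balance_abs: usize,
    pub balance_rel: f32,
    pub cache: f32,
}

/// Hypothetical per-worker snapshot taken while scanning the workers.
pub struct Candidate {
    pub load: usize,      // in-flight requests
    pub match_len: usize, // length of the request prefix already in this worker's tree
    pub tree_size: usize, // total size of this worker's tree
}

/// Balanced unless the spread exceeds BOTH the absolute and the relative threshold.
pub fn is_load_balanced(loads: &[usize], t: &Thresholds) -> bool {
    let (Some(&min), Some(&max)) = (loads.iter().min(), loads.iter().max()) else {
        return true; // no workers, nothing to rebalance
    };
    max - min <= t.balance_abs || (max as f32) <= (min as f32) * t.balance_rel
}

/// Index of the worker to route to, or `None` when there are no candidates.
pub fn choose(candidates: &[Candidate], text_len: usize, t: &Thresholds) -> Option<usize> {
    let loads: Vec<usize> = candidates.iter().map(|c| c.load).collect();
    if !is_load_balanced(&loads, t) {
        // Imbalanced: ignore the cache and pick the least-loaded worker.
        return (0..candidates.len()).min_by_key(|&i| candidates[i].load);
    }
    // Best prefix match, tracked separately from the smallest-tree fallback.
    let best = (0..candidates.len()).max_by_key(|&i| candidates[i].match_len)?;
    let rate = candidates[best].match_len as f32 / text_len.max(1) as f32;
    if rate > t.cache {
        Some(best)
    } else {
        // Weak match everywhere: use the worker with the smallest tree (most free cache).
        (0..candidates.len()).min_by_key(|&i| candidates[i].tree_size)
    }
}
```

With thresholds `{ balance_abs: 32, balance_rel: 1.5, cache: 0.5 }`, loads `[4, 5, 3]` count as balanced, so a worker holding 80 of 100 prompt characters wins; if one worker's load jumps to 100, the least-loaded worker wins regardless of its cache.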
4. Add PowerOfTwo Policy (New)
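```rust
// src/routing/policies/power_of_two.rs
use std::sync::atomic::Ordering;
use std::sync::Arc;

use rand::Rng;

use crate::core::Worker;
use crate::routing::policies::{RoutingError, RoutingPolicy};

pub struct PowerOfTwoPolicy {
    load_monitor: Arc<LoadMonitor>,
}

impl RoutingPolicy for PowerOfTwoPolicy {
    async fn select_single(&self, workers: &[Arc<dyn Worker>], _: &serde_json::Value)
        -> Result<Arc<dyn Worker>, RoutingError> {
        let healthy_workers: Vec<_> = workers.iter()
            .filter(|w| w.is_healthy())
            .collect();
        let n = healthy_workers.len();
        if n == 0 {
            return Err(RoutingError::NoHealthyWorkers);
        }
        if n == 1 {
            return Ok(Arc::clone(healthy_workers[0]));
        }
        // Sample two *distinct* workers: draw the second index from the n - 1
        // remaining slots and shift it past the first. The block keeps the
        // non-`Send` ThreadRng from living across any future `.await`.
        let (idx1, idx2) = {
            let mut rng = rand::thread_rng();
            let idx1 = rng.gen_range(0..n);
            let idx2 = rng.gen_range(0..n - 1);
            (idx1, if idx2 >= idx1 { idx2 + 1 } else { idx2 })
        };
        // Return the one with lower load (ties go to the first sample)
        let load1 = healthy_workers[idx1].load().load(Ordering::Relaxed);
        let load2 = healthy_workers[idx2].load().load(Ordering::Relaxed);
        let chosen = if load1 <= load2 { idx1 } else { idx2 };
        Ok(Arc::clone(healthy_workers[chosen]))
    }
}
```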
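Power-of-two choices only helps when the two samples differ. This std-only sketch isolates the distinct-sampling step; `Worker` is a hypothetical struct carrying an atomic load counter like the one `w.load()` returns, and the random source is injected as any `FnMut() -> usize` (for example a closure over the `rand` crate) so the choice can be tested deterministically:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical minimal worker carrying the same kind of atomic load counter as `w.load()`.
pub struct Worker {
    pub url: &'static str,
    pub healthy: bool,
    pub load: AtomicUsize,
}

/// Power-of-two choices over the healthy workers, using `next_random` for sampling.
pub fn power_of_two<'a>(workers: &'a [Worker], mut next_random: impl FnMut() -> usize) -> Option<&'a Worker> {
    let healthy: Vec<&Worker> = workers.iter().filter(|w| w.healthy).collect();
    match healthy.len() {
        0 => None,
        1 => Some(healthy[0]),
        n => {
            // Two distinct indices: draw `j` from the n - 1 other slots and skip over `i`.
            let i = next_random() % n;
            let mut j = next_random() % (n - 1);
            if j >= i {
                j += 1;
            }
            let (a, b) = (healthy[i], healthy[j]);
            // Lower load wins; ties go to the first sample, as with `load1 <= load2` above.
            if a.load.load(Ordering::Relaxed) <= b.load.load(Ordering::Relaxed) {
                Some(a)
            } else {
                Some(b)
            }
        }
    }
}
```

With two healthy workers the pair is always both of them, so the less-loaded one is always returned, and an unhealthy worker is never considered even if its load is zero.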
5. Update Router Usage
```rust
// Update routers to use policies uniformly
impl RegularRouter {
    pub async fn route(&self, request: &serde_json::Value) -> Result<String, RoutingError> {
        let workers = self.workers.read().await;
        let worker = self.policy.select_single(&workers, request).await?;
        Ok(worker.url().to_string())
    }
}

impl PdRouter {
    pub async fn route(&self, request: &serde_json::Value) -> Result<(String, String), RoutingError> {
        let prefill = self.prefill_workers.read().await;
        let decode = self.decode_workers.read().await;
        let (p, d) = self.policy.select_pair(&prefill, &decode, request).await?;
        Ok((p.url().to_string(), d.url().to_string()))
    }
}
```
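The point of this step is that both routers delegate to the same kind of policy object. A synchronous std-only model illustrates it; the `pick`/`pick_pair` trait is a hypothetical stand-in for the async Task 002 trait, and the default `pick_pair` (prefill and decode chosen independently) is an assumption, since Task 002 may define PD selection differently:

```rust
use std::sync::Arc;

/// Hypothetical synchronous stand-in for the Task 002 `RoutingPolicy` trait.
pub trait RoutingPolicy: Send + Sync {
    fn pick(&self, urls: &[String]) -> Option<String>;

    /// Assumed default for PD mode: choose prefill and decode independently.
    fn pick_pair(&self, prefill: &[String], decode: &[String]) -> Option<(String, String)> {
        Some((self.pick(prefill)?, self.pick(decode)?))
    }
}

/// Trivial policy for the demo: always the first worker.
pub struct FirstWorker;

impl RoutingPolicy for FirstWorker {
    fn pick(&self, urls: &[String]) -> Option<String> {
        urls.first().cloned()
    }
}

pub struct RegularRouter {
    pub policy: Arc<dyn RoutingPolicy>,
    pub workers: Vec<String>,
}

pub struct PdRouter {
    pub policy: Arc<dyn RoutingPolicy>,
    pub prefill: Vec<String>,
    pub decode: Vec<String>,
}

impl RegularRouter {
    pub fn route(&self) -> Option<String> {
        self.policy.pick(&self.workers)
    }
}

impl PdRouter {
    pub fn route(&self) -> Option<(String, String)> {
        self.policy.pick_pair(&self.prefill, &self.decode)
    }
}
```

Because the routers only see `Arc<dyn RoutingPolicy>`, swapping the policy needs no router changes, and one instance can even be shared by both.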
Acceptance Criteria
- Policy Migration
- Feature Parity
- Code Quality
- Performance
- Tests
Dependencies
- Task 001: Worker Abstraction
- Task 002: RoutingPolicy Trait
Estimated Effort
- Implementation: 3 days
- Testing: 2 days
- Performance validation: 1 day
- Total: 6 days
Risks
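- Risk: CacheAware routing decisions change after migration
- Mitigation: Compare cache hit rates before and after the migration
- Risk: Memory leaks from the per-worker prefix trees
- Mitigation: Stop and clean up the eviction task held in eviction_handle when the policy is dropped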
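For the eviction-cleanup mitigation, the essential pattern is an owner that signals its background task to stop and waits for it in `Drop`. This is a std-only sketch with a hypothetical `EvictionTask` type; whether `eviction_handle` is a `std::thread` or a tokio handle is not specified here, and with tokio's `JoinHandle`, calling `abort()` in `Drop` is the analogous step:

```rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Hypothetical owner of a periodic eviction task that is stopped on drop.
pub struct EvictionTask {
    stop: Arc<AtomicBool>,
    handle: Option<JoinHandle<()>>,
    pub runs: Arc<AtomicUsize>,
}

impl EvictionTask {
    pub fn spawn(interval: Duration) -> Self {
        let stop = Arc::new(AtomicBool::new(false));
        let runs = Arc::new(AtomicUsize::new(0));
        let (stop_flag, counter) = (Arc::clone(&stop), Arc::clone(&runs));
        let handle = thread::spawn(move || {
            while !stop_flag.load(Ordering::Acquire) {
                // Placeholder for "evict old entries from the prefix trees".
                counter.fetch_add(1, Ordering::Relaxed);
                thread::sleep(interval);
            }
        });
        Self { stop, handle: Some(handle), runs }
    }
}

impl Drop for EvictionTask {
    fn drop(&mut self) {
        self.stop.store(true, Ordering::Release);
        if let Some(handle) = self.handle.take() {
            // Joining guarantees the thread, and its clones of shared state, are gone.
            let _ = handle.join();
        }
    }
}
```

Joining (rather than just setting the flag) is what makes the cleanup verifiable: once `drop` returns, the task can no longer touch the trees or keep them alive.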