
Task 003: Migrate Existing Policies to New Trait #7536

@slin1237

Summary

Migrate all existing routing policies (Random, RoundRobin, CacheAware) to implement the new RoutingPolicy trait so that each policy works in both regular and PD (prefill/decode) routing modes without mode-specific code.

Motivation

  • Eliminate code duplication between regular and PD routers
  • Enable all policies to work in both routing modes
  • Standardize policy implementation patterns
  • Make policies truly pluggable

Implementation Plan

1. Migrate Random Policy

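// src/routing/policies/random.rs
use std::sync::{Arc, Mutex};

use rand::rngs::StdRng;
use rand::SeedableRng;

use crate::core::Worker;
use crate::routing::policies::{RoutingPolicy, RoutingError};

pub struct RandomPolicy {
    // ThreadRng is neither Send nor Sync, so a policy holding one cannot be
    // shared across tasks. A seeded StdRng behind a Mutex can.
    rng: Arc<Mutex<StdRng>>,
}

impl RandomPolicy {
    pub fn new() -> Self {
        Self {
            // `from_entropy` is the rand 0.8 name, matching `thread_rng` used elsewhere.
            rng: Arc::new(Mutex::new(StdRng::from_entropy())),
        }
    }
}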

impl RoutingPolicy for RandomPolicy {
    // Implementation from Task 002
}

2. Migrate RoundRobin Policy

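// src/routing/policies/round_robin.rs
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::core::Worker;
use crate::routing::policies::{RoutingPolicy, RoutingError};

pub struct RoundRobinPolicy {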
    regular_counter: Arc<AtomicUsize>,
    prefill_counter: Arc<AtomicUsize>,
    decode_counter: Arc<AtomicUsize>,
}

impl RoutingPolicy for RoundRobinPolicy {
    async fn select_single(&self, workers: &[Arc<dyn Worker>], _: &serde_json::Value) 
        -> Result<Arc<dyn Worker>, RoutingError> {
        let healthy_workers: Vec<_> = workers.iter()
            .filter(|w| w.is_healthy())
            .collect();
            
        if healthy_workers.is_empty() {
            return Err(RoutingError::NoHealthyWorkers);
        }
        
        let idx = self.regular_counter.fetch_add(1, Ordering::Relaxed) % healthy_workers.len();
        Ok(healthy_workers[idx].clone())
    }
    
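    async fn select_pair(&self, prefill: &[Arc<dyn Worker>], decode: &[Arc<dyn Worker>], _: &serde_json::Value) 
        -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), RoutingError> {
        // Filter unhealthy workers as select_single does. This also guards the
        // modulo below against an empty pool, which would otherwise panic.
        let healthy_prefill: Vec<_> = prefill.iter().filter(|w| w.is_healthy()).collect();
        let healthy_decode: Vec<_> = decode.iter().filter(|w| w.is_healthy()).collect();

        if healthy_prefill.is_empty() || healthy_decode.is_empty() {
            return Err(RoutingError::NoHealthyWorkers);
        }

        // Use separate counters for prefill and decode
        let p_idx = self.prefill_counter.fetch_add(1, Ordering::Relaxed) % healthy_prefill.len();
        let d_idx = self.decode_counter.fetch_add(1, Ordering::Relaxed) % healthy_decode.len();
        Ok((healthy_prefill[p_idx].clone(), healthy_decode[d_idx].clone()))
    }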
}
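
The counter-plus-modulo selection above can be tried in isolation. The following is a minimal, synchronous sketch using only the standard library. `StaticWorker`, `next_healthy`, and `make_workers` are hypothetical names introduced for illustration, and the `Worker` trait here is a reduced stand-in for the one from Task 001, not its actual definition.

```rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;

/// Reduced stand-in for the Worker trait from Task 001 (assumption).
pub trait Worker: Send + Sync {
    fn url(&self) -> &str;
    fn is_healthy(&self) -> bool;
}

/// Minimal worker used only for illustration.
pub struct StaticWorker {
    pub url: String,
    pub healthy: AtomicBool,
}

impl Worker for StaticWorker {
    fn url(&self) -> &str {
        &self.url
    }
    fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }
}

/// Advance `counter` and return the next healthy worker, or `None` if there is none.
pub fn next_healthy(
    counter: &AtomicUsize,
    workers: &[Arc<dyn Worker>],
) -> Option<Arc<dyn Worker>> {
    let healthy: Vec<&Arc<dyn Worker>> =
        workers.iter().filter(|w| w.is_healthy()).collect();
    if healthy.is_empty() {
        return None; // checked before the modulo to avoid dividing by zero
    }
    let idx = counter.fetch_add(1, Ordering::Relaxed) % healthy.len();
    Some(Arc::clone(healthy[idx]))
}

/// Build a worker list from (url, healthy) pairs.
pub fn make_workers(specs: &[(&str, bool)]) -> Vec<Arc<dyn Worker>> {
    specs
        .iter()
        .map(|(url, healthy)| {
            Arc::new(StaticWorker {
                url: url.to_string(),
                healthy: AtomicBool::new(*healthy),
            }) as Arc<dyn Worker>
        })
        .collect()
}
```

Because the index is taken over the filtered list, a worker that flips between healthy and unhealthy shifts which worker a given counter value maps to. That is usually acceptable for round-robin, but it means the rotation is only approximately even while health is changing.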

3. Migrate CacheAware Policy

// src/routing/policies/cache_aware.rs
pub struct CacheAwarePolicy {
    trees: Arc<DashMap<String, Tree>>,
    load_tracker: Arc<DashMap<String, AtomicUsize>>,
    config: CacheAwareConfig,
    eviction_handle: Option<JoinHandle<()>>,
}

impl CacheAwarePolicy {
    fn is_load_balanced(&self, workers: &[Arc<dyn Worker>]) -> bool {
        let loads: Vec<usize> = workers.iter()
            .map(|w| w.load().load(Ordering::Relaxed))
            .collect();
            
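        let min = loads.iter().min().copied().unwrap_or(0);
        let max = loads.iter().max().copied().unwrap_or(0);

        // No workers, or every worker idle: there is nothing to rebalance.
        if max == 0 {
            return true;
        }

        let abs_diff = max - min;
        // max > 0 here, so a zero minimum means the ratio is unbounded.
        let rel_ratio = if min > 0 { max as f32 / min as f32 } else { f32::MAX };

        abs_diff <= self.config.balance_abs_threshold &&
        rel_ratio <= self.config.balance_rel_threshold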
    }
    
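    fn select_by_cache_affinity(&self, workers: &[Arc<dyn Worker>], text: &str) 
        -> Result<Arc<dyn Worker>, RoutingError> {
        // Track the best prefix match and the smallest tree separately, so that
        // a low-match worker can never overwrite a high-match one.
        let mut best_match: Option<(&Arc<dyn Worker>, f32)> = None;
        let mut smallest_tree: Option<(&Arc<dyn Worker>, usize)> = None;

        for worker in workers.iter().filter(|w| w.is_healthy()) {
            let tree = self.trees.entry(worker.url().to_string())
                .or_insert_with(Tree::new);
            let (match_len, tree_size) = (tree.prefix_match(text), tree.size());

            // An empty prompt has no prefix to match; avoid 0/0 = NaN.
            let match_rate = if text.is_empty() {
                0.0
            } else {
                match_len as f32 / text.len() as f32
            };

            if best_match.map_or(true, |(_, rate)| match_rate > rate) {
                best_match = Some((worker, match_rate));
            }
            if smallest_tree.map_or(true, |(_, size)| tree_size < size) {
                smallest_tree = Some((worker, tree_size));
            }
        }

        // Prefer cache affinity when the best match reaches the threshold;
        // otherwise send the request to the worker with the smallest tree.
        match (best_match, smallest_tree) {
            (Some((worker, rate)), _) if rate >= self.config.cache_threshold => Ok(worker.clone()),
            (_, Some((worker, _))) => Ok(worker.clone()),
            _ => Err(RoutingError::NoHealthyWorkers),
        }
    }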
}
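
The decision rule in `select_by_cache_affinity` (route to the best prefix match when it reaches the threshold, otherwise to the smallest tree) can be sketched without the tree itself. In this standard-library-only sketch, `Candidate`, `common_prefix_len`, `match_rate`, and `select_candidate` are hypothetical names, and `common_prefix_len` is only a stand-in for `Tree::prefix_match`, whose real behavior this issue does not define.

```rust
/// Hypothetical summary of one worker's cache state for a given request.
#[derive(Debug, Clone, PartialEq)]
pub struct Candidate {
    pub url: String,
    pub match_len: usize, // bytes of the request already covered by a cached prefix
    pub tree_size: usize, // total size of the worker's prefix tree
}

/// Length in bytes of the longest common prefix of `a` and `b`.
/// Stand-in for `Tree::prefix_match` (assumption).
pub fn common_prefix_len(a: &str, b: &str) -> usize {
    a.bytes().zip(b.bytes()).take_while(|(x, y)| x == y).count()
}

/// Fraction of the request covered by the cached prefix; 0.0 for an empty request.
pub fn match_rate(c: &Candidate, text_len: usize) -> f32 {
    if text_len == 0 {
        0.0
    } else {
        c.match_len as f32 / text_len as f32
    }
}

/// Pick the highest match rate if it reaches `cache_threshold`,
/// otherwise fall back to the candidate with the smallest tree.
pub fn select_candidate(
    candidates: &[Candidate],
    text_len: usize,
    cache_threshold: f32,
) -> Option<&Candidate> {
    let best = candidates
        .iter()
        .max_by(|a, b| match_rate(a, text_len).total_cmp(&match_rate(b, text_len)))?;
    if match_rate(best, text_len) >= cache_threshold {
        Some(best)
    } else {
        candidates.iter().min_by_key(|c| c.tree_size)
    }
}
```

Keeping the two criteria in separate passes makes the outcome independent of worker order, which is the property to check when comparing cache hit rates before and after the migration.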

4. Add PowerOfTwo Policy (New)

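// src/routing/policies/power_of_two.rs
use std::sync::atomic::Ordering;
use std::sync::Arc;

use rand::Rng; // provides gen_range

use crate::core::Worker;
use crate::routing::policies::{RoutingPolicy, RoutingError};
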
pub struct PowerOfTwoPolicy {
    load_monitor: Arc<LoadMonitor>,
}

impl RoutingPolicy for PowerOfTwoPolicy {
    async fn select_single(&self, workers: &[Arc<dyn Worker>], _: &serde_json::Value) 
        -> Result<Arc<dyn Worker>, RoutingError> {
        let healthy_workers: Vec<_> = workers.iter()
            .filter(|w| w.is_healthy())
            .collect();
            
        if healthy_workers.is_empty() {
            return Err(RoutingError::NoHealthyWorkers);
        }
        
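        // Sample two distinct workers when possible. With replacement the two
        // picks can coincide, which degrades to plain random selection.
        let mut rng = rand::thread_rng();
        let idx1 = rng.gen_range(0..healthy_workers.len());
        let idx2 = if healthy_workers.len() > 1 {
            // Draw from the remaining n-1 slots and step over idx1.
            let j = rng.gen_range(0..healthy_workers.len() - 1);
            if j >= idx1 { j + 1 } else { j }
        } else {
            idx1
        };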
        
        // Return the one with lower load
        let load1 = healthy_workers[idx1].load().load(Ordering::Relaxed);
        let load2 = healthy_workers[idx2].load().load(Ordering::Relaxed);
        
        Ok(if load1 <= load2 {
            healthy_workers[idx1].clone()
        } else {
            healthy_workers[idx2].clone()
        })
    }
}
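
The two-choice rule can be exercised on plain load numbers. This is a minimal sketch under stated assumptions: `XorShift64` is a small hand-rolled generator used only so the example needs no external crate (it stands in for `rand::thread_rng()` and must be seeded with a non-zero value), and `power_of_two_choice` is a hypothetical helper operating on a slice of loads rather than on workers.

```rust
/// Tiny xorshift64 generator so the sketch has no dependencies.
/// Stand-in for `rand::thread_rng()`; the seed must be non-zero.
pub struct XorShift64(pub u64);

impl XorShift64 {
    pub fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Index in `0..n`. Requires n > 0; the slight modulo bias is fine for a sketch.
    pub fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Sample two distinct indices when possible and return the one with the lower load.
pub fn power_of_two_choice(loads: &[usize], rng: &mut XorShift64) -> Option<usize> {
    let n = loads.len();
    if n == 0 {
        return None;
    }
    let i = rng.below(n);
    let j = if n > 1 {
        // Draw from the other n-1 slots and step over i.
        let k = rng.below(n - 1);
        if k >= i { k + 1 } else { k }
    } else {
        i
    };
    Some(if loads[i] <= loads[j] { i } else { j })
}
```

With distinct samples, the most loaded worker is never chosen as long as at least one other worker has a strictly lower load, which is a convenient property to assert in unit tests.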

5. Update Router Usage

// Update routers to use policies uniformly
impl RegularRouter {
    pub async fn route(&self, request: &serde_json::Value) -> Result<String, RoutingError> {
        let workers = self.workers.read().await;
        let worker = self.policy.select_single(&workers, request).await?;
        Ok(worker.url().to_string())
    }
}

impl PdRouter {
    pub async fn route(&self, request: &serde_json::Value) -> Result<(String, String), RoutingError> {
        let prefill = self.prefill_workers.read().await;
        let decode = self.decode_workers.read().await;
        let (p, d) = self.policy.select_pair(&prefill, &decode, request).await?;
        Ok((p.url().to_string(), d.url().to_string()))
    }
}
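
To illustrate what "pluggable" means for the two routers, here is a simplified synchronous sketch in which both hold the same `Arc<dyn RoutingPolicy>`. Everything in it is an assumption made for illustration: the trait is reduced to URL strings, `select_pair` is given a default body that applies the single-worker rule to each pool, `FirstPolicy` is a toy policy, and `std::sync::RwLock` replaces the async lock used above.

```rust
use std::sync::{Arc, RwLock};

#[derive(Debug, PartialEq)]
pub enum RoutingError {
    NoHealthyWorkers,
}

/// Simplified, synchronous stand-in for the RoutingPolicy trait (assumption).
pub trait RoutingPolicy: Send + Sync {
    fn select_single(&self, workers: &[String]) -> Result<String, RoutingError>;

    /// Default pair selection: apply the single-worker rule to each pool independently.
    fn select_pair(
        &self,
        prefill: &[String],
        decode: &[String],
    ) -> Result<(String, String), RoutingError> {
        Ok((self.select_single(prefill)?, self.select_single(decode)?))
    }
}

/// Toy policy for illustration: always pick the first worker.
pub struct FirstPolicy;

impl RoutingPolicy for FirstPolicy {
    fn select_single(&self, workers: &[String]) -> Result<String, RoutingError> {
        workers.first().cloned().ok_or(RoutingError::NoHealthyWorkers)
    }
}

pub struct RegularRouter {
    pub workers: RwLock<Vec<String>>,
    pub policy: Arc<dyn RoutingPolicy>,
}

impl RegularRouter {
    pub fn route(&self) -> Result<String, RoutingError> {
        // Snapshot the list so the lock is released before the policy runs.
        let snapshot = self.workers.read().unwrap().clone();
        self.policy.select_single(&snapshot)
    }
}

pub struct PdRouter {
    pub prefill: RwLock<Vec<String>>,
    pub decode: RwLock<Vec<String>>,
    pub policy: Arc<dyn RoutingPolicy>,
}

impl PdRouter {
    pub fn route(&self) -> Result<(String, String), RoutingError> {
        let prefill = self.prefill.read().unwrap().clone();
        let decode = self.decode.read().unwrap().clone();
        self.policy.select_pair(&prefill, &decode)
    }
}
```

A default `select_pair` like this would let a policy such as PowerOfTwo support PD mode without extra code, while RoundRobin can still override it to keep separate counters. Whether the real trait offers such a default is for Task 002 to decide.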

Acceptance Criteria

  1. Policy Migration

    • Random policy migrated and tested
    • RoundRobin policy migrated and tested
    • CacheAware policy migrated and tested
    • PowerOfTwo policy implemented and tested
  2. Feature Parity

    • All policies work in regular mode
    • All policies work in PD mode
    • No regression in routing behavior
  3. Code Quality

    • No code duplication between policies
    • Consistent error handling
    • Proper resource cleanup (eviction threads)
  4. Performance

    • No performance regression
    • Load tracking accurate
    • Cache eviction working correctly
  5. Tests

    • Unit tests for each policy
    • Integration tests for policy switching
    • Load tests for cache-aware policy

Dependencies

  • Task 001: Worker Abstraction
  • Task 002: RoutingPolicy Trait

Estimated Effort

  • Implementation: 3 days
  • Testing: 2 days
  • Performance validation: 1 day
  • Total: 6 days

Risks

  • Risk: CacheAware policy behavior changes
    • Mitigation: Careful testing of cache hit rates
  • Risk: Memory leaks from trees
    • Mitigation: Ensure eviction thread cleanup
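
One way to make eviction thread cleanup hard to forget is to tie it to `Drop`. The sketch below assumes the eviction loop runs on a `std::thread`; the issue does not say which `JoinHandle` type `eviction_handle` holds. `EvictionTask` and its `spawn` constructor are hypothetical names, not part of the existing code.

```rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Hypothetical owner of a background eviction thread.
pub struct EvictionTask {
    stop: Arc<AtomicBool>,
    handle: Option<JoinHandle<()>>,
}

impl EvictionTask {
    /// Spawn a thread that calls `evict` roughly every `interval` until stopped.
    pub fn spawn<F>(interval: Duration, mut evict: F) -> Self
    where
        F: FnMut() + Send + 'static,
    {
        let stop = Arc::new(AtomicBool::new(false));
        let stop_flag = Arc::clone(&stop);
        let handle = thread::spawn(move || {
            while !stop_flag.load(Ordering::Acquire) {
                evict();
                // Sleeps up to `interval`, but wakes early when unparked by Drop.
                thread::park_timeout(interval);
            }
        });
        Self { stop, handle: Some(handle) }
    }
}

impl Drop for EvictionTask {
    fn drop(&mut self) {
        // Signal first, then wake the thread so it sees the flag promptly.
        self.stop.store(true, Ordering::Release);
        if let Some(handle) = self.handle.take() {
            handle.thread().unpark();
            let _ = handle.join(); // ignore a panic inside the eviction closure
        }
    }
}
```

Since `park_timeout` may return spuriously, `evict` can occasionally run more often than `interval` suggests, so the eviction pass should be safe to repeat. Dropping the policy then stops and joins the thread, which releases whatever the closure captured, including any `Arc` to the trees.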
