Task 005: Router Interface and Factory
Summary
Define a Router trait (interface) that abstracts routing operations and create a RouterFactory to centralize router creation logic. This enables clean separation between the server and router implementations, replacing the current enum-based approach.
Motivation
Current issues:
- Router is an enum with all logic embedded in match statements
- Server code directly depends on specific router implementations
- Difficult to extend with new router types
- No unified interface for routing operations
- Policy creation logic scattered
- No clear initialization pipeline
Implementation Plan
1. Define Router Trait
// src/router/mod.rs
use std::sync::Arc;

use actix_web::{HttpRequest, HttpResponse};
use async_trait::async_trait;
use serde::Serialize;

use crate::core::Worker;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest};

// Note: actix_web::HttpRequest is not Send, and the default #[async_trait]
// requires Send futures, so impls that take the request by value will likely
// be rejected. The usual fix is #[async_trait(?Send)] on the trait and on
// every impl.
#[async_trait]
pub trait Router: Send + Sync {
    /// Route a chat completion request
    async fn route_chat_completion(
        &self,
        req: HttpRequest,
        body: ChatCompletionRequest,
    ) -> HttpResponse;

    /// Route a text completion request
    async fn route_completion(
        &self,
        req: HttpRequest,
        body: CompletionRequest,
    ) -> HttpResponse;

    /// Route a generate request (SGLang specific)
    async fn route_generate(
        &self,
        req: HttpRequest,
        body: serde_json::Value,
    ) -> HttpResponse;

    /// Add a worker dynamically
    async fn add_worker(&self, worker: Arc<dyn Worker>) -> Result<(), RouterError>;

    /// Remove a worker by URL
    async fn remove_worker(&self, url: &str) -> Result<(), RouterError>;

    /// List current workers with their type, health, and load
    async fn list_workers(&self) -> Vec<WorkerInfo>;

    /// Get current load information
    async fn get_loads(&self) -> LoadInfo;

    /// Get router type for metrics/debugging
    fn router_type(&self) -> &'static str;

    /// Apply service discovery update
    fn apply_discovery_update(&self, update: DiscoveryUpdate);
}
#[derive(Debug, Serialize)]
pub struct WorkerInfo {
    pub url: String,
    pub worker_type: WorkerType,
    pub healthy: bool,
    pub load: usize,
}

#[derive(Debug, Serialize)]
pub struct LoadInfo {
    pub router_type: String,
    pub total_workers: usize,
    pub healthy_workers: usize,
    pub total_load: usize,
    pub worker_loads: Vec<(String, usize)>,
}
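One way `get_loads` could assemble a `LoadInfo` from a snapshot of workers is sketched below. This is a std-only illustration, not the planned implementation: the structs are trimmed copies of the ones above (no `Serialize`, no `WorkerType`), and `summarize_loads` is a hypothetical helper name.

```rust
/// Trimmed copy of `WorkerInfo` for illustration (no serde, no `WorkerType`).
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    pub url: String,
    pub healthy: bool,
    pub load: usize,
}

/// Trimmed copy of `LoadInfo`.
#[derive(Debug, PartialEq)]
pub struct LoadInfo {
    pub router_type: String,
    pub total_workers: usize,
    pub healthy_workers: usize,
    pub total_load: usize,
    pub worker_loads: Vec<(String, usize)>,
}

/// Hypothetical helper: fold a worker snapshot into a `LoadInfo`.
/// Assumption: unhealthy workers still count toward `total_load`.
pub fn summarize_loads(router_type: &str, workers: &[WorkerInfo]) -> LoadInfo {
    LoadInfo {
        router_type: router_type.to_string(),
        total_workers: workers.len(),
        healthy_workers: workers.iter().filter(|w| w.healthy).count(),
        total_load: workers.iter().map(|w| w.load).sum(),
        worker_loads: workers.iter().map(|w| (w.url.clone(), w.load)).collect(),
    }
}
```

Each router implementation could build the `WorkerInfo` snapshot under its read lock and then call a helper like this, which would keep the aggregation logic in one place.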
2. Create Router Factory
// src/router/factory.rs
use std::sync::Arc;
use std::time::Duration;

use tokio::time::timeout;
// info!/warn! are assumed to come from `tracing`; use `log` if the crate depends on that instead
use tracing::{info, warn};

use crate::config::{RouterConfig, PolicyConfig};
use crate::routing::{Router, RegularRouter, PdRouter};
use crate::routing::policies::{PolicyFactory, RoutingPolicy};

pub struct RouterFactory {
    http_client: reqwest::Client,
    worker_factory: WorkerFactory,
}

impl RouterFactory {
    pub fn new() -> Self {
        Self {
            // One shared client: 600-second request timeout, up to 100 idle connections per host
            http_client: reqwest::Client::builder()
                .timeout(Duration::from_secs(600))
                .pool_max_idle_per_host(100)
                .build()
                .expect("Failed to create HTTP client"),
            worker_factory: WorkerFactory::new(),
        }
    }

    pub async fn create_router(
        &self,
        config: &RouterConfig,
    ) -> Result<Arc<dyn Router>, RouterError> {
        // Create routing policy
        let policy = PolicyFactory::create(&config.policy)?;

        // Create router based on mode
        match &config.mode {
            RoutingMode::Regular { worker_urls } => {
                self.create_regular_router(worker_urls, policy).await
            }
            RoutingMode::PrefillDecode { prefill_urls, decode_urls } => {
                self.create_pd_router(prefill_urls, decode_urls, policy).await
            }
        }
    }
    async fn create_regular_router(
        &self,
        worker_urls: &[String],
        policy: Arc<dyn RoutingPolicy>,
    ) -> Result<Arc<dyn Router>, RouterError> {
        // Create workers, keeping only those that pass the initial health check
        let mut workers = Vec::new();
        for url in worker_urls {
            let worker = self.worker_factory.create_regular(url.clone());
            // Initial health check with a 30-second timeout
            match timeout(Duration::from_secs(30), worker.check_health()).await {
                Ok(Ok(())) => {
                    info!("Worker {} is healthy", url);
                    workers.push(worker);
                }
                // Either the check returned an error or the timeout elapsed
                _ => warn!("Worker {} failed initial health check", url),
            }
        }

        // Fail only after every URL has been tried; this also covers an empty URL list
        if workers.is_empty() {
            return Err(RouterError::NoHealthyWorkers);
        }

        Ok(Arc::new(RegularRouter::new(
            workers,
            policy,
            self.http_client.clone(),
        )))
    }
    async fn create_pd_router(
        &self,
        prefill_urls: &[(String, Option<u16>)],
        decode_urls: &[String],
        policy: Arc<dyn RoutingPolicy>,
    ) -> Result<Arc<dyn Router>, RouterError> {
        // Create prefill workers
        let mut prefill_workers = Vec::new();
        for (url, bootstrap_port) in prefill_urls {
            let worker = self.worker_factory.create_prefill(url.clone(), *bootstrap_port);
            if worker.check_health().await.is_ok() {
                prefill_workers.push(worker);
            }
        }

        // Create decode workers
        let mut decode_workers = Vec::new();
        for url in decode_urls {
            let worker = self.worker_factory.create_decode(url.clone());
            if worker.check_health().await.is_ok() {
                decode_workers.push(worker);
            }
        }

        // A PD router needs at least one healthy worker in each pool
        if prefill_workers.is_empty() || decode_workers.is_empty() {
            return Err(RouterError::NoHealthyWorkers);
        }

        Ok(Arc::new(PdRouter::new(
            prefill_workers,
            decode_workers,
            policy,
            self.http_client.clone(),
        )))
    }
}
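The dispatch in `create_router` can be illustrated in isolation. The sketch below is a synchronous, std-only reduction of the factory: the `Router` trait is cut down to two methods, workers are plain URL strings, and the health check is a caller-supplied closure. All of these simplifications are assumptions made for the example, not part of the plan.

```rust
use std::sync::Arc;

/// Simplified stand-in for the `Router` trait (synchronous, two methods).
pub trait Router: Send + Sync {
    fn router_type(&self) -> &'static str;
    fn worker_count(&self) -> usize;
}

pub struct RegularRouter { workers: Vec<String> }
pub struct PdRouter { prefill: Vec<String>, decode: Vec<String> }

impl Router for RegularRouter {
    fn router_type(&self) -> &'static str { "regular" }
    fn worker_count(&self) -> usize { self.workers.len() }
}

impl Router for PdRouter {
    fn router_type(&self) -> &'static str { "prefill_decode" }
    fn worker_count(&self) -> usize { self.prefill.len() + self.decode.len() }
}

pub enum RoutingMode {
    Regular { worker_urls: Vec<String> },
    PrefillDecode { prefill_urls: Vec<String>, decode_urls: Vec<String> },
}

#[derive(Debug, PartialEq)]
pub enum RouterError { NoHealthyWorkers }

/// Keep only the URLs that pass the caller-supplied health check.
fn healthy(urls: &[String], is_healthy: &dyn Fn(&str) -> bool) -> Vec<String> {
    urls.iter().filter(|u| is_healthy(u.as_str())).cloned().collect()
}

/// Pick the router implementation from the mode; callers only see `dyn Router`.
pub fn create_router(
    mode: &RoutingMode,
    is_healthy: &dyn Fn(&str) -> bool,
) -> Result<Arc<dyn Router>, RouterError> {
    match mode {
        RoutingMode::Regular { worker_urls } => {
            let workers = healthy(worker_urls, is_healthy);
            if workers.is_empty() {
                return Err(RouterError::NoHealthyWorkers);
            }
            Ok(Arc::new(RegularRouter { workers }))
        }
        RoutingMode::PrefillDecode { prefill_urls, decode_urls } => {
            let prefill = healthy(prefill_urls, is_healthy);
            let decode = healthy(decode_urls, is_healthy);
            // Both pools must be non-empty for a prefill/decode pair to exist
            if prefill.is_empty() || decode.is_empty() {
                return Err(RouterError::NoHealthyWorkers);
            }
            Ok(Arc::new(PdRouter { prefill, decode }))
        }
    }
}
```

Passing the health check in as a closure is only a convenience for the sketch; it makes the "no healthy workers" paths easy to exercise without a network.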
3. Implement Router Trait for RegularRouter
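// src/router/router.rs
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Instant;

use serde_json::json;
// Async lock (the code below awaits read()/write()); tokio is assumed here
use tokio::sync::RwLock;

pub struct RegularRouter {
    workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
    policy: Arc<dyn RoutingPolicy>,
    http_client: reqwest::Client,
}

impl RegularRouter {
    pub fn new(
        workers: Vec<Arc<dyn Worker>>,
        policy: Arc<dyn RoutingPolicy>,
        http_client: reqwest::Client,
    ) -> Self {
        Self {
            workers: Arc::new(RwLock::new(workers)),
            policy,
            http_client,
        }
    }
}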
#[async_trait]
impl Router for RegularRouter {
    async fn route_chat_completion(
        &self,
        req: HttpRequest,
        body: ChatCompletionRequest,
    ) -> HttpResponse {
        let start = Instant::now();
        let route = "/v1/chat/completions";

        // Convert to JSON for policy selection
        let json_body = serde_json::to_value(&body).unwrap();

        // Select worker using policy; the read lock is released at the end of this block
        let worker = {
            let workers = self.workers.read().await;
            match self.policy.select_single(&workers, &json_body).await {
                Ok(w) => w,
                Err(e) => {
                    RouterMetrics::record_routing_error(route, &e.to_string());
                    return HttpResponse::ServiceUnavailable()
                        .json(json!({ "error": e.to_string() }));
                }
            }
        };

        // Update load
        worker.load().fetch_add(1, Ordering::Relaxed);
        RouterMetrics::set_worker_load(worker.url(), worker.load().load(Ordering::Relaxed));

        // Forward request
        let response = self.forward_request(req, json_body, worker.url(), route).await;

        // Update load and metrics
        worker.load().fetch_sub(1, Ordering::Relaxed);
        RouterMetrics::set_worker_load(worker.url(), worker.load().load(Ordering::Relaxed));
        RouterMetrics::record_request_duration(route, start.elapsed());

        response
    }
    async fn route_completion(
        &self,
        req: HttpRequest,
        body: CompletionRequest,
    ) -> HttpResponse {
        // Same flow as chat completions, delegated to a shared helper
        let json_body = serde_json::to_value(&body).unwrap();
        self.route_internal(req, json_body, "/v1/completions").await
    }

    async fn route_generate(
        &self,
        req: HttpRequest,
        body: serde_json::Value,
    ) -> HttpResponse {
        self.route_internal(req, body, "/generate").await
    }

    async fn add_worker(&self, worker: Arc<dyn Worker>) -> Result<(), RouterError> {
        // A regular router only accepts regular workers
        if worker.worker_type() != WorkerType::Regular {
            return Err(RouterError::InvalidWorkerType);
        }
        let mut workers = self.workers.write().await;
        workers.push(worker);
        Ok(())
    }

    async fn remove_worker(&self, url: &str) -> Result<(), RouterError> {
        let mut workers = self.workers.write().await;
        let initial_len = workers.len();
        workers.retain(|w| w.url() != url);
        // An unchanged length means no worker matched the URL
        if workers.len() == initial_len {
            Err(RouterError::WorkerNotFound)
        } else {
            Ok(())
        }
    }

    fn router_type(&self) -> &'static str {
        "regular"
    }

    // list_workers, get_loads, and apply_discovery_update are required by the
    // trait and must also be implemented here; they are not shown in this plan.
}
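The paired `fetch_add` / `fetch_sub` calls in `route_chat_completion` only balance if execution reaches the second call. A drop guard is one common way to make the decrement unconditional, including on early returns and when a future is dropped part-way. The sketch below is std-only and assumes the load counter is an `Arc<AtomicUsize>`; `LoadGuard` and `forward` are hypothetical names, and `forward` merely stands in for the real request forwarding.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Increments a worker's in-flight counter on creation and decrements it on drop.
pub struct LoadGuard {
    load: Arc<AtomicUsize>,
}

impl LoadGuard {
    pub fn new(load: Arc<AtomicUsize>) -> Self {
        load.fetch_add(1, Ordering::Relaxed);
        Self { load }
    }
}

impl Drop for LoadGuard {
    fn drop(&mut self) {
        // Runs on every exit path: normal return, early return, or unwinding
        self.load.fetch_sub(1, Ordering::Relaxed);
    }
}

/// Hypothetical stand-in for forwarding a request; rejects empty bodies.
pub fn forward(load: &Arc<AtomicUsize>, body: &str) -> Result<String, String> {
    let _guard = LoadGuard::new(Arc::clone(load));
    if body.is_empty() {
        // Early return: the guard is still dropped, so the counter is restored
        return Err("empty body".to_string());
    }
    Ok(format!(
        "forwarded {} bytes at load {}",
        body.len(),
        load.load(Ordering::Relaxed)
    ))
}
```

With a guard like this, the router would create it right after worker selection and let scope exit handle the decrement; the metrics update could move into `Drop` as well if that is desired.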
4. Implement Router Trait for PdRouter
// src/router/pd_router.rs
pub struct PdRouter {
    prefill_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
    decode_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
    policy: Arc<dyn RoutingPolicy>,
    http_client: reqwest::Client,
}

#[async_trait]
impl Router for PdRouter {
    async fn route_chat_completion(
        &self,
        req: HttpRequest,
        body: ChatCompletionRequest,
    ) -> HttpResponse {
        let route = "/v1/chat/completions";
        let mut json_body = serde_json::to_value(&body).unwrap();

        // Select one prefill and one decode worker using the policy
        let (prefill_worker, decode_worker) = {
            let prefill = self.prefill_workers.read().await;
            let decode = self.decode_workers.read().await;
            match self.policy.select_pair(&prefill, &decode, &json_body).await {
                Ok((p, d)) => (p, d),
                Err(e) => {
                    return HttpResponse::ServiceUnavailable()
                        .json(json!({ "error": e.to_string() }));
                }
            }
        };

        // Inject bootstrap content if needed
        if let WorkerType::Prefill { bootstrap_port: Some(port) } = prefill_worker.worker_type() {
            self.inject_bootstrap_content(&mut json_body, prefill_worker.url(), port).await;
        }

        // Forward to both workers
        let prefill_future = self.forward_request(&req, &json_body, prefill_worker.url(), route);
        let decode_future = self.forward_request(&req, &json_body, decode_worker.url(), route);

        // Wait for both responses
        let (prefill_result, decode_result) = join!(prefill_future, decode_future);

        // Merge responses (especially for logprobs)
        self.merge_responses(prefill_result, decode_result).await
    }
    async fn route_completion(
        &self,
        req: HttpRequest,
        body: CompletionRequest,
    ) -> HttpResponse {
        // Same flow as chat completions, delegated to a shared helper
        let json_body = serde_json::to_value(&body).unwrap();
        self.route_pd_internal(req, json_body, "/v1/completions").await
    }

    async fn route_generate(
        &self,
        req: HttpRequest,
        body: serde_json::Value,
    ) -> HttpResponse {
        self.route_pd_internal(req, body, "/generate").await
    }

    async fn add_worker(&self, worker: Arc<dyn Worker>) -> Result<(), RouterError> {
        // Place the worker in the pool that matches its type
        match worker.worker_type() {
            WorkerType::Prefill { .. } => {
                let mut workers = self.prefill_workers.write().await;
                workers.push(worker);
                Ok(())
            }
            WorkerType::Decode => {
                let mut workers = self.decode_workers.write().await;
                workers.push(worker);
                Ok(())
            }
            _ => Err(RouterError::InvalidWorkerType),
        }
    }

    fn router_type(&self) -> &'static str {
        "prefill_decode"
    }

    // remove_worker, list_workers, get_loads, and apply_discovery_update are
    // required by the trait and must also be implemented here; they are not
    // shown in this plan.
}
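The plan shows `add_worker` for `PdRouter` but not `remove_worker`. A minimal sketch of how both could work over two pools follows. It is an illustration under stated assumptions: workers are plain URL strings, `std::sync::RwLock` replaces the async lock used above, and `PdPools` and `counts` are hypothetical names.

```rust
use std::sync::RwLock;

#[derive(Debug, Clone, PartialEq)]
pub enum WorkerType {
    Regular,
    Prefill { bootstrap_port: Option<u16> },
    Decode,
}

#[derive(Debug, PartialEq)]
pub enum RouterError {
    InvalidWorkerType,
    WorkerNotFound,
}

/// Hypothetical container for the two worker pools of a PD router.
#[derive(Default)]
pub struct PdPools {
    prefill: RwLock<Vec<String>>,
    decode: RwLock<Vec<String>>,
}

impl PdPools {
    /// Route the worker to the pool that matches its type.
    pub fn add_worker(&self, url: &str, worker_type: &WorkerType) -> Result<(), RouterError> {
        let pool = match worker_type {
            WorkerType::Prefill { .. } => &self.prefill,
            WorkerType::Decode => &self.decode,
            WorkerType::Regular => return Err(RouterError::InvalidWorkerType),
        };
        pool.write().unwrap().push(url.to_string());
        Ok(())
    }

    /// Remove the URL from whichever pool holds it.
    pub fn remove_worker(&self, url: &str) -> Result<(), RouterError> {
        let mut removed = false;
        for pool in [&self.prefill, &self.decode] {
            let mut workers = pool.write().unwrap();
            let before = workers.len();
            workers.retain(|w| w.as_str() != url);
            removed |= workers.len() != before;
        }
        if removed { Ok(()) } else { Err(RouterError::WorkerNotFound) }
    }

    /// (prefill count, decode count)
    pub fn counts(&self) -> (usize, usize) {
        (
            self.prefill.read().unwrap().len(),
            self.decode.read().unwrap().len(),
        )
    }
}
```

Checking both pools on removal is a design choice for the sketch: the trait's `remove_worker` takes only a URL, so the caller cannot say which pool the worker belongs to.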
5. Update Server Handlers to Use Router Interface
// src/server.rs
/// Updated handler that uses the Router trait
pub async fn chat_completions_handler(
    req: HttpRequest,
    body: web::Json<ChatCompletionRequest>,
    data: web::Data<AppState>,
) -> HttpResponse {
    // The server no longer knows which router implementation is being used
    data.router.route_chat_completion(req, body.into_inner()).await
}

pub async fn completions_handler(
    req: HttpRequest,
    body: web::Json<CompletionRequest>,
    data: web::Data<AppState>,
) -> HttpResponse {
    data.router.route_completion(req, body.into_inner()).await
}

pub async fn generate_handler(
    req: HttpRequest,
    body: web::Json<serde_json::Value>,
    data: web::Data<AppState>,
) -> HttpResponse {
    data.router.route_generate(req, body.into_inner()).await
}

pub async fn get_loads_handler(
    data: web::Data<AppState>,
) -> HttpResponse {
    let load_info = data.router.get_loads().await;
    HttpResponse::Ok().json(load_info)
}
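Because the handlers only touch `data.router` through the trait, a test double can stand in for a real router. The sketch below shows the idea with a synchronous, std-only reduction: the `Router` trait is cut down to two methods, the handler takes a plain `&str` body, and `MockRouter` is a hypothetical test type. None of these shapes are taken from the actual server code.

```rust
use std::sync::{Arc, Mutex};

/// Simplified, synchronous stand-in for the `Router` trait.
pub trait Router: Send + Sync {
    fn route_generate(&self, body: &str) -> String;
    fn router_type(&self) -> &'static str;
}

/// Test double that records every request body it receives.
#[derive(Default)]
pub struct MockRouter {
    pub seen: Mutex<Vec<String>>,
}

impl Router for MockRouter {
    fn route_generate(&self, body: &str) -> String {
        self.seen.lock().unwrap().push(body.to_string());
        "mock-response".to_string()
    }

    fn router_type(&self) -> &'static str {
        "mock"
    }
}

/// Application state holds the router only as a trait object.
pub struct AppState {
    pub router: Arc<dyn Router>,
}

/// Handler-shaped function: it cannot tell which implementation it is calling.
pub fn generate_handler(state: &AppState, body: &str) -> String {
    state.router.route_generate(body)
}
```

A test can keep its own `Arc<MockRouter>` handle, pass a clone into `AppState`, call the handler, and then inspect `seen` to check what the handler forwarded.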
6. Update Server Initialization
// src/server.rs
pub async fn startup(config: RouterConfig) -> Result<(), ServerError> {
    // Initialize observability
    init_observability(config.observability.clone())?;

    // Create router
    let router_factory = RouterFactory::new();
    let router = router_factory.create_router(&config).await?;

    // Set up service discovery if enabled. The callback gets its own Arc
    // handle so that `router` can still be moved into AppState below, and
    // `config.discovery` is cloned so that `config` is not partially moved.
    let _discovery_handle = if let Some(discovery_config) = config.discovery.clone() {
        let discovery_router = Arc::clone(&router);
        let manager = ServiceDiscoveryManager::new(
            discovery_config,
            Arc::new(move |update| discovery_router.apply_discovery_update(update)),
        )
        .await?;
        Some(manager.start().await)
    } else {
        None
    };

    // Create app state
    let app_state = AppState {
        router,
        config: config.clone(),
    };

    // Start HTTP server
    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::new(app_state.clone()))
            .wrap(metrics_middleware)
            .configure(configure_routes)
    })
    .bind((config.host, config.port))?
    .run()
    .await?;

    Ok(())
}
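The ownership pattern in `startup`, where the discovery callback and the application state both need the router, can be shown on its own. The following std-only sketch assumes a two-variant `DiscoveryUpdate` and a trivial `SimpleRouter`; both are hypothetical, as are `wire` and `worker_urls`. It only demonstrates cloning the `Arc` before moving one handle into the closure.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical discovery event; the real `DiscoveryUpdate` is not shown in this plan.
pub enum DiscoveryUpdate {
    Added(String),
    Removed(String),
}

pub trait Router: Send + Sync {
    fn apply_discovery_update(&self, update: DiscoveryUpdate);
    fn worker_urls(&self) -> Vec<String>;
}

#[derive(Default)]
pub struct SimpleRouter {
    urls: Mutex<Vec<String>>,
}

impl Router for SimpleRouter {
    fn apply_discovery_update(&self, update: DiscoveryUpdate) {
        let mut urls = self.urls.lock().unwrap();
        match update {
            DiscoveryUpdate::Added(url) => urls.push(url),
            DiscoveryUpdate::Removed(url) => urls.retain(|u| *u != url),
        }
    }

    fn worker_urls(&self) -> Vec<String> {
        self.urls.lock().unwrap().clone()
    }
}

pub type DiscoveryCallback = Arc<dyn Fn(DiscoveryUpdate) + Send + Sync>;

pub struct AppState {
    pub router: Arc<dyn Router>,
}

/// Give the callback its own handle, then move the original into the state.
pub fn wire(router: Arc<dyn Router>) -> (DiscoveryCallback, AppState) {
    let discovery_router = Arc::clone(&router);
    let callback: DiscoveryCallback = Arc::new(move |update: DiscoveryUpdate| {
        discovery_router.apply_discovery_update(update)
    });
    (callback, AppState { router })
}
```

Both handles point at the same router, so updates applied through the callback are visible to handlers that read from `AppState`.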
Benefits
- Separation of Concerns: Server no longer knows about router internals
- Extensibility: Easy to add new router types without changing server code
- Type Safety: Each endpoint gets proper request types
- Testability: Can mock Router trait for server testing
- Future Ready: Supports the long-term vision of different router modes
Acceptance Criteria
- Router Trait
- Router Factory
- Router Implementations
- Server Integration
- Testing
Dependencies
- Task 001: Worker Abstraction
- Task 002: RoutingPolicy Trait
- Task 003: Migrate Policies
Estimated Effort
- Implementation: 3 days
- Refactoring: 2 days
- Testing: 2 days
- Total: 7 days
Risks
- Risk: Breaking existing router behavior
- Mitigation: Extensive testing, gradual rollout
- Risk: Performance regression
- Mitigation: Benchmark before/after, optimize hot paths