Feature Request: Cross-Validation Utilities
Motivation
Machine learning models need validation to ensure they generalize well. Cross-validation is the gold standard for:
- Model evaluation: Assess performance on held-out data
- Hyperparameter tuning: Find optimal model parameters
- Overfitting detection: Identify when models memorize training data
- Scientific rigor: Industry best practice for ML pipelines
Proposed API
use aprender::prelude::*;
use aprender::model_selection::{cross_validate, train_test_split, KFold, StratifiedKFold};

// K-Fold Cross-Validation
let kfold = KFold::new(5); // 5-fold CV
let cv_results = cross_validate(
    &model,
    &x,
    &y,
    &kfold,
    &["accuracy", "precision", "recall"],
)?;
println!(
    "Mean accuracy: {:.3} ± {:.3}",
    cv_results.mean("accuracy"),
    cv_results.std("accuracy")
);

// Stratified K-Fold (balanced classes)
let stratified = StratifiedKFold::new(5);
let cv_results = cross_validate(&model, &x, &y, &stratified, &["f1_score"])?;

// Train/test split
let (x_train, x_test, y_train, y_test) = train_test_split(
    &x,
    &y,
    0.2,      // 20% test size
    Some(42), // random state
)?;
Use Case: PMAT Mutation Testing
PMAT needs to validate mutant survivability predictions with cross-validation:
// From: server/src/services/mutation/cross_validation_test.rs
let mut predictor = SurvivabilityPredictor::new();

// Cross-validate prediction accuracy
let kfold = KFold::new(5);
let cv_scores = cross_validate(
    &predictor,
    &training_features,
    &labels,
    &kfold,
    &["accuracy", "precision", "recall"],
)?;
assert!(
    cv_scores.mean("accuracy") > 0.7,
    "Model should achieve >70% accuracy on cross-validation"
);
Core Requirements
K-Fold Cross-Validation:
- Split data into K folds
- Train on K-1 folds, test on held-out fold
- Return scores for each fold + mean/std
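The K-fold splitting step can be sketched with indices only, so no rows are copied. This is a minimal illustration, not aprender's implementation: the function name `k_fold_indices` and its error type are assumptions made for the example, and it produces contiguous, unshuffled folds.

```rust
/// Returns `k` pairs of (train_indices, test_indices) over `0..n_samples`.
/// Hypothetical helper for illustration; not part of aprender's actual API.
fn k_fold_indices(
    n_samples: usize,
    k: usize,
) -> Result<Vec<(Vec<usize>, Vec<usize>)>, String> {
    if k < 2 {
        return Err(format!("k must be at least 2, got {k}"));
    }
    if k > n_samples {
        return Err(format!("k ({k}) exceeds n_samples ({n_samples})"));
    }
    let base = n_samples / k; // minimum fold size
    let extra = n_samples % k; // the first `extra` folds get one more sample
    let mut folds = Vec::with_capacity(k);
    let mut start = 0;
    for fold in 0..k {
        let size = base + usize::from(fold < extra);
        let end = start + size;
        // The held-out fold is a contiguous block; everything else is training data.
        let test: Vec<usize> = (start..end).collect();
        let train: Vec<usize> = (0..start).chain(end..n_samples).collect();
        folds.push((train, test));
        start = end;
    }
    Ok(folds)
}
```

With 10 samples and `k = 3`, the test folds have sizes 4, 3 and 3, and every sample is held out exactly once. A real implementation would probably shuffle the indices first, using a seeded RNG.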
Stratified K-Fold:
- Maintain class distribution in each fold
- Critical for imbalanced datasets
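One simple way to keep the class distribution roughly equal across folds is to deal the samples of each class out round-robin. The sketch below assumes integer class labels; the name `stratified_fold_assignment` is hypothetical, and the sketch does not shuffle within a class.

```rust
use std::collections::BTreeMap;

/// Assigns each sample to one of `k` folds so that every class is spread
/// as evenly as possible across folds. Returns `fold_of[i]` for sample `i`.
/// Hypothetical helper for illustration; not part of aprender's actual API.
fn stratified_fold_assignment(labels: &[usize], k: usize) -> Vec<usize> {
    assert!(k >= 2, "k must be at least 2");
    // Group sample indices by class label (BTreeMap gives a deterministic order).
    let mut by_class: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
    for (i, &label) in labels.iter().enumerate() {
        by_class.entry(label).or_default().push(i);
    }
    let mut fold_of = vec![0; labels.len()];
    // The round-robin position carries over between classes so that
    // overall fold sizes stay balanced too.
    let mut next = 0;
    for indices in by_class.values() {
        for &i in indices {
            fold_of[i] = next % k;
            next += 1;
        }
    }
    fold_of
}
```

For labels `[0, 0, 0, 0, 1, 1]` and `k = 2`, each fold receives two samples of class 0 and one of class 1, which matches the 2:1 ratio of the full dataset.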
Train/Test Split:
- Simple random split
- Stratified split (maintain class balance)
- Reproducible with random seed
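A reproducible random split boils down to a seeded shuffle followed by a cut. The sketch below is dependency-free, so it uses a tiny xorshift generator as a stand-in for a proper RNG crate. `XorShift64` and `train_test_split_indices` are assumptions made for illustration, and the function returns index vectors rather than the `(x_train, x_test, y_train, y_test)` tuple of the proposed API.

```rust
/// Tiny xorshift64 generator, used only to keep this sketch dependency-free.
struct XorShift64(u64);

impl XorShift64 {
    fn new(seed: u64) -> Self {
        // xorshift must not start from zero, so remap that seed.
        XorShift64(if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed })
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Splits `0..n_samples` into shuffled (train, test) index vectors.
/// The same seed always yields the same split.
fn train_test_split_indices(
    n_samples: usize,
    test_size: f64,
    seed: u64,
) -> (Vec<usize>, Vec<usize>) {
    assert!((0.0..=1.0).contains(&test_size), "test_size must be in [0, 1]");
    let mut indices: Vec<usize> = (0..n_samples).collect();
    let mut rng = XorShift64::new(seed);
    // Fisher-Yates shuffle (modulo bias is ignored for brevity).
    for i in (1..n_samples).rev() {
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        indices.swap(i, j);
    }
    let n_test = (n_samples as f64 * test_size).round() as usize;
    // `split_off` keeps the first `n_test` entries in `indices` (the test set)
    // and returns the remainder (the training set).
    let train = indices.split_off(n_test);
    (train, indices)
}
```

A stratified variant could apply the same shuffle-and-cut separately within each class and then concatenate the results.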
Scoring Metrics:
- Support custom scoring functions
- Built-in metrics: accuracy, precision, recall, F1, R²
- Per-fold scores + aggregated statistics
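Built-in and custom metrics can share one function-pointer type, which lets string names such as "accuracy" resolve to plain functions. The following is a sketch under stated assumptions: the `Scorer` alias, the `builtin_scorer` lookup, and the choice of label `1` as the positive class are all hypothetical, not aprender's API.

```rust
/// A scorer maps (y_true, y_pred) to a single number; higher is better.
type Scorer = fn(&[usize], &[usize]) -> f64;

/// Fraction of predictions that match the true label.
fn accuracy(y_true: &[usize], y_pred: &[usize]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len());
    if y_true.is_empty() {
        return 0.0;
    }
    let correct = y_true.iter().zip(y_pred).filter(|(t, p)| t == p).count();
    correct as f64 / y_true.len() as f64
}

/// Counts samples where both the true and predicted label are the positive class `1`.
fn true_positives(y_true: &[usize], y_pred: &[usize]) -> usize {
    y_true.iter().zip(y_pred).filter(|(t, p)| **t == 1 && **p == 1).count()
}

/// Binary precision for positive class `1`: TP / (TP + FP).
fn precision(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let predicted_pos = y_pred.iter().filter(|&&p| p == 1).count();
    if predicted_pos == 0 {
        return 0.0;
    }
    true_positives(y_true, y_pred) as f64 / predicted_pos as f64
}

/// Binary recall for positive class `1`: TP / (TP + FN).
fn recall(y_true: &[usize], y_pred: &[usize]) -> f64 {
    let actual_pos = y_true.iter().filter(|&&t| t == 1).count();
    if actual_pos == 0 {
        return 0.0;
    }
    true_positives(y_true, y_pred) as f64 / actual_pos as f64
}

/// Resolves a metric name to a built-in scorer, or `None` if the name is unknown.
fn builtin_scorer(name: &str) -> Option<Scorer> {
    match name {
        "accuracy" => Some(accuracy as Scorer),
        "precision" => Some(precision as Scorer),
        "recall" => Some(recall as Scorer),
        _ => None,
    }
}
```

Any function with the same signature can be passed where a `Scorer` is expected, which is one way to support custom scoring functions alongside the built-in names.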
Implementation Considerations
API Design:
- Follow scikit-learn conventions (familiar API)
- Work with aprender's Estimator trait
- Support both supervised and unsupervised models
- Return CrossValidationResult with scores + statistics
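To make the API design points concrete, here is a minimal sketch of a cross-validation loop over a trait object boundary. The `Estimator` trait below is a hypothetical stand-in (aprender's real trait may have different methods and data types), and `MajorityClass`, `select_rows` and `cross_validate_accuracy` exist only for this example. It scores accuracy only and copies rows per fold for simplicity.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for aprender's `Estimator` trait; the real trait may differ.
trait Estimator: Clone {
    fn fit(&mut self, x: &[Vec<f64>], y: &[usize]);
    fn predict(&self, x: &[Vec<f64>]) -> Vec<usize>;
}

/// Toy estimator that always predicts the most frequent training label.
#[derive(Clone, Default)]
struct MajorityClass {
    label: usize,
}

impl Estimator for MajorityClass {
    fn fit(&mut self, _x: &[Vec<f64>], y: &[usize]) {
        let mut counts: BTreeMap<usize, usize> = BTreeMap::new();
        for &label in y {
            *counts.entry(label).or_insert(0) += 1;
        }
        self.label = counts
            .into_iter()
            .max_by_key(|&(_, count)| count)
            .map(|(label, _)| label)
            .unwrap_or(0);
    }

    fn predict(&self, x: &[Vec<f64>]) -> Vec<usize> {
        vec![self.label; x.len()]
    }
}

/// Copies the rows and labels at `idx` (a real implementation would avoid the copy).
fn select_rows(x: &[Vec<f64>], y: &[usize], idx: &[usize]) -> (Vec<Vec<f64>>, Vec<usize>) {
    let rows = idx.iter().map(|&i| x[i].clone()).collect();
    let labels = idx.iter().map(|&i| y[i]).collect();
    (rows, labels)
}

/// Fits a fresh clone of `model` on each fold's training rows and returns
/// the accuracy on that fold's test rows, one score per fold.
fn cross_validate_accuracy<E: Estimator>(
    model: &E,
    x: &[Vec<f64>],
    y: &[usize],
    folds: &[(Vec<usize>, Vec<usize>)],
) -> Vec<f64> {
    folds
        .iter()
        .map(|(train, test)| {
            let (x_train, y_train) = select_rows(x, y, train);
            let (x_test, y_test) = select_rows(x, y, test);
            let mut fold_model = model.clone(); // the caller's model stays untouched
            fold_model.fit(&x_train, &y_train);
            let predicted = fold_model.predict(&x_test);
            let correct = predicted.iter().zip(&y_test).filter(|(p, t)| p == t).count();
            correct as f64 / y_test.len().max(1) as f64
        })
        .collect()
}
```

Cloning the model per fold is one possible design choice: it lets the function take `&model` while still training, and it keeps folds independent of one another.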
Performance:
- Parallel fold execution (use rayon)
- Memory efficient (don't copy full dataset K times)
- Progress reporting for long-running CV
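Parallel fold execution and the no-copy requirement fit together naturally if every fold borrows the same dataset and differs only in its index lists. The issue proposes rayon; to stay dependency-free, the sketch below swaps in `std::thread::scope` from the standard library instead. The function name `score_folds_in_parallel` and the flat `&[f64]` dataset are assumptions made for illustration.

```rust
use std::thread;

/// Runs `score_fold` once per fold on scoped threads.
/// The dataset is shared by reference, so it is never copied per fold.
fn score_folds_in_parallel<F>(
    data: &[f64],
    folds: &[(Vec<usize>, Vec<usize>)],
    score_fold: F,
) -> Vec<f64>
where
    F: Fn(&[f64], &[usize], &[usize]) -> f64 + Sync,
{
    thread::scope(|s| {
        // Spawn one thread per fold; each borrows `data` and its own index lists.
        let handles: Vec<_> = folds
            .iter()
            .map(|(train, test)| {
                let score_fold = &score_fold;
                s.spawn(move || score_fold(data, train.as_slice(), test.as_slice()))
            })
            .collect();
        // Join in spawn order so the scores line up with the folds.
        handles
            .into_iter()
            .map(|handle| handle.join().expect("fold thread panicked"))
            .collect()
    })
}
```

One thread per fold is reasonable for small K. With rayon, the same shape would presumably become a `par_iter()` over the folds, which also caps the number of worker threads.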
Edge Cases:
- Handle class imbalance (stratification)
- Small dataset warnings (K > n_samples)
- Reproducible splits (seeded RNG)
Example Output
use std::collections::HashMap;
use std::time::Duration;

pub struct CrossValidationResult {
    scores: HashMap<String, Vec<f64>>, // metric -> scores per fold
    train_time: Vec<Duration>,         // training time per fold
    test_time: Vec<Duration>,          // testing time per fold
}

// Helper methods
result.mean("accuracy");   // -> 0.844
result.std("accuracy");    // -> ~0.02
result.scores("accuracy"); // -> [0.82, 0.87, 0.84, 0.86, 0.83]
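The aggregate helpers are small enough to sketch in full. The example below uses a cut-down, hypothetical `CvScores` type rather than the proposed `CrossValidationResult`. Two details are assumptions: it returns `Option<f64>` so that an unknown metric name is not a panic, and it computes the population standard deviation (dividing by n), whereas the proposal does not say which variant is intended.

```rust
use std::collections::HashMap;

/// Illustrative result type holding only the per-fold scores.
struct CvScores {
    scores: HashMap<String, Vec<f64>>,
}

impl CvScores {
    /// Mean of the per-fold scores for `metric`, or `None` if unknown or empty.
    fn mean(&self, metric: &str) -> Option<f64> {
        let s = self.scores.get(metric)?;
        if s.is_empty() {
            return None;
        }
        Some(s.iter().sum::<f64>() / s.len() as f64)
    }

    /// Population standard deviation (divides by n) of the per-fold scores.
    fn std(&self, metric: &str) -> Option<f64> {
        let s = self.scores.get(metric)?;
        let mean = self.mean(metric)?;
        let variance = s.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / s.len() as f64;
        Some(variance.sqrt())
    }
}
```

For the five fold scores listed above, this gives a mean of 0.844 and a standard deviation of about 0.019. Dividing by n - 1 instead would give about 0.021.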
Priority
High: Essential for validating ML models. Most ML practitioners expect cross-validation utilities out of the box.
References
- server/src/services/mutation/cross_validation_test.rs
Context: Filed as part of aprender integration planning for PMAT (Issue #1: Decision Trees)