Feature Request: Cross-Validation Utilities #2

@noahgift

Motivation

Machine learning models need validation to ensure they generalize well. Cross-validation is the gold standard for:

  1. Model evaluation: Assess performance on held-out data
  2. Hyperparameter tuning: Find optimal model parameters
  3. Overfitting detection: Identify when models memorize training data
  4. Scientific rigor: Industry best practice for ML pipelines

Proposed API

use aprender::prelude::*;
use aprender::model_selection::{cross_validate, train_test_split, KFold, StratifiedKFold};

// K-Fold Cross-Validation
let kfold = KFold::new(5); // 5-fold CV
let cv_results = cross_validate(
    &model,
    &x,
    &y,
    &kfold,
    &["accuracy", "precision", "recall"],
)?;

println!("Mean accuracy: {:.3} ± {:.3}", 
    cv_results.mean("accuracy"),
    cv_results.std("accuracy")
);

// Stratified K-Fold (balanced classes)
let stratified = StratifiedKFold::new(5);
let cv_results = cross_validate(&model, &x, &y, &stratified, &["f1_score"])?;

// Train/test split
let (x_train, x_test, y_train, y_test) = train_test_split(
    &x,
    &y,
    0.2,  // 20% test size
    Some(42),  // random state
)?;

Use Case: PMAT Mutation Testing

PMAT needs to validate mutant survivability predictions with cross-validation:

// From: server/src/services/mutation/cross_validation_test.rs
let predictor = SurvivabilityPredictor::new();

// Cross-validate prediction accuracy
let kfold = KFold::new(5);
let cv_scores = cross_validate(
    &predictor,
    &training_features,
    &labels,
    &kfold,
    &["accuracy", "precision", "recall"],
)?;

assert!(cv_scores.mean("accuracy") > 0.7, 
    "Model should achieve >70% accuracy on cross-validation"
);

Core Requirements

K-Fold Cross-Validation:

  • Split data into K folds
  • Train on K-1 folds, test on held-out fold
  • Return scores for each fold + mean/std
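The fold bookkeeping above could be sketched roughly as follows. This is only an illustration: `k_fold_indices` is a hypothetical helper name, not part of aprender, and it assumes contiguous, unshuffled folds where the first `n_samples % k` folds take one extra sample.

```rust
/// Hypothetical helper (not aprender's API): split `n_samples` indices into
/// `k` contiguous folds and return one (train, test) index pair per fold.
fn k_fold_indices(
    n_samples: usize,
    k: usize,
) -> Result<Vec<(Vec<usize>, Vec<usize>)>, String> {
    if k < 2 {
        return Err(format!("k must be at least 2, got {k}"));
    }
    if k > n_samples {
        return Err(format!("k ({k}) exceeds n_samples ({n_samples})"));
    }
    let base = n_samples / k;
    let remainder = n_samples % k;
    let mut folds = Vec::with_capacity(k);
    let mut start = 0;
    for fold in 0..k {
        // The first `remainder` folds take one extra sample.
        let size = base + usize::from(fold < remainder);
        let end = start + size;
        let test: Vec<usize> = (start..end).collect();
        // Everything outside the held-out range is training data.
        let train: Vec<usize> = (0..start).chain(end..n_samples).collect();
        folds.push((train, test));
        start = end;
    }
    Ok(folds)
}
```

Returning indices rather than copied rows keeps the splitter independent of how the dataset is stored, and the `k > n_samples` check gives a place to surface the small-dataset case as an error.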

Stratified K-Fold:

  • Maintain class distribution in each fold
  • Critical for imbalanced datasets
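One possible way to keep the class distribution per fold is to deal each class's samples out round-robin. The sketch below assumes integer class labels and does no shuffling; `stratified_fold_ids` is a hypothetical name used only for illustration.

```rust
use std::collections::BTreeMap;

/// Hypothetical helper: assign each sample a fold id in 0..k so that every
/// class is spread as evenly as possible across folds.
fn stratified_fold_ids(labels: &[usize], k: usize) -> Vec<usize> {
    assert!(k >= 2, "k must be at least 2");
    // Group sample indices by class label; BTreeMap keeps the order deterministic.
    let mut by_class: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
    for (i, &label) in labels.iter().enumerate() {
        by_class.entry(label).or_default().push(i);
    }
    let mut fold_of = vec![0; labels.len()];
    // The counter carries over between classes so leftover samples
    // do not all pile up in fold 0.
    let mut next = 0;
    for indices in by_class.values() {
        for &i in indices {
            fold_of[i] = next % k;
            next += 1;
        }
    }
    fold_of
}
```

With this scheme the per-fold count of any single class differs by at most one, which is the property that matters for imbalanced datasets.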

Train/Test Split:

  • Simple random split
  • Stratified split (maintain class balance)
  • Reproducible with random seed
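A reproducible random split might look like the sketch below. To stay dependency-free it uses a small SplitMix64 generator; a real implementation would more likely use the `rand` crate. `split_indices` and its index-returning signature are assumptions for illustration, not the proposed `train_test_split` itself.

```rust
/// SplitMix64: a tiny deterministic generator, used here only to keep the
/// sketch free of external crates.
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

/// Hypothetical helper: returns (train_indices, test_indices).
fn split_indices(n_samples: usize, test_size: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
    assert!(test_size > 0.0 && test_size < 1.0, "test_size must be in (0, 1)");
    let mut indices: Vec<usize> = (0..n_samples).collect();
    let mut rng = SplitMix64(seed);
    // Fisher-Yates shuffle; the modulo adds a slight bias, acceptable for a sketch.
    for i in (1..n_samples).rev() {
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        indices.swap(i, j);
    }
    let n_test = ((n_samples as f64) * test_size).round() as usize;
    // The first `n_test` shuffled indices become the test set, the rest train.
    let train = indices.split_off(n_test);
    (train, indices)
}
```

Calling it twice with the same seed yields the same partition, which is what makes experiments repeatable.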

Scoring Metrics:

  • Support custom scoring functions
  • Built-in metrics: accuracy, precision, recall, F1, R²
  • Per-fold scores + aggregated statistics
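To make the custom-scorer idea concrete, here is a rough sketch in which a scorer is any function from true labels and predictions to a number. The `Scorer` alias, the function names, and the choice of `1` as the positive class are assumptions for illustration, not a proposed signature.

```rust
/// A scorer compares true labels with predictions and returns one number.
type Scorer = fn(&[usize], &[usize]) -> f64;

/// Fraction of predictions that match the true label.
fn accuracy(y_true: &[usize], y_pred: &[usize]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len());
    if y_true.is_empty() {
        return 0.0;
    }
    let correct = y_true.iter().zip(y_pred).filter(|(t, p)| t == p).count();
    correct as f64 / y_true.len() as f64
}

/// Binary precision, treating label `1` as the positive class.
fn precision(y_true: &[usize], y_pred: &[usize]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len());
    let predicted_pos = y_pred.iter().filter(|&&p| p == 1).count();
    let true_pos = y_true
        .iter()
        .zip(y_pred)
        .filter(|(t, p)| **t == 1 && **p == 1)
        .count();
    if predicted_pos == 0 {
        0.0
    } else {
        true_pos as f64 / predicted_pos as f64
    }
}

/// Apply every named scorer to the same predictions.
fn score_all(
    scorers: &[(&str, Scorer)],
    y_true: &[usize],
    y_pred: &[usize],
) -> Vec<(String, f64)> {
    scorers
        .iter()
        .map(|(name, scorer)| ((*name).to_string(), scorer(y_true, y_pred)))
        .collect()
}
```

Under this design a user-defined metric is just another function with the same signature added to the list, so built-in and custom metrics go through one code path.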

Implementation Considerations

API Design:

  • Follow scikit-learn conventions (familiar API)
  • Work with aprender's Estimator trait
  • Support both supervised and unsupervised models
  • Return CrossValidationResult with scores + statistics
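As a rough picture of how the pieces could fit together, the sketch below runs one fit/predict cycle per fold. The real signature of aprender's `Estimator` trait is not shown in this issue, so `FitPredict` is a hypothetical stand-in, and `Majority` is a toy model included only so the example is self-contained.

```rust
/// Hypothetical stand-in for aprender's `Estimator` trait.
trait FitPredict: Clone {
    fn fit(&mut self, x: &[Vec<f64>], y: &[usize]);
    fn predict(&self, x: &[Vec<f64>]) -> Vec<usize>;
}

/// Toy binary model: always predicts the majority class seen during `fit`.
#[derive(Clone)]
struct Majority {
    label: usize,
}

impl FitPredict for Majority {
    fn fit(&mut self, _x: &[Vec<f64>], y: &[usize]) {
        let ones = y.iter().filter(|&&label| label == 1).count();
        self.label = usize::from(ones * 2 > y.len());
    }

    fn predict(&self, x: &[Vec<f64>]) -> Vec<usize> {
        vec![self.label; x.len()]
    }
}

/// Returns the accuracy of each fold. `folds` holds (train, test) index
/// pairs; every test set is assumed to be non-empty.
fn cross_val_accuracy<M: FitPredict>(
    model: &M,
    x: &[Vec<f64>],
    y: &[usize],
    folds: &[(Vec<usize>, Vec<usize>)],
) -> Vec<f64> {
    folds
        .iter()
        .map(|(train, test)| {
            // For brevity this gathers rows into fresh vectors per fold;
            // an index-based view would avoid the copies.
            let x_train: Vec<Vec<f64>> = train.iter().map(|&i| x[i].clone()).collect();
            let y_train: Vec<usize> = train.iter().map(|&i| y[i]).collect();
            let x_test: Vec<Vec<f64>> = test.iter().map(|&i| x[i].clone()).collect();
            // Clone the unfitted model so folds never share learned state.
            let mut fold_model = model.clone();
            fold_model.fit(&x_train, &y_train);
            let predictions = fold_model.predict(&x_test);
            let correct = test
                .iter()
                .zip(&predictions)
                .filter(|(i, p)| y[**i] == **p)
                .count();
            correct as f64 / test.len() as f64
        })
        .collect()
}
```

Taking the model by shared reference and cloning it per fold mirrors the `&model` argument in the proposed `cross_validate` call and leaves the caller's instance untouched.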

Performance:

  • Parallel fold execution (use rayon)
  • Memory efficient (don't copy full dataset K times)
  • Progress reporting for long-running CV
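The first two points can be illustrated without external crates. The sketch below uses `std::thread::scope` as a stand-in for rayon (whose `par_iter` would be the more likely choice in practice) and spawns one thread per fold rather than using a pool. The dataset is borrowed by every fold instead of being cloned. `evaluate_folds_parallel` and the range-based fold representation are assumptions made for the example.

```rust
use std::ops::Range;
use std::thread;

/// Hypothetical sketch: evaluate each fold on its own scoped thread.
/// `data` is shared by reference; a fold is just a range of test indices.
fn evaluate_folds_parallel<F>(data: &[f64], folds: &[Range<usize>], eval: F) -> Vec<f64>
where
    F: Fn(&[f64], &Range<usize>) -> f64 + Sync,
{
    let eval = &eval;
    thread::scope(|s| {
        // Spawn all folds first so they run concurrently.
        let handles: Vec<_> = folds
            .iter()
            .map(|fold| s.spawn(move || eval(data, fold)))
            .collect();
        // Join in spawn order so results line up with `folds`.
        handles
            .into_iter()
            .map(|h| h.join().expect("fold thread panicked"))
            .collect()
    })
}
```

Because scoped threads may borrow from the caller's stack, no `Arc` or cloning of the dataset is needed, which addresses the memory concern directly.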

Edge Cases:

  • Handle class imbalance (stratification)
  • Small dataset warnings (K > n_samples)
  • Reproducible splits (seeded RNG)

Example Output

struct CrossValidationResult {
    scores: HashMap<String, Vec<f64>>,  // metric -> scores per fold
    train_time: Vec<Duration>,           // training time per fold
    test_time: Vec<Duration>,            // testing time per fold
}

// Helper methods
result.mean("accuracy");  // -> 0.844
result.std("accuracy");   // -> ~0.02
result.scores("accuracy"); // -> [0.82, 0.87, 0.84, 0.86, 0.83]
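The helper methods could be implemented along these lines. `CvScores` is a hypothetical, trimmed-down stand-in for `CrossValidationResult`. Two further assumptions: the methods return `Option` so an unknown metric name is not a panic, and `std` is the population standard deviation (dividing by the number of folds); the issue does not say which convention is intended.

```rust
use std::collections::HashMap;

/// Hypothetical, reduced version of the result type: metric -> per-fold scores.
struct CvScores {
    scores: HashMap<String, Vec<f64>>,
}

impl CvScores {
    /// Per-fold scores for one metric, or `None` if the metric is unknown.
    fn scores(&self, metric: &str) -> Option<&[f64]> {
        self.scores.get(metric).map(Vec::as_slice)
    }

    /// Arithmetic mean over folds.
    fn mean(&self, metric: &str) -> Option<f64> {
        let values = self.scores(metric)?;
        if values.is_empty() {
            return None;
        }
        Some(values.iter().sum::<f64>() / values.len() as f64)
    }

    /// Population standard deviation over folds (divides by n, not n - 1).
    fn std(&self, metric: &str) -> Option<f64> {
        let values = self.scores(metric)?;
        let mean = self.mean(metric)?;
        let variance =
            values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
        Some(variance.sqrt())
    }
}
```

If the infallible `f64` return shown above is preferred, the same bodies apply with a panic or `NaN` on unknown metrics; that trade-off is a design decision for the maintainers.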

Priority

High: Essential for validating ML models. Most ML practitioners expect cross-validation utilities out of the box.

Context: Filed as part of aprender integration planning for PMAT (Issue #1: Decision Trees)
