Feature Request: Model Serialization (Save/Load) #3

@noahgift

Motivation

Production ML systems need to persist trained models to disk for:

  1. Deployment: Train once, serve many times
  2. Reproducibility: Save model state for later analysis
  3. Versioning: Track model versions across experiments
  4. Performance: Avoid re-training expensive models
  5. Sharing: Distribute pre-trained models

Proposed API

use aprender::prelude::*;
use std::path::Path;

// Train and save
let mut model = LinearRegression::new();
model.fit(&x_train, &y_train)?;

// Save to disk
model.save("model.bin")?;
// or
model.save_json("model.json")?;

// Load from disk
let loaded_model = LinearRegression::load("model.bin")?;
let predictions = loaded_model.predict(&x_test);

// Verify loaded model matches original
assert_eq!(model.predict(&x_test), loaded_model.predict(&x_test));

Use Case: PMAT Mutation Testing

PMAT trains mutant survivability predictors that should persist across runs:

// From: server/src/services/mutation/ml_predictor_tests.rs (test at line 173)
let mut predictor = SurvivabilityPredictor::new();
predictor.train(&training_data)?;

// Save model
let model_path = Path::new("/tmp/test_model.bin");
predictor.save(&model_path)?;

// Load model (potentially on different machine)
let loaded = SurvivabilityPredictor::load(&model_path)?;
assert!(loaded.is_trained());

// Make predictions with loaded model
let prediction = loaded.predict(&test_mutant)?;
assert!(prediction.kill_probability >= 0.0 && prediction.kill_probability <= 1.0);

Implementation Options

Option 1: Binary Serialization (bincode)

  • Pros: Fast, compact; can borrow strings and byte slices from the input buffer (numeric vectors are still copied on load)
  • Cons: Not human-readable, Rust-specific
  • Library: bincode crate

// `Result` and `Vector` are assumed to come from aprender.
use serde::{Deserialize, Serialize};
use std::path::Path;

#[derive(Serialize, Deserialize)]
pub struct LinearRegression {
    coefficients: Vector,
    intercept: f64,
}

impl LinearRegression {
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let bytes = bincode::serialize(self)?;
        std::fs::write(path, bytes)?;
        Ok(())
    }

    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let bytes = std::fs::read(path)?;
        let model = bincode::deserialize(&bytes)?;
        Ok(model)
    }
}

Option 2: JSON Serialization (serde_json)

  • Pros: Human-readable, language-agnostic, debuggable
  • Cons: Larger file size, slower
  • Library: serde_json crate

impl LinearRegression {
    pub fn save_json<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string_pretty(self)?;
        std::fs::write(path, json)?;
        Ok(())
    }

    pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Self> {
        let json = std::fs::read_to_string(path)?;
        let model = serde_json::from_str(&json)?;
        Ok(model)
    }
}

Recommendation: Support both formats

  • Binary (.bin) for production (fast, compact)
  • JSON (.json) for debugging and inspection
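
If both formats are supported, one possible convenience is to pick the format from the file extension. The sketch below uses only the standard library; `ModelFormat` and `format_for` are hypothetical names for illustration and are not part of the proposed API.

```rust
use std::path::Path;

/// Hypothetical format selector; not part of the proposed aprender API.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ModelFormat {
    Binary,
    Json,
}

/// Picks a format from the file extension (case-insensitive).
/// Returns `None` when the extension is missing or unrecognized.
pub fn format_for(path: &Path) -> Option<ModelFormat> {
    let ext = path.extension()?.to_str()?.to_ascii_lowercase();
    match ext.as_str() {
        "bin" => Some(ModelFormat::Binary),
        "json" => Some(ModelFormat::Json),
        _ => None,
    }
}
```

A generic `save`/`load` pair could call a helper like this and return an error for `None`, while the explicit `save_json`/`load_json` methods stay available for callers who want to choose the format themselves.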

Core Requirements

Model Types to Support:

  • LinearRegression (coefficients + intercept)
  • KMeans (centroids + cluster assignments)
  • Future: DecisionTree, RandomForest

Metadata to Include:

  • Model version (for backward compatibility)
  • Training timestamp
  • Feature names (if available)
  • Hyperparameters used
  • Training metrics (R², accuracy, etc.)

Error Handling:

  • Version mismatch detection
  • Corrupted file detection
  • Clear error messages
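
The error-handling requirements above could be met with a dedicated error type and a small header check. This is a minimal sketch under stated assumptions: the `ModelLoadError` type, the `APRN` magic bytes, and the header layout (4 magic bytes, 1 length byte, then the version string) are all hypothetical and not something aprender defines.

```rust
use std::fmt;

/// Hypothetical error type; variant names are illustrative only.
#[derive(Debug, PartialEq)]
pub enum ModelLoadError {
    /// The file is truncated or does not look like a model file.
    Corrupted(String),
    /// The file was written by a different library version.
    VersionMismatch { found: String, expected: String },
}

impl fmt::Display for ModelLoadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ModelLoadError::Corrupted(why) => write!(f, "model file is corrupted: {why}"),
            ModelLoadError::VersionMismatch { found, expected } => write!(
                f,
                "model was saved with version {found}, but this library is version {expected}"
            ),
        }
    }
}

impl std::error::Error for ModelLoadError {}

/// Assumed magic bytes marking the start of a model file.
const MAGIC: &[u8] = b"APRN";

/// Validates the assumed header: magic bytes, one length byte, version string.
pub fn check_header(bytes: &[u8], expected: &str) -> Result<(), ModelLoadError> {
    if !bytes.starts_with(MAGIC) {
        return Err(ModelLoadError::Corrupted("missing magic bytes".into()));
    }
    let len = *bytes
        .get(MAGIC.len())
        .ok_or_else(|| ModelLoadError::Corrupted("missing version length".into()))?
        as usize;
    let start = MAGIC.len() + 1;
    let raw = bytes
        .get(start..start + len)
        .ok_or_else(|| ModelLoadError::Corrupted("truncated version field".into()))?;
    let found = std::str::from_utf8(raw)
        .map_err(|_| ModelLoadError::Corrupted("version is not valid UTF-8".into()))?;
    if found != expected {
        return Err(ModelLoadError::VersionMismatch {
            found: found.to_string(),
            expected: expected.to_string(),
        });
    }
    Ok(())
}
```

Using `get` instead of direct slicing means a short or damaged file produces a descriptive error rather than a panic.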

Example Serialized Format (JSON)

{
  "model_type": "LinearRegression",
  "version": "0.1.0",
  "trained_at": "2025-01-18T12:34:56Z",
  "coefficients": [1.5, -0.3, 2.1],
  "intercept": 0.7,
  "metadata": {
    "n_features": 3,
    "n_samples_train": 1000,
    "r_squared": 0.87
  }
}

Dependencies

serde, bincode, and serde_json are all established crates in the Rust ecosystem:

[dependencies]
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3"  # For binary serialization
serde_json = "1.0"  # For JSON serialization

Edge Cases

Version Migration:

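#[derive(Serialize, Deserialize)]
pub struct ModelMetadata {
    version: String,
}

// On-disk envelope: metadata first, then the model. `save` must write this
// same struct for `load` to read it back.
#[derive(Serialize, Deserialize)]
struct SavedModel {
    metadata: ModelMetadata,
    model: LinearRegression,
}

impl LinearRegression {
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let bytes = std::fs::read(path)?;

        // Decode the whole envelope. Slicing a fixed `&bytes[..32]` header would
        // panic on short files and cannot hold a variable-length version string.
        let saved: SavedModel = bincode::deserialize(&bytes)?;

        // Check version
        if saved.metadata.version != env!("CARGO_PKG_VERSION") {
            eprintln!(
                "Warning: Model version {} doesn't match library version {}",
                saved.metadata.version,
                env!("CARGO_PKG_VERSION")
            );
        }

        Ok(saved.model)
    }
}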
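
The version check above warns on any difference, including patch releases. One possible refinement, sketched here with the standard library only, is to treat versions as compatible using the same rule Cargo applies to caret requirements: same major version, or same minor version while the major is 0. The names `parse_version` and `is_compatible` are hypothetical, and the policy itself is an assumption rather than anything the issue specifies.

```rust
/// Parses "major.minor.patch" into numeric parts.
/// Returns `None` for malformed input (including pre-release suffixes).
fn parse_version(v: &str) -> Option<(u64, u64, u64)> {
    let mut parts = v.split('.').map(|p| p.parse::<u64>().ok());
    let major = parts.next()??;
    let minor = parts.next()??;
    let patch = parts.next()??;
    if parts.next().is_some() {
        return None; // more than three components
    }
    Some((major, minor, patch))
}

/// Returns true when a model saved by `model` can be loaded by `library`
/// under the assumed policy: same major, or same minor for 0.x releases.
pub fn is_compatible(model: &str, library: &str) -> bool {
    match (parse_version(model), parse_version(library)) {
        (Some((m_major, m_minor, _)), Some((l_major, l_minor, _))) => {
            if m_major == 0 && l_major == 0 {
                m_minor == l_minor
            } else {
                m_major == l_major
            }
        }
        // Unparseable versions are treated as incompatible.
        _ => false,
    }
}
```

With a helper like this, `load` could stay silent for compatible versions and return a version-mismatch error for incompatible ones instead of only printing a warning.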

Priority

High: Essential for production deployment. Nearly all ML libraries (scikit-learn, PyTorch, TensorFlow) provide model persistence.

Context: Filed as part of the aprender integration for PMAT (see also Issues #1: Decision Trees and #2: Cross-Validation)
