multi-tensor repos fail to download #127

@alfredodeza

When using apr import with a repo like hf://microsoft/phi-4, which is split across 6 SafeTensors shard files, the import fails because there is no support for multi-tensor (sharded) repos. I tried adding multi-shard downloading and merging but quickly ran into issues like OOM and excessive disk usage.

Here is an executive summary that captures some of the changes I tried and some challenges:

Large models (14B+ parameters) are distributed as sharded SafeTensors files to work around file size limits. Example: microsoft/phi-4 has 6 shards totaling ~56GB. The original import implementation loaded all shards into memory simultaneously, causing >100GB memory usage and crashes.

Examples of changes:

1. Zero-Copy Streaming Merge (src/format/converter.rs:747-883)

Before:

// ❌ Loaded all 6 shards into memory at once
for shard in shards {
    let tensors = extract_all_tensors(shard)?;  // 56GB in RAM
    all_tensors.extend(tensors);
}
save_safetensors(&all_tensors)?;  // Another 56GB for output buffer

After:

// ✅ Stream tensors one at a time with smart caching
// Phase 1: Read index.json to map tensors → shards
let metadata_index = build_tensor_index(shard_paths)?;

// Phase 2: Stream tensors in alphabetical order (SafeTensors requirement)
let mut shard_cache = HashMap::new();  // LRU cache for 1-2 shards
for (tensor_name, (shard_idx, offsets)) in metadata_index {
    let raw_data = shard_cache.get_or_load(shard_idx)?;
    let tensor_bytes = &raw_data[offsets.start..offsets.end];
    writer.write_all(tensor_bytes)?;  // Direct write, no RAM accumulation
}

Memory reduction: 100GB → 2-5GB (a 95-98% reduction)
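The "After" snippet is pseudocode (a plain HashMap has no get_or_load and is not an LRU), so here is a minimal, dependency-free sketch of the same idea. The names TensorLoc, ShardCache and stream_merge are hypothetical and are not the code in src/format/converter.rs; shard loading is passed in as a closure so the sketch does not assume how shards are read from disk.

```rust
use std::collections::{BTreeMap, VecDeque};
use std::io::{self, Write};
use std::ops::Range;

/// Where one tensor's bytes live (hypothetical type for this sketch).
pub struct TensorLoc {
    pub shard_idx: usize,
    pub bytes: Range<usize>,
}

/// Keeps at most `capacity` shards in memory, evicting the least recently used.
pub struct ShardCache<F: FnMut(usize) -> io::Result<Vec<u8>>> {
    capacity: usize,
    entries: VecDeque<(usize, Vec<u8>)>, // front = most recently used
    load: F,
    pub loads: usize, // how many times a shard had to be (re)loaded
}

impl<F: FnMut(usize) -> io::Result<Vec<u8>>> ShardCache<F> {
    pub fn new(capacity: usize, load: F) -> Self {
        Self { capacity: capacity.max(1), entries: VecDeque::new(), load, loads: 0 }
    }

    pub fn get_or_load(&mut self, shard_idx: usize) -> io::Result<&[u8]> {
        if let Some(pos) = self.entries.iter().position(|(i, _)| *i == shard_idx) {
            // Cache hit: move the entry to the front.
            let hit = self.entries.remove(pos).expect("position is in range");
            self.entries.push_front(hit);
        } else {
            // Cache miss: load the shard, evicting the oldest entry if full.
            let data = (self.load)(shard_idx)?;
            self.loads += 1;
            if self.entries.len() == self.capacity {
                self.entries.pop_back();
            }
            self.entries.push_front((shard_idx, data));
        }
        Ok(self.entries[0].1.as_slice())
    }
}

/// Writes tensor bytes in name order (BTreeMap iteration order) and returns
/// the number of bytes written. Only cached shards are held in memory.
pub fn stream_merge<W: Write, F: FnMut(usize) -> io::Result<Vec<u8>>>(
    index: &BTreeMap<String, TensorLoc>,
    cache: &mut ShardCache<F>,
    writer: &mut W,
) -> io::Result<u64> {
    let mut written = 0u64;
    for loc in index.values() {
        let shard = cache.get_or_load(loc.shard_idx)?;
        let bytes = shard.get(loc.bytes.clone()).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "tensor range outside shard")
        })?;
        writer.write_all(bytes)?;
        written += bytes.len() as u64;
    }
    Ok(written)
}
```

The loads counter makes the cost of interleaved tensor names visible: with a capacity of 1, every switch between shards is a reload.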

2. Parse model.safetensors.index.json

The index file maps tensors to their shard files:

{
  "weight_map": {
    "model.layers.0.weight": "model-00001-of-00006.safetensors",
    "model.layers.38.weight": "model-00006-of-00006.safetensors"
  }
}

The index is used to:

  • Discover shard count and filenames
  • Auto-detect sharded models (fallback when model.safetensors missing)
  • Know which tensors come from which shard
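As a rough illustration of those three uses, the sketch below assumes weight_map has already been deserialized (for example with serde_json) into a BTreeMap of tensor name to shard filename. The function names are hypothetical.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Distinct shard filenames referenced by `weight_map`, sorted.
/// The length of the result is the shard count.
pub fn shard_files(weight_map: &BTreeMap<String, String>) -> Vec<String> {
    let unique: BTreeSet<&String> = weight_map.values().collect();
    unique.into_iter().cloned().collect()
}

/// Groups tensor names by the shard file that holds them.
pub fn tensors_by_shard(
    weight_map: &BTreeMap<String, String>,
) -> BTreeMap<String, Vec<String>> {
    let mut groups: BTreeMap<String, Vec<String>> = BTreeMap::new();
    for (tensor, shard) in weight_map {
        groups.entry(shard.clone()).or_default().push(tensor.clone());
    }
    groups
}

/// Fallback rule from the list above: treat the repo as sharded when
/// model.safetensors is missing but the index file is present.
pub fn is_sharded(has_single_file: bool, has_index: bool) -> bool {
    !has_single_file && has_index
}
```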

3. Fixed Offset Calculation Bug

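Problem: The tensor size calculated from the shape (elements × 4 bytes) didn't match the actual size recorded in the SafeTensors shard metadata. I attributed the difference to padding/alignment; note that a fixed 4 bytes per element is also only correct for 4-byte dtypes, not for 2-byte ones such as BF16 or F16.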

Fix: Use actual byte size from original shard metadata:

// ✅ Correct
let [start, end] = original_meta.data_offsets;
let actual_size = end - start;
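To make the mismatch concrete, here is a small sketch comparing a shape-derived size with the offset-derived size. The function names are hypothetical, the dtype table covers only common SafeTensors dtype tags, and data_offsets is assumed to be a two-element array as in the snippet above.

```rust
/// Bytes per element for common SafeTensors dtype tags.
/// Returns `None` for tags this sketch does not know.
pub fn dtype_size(dtype: &str) -> Option<usize> {
    match dtype {
        "F64" | "I64" | "U64" => Some(8),
        "F32" | "I32" | "U32" => Some(4),
        "F16" | "BF16" | "I16" | "U16" => Some(2),
        "I8" | "U8" | "BOOL" => Some(1),
        _ => None,
    }
}

/// Size implied by shape and dtype, with overflow checks.
pub fn size_from_shape(shape: &[usize], dtype: &str) -> Option<usize> {
    let elems = shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))?;
    elems.checked_mul(dtype_size(dtype)?)
}

/// Size recorded in the shard metadata; `None` if end < start.
pub fn size_from_offsets(data_offsets: [usize; 2]) -> Option<usize> {
    let [start, end] = data_offsets;
    end.checked_sub(start)
}
```

Copying exactly end - start bytes avoids having to guess the element size at all; the shape-derived value is then only useful as a sanity check.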

4. HF Cache Location Instead of /tmp

The merged file is now written to the Hugging Face cache directory (where the shards already are) instead of /tmp, avoiding "Disk quota exceeded" errors.
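For reference, a sketch of how the cache directory could be resolved, following the environment variables documented for huggingface_hub (HF_HUB_CACHE, then HF_HOME/hub, then ~/.cache/huggingface/hub). The function names are hypothetical, XDG_CACHE_HOME handling is left out, and whether apr resolves the directory this way is an assumption.

```rust
use std::env;
use std::path::PathBuf;

/// Pure resolution logic, so it can be tested without touching the environment.
pub fn resolve_hub_cache(
    hf_hub_cache: Option<&str>,
    hf_home: Option<&str>,
    home: Option<&str>,
) -> Option<PathBuf> {
    if let Some(dir) = hf_hub_cache {
        return Some(PathBuf::from(dir));
    }
    if let Some(dir) = hf_home {
        return Some(PathBuf::from(dir).join("hub"));
    }
    home.map(|h| PathBuf::from(h).join(".cache").join("huggingface").join("hub"))
}

/// Reads the relevant variables from the process environment.
/// Empty values are treated as unset.
pub fn hub_cache_from_env() -> Option<PathBuf> {
    let var = |key: &str| env::var(key).ok().filter(|v| !v.is_empty());
    let hub = var("HF_HUB_CACHE");
    let hf_home = var("HF_HOME");
    let home = var("HOME");
    resolve_hub_cache(hub.as_deref(), hf_home.as_deref(), home.as_deref())
}
```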

Challenges & Limitations

1. APR Format Not Implemented

  • Currently writes SafeTensors format, not APR
  • Files have .apr extension but SafeTensors content
  • Causes "Invalid magic bytes" errors in CLI tools
  • TODO: Implement proper APR header with APRN/APR1 magic bytes
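Until the APR writer exists, a cheap sniff of the first bytes would at least let tools report "this is SafeTensors content with an .apr extension" rather than a bare magic-bytes error. The sketch below assumes the APR magic (APRN or APR1) sits at offset 0, which I have not verified against the APR spec; the SafeTensors check relies on its documented layout of an 8-byte little-endian header length followed by a JSON object. The names Sniffed and sniff_format are hypothetical.

```rust
#[derive(Debug, PartialEq, Eq)]
pub enum Sniffed {
    Apr,
    SafeTensors,
    Unknown,
}

/// Guesses the container format from the first bytes of a file.
pub fn sniff_format(head: &[u8]) -> Sniffed {
    // Assumption: APR files begin with one of these 4-byte magics.
    if head.starts_with(b"APRN") || head.starts_with(b"APR1") {
        return Sniffed::Apr;
    }
    // SafeTensors: u64 little-endian header length, then JSON starting with '{'.
    if head.len() >= 9 {
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&head[..8]);
        let header_len = u64::from_le_bytes(len_bytes);
        if header_len > 0 && head[8] == b'{' {
            return Sniffed::SafeTensors;
        }
    }
    Sniffed::Unknown
}
```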

2. Sequential Shard Downloads

  • Downloads 6 shards one-by-one
  • Opportunity: Use tokio for parallel downloads (up to 6x faster, if bandwidth is not the bottleneck)
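A possible shape for this, sketched with std scoped threads rather than tokio so it has no dependencies. The fetch closure is a stand-in for whatever call actually downloads one shard, and fetch_all is a hypothetical name.

```rust
use std::thread;

/// Runs `fetch` for every shard concurrently and returns the results in
/// shard order. If any fetch fails, the first error in shard order is returned.
pub fn fetch_all<T, E, F>(shards: &[String], fetch: F) -> Result<Vec<T>, E>
where
    T: Send,
    E: Send,
    F: Fn(&str) -> Result<T, E> + Sync,
{
    thread::scope(|scope| {
        // Spawn one thread per shard; all threads are joined before scope returns.
        let handles: Vec<_> = shards
            .iter()
            .map(|name| {
                let fetch = &fetch;
                scope.spawn(move || fetch(name.as_str()))
            })
            .collect();
        handles
            .into_iter()
            .map(|handle| handle.join().expect("download thread panicked"))
            .collect()
    })
}
```

One thread per shard is fine for 6 shards; a repo with many more shards would probably want a bounded pool or an async client instead.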

3. Tensor Ordering Fragility

  • SafeTensors requires alphabetical tensor order
  • Current impl relies on BTreeMap iteration guarantees
  • If tensor names interleave across shards → multiple shards in cache → higher memory
  • Opportunity: Group by shard first, reorder metadata
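To show why interleaving matters, the sketch below counts shard switches under name order versus a shard-first order. It assumes the data section may be written in an order that differs from name order, which needs to be checked against the writer in use; all function names are hypothetical.

```rust
use std::collections::BTreeMap;

/// Number of times consecutive tensors in `order` come from different shards.
pub fn shard_switches(order: &[(String, usize)]) -> usize {
    order.windows(2).filter(|pair| pair[0].1 != pair[1].1).count()
}

/// Name order, as produced by iterating a BTreeMap keyed by tensor name.
pub fn alphabetical_order(index: &BTreeMap<String, usize>) -> Vec<(String, usize)> {
    index.iter().map(|(name, shard)| (name.clone(), *shard)).collect()
}

/// Shard-first order: each shard is visited once, names sorted within a shard.
pub fn shard_first_order(index: &BTreeMap<String, usize>) -> Vec<(String, usize)> {
    let mut order = alphabetical_order(index);
    order.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
    order
}
```

Names like layers.10 sort before layers.2 lexicographically, so a model whose shards are split by layer number can bounce between shards in name order even though each shard is internally contiguous.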

4. Index File Underutilized

  • We parse index.json but then reload all shard metadata anyway
  • Opportunity: Use index weight_map directly, skip metadata loading

5. No Resume on Failure

  • If merge fails at 90%, must restart from scratch
  • Opportunity: Implement checkpointing
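One way checkpointing could work is a small sidecar file that records how many tensors have been fully written and how many output bytes that corresponds to. The sketch below covers only the bookkeeping; the Checkpoint type, its text encoding and resume_index are hypothetical, and truncating the output file to bytes_written before resuming is left to the caller.

```rust
/// Progress record for a streaming merge (hypothetical sidecar format).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Checkpoint {
    pub tensors_done: usize,
    pub bytes_written: u64,
}

impl Checkpoint {
    /// Encodes as "<tensors_done> <bytes_written>\n".
    pub fn encode(&self) -> String {
        format!("{} {}\n", self.tensors_done, self.bytes_written)
    }

    /// Parses the text form; returns `None` on any malformed input.
    pub fn decode(text: &str) -> Option<Self> {
        let mut parts = text.split_whitespace();
        let tensors_done = parts.next()?.parse().ok()?;
        let bytes_written = parts.next()?.parse().ok()?;
        if parts.next().is_some() {
            return None;
        }
        Some(Self { tensors_done, bytes_written })
    }
}

/// Index of the first tensor still to be written, given tensor sizes in write
/// order. Falls back to 0 (restart) if the checkpoint does not match the sizes.
pub fn resume_index(sizes: &[u64], checkpoint: Checkpoint) -> usize {
    let expected: u64 = sizes.iter().take(checkpoint.tensors_done).sum();
    if checkpoint.tensors_done <= sizes.len() && expected == checkpoint.bytes_written {
        checkpoint.tensors_done
    } else {
        0
    }
}
```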

Testing

# Verify streaming merge
cargo test --features hf-hub-integration merge_safetensors_files

# Test with real model
apr import hf://microsoft/phi-4 --output phi-4.apr --force
apr validate phi-4.apr

Files Changed

  • src/format/converter.rs - Streaming merge implementation (lines 747-883)
  • src/serialization/safetensors.rs - Streaming writer (lines 48-178)

Next Steps

  1. Implement proper APR format (magic bytes, compression)
  2. Add parallel shard downloads
  3. Implement merge checkpointing (resume on failure)
  4. Utilize index.json weight_map directly
  5. Add progress bars (indicatif)
