
Remove upstream sharding + misc loading fixes #1291

Merged
rltakashige merged 8 commits into main from leo/fix-basic-model-shard on Jan 26, 2026

Conversation

@rltakashige
Collaborator

Motivation

On some configurations, several models would hit issues that left them stuck during loading.

Changes

Several of the loading issues came from the upstream mlx-lm shard loading for tensor parallelism; this PR removes that upstream sharding path.
GLM 4.7 Flash now uses GLM 4.7 Lite.
The final portion of the issues came from MLX memory not being properly released before calling mx.eval(model), causing the system to run out of memory.
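The memory issue above follows a general pattern: an allocator that caches freed buffers can push peak usage to old-cache-plus-new-weights unless the cache is released before materialization. The sketch below models that pattern with a toy runtime; all names (`FakeRuntime`, `peak_usage`) are illustrative assumptions, not code from this PR.

```python
class FakeRuntime:
    """Toy model of a runtime whose allocator caches freed buffers."""

    def __init__(self):
        self.cached = 0  # bytes held in the allocator's free-buffer cache
        self.live = 0    # bytes referenced by live arrays
        self.peak = 0

    def _track(self):
        self.peak = max(self.peak, self.cached + self.live)

    def free_to_cache(self, nbytes):
        # Freeing an array moves its memory into the cache, not back to the OS.
        self.live -= nbytes
        self.cached += nbytes
        self._track()

    def clear_cache(self):
        # Analogue of explicitly releasing cached buffers before mx.eval(model).
        self.cached = 0

    def materialize(self, nbytes):
        # Analogue of mx.eval(model): weights are actually allocated here.
        self.live += nbytes
        self._track()


def peak_usage(clear_first: bool) -> int:
    rt = FakeRuntime()
    rt.live = 8          # old weights resident (units: GB)
    rt._track()
    rt.free_to_cache(8)  # old weights "freed", but retained in the cache
    if clear_first:
        rt.clear_cache() # the fix: release the cache before materializing
    rt.materialize(10)   # load the new shard
    return rt.peak


# Without the fix the peak is 8 + 10 = 18; with it, only 10.
```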

Test Plan

Manual Testing

Done a bunch (thanks @AlexCheema), hopefully exhaustive.

Automated Testing

A bunch of automated testing is imminent but not landed yet.

@rltakashige rltakashige enabled auto-merge (squash) January 26, 2026 17:42
@rltakashige rltakashige changed the title Leo/fix basic model shard Remove upstream sharding + misc loading fixes Jan 26, 2026
@JakeHillion (Member) left a comment

Nice! Glad to get off of the mlx-lm fork.

@rltakashige rltakashige merged commit 9968abe into main Jan 26, 2026
8 checks passed
@rltakashige rltakashige deleted the leo/fix-basic-model-shard branch January 26, 2026 17:49
Comment on lines +150 to +152
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore

Is this intended to come back here? Also, the type ignore looks vaguely wrong.
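For context, the slicing pattern under review can be simulated without a distributed setup. The sketch below models `all_gather` as a plain concatenation over ranks (an assumption for illustration, not the exo/mlx implementation); it shows that because `output.shape[0]` in the subscript is evaluated on the still-bound pre-gather tensor, the expression keeps only the last shard-sized slice of the gathered result.

```python
import numpy as np


def fake_all_gather(shards):
    # Models mx.distributed.all_gather: every rank receives the
    # concatenation of all ranks' tensors along axis 0.
    return np.concatenate(shards, axis=0)


# Two ranks, each holding a (2, 3) local output.
rank0 = np.zeros((2, 3))
rank1 = np.ones((2, 3))

output = rank0  # local tensor as seen on rank 0
# The pattern from the diff: gather, then keep the last output.shape[0]
# rows. output.shape[0] refers to the *local* (pre-gather) tensor, since
# Python evaluates the whole right-hand side before rebinding `output`.
gathered_tail = fake_all_gather([rank0, rank1])[-output.shape[0]:]
# gathered_tail is the slice contributed by the last rank (here, rank1).
```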

leocamello added a commit to leocamello/exo that referenced this pull request Jan 26, 2026
This PR makes exo engine-agnostic by adding PyTorch as an inference backend,
enabling Linux systems with NVIDIA GPUs to run inference.

## Architecture Changes

- **Engine abstraction**: Created base_engine.py with Engine interface
- **MlxEngine**: Moved all MLX-specific patches into MlxEngine.generate()
  - Includes KV prefix cache for faster prefill (upstream feature exo-explore#1262)
  - Properly passes kv_prefix_cache to mlx_generate()
- **PytorchEngine**: New engine for HuggingFace transformers on NVIDIA GPUs
- **Engine-agnostic runner**: runner.py no longer imports MLX at top level
- **Conditional imports**: bootstrap.py selects engine based on instance type
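A minimal sketch of what such an `Engine` interface might look like; only `base_engine.py` and the `Engine` name come from the commit message, and the method names and the `EchoEngine` stand-in below are illustrative assumptions.

```python
from abc import ABC, abstractmethod
from typing import Iterator


class Engine(ABC):
    """Abstract inference backend; concrete engines (e.g. MLX, PyTorch)
    implement model loading and token generation behind this interface."""

    @abstractmethod
    def load(self, model_id: str) -> None:
        """Load (a shard of) the model onto this engine's device."""

    @abstractmethod
    def generate(self, prompt: str, max_tokens: int = 256) -> Iterator[str]:
        """Yield generated text chunks for the given prompt."""


class EchoEngine(Engine):
    """Trivial stand-in engine, handy for exercising the runner in tests."""

    def load(self, model_id: str) -> None:
        self.model_id = model_id

    def generate(self, prompt: str, max_tokens: int = 256) -> Iterator[str]:
        # "Generates" by echoing the prompt back one word at a time.
        yield from prompt.split()
```

The runner only sees `Engine`, so it never needs to import MLX or PyTorch at the top level; the concrete class is chosen at bootstrap time.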

## New Files

- src/exo/worker/engines/base_engine.py - Abstract Engine interface
- src/exo/worker/engines/pytorch/__init__.py - PyTorch engine implementation
- src/exo/worker/engines/pytorch/auto_parallel.py - Pipeline parallelism for PyTorch
- src/exo/worker/engines/mlx/patches.py - Extracted MLX-specific helpers
- src/exo/utils/info_gatherer/linux_metrics.py - nvidia-smi GPU metrics
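A metrics gatherer like `linux_metrics.py` typically shells out to `nvidia-smi` and parses its CSV output. The sketch below shows that approach using the real `nvidia-smi --query-gpu=... --format=csv,noheader,nounits` interface; the specific query fields, function names, and returned dict shape are assumptions, not the file's actual contents.

```python
import subprocess

QUERY = "name,memory.total,memory.used,utilization.gpu"


def parse_nvidia_smi(csv_text: str) -> list[dict]:
    """Parse `nvidia-smi --query-gpu=... --format=csv,noheader,nounits`
    output: one comma-separated line per GPU, memory in MiB."""
    gpus = []
    for line in csv_text.strip().splitlines():
        name, mem_total, mem_used, util = (f.strip() for f in line.split(","))
        gpus.append({
            "name": name,
            "memory_total_mib": int(mem_total),
            "memory_used_mib": int(mem_used),
            "utilization_pct": int(util),
        })
    return gpus


def gather_gpu_metrics() -> list[dict]:
    # Raises FileNotFoundError on hosts without nvidia-smi, which the
    # caller can treat as "no NVIDIA GPUs present".
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_nvidia_smi(out)
```

Keeping the parser separate from the subprocess call is what makes the "20+ tests for nvidia-smi parsing edge cases" mentioned below practical: the tests feed canned CSV strings rather than requiring a GPU.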

## Upstream Features Preserved

- KV prefix cache (exo-explore#1262) - integrated into MlxEngine
- Empty message fix (exo-explore#1292) - in utils_mlx.py
- Model shard loading fix (exo-explore#1291) - in auto_parallel.py

## Dashboard Changes

- Model-engine compatibility filtering
- Reordered instance type buttons (MLX above PyTorch)
- Fixed matchesSelectedRuntime() for PyTorch
- Model dropdown reset on instance type change

## Bug Fixes

- placement.py: Added missing logger import
- api.py: Model/engine compatibility validation
- api.py: Fixed tags hardcoding (tags=card.tags or [])
- test_event_ordering.py: Updated to use MockEngine instead of MLX patches

## Testing

- 20+ tests for nvidia-smi parsing edge cases
- PyTorch engine tests
- All existing tests pass (137 passed)