Remove upstream sharding + misc loading fixes #1291
Merged
rltakashige merged 8 commits into main on Jan 26, 2026
Conversation
If this doesn't work, we should fall back to the other loading strategies we have.
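That fallback could follow a simple try-in-order pattern. A minimal sketch, assuming a list of (name, callable) loading strategies; the names and functions here are illustrative, not exo's actual API:

```python
def load_with_fallback(strategies):
    """Try each loading strategy in order and return the first result.

    `strategies` is a list of (name, callable) pairs. The names are
    hypothetical stand-ins, not exo's real loading paths.
    """
    errors = []
    for name, strategy in strategies:
        try:
            return strategy()
        except Exception as exc:  # a failed strategy moves us to the next one
            errors.append((name, exc))
    raise RuntimeError(f"all loading strategies failed: {errors}")

def sharded_load():
    raise IOError("shard load failed")  # simulate the primary path failing

def whole_model_load():
    return "model-loaded-whole"

# The sharded path fails, so we fall back to loading the whole model.
result = load_with_fallback([("sharded", sharded_load), ("whole", whole_model_load)])
```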
JakeHillion (Member) approved these changes on Jan 26, 2026 and left a comment:
Nice! Glad to get off of the mlx-lm fork.
Evanev7 reviewed on Jan 26, 2026
Comment on lines +150 to +152:

```python
output = mx.distributed.all_gather(output, group=self.group)[
    -output.shape[0] :
]  # type :ignore
```
Member: is this intended to come back here? also the type ignore looks vaguely wrong
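For context on what the questioned slice computes: `all_gather` concatenates every rank's rows along the first axis, and the subscript `-output.shape[0]:` is evaluated against the original local `output` (before the reassignment takes effect), so it keeps only the last local-sized chunk of the gathered result. A list-based simulation of that behavior, assuming three ranks with two rows each:

```python
group_size, local_rows = 3, 2
local_output = [f"rank2_row{i}" for i in range(local_rows)]  # this rank's rows

# all_gather result: rows from ranks 0, 1, 2 concatenated in rank order.
gathered = [f"rank{r}_row{i}" for r in range(group_size) for i in range(local_rows)]

# Stands in for [-output.shape[0]:] -- keeps the final rank's contribution.
kept = gathered[-len(local_output):]
```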
leocamello added a commit to leocamello/exo that referenced this pull request on Jan 26, 2026:
This PR makes exo engine-agnostic by adding PyTorch as an inference backend, enabling Linux systems with NVIDIA GPUs to run inference.

## Architecture Changes
- **Engine abstraction**: Created base_engine.py with Engine interface
- **MlxEngine**: Moved all MLX-specific patches into MlxEngine.generate()
  - Includes KV prefix cache for faster prefill (upstream feature exo-explore#1262)
  - Properly passes kv_prefix_cache to mlx_generate()
- **PytorchEngine**: New engine for HuggingFace transformers on NVIDIA GPUs
- **Engine-agnostic runner**: runner.py no longer imports MLX at top level
- **Conditional imports**: bootstrap.py selects engine based on instance type

## New Files
- src/exo/worker/engines/base_engine.py - Abstract Engine interface
- src/exo/worker/engines/pytorch/__init__.py - PyTorch engine implementation
- src/exo/worker/engines/pytorch/auto_parallel.py - Pipeline parallelism for PyTorch
- src/exo/worker/engines/mlx/patches.py - Extracted MLX-specific helpers
- src/exo/utils/info_gatherer/linux_metrics.py - nvidia-smi GPU metrics

## Upstream Features Preserved
- KV prefix cache (exo-explore#1262) - integrated into MlxEngine
- Empty message fix (exo-explore#1292) - in utils_mlx.py
- Model shard loading fix (exo-explore#1291) - in auto_parallel.py

## Dashboard Changes
- Model-engine compatibility filtering
- Reordered instance type buttons (MLX above PyTorch)
- Fixed matchesSelectedRuntime() for PyTorch
- Model dropdown reset on instance type change

## Bug Fixes
- placement.py: Added missing logger import
- api.py: Model/engine compatibility validation
- api.py: Fixed tags hardcoding (tags=card.tags or [])
- test_event_ordering.py: Updated to use MockEngine instead of MLX patches

## Testing
- 20+ tests for nvidia-smi parsing edge cases
- PyTorch engine tests
- All existing tests pass (137 passed)
Motivation
On some configurations, certain models hit several distinct issues that left them stuck on loading.
Changes
Several of the loading issues traced back to upstream mlx-lm shard loading for tensor parallelism, which this PR removes.
GLM 4.7 Flash now uses GLM 4.7 Lite.
A final portion of the issues came from MLX memory not being released before calling mx.eval(model), which could run the system out of memory.
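The ordering matters here: references to the previously held buffers must be dropped before the new weights are materialized. A minimal, MLX-free sketch of that pattern, assuming a hypothetical Shard class standing in for a model shard holding large buffers (this is not exo's code):

```python
import gc
import weakref

class Shard:
    """Hypothetical stand-in for a model shard holding large buffers."""
    def __init__(self, nbytes):
        self.data = bytearray(nbytes)

old = Shard(1 << 20)
probe = weakref.ref(old)  # lets us observe when the old shard is collected

# Drop every reference to the old shard BEFORE building the new one,
# analogous to releasing MLX memory before calling mx.eval(model).
del old
gc.collect()
reclaimed = probe() is None  # True: old buffers freed before new allocation

new = Shard(1 << 20)
```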
Test Plan
Manual Testing
Done a bunch (thanks @AlexCheema), hopefully exhaustive.
Automated Testing
A bunch of automated testing is imminent but not landed yet.