[fix] Support MPS in the cached losses' RandContext #3812
Conversation
|
Hello! Thanks for opening this. I totally missed the original issue as well, that's on me. Your changes here look pretty sound, but I just can't really reproduce them as I don't have any MPS devices available. @Incheonkirin Have you been able to reproduce this yourself?
|
Pull request overview
Fixes MPS (Apple Silicon) crashes in cached-loss RNG replay by teaching the cached-loss `RandContext` to snapshot/restore MPS RNG state without calling `torch.utils.checkpoint.get_device_states()` on MPS tensors.
Changes:
- Update `RandContext` in cached loss implementations to capture MPS RNG state via `torch.mps.get_rng_state()` and exclude MPS tensors from `get_device_states()`.
- Restore MPS RNG state on `__enter__` for deterministic replay, and restore the surrounding MPS RNG state on `__exit__` to avoid leaking RNG changes.
- Add an MPS-only regression test that validates both “replay determinism” and “no RNG leakage” across both `RandContext` copies.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `tests/sentence_transformer/losses/test_cmnrl.py` | Adds an MPS regression test covering both cached-loss `RandContext` implementations, including replay + no-leak assertions. |
| `sentence_transformers/sentence_transformer/losses/cached_multiple_negatives_ranking.py` | Captures/restores MPS RNG state and filters MPS tensors out of `get_device_states()` to prevent the `torch.mps.device` crash while preserving replay determinism. |
| `sentence_transformers/sentence_transformer/losses/cached_gist_embed.py` | Applies the same MPS-safe RNG snapshot/restore behavior to the GIST cached loss `RandContext`. |
RandContext backs up the device RNG state via torch.utils.checkpoint.get_device_states(), which raises "module 'torch.mps' has no attribute 'device'" as soon as it sees an MPS tensor.

Capture the MPS RNG state for top-level MPS tensor arguments and pass only the non-MPS arguments to get_device_states(). __enter__ restores the snapshot so the cached second forward replays the same randomness; __exit__ restores the surrounding MPS state, since the existing fork_rng() call defaults to device_type="cuda".

Covers CachedMultipleNegativesRankingLoss, CachedGISTEmbedLoss, and sparse CachedSpladeLoss (which imports the CachedMultipleNegativesRankingLoss RandContext).

Fixes huggingface#3564
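The approach described in the commit message can be sketched roughly as follows. This is a minimal illustration written from the description above, not the PR's actual diff: the class name `MpsSafeRandContext`, the helper `_is_mps_tensor`, and the attribute names are hypothetical, and the surrounding structure is assumed to follow the usual snapshot-in-`__init__`, replay-in-`__enter__` pattern.

```python
import torch
from torch.utils.checkpoint import get_device_states, set_device_states


def _is_mps_tensor(obj) -> bool:
    # Hypothetical helper: True only for tensors that live on the MPS device.
    return isinstance(obj, torch.Tensor) and obj.device.type == "mps"


class MpsSafeRandContext:
    """Hypothetical sketch: snapshot RNG state now, replay it inside `with`."""

    def __init__(self, *tensors) -> None:
        self.fwd_cpu_state = torch.get_rng_state()
        # Snapshot the MPS generator only when a top-level argument is on MPS.
        self.fwd_mps_state = (
            torch.mps.get_rng_state()
            if any(_is_mps_tensor(t) for t in tensors)
            else None
        )
        # Pass only the non-MPS arguments to get_device_states().
        others = [t for t in tensors if not _is_mps_tensor(t)]
        self.fwd_devices, self.fwd_states = get_device_states(*others)

    def __enter__(self) -> None:
        # fork_rng() defaults to device_type="cuda", so it saves and restores
        # the CPU state and the listed devices, but not the MPS generator.
        self._fork = torch.random.fork_rng(devices=self.fwd_devices, enabled=True)
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_devices, self.fwd_states)
        if self.fwd_mps_state is not None:
            # Remember the surrounding MPS state, then replay the snapshot.
            self._outer_mps_state = torch.mps.get_rng_state()
            torch.mps.set_rng_state(self.fwd_mps_state)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.fwd_mps_state is not None:
            # Put the surrounding MPS state back so nothing leaks out.
            torch.mps.set_rng_state(self._outer_mps_state)
        self._fork.__exit__(exc_type, exc_val, exc_tb)
        self._fork = None
```

On a machine without MPS the two MPS branches are simply skipped, so the sketch behaves like a plain CPU/CUDA replay context there.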
|
Yes, I originally hit this crash on my own machine (Apple M4 Max), and I re-verified it today against current

Repro (MPS):

```python
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.sentence_transformer.losses import CachedMultipleNegativesRankingLoss

m = SentenceTransformer("intfloat/multilingual-e5-small", device="mps")
loss = CachedMultipleNegativesRankingLoss(m, mini_batch_size=2)
# Move every tensor in a tokenized feature dict onto the MPS device.
mv = lambda f: {k: (v.to("mps") if torch.is_tensor(v) else v) for k, v in f.items()}
feats = [
    mv(m.tokenize(["query: a", "query: b", "query: c", "query: d"])),
    mv(m.tokenize(["passage: w", "passage: x", "passage: y", "passage: z"])),
]
out = loss(feats, torch.zeros(4, dtype=torch.long).to("mps"))
out.backward()
print("NO CRASH — loss:", round(out.item(), 4))
```

On On this branch the same snippet prints I also rebased the branch onto current |
680755c to
fe38c7c
tomaarsen left a comment
That's excellent news, thank you! Then I'd be happy to support merging this. I'll await the tests.
- Tom Aarsen
|
Thanks! I appreciate you checking this even without an MPS device on hand. I re-tested it on current main after the rebase, so I'm glad this can now cover the cached-loss MPS path. |
Fixes #3564.
`CachedMultipleNegativesRankingLoss` and `CachedGISTEmbedLoss` crash on MPS (Apple Silicon): `RandContext` backs up the device RNG state with `torch.utils.checkpoint.get_device_states(*tensors)`, which calls the non-existent `torch.mps.device()` as soon as it sees an MPS tensor.

`RandContext` is there so GradCache's second (cached) forward replays the same randomness, so just skipping the device state on MPS would stop the crash but break that replay: dropout would differ between the two forwards. Instead I capture the MPS RNG state with `torch.mps.get_rng_state()` when an input is on MPS, pass only the non-MPS tensors to `get_device_states()`, restore the snapshot in `__enter__`, and restore the surrounding MPS state in `__exit__` (the existing `fork_rng()` call defaults to `device_type="cuda"`, so it does not cover MPS).

Both SentenceTransformer `RandContext` copies get the fix, so this also covers sparse `CachedSpladeLoss`, which imports the `CachedMultipleNegativesRankingLoss` one. The CrossEncoder cached loss calls `RandContext(pairs)` with raw text rather than tensors, so it never hit this and is left alone.

Added an MPS regression test (skipped without MPS) over both copies. It raises the `AttributeError` on main and passes here, and checks both the replay and that the outer RNG state is not leaked.
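The two properties the regression test checks, replay and no leakage, can be illustrated without MPS hardware. The sketch below is only an analogy: it uses Python's stdlib `random` as a stand-in for the device generator, and `ReplayRng` and `check_replay_and_no_leak` are hypothetical names, not part of the PR or its test file.

```python
import random


class ReplayRng:
    """Stdlib stand-in: snapshot at construction, replay inside `with`."""

    def __init__(self) -> None:
        # Snapshot taken before the "first forward" consumes randomness.
        self.snapshot = random.getstate()

    def __enter__(self) -> "ReplayRng":
        # Save the surrounding state, then rewind to the snapshot.
        self._outer = random.getstate()
        random.setstate(self.snapshot)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Restore the surrounding state so the replay does not leak out.
        random.setstate(self._outer)


def check_replay_and_no_leak() -> tuple:
    random.seed(0)
    ctx = ReplayRng()
    first = [random.random() for _ in range(4)]  # "first forward"
    outer_before = random.getstate()
    with ctx:
        replay = [random.random() for _ in range(4)]  # "cached second forward"
    outer_after = random.getstate()
    # (replay matches the first pass, outer state is untouched)
    return first == replay, outer_before == outer_after
```

Dropping the `setstate` call in `__exit__` would leave the first check passing and the second failing, which is the kind of leak the second assertion is meant to catch.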