[fix] Support MPS in the cached losses' RandContext #3812

Merged: tomaarsen merged 1 commit into huggingface:main from Incheonkirin:support-mps-cached-randcontext on Jun 12, 2026

Conversation

@Incheonkirin

Contributor

Fixes #3564.

CachedMultipleNegativesRankingLoss and CachedGISTEmbedLoss crash on MPS (Apple Silicon):

AttributeError: module 'torch.mps' has no attribute 'device'

RandContext backs up the device RNG state with torch.utils.checkpoint.get_device_states(*tensors), which calls the non-existent torch.mps.device() as soon as it sees an MPS tensor.

RandContext exists so that GradCache's second (cached) forward replays the same randomness as the first. Simply skipping the device state on MPS would stop the crash but break that replay, because dropout would differ between the two forwards. Instead, when an input is on MPS I capture the MPS RNG state with torch.mps.get_rng_state(), pass only the non-MPS tensors to get_device_states(), restore the snapshot in __enter__, and restore the surrounding MPS state in __exit__. The last step is needed because the existing fork_rng() call defaults to device_type="cuda" and so does not cover MPS.
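To see why dropping the RNG snapshot would break the replay, here is a small CPU-only illustration. It is a sketch for readers, not code from this PR: the helper name dropout_masks_match is hypothetical, and it uses the CPU generator rather than MPS so that it runs anywhere.

```python
import torch


def dropout_masks_match(restore_state: bool) -> bool:
    """Run dropout twice on the same input, optionally rewinding the CPU RNG in between."""
    drop = torch.nn.Dropout(p=0.5)  # a fresh module starts in training mode, so dropout is active
    x = torch.ones(1000)

    saved = torch.get_rng_state()   # snapshot taken before the first forward
    first = drop(x)

    if restore_state:
        torch.set_rng_state(saved)  # rewind so the second forward draws the same mask
    second = drop(x)

    return torch.equal(first, second)


print(dropout_masks_match(restore_state=False))
print(dropout_masks_match(restore_state=True))
```

Without the rewind, the two forwards draw independent masks and the outputs differ. With it, the second forward reproduces the first exactly, which is the property the cached backward pass depends on.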

Both SentenceTransformer RandContext copies get the fix, so this also covers sparse CachedSpladeLoss, which imports the CachedMultipleNegativesRankingLoss one. The CrossEncoder cached loss calls RandContext(pairs) with raw text rather than tensors, so it never hit this and is left alone.
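As a quick check of the raw-text case, the snippet below passes non-tensor arguments to get_device_states(). It is only an illustration: the pairs list is a made-up input, not the CrossEncoder loss's actual data structure.

```python
import torch
from torch.utils.checkpoint import get_device_states

# Hypothetical raw-text input, similar in spirit to what a CrossEncoder loss would pass.
pairs = [("query", "passage a"), ("query", "passage b")]

# Only non-CPU tensors contribute device ids, so plain Python objects yield no devices.
devices, states = get_device_states(pairs)
print(devices, states)
```

Because no device ids are collected, the per-device loop that raises on MPS is never entered, which is consistent with the CrossEncoder path not hitting the crash.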

Added an MPS regression test (skipped without MPS) over both copies. It raises the AttributeError on main and passes here, and it checks both the replay and that the outer RNG state is not leaked.

@tomaarsen

tomaarsen commented Jun 12, 2026

Member

Hello!

Thanks for opening this. I totally missed the original issue as well, that's on me. Your changes here look pretty sound, but I can't really test them myself as I don't have any MPS devices available. @Incheonkirin Have you been able to reproduce this yourself?

  • Tom Aarsen

Copilot AI left a comment

Contributor

Pull request overview

Fixes MPS (Apple Silicon) crashes in cached-loss RNG replay by teaching the cached-loss RandContext to snapshot/restore MPS RNG state without calling torch.utils.checkpoint.get_device_states() on MPS tensors.

Changes:

  • Update RandContext in cached loss implementations to capture MPS RNG state via torch.mps.get_rng_state() and exclude MPS tensors from get_device_states().
  • Restore MPS RNG state on __enter__ for deterministic replay, and restore the surrounding MPS RNG state on __exit__ to avoid leaking RNG changes.
  • Add an MPS-only regression test that validates both “replay determinism” and “no RNG leakage” across both RandContext copies.
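The changes listed above can be sketched roughly as follows. This is an illustrative reconstruction under assumptions, not the merged diff: the names MPSAwareRandContext, _on_mps, and replays_without_leaking are hypothetical, and the overall layout (get_device_states() in __init__, fork_rng() in __enter__) is assumed from the description in this thread. On a machine without MPS the MPS branches are simply skipped.

```python
import torch
from torch.utils.checkpoint import get_device_states, set_device_states


def _on_mps(obj) -> bool:
    """True for a top-level tensor argument that lives on the MPS device."""
    return isinstance(obj, torch.Tensor) and obj.device.type == "mps"


class MPSAwareRandContext:
    """Hypothetical sketch: snapshot RNG state at construction, replay it inside `with`."""

    def __init__(self, *tensors) -> None:
        self.fwd_cpu_state = torch.get_rng_state()
        # Snapshot the MPS generator only if some argument is an MPS tensor.
        has_mps = any(_on_mps(t) for t in tensors)
        self.fwd_mps_state = torch.mps.get_rng_state() if has_mps else None
        self._outer_mps_state = None
        # Keep MPS tensors away from get_device_states(), which cannot handle them.
        others = [t for t in tensors if not _on_mps(t)]
        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*others)

    def __enter__(self) -> None:
        # fork_rng() saves the surrounding CPU/CUDA state and restores it on exit.
        self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
        if self.fwd_mps_state is not None:
            # fork_rng() does not cover MPS here, so save the outer MPS state by hand.
            self._outer_mps_state = torch.mps.get_rng_state()
            torch.mps.set_rng_state(self.fwd_mps_state)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self._outer_mps_state is not None:
            torch.mps.set_rng_state(self._outer_mps_state)
            self._outer_mps_state = None
        self._fork.__exit__(exc_type, exc_val, exc_tb)
        self._fork = None


def replays_without_leaking(device: str) -> bool:
    """Check replay determinism and that the surrounding RNG state is left untouched."""
    x = torch.ones(4, device=device)
    ctx = MPSAwareRandContext(x)               # snapshot is taken here
    first = torch.rand(4, device=device)       # stands in for the first forward
    outer_cpu = torch.get_rng_state()
    outer_mps = torch.mps.get_rng_state() if device == "mps" else None
    with ctx:
        replay = torch.rand(4, device=device)  # stands in for the cached second forward
    same_draw = torch.equal(first.cpu(), replay.cpu())
    cpu_kept = torch.equal(torch.get_rng_state(), outer_cpu)
    mps_kept = outer_mps is None or torch.equal(torch.mps.get_rng_state(), outer_mps)
    return same_draw and cpu_kept and mps_kept


device = "mps" if torch.backends.mps.is_available() else "cpu"
print(device, replays_without_leaking(device))
```

The check at the end mirrors the two properties the regression test is described as covering: the draw inside the context matches the one made right after the snapshot, and the RNG state outside the context is the same before and after.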

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| tests/sentence_transformer/losses/test_cmnrl.py | Adds an MPS regression test covering both cached-loss RandContext implementations, including replay + no-leak assertions. |
| sentence_transformers/sentence_transformer/losses/cached_multiple_negatives_ranking.py | Captures/restores MPS RNG state and filters MPS tensors out of get_device_states() to prevent the torch.mps.device crash while preserving replay determinism. |
| sentence_transformers/sentence_transformer/losses/cached_gist_embed.py | Applies the same MPS-safe RNG snapshot/restore behavior to the GIST cached loss RandContext. |

RandContext backs up the device RNG state via torch.utils.checkpoint.get_device_states(),
which raises "module 'torch.mps' has no attribute 'device'" as soon as it sees an MPS tensor.

Capture the MPS RNG state for top-level MPS tensor arguments and pass only the non-MPS
arguments to get_device_states(). __enter__ restores the snapshot so the cached second
forward replays the same randomness; __exit__ restores the surrounding MPS state, since the
existing fork_rng() call defaults to device_type="cuda".

Covers CachedMultipleNegativesRankingLoss, CachedGISTEmbedLoss, and sparse CachedSpladeLoss
(which imports the CachedMultipleNegativesRankingLoss RandContext).

Fixes huggingface#3564
@Incheonkirin

Contributor Author

Yes — I originally hit this crash on my own machine (Apple M4 Max), and I re-verified it today against current main with torch 2.12.0.

Repro (MPS):

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.sentence_transformer.losses import CachedMultipleNegativesRankingLoss

m = SentenceTransformer("intfloat/multilingual-e5-small", device="mps")
loss = CachedMultipleNegativesRankingLoss(m, mini_batch_size=2)
mv = lambda f: {k: (v.to("mps") if torch.is_tensor(v) else v) for k, v in f.items()}
feats = [mv(m.tokenize(["query: a", "query: b", "query: c", "query: d"])), mv(m.tokenize(["passage: w", "passage: x", "passage: y", "passage: z"]))]
out = loss(feats, torch.zeros(4, dtype=torch.long).to("mps"))
out.backward()
print("NO CRASH — loss:", round(out.item(), 4))

On main this still fails:

Traceback (most recent call last):
  File "<stdin>", line 9, in <module>
  File "/private/tmp/st-3164/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/private/tmp/st-3164/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/private/tmp/st-3164/sentence_transformers/sentence_transformer/losses/cached_multiple_negatives_ranking.py", line 562, in forward
    for reps_mb, random_state in self.embed_minibatch_iter(
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~^
        sentence_feature=sentence_feature,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        with_grad=False,
        ^^^^^^^^^^^^^^^^
        copy_random_state=True,
        ^^^^^^^^^^^^^^^^^^^^^^^
    ):
    ^
  File "/private/tmp/st-3164/sentence_transformers/sentence_transformer/losses/cached_multiple_negatives_ranking.py", line 408, in embed_minibatch_iter
    reps, random_state = self.embed_minibatch(
                         ~~~~~~~~~~~~~~~~~~~~^
        sentence_feature=sentence_feature,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<4 lines>...
        random_state=None if random_states is None else random_states[i],
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/private/tmp/st-3164/sentence_transformers/sentence_transformer/losses/cached_multiple_negatives_ranking.py", line 385, in embed_minibatch
    random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None
                   ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/private/tmp/st-3164/sentence_transformers/sentence_transformer/losses/cached_multiple_negatives_ranking.py", line 32, in __init__
    self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
                                                ~~~~~~~~~~~~~~~~~^^^^^^^^^^
  File "/private/tmp/st-3164/.venv/lib/python3.13/site-packages/torch/utils/checkpoint.py", line 189, in get_device_states
    with device_module.device(device_id):
         ^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch.mps' has no attribute 'device'

On this branch the same snippet prints NO CRASH — loss: 1.3717, and the MPS regression tests pass on the actual hardware:

$ pytest tests/sentence_transformer/losses/test_cmnrl.py -k "rand_context or mps"
4 passed, 10 deselected, 1 warning in 0.06s
$ pytest tests/sentence_transformer/losses/test_cmnrl.py
14 passed, 1 warning in 2.48s

I also rebased the branch onto current main (no conflicts after #3816/#3817/#3818) and re-ran everything above on the rebased state.

@Incheonkirin force-pushed the support-mps-cached-randcontext branch from 680755c to fe38c7c on June 12, 2026 10:47

@tomaarsen left a comment

Member

That's excellent news, thank you! Then I'd be happy to support merging this. I'll await the tests.

  • Tom Aarsen

@tomaarsen merged commit 429cf5d into huggingface:main on Jun 12, 2026
15 of 18 checks passed
@Incheonkirin

Contributor Author

Thanks! I appreciate you checking this even without an MPS device on hand. I re-tested it on current main after the rebase, so I'm glad this can now cover the cached-loss MPS path.


Development

Successfully merging this pull request may close these issues.

CachedMultipleNegativesRankingLoss + MPS is broken

3 participants