
skip_memory_metrics=False breaks training loop when on MPS device #27181

@Datamance

Description

System Info

Here's my environment info:

  • transformers version: 4.34.0
  • Platform: macOS-14.0-arm64-arm-64bit
  • Python version: 3.11.5
  • Huggingface_hub version: 0.17.3
  • Safetensors version: 0.4.0
  • Accelerate version: 0.23.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.7.4 (cpu)
  • Jax version: 0.4.18
  • JaxLib version: 0.4.18
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: N/A

Who can help?

@muellerzr @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Create a Trainer with TrainingArguments(skip_memory_metrics=False)
  2. Receive ValueError: No available GPU device found!

The problem comes from the device check in TrainerMemoryTracker, which doesn't account for torch.mps.
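To sketch the shape of the fix (this is not the actual TrainerMemoryTracker code, and pick_memory_backend is a name I made up for illustration): the check just needs an MPS branch instead of raising when CUDA is absent.

```python
import torch


def pick_memory_backend() -> str:
    """Device-agnostic sketch of the check: prefer CUDA, then MPS,
    and fall back to CPU instead of raising "No available GPU device found!"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```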

Expected behavior

Unless there's some special case for MPS or apple silicon (in which case, it should be documented), I'd like to be able to log/profile memory metrics with the tooling here.

Specifically, what I am trying to do is track down what looks like a memory leak in _inner_training_loop. Even with a batch size of 1, a single gradient accumulation step, and no eval metrics, Activity Monitor shows the memory footprint of my training script growing by roughly 10 MB every 30-40 seconds. This wouldn't normally be a big deal, but the assignment I'm working on limits us to 15 GB of GPU memory. My total footprint starts at about 14.3 GB and reaches 15 GB after only a few iterations. You can see the way I construct the trainer here.
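In the meantime I'm polling the MPS allocator by hand with torch.mps.current_allocated_memory() (the helper name below is mine; it guards on availability so the same script runs off Apple Silicon):

```python
import torch


def mps_memory_mb():
    """Return the MPS allocator's current live tensor memory in MB,
    or None when running on a machine without an MPS device."""
    if torch.backends.mps.is_available():
        return torch.mps.current_allocated_memory() / 2**20
    return None
```

Calling this once per logging step and diffing successive values is how I'm currently eyeballing the ~10 MB creep.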

Also open to hearing workarounds for this - it sounds like TrainerCallback could be useful here, something like:

import gc

import torch
import transformers
from transformers import TrainerControl, TrainerState, TrainingArguments


class MpsCacheClearCallback(transformers.TrainerCallback):
    """Release cached MPS blocks back to the OS at the end of each epoch."""

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        gc.collect()  # drop unreachable tensors so their blocks become cacheable
        torch.mps.empty_cache()  # return cached blocks to the system allocator
        gc.collect()

But I'm not clear on when it's a good idea to clear the cache. Also, semi-related: shouldn't accelerator.free_memory / release_memory clear the MPS cache as well?
