System Info
Here's my environment info:
- transformers version: 4.34.0
- Platform: macOS-14.0-arm64-arm-64bit
- Python version: 3.11.5
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.4.0
- Accelerate version: 0.23.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.0 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.7.4 (cpu)
- Jax version: 0.4.18
- JaxLib version: 0.4.18
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: N/A
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- Create a Trainer and pass skip_memory_metrics=False (minimal sketch below)
- Receive ValueError: No available GPU device found!
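For reference, a minimal sketch of what I mean; the checkpoint name is just a small placeholder, any model on an MPS-only machine should reproduce it:

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Placeholder model; the specific checkpoint doesn't matter for the repro.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    skip_memory_metrics=False,  # asking for memory metrics is what triggers the error
)

# On a machine where MPS is the only accelerator, this fails with
# ValueError: No available GPU device found!
trainer = Trainer(model=model, args=training_args)
```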
The problem comes from this block of code in TrainerMemoryTracker that doesn't check for torch.mps.
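For what it's worth, I'd expect something along these lines next to the existing CUDA branch; this is just a sketch on my part using the torch.mps counters available in recent PyTorch, not the actual library code:

```python
import torch

# Illustrative device-detection sketch, not the actual TrainerMemoryTracker code.
if torch.cuda.is_available():
    gpu_mem_allocated = torch.cuda.memory_allocated()
elif torch.backends.mps.is_available():
    # torch.mps.current_allocated_memory() reports bytes currently owned by MPS tensors.
    gpu_mem_allocated = torch.mps.current_allocated_memory()
else:
    # Fall back to CPU-only metrics instead of raising.
    gpu_mem_allocated = None
```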
Expected behavior
Unless there's some special case for MPS or Apple silicon (in which case it should be documented), I'd like to be able to log/profile memory metrics with the tooling here.
Specifically, what I am trying to do is track down what looks like a memory leak in _inner_training_loop. Even with a batch size of 1, a single gradient accumulation step, and no eval metrics, Activity Monitor shows the memory footprint of my training script growing by roughly 10 MB every 30-40 seconds. This wouldn't normally be a big deal, but the assignment I'm working on limits us to 15 GB of GPU memory; my total footprint starts at about 14.3 GB and reaches 15 GB after just a few iterations. You can see the way I construct the trainer here.
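A small logging callback along these lines would at least make the growth visible without the built-in tracker; just a sketch, it only relies on on_step_end and the torch.mps counters:

```python
import torch
import transformers


class MpsMemoryLogger(transformers.TrainerCallback):
    """Print MPS allocator stats every N steps to watch for steady growth."""

    def __init__(self, every_n_steps: int = 50):
        self.every_n_steps = every_n_steps

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.every_n_steps == 0:
            allocated_mib = torch.mps.current_allocated_memory() / 2**20  # tensor memory
            driver_mib = torch.mps.driver_allocated_memory() / 2**20      # total driver memory
            print(
                f"step {state.global_step}: "
                f"allocated={allocated_mib:.1f} MiB, driver={driver_mib:.1f} MiB"
            )
```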
Also open to hearing workarounds for the leak itself - a TrainerCallback seems like the natural place for those too, something like:
```python
import gc

import torch
import transformers
from transformers import TrainerControl, TrainerState, TrainingArguments


class MpsCacheClearCallback(transformers.TrainerCallback):
    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Drop dead Python references first, then release cached MPS allocations.
        gc.collect()
        torch.mps.empty_cache()
        gc.collect()
```

But I'm not clear on when it's a good idea to clear the cache. Also, a semi-related question: shouldn't accelerator.free_memory / release_memory clear the MPS cache as well?
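For completeness, wiring it up would just be the standard callbacks argument (model, training_args, and train_dataset here stand in for my actual setup):

```python
trainer = transformers.Trainer(
    model=model,                 # placeholder for my actual model
    args=training_args,          # placeholder TrainingArguments
    train_dataset=train_dataset, # placeholder dataset
    callbacks=[MpsCacheClearCallback()],
)
```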