maravichandran / maya_notes_pytorch_distributed_model_training.md
Last active September 6, 2024 02:19
PyTorch distributed model training notes

Distributed model training in PyTorch

  • training in PyTorch works in 3 simple steps:
    1. compute loss during the forward pass
    2. compute the gradients during the backward pass
    3. update the model using the optimizer
    • in single-GPU training, all of these steps will take place on the same GPU
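
The three steps above can be sketched as a minimal single-device training loop; the toy model, data, and hyperparameters here are illustrative placeholders, not anything from the notes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)                     # toy model (illustrative)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x = torch.randn(32, 10)                      # toy batch of inputs
y = torch.randn(32, 1)                       # toy targets

losses = []
for _ in range(5):
    optimizer.zero_grad()                    # clear gradients from the previous step
    loss = loss_fn(model(x), y)              # 1. compute loss during the forward pass
    loss.backward()                          # 2. compute gradients during the backward pass
    optimizer.step()                         # 3. update the model using the optimizer
    losses.append(loss.item())
```

On a GPU, the same loop runs after moving `model`, `x`, and `y` to the device (e.g. with `.to("cuda")`); all three steps then execute on that one GPU.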
  • PyTorch DDP: distributed data parallel
    • used if you can fit the entire model on one GPU, but want to distribute training across GPUs
    • how it works:
      1. each GPU is initialized with the same initial model and optimizer (the entire model weights are stored on the GPU)