
Partitioner stores fp8 copy of all weights between fwd and bwd, causing OOM #141881

@lw

🐛 Describe the bug

We have some code to convert linear layers to fp8. The weights are still stored in high precision, but we have an autograd.Function which converts them to fp8 in the forward and then again in the backward, in two slightly different ways. The autograd.Function does not save the weight for the backward.

However, when we compile our code, the partitioner chooses to fuse the two conversions into a single one and save one of its results for the backward. Concretely, this means that the partitioner stores an additional copy of the entire model in fp8 between forward and backward! This amounts to multiple GBs of extra memory and prevents us from training large models.
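To put "multiple GBs" in perspective, here is a back-of-the-envelope calculation. The layer shapes and block count are hypothetical (roughly a 7B-parameter decoder) and are not taken from the issue. The only fact used is that fp8 formats store one byte per element.

```python
# Hypothetical linear-layer shapes (out_features, in_features) for one block;
# assumptions chosen for illustration, not taken from the issue.
BLOCK_LINEARS = [(4096, 4096)] * 4 + [(11008, 4096)] * 3
NUM_BLOCKS = 32
FP8_BYTES = 1  # float8_e4m3fn / float8_e5m2 use one byte per element


def fp8_copy_bytes(linears, num_blocks):
    """Bytes needed to keep one fp8 copy of every linear weight alive."""
    per_block = sum(out_f * in_f for out_f, in_f in linears) * FP8_BYTES
    return per_block * num_blocks


extra = fp8_copy_bytes(BLOCK_LINEARS, NUM_BLOCKS)
print(f"{extra:,} bytes = {extra / 2**30:.5f} GiB")
# → 6,476,005,376 bytes = 6.03125 GiB
```

Because the high-precision weights stay resident as well, a saved fp8 copy adds half the size of bf16 weights, or a quarter of fp32 weights, on top of them.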

I don't have a strong opinion on how this should be fixed. I do not think that the partitioner should be constrained to honor exactly what the autograd.Functions choose to keep/drop, but I do believe that the partitioner should take into account the amount of memory used by eager as an upper bound.
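One way to read the suggested upper bound is to count the bytes each plan keeps alive between forward and backward and reject a partition that keeps more than eager does. The sketch below is a hypothetical check, not part of PyTorch's partitioner. The `SavedTensor` type, the helper names, and the saved-tensor lists (eager saving a bf16 input plus the weight parameter, the compiled graph saving the input plus an fp8 weight copy, for one 4096×4096 linear layer over 8192 tokens) are all assumptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SavedTensor:
    """Hypothetical description of a tensor saved for the backward."""
    name: str
    numel: int
    itemsize: int  # bytes per element
    is_param: bool = False  # parameters stay alive anyway: no extra cost


def extra_bytes(saved):
    """Memory held between fwd and bwd beyond what is alive regardless."""
    return sum(t.numel * t.itemsize for t in saved if not t.is_param)


def within_eager_bound(candidate, eager):
    """Accept a candidate partition only if it saves no more than eager."""
    return extra_bytes(candidate) <= extra_bytes(eager)


x = SavedTensor("x", 8192 * 4096, 2)                       # bf16 input
w = SavedTensor("weight", 4096 * 4096, 2, is_param=True)   # bf16 parameter
w_fp8 = SavedTensor("weight_fp8", 4096 * 4096, 1)          # extra fp8 copy

eager = [x, w]
compiled = [x, w_fp8]
print(within_eager_bound(compiled, eager))  # → False
```

A check like this leaves the partitioner free to choose which tensors to save or recompute, as long as the total stays within what eager would hold.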

Versions

Installed from the pytorch-nightly conda channel, v2.6.0.dev20241107, build py3.12_cuda12.4_cudnn9.1.0_0.

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu

Labels: high priority, oncall: pt2, triaged