Bug description
Description
When selective activation checkpointing (SAC) is used with selective_ac_option=op, torch.compile is not applied to FlexAttention, even though TorchTitan explicitly wraps the call as torch.compile(flex_attention).
As a result, the following warning is emitted and performance is very low:
[rank0]:
[rank0]:SOLUTION: Use torch.compile(flex_attention)(...)
[rank0]:
[rank0]:If you want to debug your score_mod/mask_mod, you can set:
[rank0]:torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
[rank0]:
[rank0]:This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results.
Steps to Reproduce
- Use the default DeepSeek 16B configuration, or
- Use the debug model with the flag --model.flavor=debugmodel_flex_attn.
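For reference, the failing combination can be sketched as a training-config fragment. This is a minimal sketch, not a verified reproduction file: the flavor and selective_ac_option values come from this report, while the [model] and [activation_checkpoint] section names and the mode key are assumptions based on TorchTitan's TOML config layout and may differ in the nightly build.

```toml
# Hypothetical fragment; section and key names are assumed, not taken from this report.
[model]
flavor = "debugmodel_flex_attn"   # debug model that uses FlexAttention

[activation_checkpoint]
mode = "selective"                # assumed key that enables SAC
selective_ac_option = "op"        # per-op SAC, the setting that triggers the warning
```

Changing selective_ac_option to "2", or switching away from selective checkpointing, corresponds to the cases reported below as not showing the problem.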
Additional Information
- This issue does not occur when:
  - No Activation Checkpointing (AC) is used,
  - Full Activation Checkpointing is used,
  - SAC is used with selective_ac_option=2.
Versions
Nightly