
Minimizing recompiles when compiling multiple inlined nn.Modules #141589

@salmanmohammadi

Description

🐛 Describe the bug

I'm working on integrating torch.compile into our RLHF recipe for torchtune (meta-pytorch/torchtune#2066). I'm seeing a lot of recompiles, some of which I can definitely try to address on our side, and I'm hoping for some advice on how to go about this.

Context:
Our RLHF recipe involves 4 separate models, all of which I'd like to speed-up with torch.compile:

  • A policy model, which requires optimization
  • A value model, which requires optimization
  • A reference policy model, which is frozen and has its params set to requires_grad=False
  • A reward model, which is frozen and has its params set to requires_grad=False

All of these models share the same architecture and model definition (with the exception of the value and reward models, which have different final output layers).

When we compile a model in torchtune, we compile each attention layer in the model individually. IIUC, the entire block is inlined and the same graph is shared across all attention layers, with the layer weights passed in as parameters. So one graph is reused across all the attention layers, across all four models, right?
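For clarity, here is roughly what I mean by per-layer compilation. This is only a sketch of the idea, not the actual torchtune implementation; the layer class name check and the eager backend are my assumptions:

import torch
from torch import nn

def compile_layers_sketch(model: nn.Module, backend: str = "eager") -> None:
    # Compile each transformer layer in place rather than the whole model.
    # Dynamo inlines each layer's forward, so identical layers can share a
    # single compiled graph, with the layer weights passed in as inputs.
    for module in model.modules():
        if module.__class__.__name__ == "TransformerSelfAttentionLayer":  # assumption
            module.compile(backend=backend)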

Please see this minimal repro for how I'm using these models in our recipe, and the tlparse output below:

import os
import torch
from torchtune import training
from torchtune.models.llama3_2 import llama3_2_1b

os.environ["TORCH_COMPILE_BACKEND"] = "eager"

torch._logging.set_logs(recompiles=True, graph_breaks=True, perf_hints=True)
torch.set_default_device(torch.device("mps"))
torch.set_default_dtype(torch.bfloat16)

policy_model = llama3_2_1b()
value_model = llama3_2_1b()
reference_policy_model = llama3_2_1b()
reward_model = llama3_2_1b()

training.compile_model(policy_model)
training.compile_model(value_model)
training.compile_model(reference_policy_model)
training.compile_model(reward_model)

reference_policy_model.eval()
for param in reference_policy_model.parameters():
    param.requires_grad = False

reward_model.eval()
for param in reward_model.parameters():
    param.requires_grad = False


inps_full_seq = torch.ones(1, 64, dtype=torch.long)
inps_single_token = torch.ones(1, 1, dtype=torch.long)
inps_longer_seq_len = torch.ones(1, 68, dtype=torch.long)
inps_medium_seq_len = torch.ones(1, 66, dtype=torch.long)

with torch.no_grad():
    policy_model(inps_full_seq)
    # recompile due to seq len 1
    policy_model(inps_single_token)
    policy_model(inps_single_token)
    policy_model(inps_single_token)
    policy_model(inps_single_token)

    print("Running reference model with longer sequence length")
    # recompile - now graph is dynamic on dim 1
    reference_policy_model(inps_longer_seq_len)

    print("Running value model with longer sequence length")
    # recompile - param grad changed!
    value_model(inps_longer_seq_len)

    print("Running reward model with medium sequence length")
    reward_model(inps_medium_seq_len)

print("Optimizing")
out1 = policy_model(inps_full_seq)
out2 = value_model(inps_longer_seq_len)
out1.mean().backward()
out2.mean().backward()

"""outputs

2024-11-26:11:19:09,639 INFO     [_compile.py:50] Compiling model layers with torch.compile...
2024-11-26:11:19:09,655 INFO     [_compile.py:50] Compiling model layers with torch.compile...
2024-11-26:11:19:09,656 INFO     [_compile.py:50] Compiling model layers with torch.compile...
2024-11-26:11:19:09,657 INFO     [_compile.py:50] Compiling model layers with torch.compile...
V1126 11:19:10.402000 89573 torch/_dynamo/guards.py:2760] [0/1] [__recompiles] Recompiling function forward in /Users/salmanmohammadi/projects/torchtune/torchtune/modules/transformer.py:82
V1126 11:19:10.402000 89573 torch/_dynamo/guards.py:2760] [0/1] [__recompiles]     triggered by the following guard failure(s):
V1126 11:19:10.402000 89573 torch/_dynamo/guards.py:2760] [0/1] [__recompiles]     - 0/0: tensor 'L['x']' size mismatch at index 1. expected 64, actual 1
Running reference model with longer sequence length
V1126 11:19:11.178000 89573 torch/_dynamo/guards.py:2760] [0/2] [__recompiles] Recompiling function forward in /Users/salmanmohammadi/projects/torchtune/torchtune/modules/transformer.py:82
V1126 11:19:11.178000 89573 torch/_dynamo/guards.py:2760] [0/2] [__recompiles]     triggered by the following guard failure(s):
V1126 11:19:11.178000 89573 torch/_dynamo/guards.py:2760] [0/2] [__recompiles]     - 0/1: tensor 'L['x']' size mismatch at index 1. expected 1, actual 68
V1126 11:19:11.178000 89573 torch/_dynamo/guards.py:2760] [0/2] [__recompiles]     - 0/0: tensor 'L['x']' size mismatch at index 1. expected 64, actual 68
Running value model with longer sequence length
V1126 11:19:12.002000 89573 torch/_dynamo/guards.py:2760] [0/3] [__recompiles] Recompiling function forward in /Users/salmanmohammadi/projects/torchtune/torchtune/modules/transformer.py:82
V1126 11:19:12.002000 89573 torch/_dynamo/guards.py:2760] [0/3] [__recompiles]     triggered by the following guard failure(s):
V1126 11:19:12.002000 89573 torch/_dynamo/guards.py:2760] [0/3] [__recompiles]     - 0/2: tensor 'L['self']._modules['attn']._modules['q_proj']._parameters['weight']' requires_grad mismatch. expected requires_grad=0. Guard failed on a parameter, consider using torch._dynamo.config.force_parameter_static_shapes = False to allow dynamism on parameters.
V1126 11:19:12.002000 89573 torch/_dynamo/guards.py:2760] [0/3] [__recompiles]     - 0/1: tensor 'L['x']' size mismatch at index 1. expected 1, actual 68
V1126 11:19:12.002000 89573 torch/_dynamo/guards.py:2760] [0/3] [__recompiles]     - 0/0: tensor 'L['x']' size mismatch at index 1. expected 64, actual 68
Running reward model with medium sequence length
Optimizing
V1126 17:33:20.792000 95109 torch/_dynamo/guards.py:2760] [0/4] [__recompiles] Recompiling function forward in /Users/salmanmohammadi/projects/torchtune/torchtune/modules/transformer.py:82
V1126 17:33:20.792000 95109 torch/_dynamo/guards.py:2760] [0/4] [__recompiles]     triggered by the following guard failure(s):
V1126 17:33:20.792000 95109 torch/_dynamo/guards.py:2760] [0/4] [__recompiles]     - 0/2: GLOBAL_STATE changed: grad_mode 
V1126 17:33:20.792000 95109 torch/_dynamo/guards.py:2760] [0/4] [__recompiles]     - 0/3: GLOBAL_STATE changed: grad_mode 
V1126 17:33:20.792000 95109 torch/_dynamo/guards.py:2760] [0/4] [__recompiles]     - 0/1: GLOBAL_STATE changed: grad_mode 
V1126 17:33:20.792000 95109 torch/_dynamo/guards.py:2760] [0/4] [__recompiles]     - 0/0: GLOBAL_STATE changed: grad_mode 
"""

A few immediate questions I have:

  1. Can we somehow avoid the recompile on ._parameters["weight"].requires_grad? The relevant code is already being executed under no_grad. This seems like a relevant issue: "torch.compile should not recompiles when .requires_grad=True under torch.no_grad() context" #131975. I suppose that if these models are always run under no_grad, I could avoid explicitly setting param.requires_grad=False? EDIT: with this approach I see a recompile on _modules["attn"].training instead.
  2. I was thinking of trying to isolate some portions of the code. For example, we have a function which we know will always call policy_model on inputs of shape [bsz, 1, ...]. If I move this into a function and compile it separately with dynamic=False, is there a way to avoid reusing the graph generated when we do module.compile(), and use this fully static function instead? (See the sketch after this list.)
  3. Separately, if we divide the repro code above into its analogous counterparts in the recipe, a "generation" stage (under no_grad) and an "optimization" stage (the backward passes afterwards), then in the full recipe I only really see speedups for the generation stage. In fact, I pretty much always see slowdowns during the optimization stage compared to running without compile. Is there something immediately obvious that could explain this?
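Related to question 2 (and the seq-len guard failures above), here is what I plan to try. The mark_dynamic usage and the decode_one_token helper are my own assumptions, not existing torchtune code:

import torch
import torch._dynamo

# (1) Pre-mark the sequence dimension as dynamic so the seq-len guard
# failures (64 -> 1 -> 68) don't each force a full recompile. Note that
# sizes 0 and 1 are still specialized separately.
inps = torch.ones(1, 64, dtype=torch.long)
torch._dynamo.mark_dynamic(inps, 1)

# (2) A separately compiled, fully static single-token path.
# decode_one_token is a hypothetical helper; it assumes it is handed the
# *uncompiled* module so its static graph isn't shared with the
# per-layer compiled one.
@torch.compile(dynamic=False)
def decode_one_token(model: torch.nn.Module, tok: torch.Tensor) -> torch.Tensor:
    return model(tok)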

Thanks!

Error logs

dedicated_log_torch_trace_5kkmzqju.log

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241121
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.4 (main, Jun  6 2024, 18:26:44) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] flake8==7.1.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.6.0.dev20241121
[pip3] torchao==0.5.0
[pip3] torchaudio==2.5.0.dev20241121
[pip3] torchtune==0.0.0
[pip3] torchvision==0.20.0.dev20241121
[conda] No relevant packages

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames
