Guided Diffusion fails guards with runtime errors. #93405

@0x6b64

🐛 Describe the bug

Hi,

I'm working with the Jan 28 nightly build of PyTorch (nightly branch), commit 5d6a4f6.

I've updated the training script to use torch.compile
Ref: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/train_util.py#L93

model = th.compile(model)

This is the command I use to launch training:

mpirun -np 1 --allow-run-as-root python guided-diffusion/scripts/image_train.py \
--data_dir "datasets/cifar" \
--attention_resolutions '32,16,8' \
--class_cond True \
--diffusion_steps 1000 \
--image_size 128 \
--channel_mult '1,1.5,2,4,5' \
--learn_sigma True \
--noise_schedule 'linear' \
--num_channels 256 \
--num_heads 4 \
--num_res_blocks 3 \
--resblock_updown True \
--use_fp16 True \
--use_scale_shift_norm True \
--lr 1e-4 \
--weight_decay 1e-4 \
--batch_size 8 \
--ema_rate '0.9999' \
--log_interval 10 \
--save_interval 100000

Any pointers will be very helpful!

Error logs

Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 363, in _compile
    check_fn = CheckFunctionManager(
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/guards.py", line 548, in __init__
    guard.create(local_builder, global_builder)
  File "/opt/conda/lib/python3.9/site-packages/torch/_guards.py", line 163, in create
    return self.create_fn(self.source.select(local_builder, global_builder), self)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/guards.py", line 321, in LIST_LENGTH
    code.append(f"len({ref}) == {len(value)}")
TypeError: object of type 'tuple_iterator' has no len()
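
For reference, the final `TypeError` can be reproduced in plain Python without Dynamo. The sketch below is only an illustration of the failure mode: `list_length_guard` is a hypothetical stand-in that copies the f-string from the last traceback frame, not Dynamo's actual guard builder. I'm assuming the guarded value reaches `LIST_LENGTH` as a tuple iterator, which is what the message suggests.

```python
def list_length_guard(ref, value):
    # Hypothetical helper mirroring the f-string in the traceback's last frame;
    # it is NOT the real torch._dynamo guard implementation.
    return f"len({ref}) == {len(value)}"


shapes = (8, 256)

# A real tuple has a length, so the guard expression can be built.
guard_for_tuple = list_length_guard("shapes", shapes)

# An iterator over the same tuple has no __len__, so len() raises.
try:
    list_length_guard("it", iter(shapes))
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(guard_for_tuple)
print(error_message)
```

If this reading is right, the guard is being created for an iterator object rather than for the sequence it was built from.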

Minified repro

This is the repro generated after Dynamo (it doesn't get to AOT). It doesn't reproduce the issue. I'm beginning to doubt how useful these minifiers truly are: none of the ones I've generated recaptures the issue encountered with full training.

import os
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import functools
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.optimizations.backends import BACKENDS
from torch._dynamo.testing import rand_strided

import torch._dynamo.config
import torch._inductor.config
torch._dynamo.config.load_config(b'\x80\x04\x95\xa9\x07\x00\x00\x00\x00\x00\x00}\x94(\x8c\x08__name__\x94\x8c\x14torch._dynamo.config\x94\x8c\x07__doc__\x94N\x8c\x0b__package__\x94\x8c\rtorch._dynamo\x94\x8c\n__loader__\x94\x8c\x1a_frozen_importlib_external\x94\x8c\x10SourceFileLoader\x94\x93\x94)\x81\x94}\x94(\x8c\x04name\x94h\x02\x8c\x04path\x94\x8c>/opt/conda/lib/python3.9/site-packages/torch/_dynamo/config.py\x94ub\x8c\x08__spec__\x94\x8c\x11_frozen_importlib\x94\x8c\nModuleSpec\x94\x93\x94)\x81\x94}\x94(h\x0ch\x02\x8c\x06loader\x94h\n\x8c\x06origin\x94h\x0e\x8c\x0cloader_state\x94N\x8c\x1asubmodule_search_locations\x94N\x8c\r_set_fileattr\x94\x88\x8c\x07_cached\x94\x8cV/opt/conda/lib/python3.9/site-packages/torch/_dynamo/__pycache__/config.cpython-39.pyc\x94\x8c\r_initializing\x94\x89ub\x8c\x08__file__\x94h\x0e\x8c\n__cached__\x94h\x1b\x8c\x07abspath\x94\x8c\tposixpath\x94h\x1f\x93\x94\x8c\x07dirname\x94h h"\x93\x94\x8c\x0eHAS_REFS_PRIMS\x94\x88\x8c\tlog_level\x94K\x1e\x8c\x0boutput_code\x94\x89\x8c\rlog_file_name\x94N\x8c\x07verbose\x94\x89\x8c\x11output_graph_code\x94\x89\x8c\x12verify_correctness\x94\x89\x8c\x12minimum_call_count\x94K\x01\x8c\x15dead_code_elimination\x94\x88\x8c\x10cache_size_limit\x94K@\x8c\x14specialize_int_float\x94\x88\x8c\x0edynamic_shapes\x94\x89\x8c\x10guard_nn_modules\x94\x89\x8c\x0cnormalize_ir\x94\x89\x8c\x1btraceable_tensor_subclasses\x94\x8f\x94\x8c\x0fsuppress_errors\x94\x89\x8c\x15replay_record_enabled\x94\x89\x8c rewrite_assert_with_torch_assert\x94\x88\x8c\x12print_graph_breaks\x94\x89\x8c\x07disable\x94\x89\x8c*allowed_functions_module_string_ignorelist\x94\x8f\x94(\x8c\x13torch.distributions\x94\x8c\rtorch.testing\x94\x8c\rtorch._decomp\x94\x8c\x0btorch._refs\x94\x8c\x0ctorch._prims\x94\x90\x8c\x16capture_scalar_outputs\x94\x89\x8c\x19enforce_cond_guards_match\x94\x88\x8c\x0coptimize_ddp\x94\x88\x8c\x1araise_on_ctx_manager_usage\x94\x88\x8c\x1craise_on_unsafe_aot_autograd\x94\x89\x8c\rdynamo_import\x94\x8c\rtorch._dynamo\x94\x8c\x0finductor_import\x94\x8c\x0ftorch._inductor\x94\x8c\x18error_on_nested_fx_trace\x94\x88\x8c\tallow_rnn\x94\x89\x8c\x08base_dir\x94\x8c&/opt/conda/lib/python3.9/site-packages\x94\x8c\x0edebug_dir_root\x94\x8c>/fsx/karan/pt2_benchmarks/guided_diffusion/torch_compile_debug\x94\x8c)DO_NOT_USE_legacy_non_fake_example_inputs\x94\x89\x8c\x15_AccessLimitingConfig\x94}\x94(\x8c\n__module__\x94h\x02\x8c\x0b__setattr__\x94h\x02\x8c!_AccessLimitingConfig.__setattr__\x94\x93\x94h\x03Nu\x8c\x15_allowed_config_names\x94\x8f\x94(\x8c\x0brepro_level\x94\x8c\x02os\x94\x8c\x07logging\x94hCh5h\x1dh\'h\x1fh8h9h\x1eh/hDhP\x8c\nModuleType\x94h\x04hEh"h\x03h$hMh%h,hAhOh*h(h\x06h)hJh1h\x0fh2hG\x8c\x12constant_functions\x94h\x01\x8c!skipfiles_inline_module_allowlist\x94h+hKh0h6h7\x8c\x0c__builtins__\x94h.\x8c\x0brepro_after\x94hB\x8c\x05torch\x94\x8c\x03sys\x94h-hIh&\x8c\x0eexternal_utils\x94h4h@\x90\x8c\x1cget_config_serialization_fns\x94\x8c\x1atorch._dynamo.config_utils\x94hc\x93\x94u.')
torch._inductor.config.load_config(b'\x80\x04\x95\xd3\x08\x00\x00\x00\x00\x00\x00}\x94(\x8c\x08__name__\x94\x8c\x16torch._inductor.config\x94\x8c\x07__doc__\x94N\x8c\x0b__package__\x94\x8c\x0ftorch._inductor\x94\x8c\n__loader__\x94\x8c\x1a_frozen_importlib_external\x94\x8c\x10SourceFileLoader\x94\x93\x94)\x81\x94}\x94(\x8c\x04name\x94h\x02\x8c\x04path\x94\x8c@/opt/conda/lib/python3.9/site-packages/torch/_inductor/config.py\x94ub\x8c\x08__spec__\x94\x8c\x11_frozen_importlib\x94\x8c\nModuleSpec\x94\x93\x94)\x81\x94}\x94(h\x0ch\x02\x8c\x06loader\x94h\n\x8c\x06origin\x94h\x0e\x8c\x0cloader_state\x94N\x8c\x1asubmodule_search_locations\x94N\x8c\r_set_fileattr\x94\x88\x8c\x07_cached\x94\x8cX/opt/conda/lib/python3.9/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc\x94\x8c\r_initializing\x94\x89ub\x8c\x08__file__\x94h\x0e\x8c\n__cached__\x94h\x1b\x8c\x05debug\x94\x89\x8c\x10disable_progress\x94\x88\x8c\x10verbose_progress\x94\x89\x8c\x0bcpp_wrapper\x94\x89\x8c\x03dce\x94\x89\x8c\x14static_weight_shapes\x94\x88\x8c\x0csize_asserts\x94\x88\x8c\x10pick_loop_orders\x94\x88\x8c\x0finplace_buffers\x94\x88\x8c\x11benchmark_harness\x94\x88\x8c\x0fepilogue_fusion\x94\x89\x8c\x15epilogue_fusion_first\x94\x89\x8c\x0cmax_autotune\x94\x89\x8c\x17realize_reads_threshold\x94K\x04\x8c\x17realize_bytes_threshold\x94M\xd0\x07\x8c\x1brealize_acc_reads_threshold\x94K\x08\x8c\x0ffallback_random\x94\x89\x8c\x12implicit_fallbacks\x94\x88\x8c\rprefuse_nodes\x94\x88\x8c\x0btune_layout\x94\x89\x8c\x11aggressive_fusion\x94\x89\x8c\x0fmax_fusion_size\x94K@\x8c\x1bunroll_reductions_threshold\x94K\x08\x8c\x0ecomment_origin\x94\x89\x8c\tis_fbcode\x94h\x02h7\x93\x94\x8c\x0fcompile_threads\x94K \x8c\x13kernel_name_max_ops\x94K\n\x8c\x0finductor_import\x94\x8c\x0ftorch._inductor\x94\x8c\rshape_padding\x94\x89\x8c\x0epermute_fusion\x94\x89\x8c\x1aprofiler_mark_wrapper_call\x94\x89\x8c\x03cpp\x94}\x94(\x8c\n__module__\x94h\x02\x8c\x07threads\x94J\xff\xff\xff\xff\x8c\x0fdynamic_threads\x94\x89\x8c\x07simdlen\x94N\x8c\x0emin_chunk_size\x94M\x00\x10\x8c\x03cxx\x94N\x8c\x03g++\x94\x86\x94\x8c\x15enable_kernel_profile\x94\x89h\x03Nu\x8c\x06triton\x94}\x94(hBh\x02\x8c\ncudagraphs\x94\x88\x8c\x10debug_sync_graph\x94\x89\x8c\x11debug_sync_kernel\x94\x89\x8c\x0bconvolution\x94\x8c\x04aten\x94\x8c\x0edense_indexing\x94\x89\x8c\tmax_tiles\x94K\x02\x8c\x12autotune_pointwise\x94\x88\x8c tiling_prevents_pointwise_fusion\x94\x88\x8c tiling_prevents_reduction_fusion\x94\x88\x8c\x14ordered_kernel_names\x94\x89\x8c\x18descriptive_kernel_names\x94\x89h\x03Nu\x8c\x05trace\x94}\x94(hBh\x02\x8c\x07enabled\x94\x89\x8c\tdebug_log\x94\x88\x8c\x08info_log\x94\x89\x8c\x08fx_graph\x94\x88\x8c\rir_pre_fusion\x94\x88\x8c\x0eir_post_fusion\x94\x88\x8c\x0boutput_code\x94\x88\x8c\rgraph_diagram\x94\x89\x8c\x0fcompile_profile\x94\x89\x8c\nupload_tar\x94Nh\x03Nu\x8c\x15InductorConfigContext\x94}\x94(hBh\x02\x8c\x0f__annotations__\x94}\x94(\x8c\rstatic_memory\x94\x8c\x08builtins\x94\x8c\x04bool\x94\x93\x94\x8c\x0ematmul_padding\x94hlh+hl\x8c\x12triton_convolution\x94hj\x8c\x03str\x94\x93\x94\x8c\x17rematerialize_threshold\x94hj\x8c\x03int\x94\x93\x94\x8c\x1brematerialize_acc_threshold\x94hsu\x8c\x05_save\x94h\x02\x8c\x1bInductorConfigContext._save\x94\x93\x94\x8c\x06_apply\x94h\x02\x8c\x1cInductorConfigContext._apply\x94\x93\x94\x8c\x08__init__\x94h\x02\x8c\x1eInductorConfigContext.__init__\x94\x93\x94\x8c\t__enter__\x94h\x02\x8c\x1fInductorConfigContext.__enter__\x94\x93\x94\x8c\x08__exit__\x94h\x02\x8c\x1eInductorConfigContext.__exit__\x94\x93\x94h\x03Nu\x8c\x1cget_config_serialization_fns\x94\x8c\x1atorch._dynamo.config_utils\x94h\x84\x93\x94u.')


# REPLACEABLE COMMENT FOR TESTING PURPOSES


args = [((8, 256), (256, 1), torch.float32, 'cuda', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_0 = Linear(in_features=256, out_features=1024, bias=True).cuda()
        self.self_1 = SiLU()



    def forward(self, input_1 : torch.Tensor):
        self_0 = self.self_0(input_1);  input_1 = None
        self_1 = self.self_1(self_0);  self_0 = None
        return (self_1,)


mod = Repro()

# Setup debug minifier compiler
torch._dynamo.debug_utils.MINIFIER_SPAWNED = True
compiler_fn = BACKENDS["dynamo_minifier_backend"]

dynamo_minifier_backend = functools.partial(
    compiler_fn,
    compiler_name="inductor",
)
opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

with torch.cuda.amp.autocast(enabled=False):
    opt_mod(*args)

Versions

Collecting environment information...
PyTorch version: 2.0.0a0+git5876d91
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.3
Libc version: glibc-2.31

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:58:50)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-4.14.296-222.539.amzn2.x86_64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: NVIDIA A100-SXM4-40GB 

Nvidia driver version: 470.103.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] bert-pytorch==0.0.1a4
[pip3] clip-anytorch==2.5.0
[pip3] CoCa-pytorch==0.0.7
[pip3] dalle2-pytorch==1.10.5
[pip3] ema-pytorch==0.1.4
[pip3] functorch==1.14.0a0+408bcf1
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] pytorch-transformers==1.2.0
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.2.1
[pip3] torch==2.0.0a0+git5876d91
[pip3] torch-fidelity==0.3.0
[pip3] torch-struct==0.5
[pip3] torchaudio==2.0.0a0+4699ef2
[pip3] torchdata==0.6.0a0+a1612ee
[pip3] torchmetrics==0.11.0
[pip3] torchrec-nightly==2023.1.29
[pip3] torchtext==0.15.0a0+f653dac
[pip3] torchvision==0.15.0a0+c35e8d5
[pip3] vector-quantize-pytorch==0.10.15
[conda] bert-pytorch              0.0.1a4                   dev_0    <develop>
[conda] clip-anytorch             2.5.0                    pypi_0    pypi
[conda] coca-pytorch              0.0.7                    pypi_0    pypi
[conda] dalle2-pytorch            1.10.5                   pypi_0    pypi
[conda] ema-pytorch               0.1.4                    pypi_0    pypi
[conda] functorch                 1.14.0a0+408bcf1          pypi_0    pypi
[conda] magma-cuda117             2.6.1                         1    pytorch
[conda] mkl                       2022.2.1         h84fe81f_16997    conda-forge
[conda] mkl-include               2023.0.0         h84fe81f_25396    conda-forge
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] pytorch-transformers      1.2.0                    pypi_0    pypi
[conda] pytorch-warmup            0.1.1                    pypi_0    pypi
[conda] rotary-embedding-torch    0.2.1                    pypi_0    pypi
[conda] torch                     2.0.0a0+git5876d91          pypi_0    pypi
[conda] torch-fidelity            0.3.0                    pypi_0    pypi
[conda] torch-struct              0.5                      pypi_0    pypi
[conda] torchaudio                2.0.0a0+4699ef2          pypi_0    pypi
[conda] torchdata                 0.6.0a0+a1612ee          pypi_0    pypi
[conda] torchmetrics              0.11.0                   pypi_0    pypi
[conda] torchrec-nightly          2023.1.29                pypi_0    pypi
[conda] torchtext                 0.15.0a0+f653dac          pypi_0    pypi
[conda] torchvision               0.15.0a0+c35e8d5          pypi_0    pypi
[conda] vector-quantize-pytorch   0.10.15                  pypi_0    pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh

Labels

oncall: pt2, triaged
