Skip to content

[Inductor] [CPU] Crash failure in torchbench model hf_BigBird #93460

@yudongsi

Description

@yudongsi

🐛 Describe the bug

This failure was found in the latest TorchInductor CPU Performance Dashboard refresh test, with the error log below.

SW information

SW Nightly commit Master/Main commit
Pytorch f8506fb 39449ea
Torchbench / 2e5d723
torchaudio c44b576 8ba323b
torchtext ebcfed5 b3390fb
torchvision d0f2888 5b4f79d

Error logs

cpu  eval  hf_BigBird                         ERROR:common:Failed for dynamo While executing return (permute_28, cat)
Original traceback:
None
Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 1189, in warmup
    fn(model, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 193, in _fn
    return fn(*args, **kwargs)
  File "benchmarks/dynamo/torchbench.py", line 377, in forward_pass
    return mod(*inputs)
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 2462, in forward
    outputs = self.bert(
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 2092, in forward
    ) = self._pad_to_block_size(
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 2148, in <graph break in forward>
    encoder_outputs = self.encoder(
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 1641, in forward
    layer_outputs = layer_module(
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 1493, in forward
    self_attention_outputs = self.attention(
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 1406, in forward
    self_outputs = self.self(
  File "/workspace/pytorch/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 475, in forward
    context_layer, attention_probs = self.bigbird_block_sparse_attention(
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 573, in bigbird_block_sparse_attention
    np.random.seed(seed)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 587, in <graph break in bigbird_block_sparse_attention>
    rand_attn = self._bigbird_block_rand_mask_with_head(
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 597, in <graph break in bigbird_block_sparse_attention>
    rand_attn = np.stack(rand_attn, axis=0)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 598, in <graph break in bigbird_block_sparse_attention>
    rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py", line 598, in <graph break in bigbird_block_sparse_attention>
    rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 193, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/functorch/_src/aot_autograd.py", line 1800, in forward
    return compiled_f(
  File "/workspace/pytorch/functorch/_src/aot_autograd.py", line 1792, in compiled_f
    compiled_fn = create_aot_dispatcher_function(
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/functorch/_src/aot_autograd.py", line 1518, in create_aot_dispatcher_function
    return aot_dispatch_base(flat_fn, fake_flat_tensor_args, aot_config)
  File "/workspace/pytorch/functorch/_src/aot_autograd.py", line 852, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 371, in fw_compiler
    return inner_compile(
  File "/workspace/pytorch/torch/_dynamo/debug_utils.py", line 473, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/workspace/pytorch/torch/_inductor/debug.py", line 177, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 137, in compile_fx_inner
    graph.run(*example_inputs)
  File "/workspace/pytorch/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 138, in run
    return super().run(*args)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 130, in run
    self.env[node] = self.run_node(node)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 322, in run_node
    result = super().run_node(n)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 171, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 296, in output
    assert isinstance(value, ir.StorageBox)
AssertionError: While executing return (permute_28, cat)
Original traceback:
None

Minified repro

import torch._inductor.overrides

import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# REPLACEABLE COMMENT FOR TESTING PURPOSES

# torch version: 1.14.0a0+git76ba93c
# torch cuda version: None
# torch git version: 76ba93c1cb4e9584e749e0a51bdfbe7bf186df90


# torch.cuda.is_available()==False, no GPU info collected

from torch.nn import *
class Repro(torch.nn.Module):
    """Minified reproducer module.

    Its forward performs an in-place ``unsqueeze_`` on one input and then
    returns the *other* input unchanged — i.e. a graph output that is a
    pass-through graph input.  Presumably this input-as-output pattern is
    what trips ``assert isinstance(value, ir.StorageBox)`` in
    ``torch/_inductor/graph.py`` (see the traceback above) — TODO confirm.

    NOTE(review): this code is minifier-generated; the exact aten overload
    call and the dead ``unsqueeze_ = ...`` binding must be kept verbatim so
    the traced FX graph matches the failing one.  Do not restyle.
    """
    def __init__(self):
        super().__init__()

    
    
    def forward(self, arg0_1, cat):
        # In-place op: mutates arg0_1 from (1, 1, 1, 12, 11, 3) to
        # (1, 1, 1, 1, 12, 11, 3); the result binding is intentionally unused.
        unsqueeze_ = torch.ops.aten.unsqueeze_.default(arg0_1, 0);  arg0_1 = None
        # `cat` is returned untouched — the output tuple aliases a graph input.
        return (cat,)
        
# Input specs as (shape, stride, dtype, device) tuples — both inputs are
# contiguous int64 CPU tensors of shape (1, 1, 1, 12, 11, 3).
args = [((1, 1, 1, 12, 11, 3), (396, 396, 396, 33, 3, 1), torch.int64, 'cpu'), ((1, 1, 1, 12, 11, 3), (396, 396, 396, 33, 3, 1), torch.int64, 'cpu')]
# Materialize real tensors matching each spec.
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
# Trace the module into an FX GraphModule via proxy tensors.
mod = make_fx(Repro())(*args)

from torch._inductor.compile_fx import compile_fx_inner
# NOTE(review): same_two_models is emitted by the minifier but never used here.
from torch._dynamo.debug_utils import same_two_models

# Compiling the traced graph is the step that raises
# "AssertionError: While executing return (permute_28, cat)"
# (assert isinstance(value, ir.StorageBox) in Inductor's graph.output).
compiled = compile_fx_inner(mod, args)
compiled(args)

Metadata

Metadata

Assignees

Labels

triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Type

No type

Projects

Status

Done

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions