fix: Update lowering passes in aten tracer FX #1708

gs-olive wants to merge 1 commit into pytorch:main

Conversation
- Enable translation to `reshape` from `view`, which was causing failures when compiling BERT model due to memory layout of Tensors
- Default to `matmul` within `compose_bmm` lowering pass when the dimension of inputs exceeds 3
```python
for n in module.graph.nodes:
    if n.op == "call_function" and n.target in (
        torch.ops.aten._unsafe_view.default,
        torch.ops.aten.view.default,
```
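The excerpt above walks the traced graph looking for view-type call targets. As a rough illustration of this style of FX rewrite (a hypothetical helper that matches `call_method` nodes for simplicity, not the PR's actual `remove_ops` pass, which matches `torch.ops.aten` `call_function` targets):

```python
import torch
import torch.fx


class ViewModule(torch.nn.Module):
    def forward(self, x):
        return x.view(-1)


def replace_view_with_reshape(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Rewrite every Tensor.view call in the traced graph to Tensor.reshape,
    # which tolerates non-contiguous inputs by copying when necessary
    for n in module.graph.nodes:
        if n.op == "call_method" and n.target == "view":
            n.target = "reshape"
    module.recompile()
    return module


gm = replace_view_with_reshape(torch.fx.symbolic_trace(ViewModule()))
print(gm.code)
```

After the rewrite, the generated `forward` calls `reshape` instead of `view`, while producing the same values for contiguous inputs.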
It is not necessary to remove `aten.view`, since the `reshape` operation is decomposed into `aten.view` (which is safe) and we have a converter to support `aten.view`.
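The distinction this thread hinges on can be seen directly in eager PyTorch: `view` requires a stride-compatible memory layout, while `reshape` falls back to a copy when the layout is incompatible (a minimal sketch for context, not part of the PR):

```python
import torch

x = torch.randn(2, 4, 3, 5)
# permute reorders strides without moving data, leaving the tensor
# non-contiguous, so view() cannot reinterpret the memory in place
y = x.permute(0, 2, 1, 3)
new_shape = y.size()[:-2] + (-1,)  # (2, 3, -1)

try:
    y.view(new_shape)
except RuntimeError as e:
    print("view failed:", e)

# reshape() copies when the layout is incompatible, so it succeeds
z = y.reshape(new_shape)
print(z.shape)  # torch.Size([2, 3, 20])
```

Calling `.contiguous()` before `.view()`, as the GPT2 code below does, also avoids the error; the failure appears when the traced graph no longer preserves that contiguity.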
I see - thank you for the clarification on that. The reason I had removed the view operator was for cases like this:

```python
def forward(self, x):
    x = x.permute(0, 2, 1, 3).contiguous()
    new_shape = x.size()[:-2] + (-1,)
    return x.view(new_shape)
```

These show up in the GPT2 code, and when using the aten tracer, they result in the following error (though they run fine in Torch):
```
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 161, in opt_trace
    fx_module(*args)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 281, in __call__
    raise e
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.15", line 9, in forward
  File "/usr/local/lib/python3.8/dist-packages/torch/_ops.py", line 329, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```

```diff
-    if len(real_other.meta["val"].size()) == 3:
+    elif len(real_other.meta["val"].size()) == 3:
         new_func = aten_compose_bmm_3d
     else:
```
Not clear why we need this `new_func = torch.ops.aten.matmul` - any example or unit test?
This addition is related to an issue in the `compose_bmm` lowering pass. I noticed that `input_n` can have a different shape than `real_input`, which causes the batch matrix multiply to have 4 dimensions instead of 3, reaching this `else` statement. I don't yet have a minimal reproducing example, as #1789 would likely need to be addressed first.
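The rationale for the fallback can be illustrated in eager PyTorch: `torch.bmm` strictly requires 3-D inputs, while `torch.matmul` broadcasts over any extra leading batch dimensions (a minimal sketch of the operator behavior, not a reproduction of the lowering-pass issue):

```python
import torch

a = torch.randn(2, 3, 4, 5)
b = torch.randn(2, 3, 5, 6)

# bmm rejects anything other than 3-D inputs
try:
    torch.bmm(a, b)
except RuntimeError as e:
    print("bmm failed:", e)

# matmul treats the leading dimensions as batch dimensions and broadcasts,
# so it handles the 4-D case that reaches the else branch above
out = torch.matmul(a, b)
print(out.shape)  # torch.Size([2, 3, 4, 6])
```

This is why defaulting to `aten.matmul` is a safe choice when the input rank exceeds 3 and the specialized `aten_compose_bmm_3d` path does not apply.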
Description

- Enable translation to `reshape` from `view`, which was causing failures when compiling BERT model due to memory layout of Tensors
- Default to `matmul` within `compose_bmm` lowering pass when the dimension of inputs exceeds 3

Error displayed prior to `remove_ops` view fix (BERT model from Issue #1673):

Error displayed prior to `compose_bmm` fix:

Note: `test_reshape_aten` is currently failing since the `aten.view.default` ops are being converted to `aten.reshape`

Fixes #1673
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: