🐛 Describe the bug
Summary:
aten.baddbmm with fp16 inputs is decomposed to aten.bmm in fp32 (the inputs are upcast, the bmm and add run in fp32, and the result is cast back to fp16).
This results in a performance regression.
For more details, see https://docs.google.com/document/d/1OOV1UaiwcQH58aGzvJByUn29i1wLphey4mXqcvS60cM/edit?usp=sharing
For the following model:
import torch
import torch.export._trace
from torch._inductor.decomposition import decompositions, get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(64, 64, 192, dtype=torch.float16))
        self.bias = torch.nn.Parameter(torch.randn(64, 1, 192, dtype=torch.float16))

    def forward(self, x):
        return torch.ops.aten.baddbmm.default(self.bias, x, self.weight)

x = torch.randn(64, 2048, 64, dtype=torch.float16, requires_grad=False)
inputs = (x,)
m = M()
gm = make_fx(m, pre_dispatch=False, decomposition_table=decompositions)(*inputs)
gm.print_readable(print_output=True)
Before the fix:
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f16[64, 2048, 64]"):
        # No stacktrace found for following nodes
        _param_constant0 = self._param_constant0
        convert_element_type: "f32[64, 1, 192]" = torch.ops.prims.convert_element_type.default(_param_constant0, torch.float32); _param_constant0 = None
        convert_element_type_1: "f32[64, 2048, 64]" = torch.ops.prims.convert_element_type.default(arg0_1, torch.float32); arg0_1 = None
        _param_constant1 = self._param_constant1
        convert_element_type_2: "f32[64, 64, 192]" = torch.ops.prims.convert_element_type.default(_param_constant1, torch.float32); _param_constant1 = None
        bmm: "f32[64, 2048, 192]" = torch.ops.aten.bmm.default(convert_element_type_1, convert_element_type_2); convert_element_type_1 = convert_element_type_2 = None
        add: "f32[64, 2048, 192]" = torch.ops.aten.add.Tensor(convert_element_type, bmm); convert_element_type = bmm = None
        convert_element_type_3: "f16[64, 2048, 192]" = torch.ops.prims.convert_element_type.default(add, torch.float16); add = None
        return convert_element_type_3
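To check a traced graph for this upcast without reading the printed output by eye, something like the following may help. It is only a sketch: `find_fp32_upcasts` is a hypothetical helper name (not part of PyTorch), and it assumes the upcast shows up as `prims.convert_element_type` nodes with `torch.float32` as the second positional argument, as in the graph above. It flags every conversion to fp32, regardless of the source dtype.

```python
import torch
from torch.fx import GraphModule


def find_fp32_upcasts(gm: GraphModule) -> list[str]:
    """Return the names of nodes that convert a tensor to float32.

    Hypothetical helper for illustration; it only recognizes
    prims.convert_element_type calls shaped like the ones in the
    graph printed in this issue.
    """
    upcast_target = torch.ops.prims.convert_element_type.default
    return [
        node.name
        for node in gm.graph.nodes
        if node.op == "call_function"
        and node.target == upcast_target
        and len(node.args) > 1
        and node.args[1] == torch.float32
    ]
```

Called as `find_fp32_upcasts(gm)` on a graph like the one printed above, this should return the three input conversions (`convert_element_type`, `convert_element_type_1`, `convert_element_type_2`) and skip the final cast back to fp16. An empty list would suggest the bmm stays in fp16.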
See the attempted fix at #137671.
Error logs
No response
Minified repro
No response
Versions
main
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire