aten.baddbmm with fp16 inputs is run in fp32, causing a perf issue #137897

@SherlockNoMad

🐛 Describe the bug

Summary:
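aten.baddbmm with fp16 inputs is decomposed into aten.bmm running in fp32: the inputs are upcast to fp32, the bmm and add run in fp32, and the result is cast back to fp16.
This causes a performance regression.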

For more details, see https://docs.google.com/document/d/1OOV1UaiwcQH58aGzvJByUn29i1wLphey4mXqcvS60cM/edit?usp=sharing

For the following model:

import torch
import torch.export._trace
from torch._inductor.decomposition import decompositions, get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(64, 64, 192, dtype=torch.float16))
        self.bias = torch.nn.Parameter(torch.randn(64, 1, 192, dtype=torch.float16))

    def forward(self, x):
        return torch.ops.aten.baddbmm.default(self.bias, x, self.weight)

x = torch.randn(64, 2048, 64, dtype=torch.float16, requires_grad=False)
inputs = (x,)

m = M()
gm = make_fx(m, pre_dispatch=False, decomposition_table=decompositions)(*inputs)
gm.print_readable(print_output=True)

Before the fix:

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f16[64, 2048, 64]"):
        # No stacktrace found for following nodes
        _param_constant0 = self._param_constant0
        convert_element_type: "f32[64, 1, 192]" = torch.ops.prims.convert_element_type.default(_param_constant0, torch.float32);  _param_constant0 = None
        convert_element_type_1: "f32[64, 2048, 64]" = torch.ops.prims.convert_element_type.default(arg0_1, torch.float32);  arg0_1 = None
        _param_constant1 = self._param_constant1
        convert_element_type_2: "f32[64, 64, 192]" = torch.ops.prims.convert_element_type.default(_param_constant1, torch.float32);  _param_constant1 = None
        bmm: "f32[64, 2048, 192]" = torch.ops.aten.bmm.default(convert_element_type_1, convert_element_type_2);  convert_element_type_1 = convert_element_type_2 = None
        add: "f32[64, 2048, 192]" = torch.ops.aten.add.Tensor(convert_element_type, bmm);  convert_element_type = bmm = None
        convert_element_type_3: "f16[64, 2048, 192]" = torch.ops.prims.convert_element_type.default(add, torch.float16);  add = None
        return convert_element_type_3
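
One way to check a traced graph for this pattern programmatically is to scan it for `prims.convert_element_type` nodes that cast to `torch.float32`. The sketch below is an assumption about how such a check could be written, not part of PyTorch; `find_fp32_upcasts` is a hypothetical helper, and the tiny hand-built graph exists only so the helper can be exercised without Inductor.

```python
import torch
import torch._prims  # ensures the torch.ops.prims namespace is registered
import torch.fx


def find_fp32_upcasts(gm: torch.fx.GraphModule) -> list:
    """Return names of convert_element_type nodes that cast to float32."""
    target = torch.ops.prims.convert_element_type.default
    return [
        node.name
        for node in gm.graph.nodes
        if node.op == "call_function"
        and node.target == target
        and len(node.args) > 1
        and node.args[1] == torch.float32
    ]


# Hand-built demo graph: one placeholder followed by a cast to float32.
graph = torch.fx.Graph()
inp = graph.placeholder("x")
cast = graph.call_function(
    torch.ops.prims.convert_element_type.default, (inp, torch.float32)
)
graph.output(cast)
demo_gm = torch.fx.GraphModule(torch.nn.Module(), graph)

upcasts = find_fp32_upcasts(demo_gm)
print(upcasts)
```

Calling `find_fp32_upcasts(gm)` on the `gm` from the repro should flag the three casts to `torch.float32` in the graph above and skip the final cast back to `torch.float16`.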

See the attempted fix at #137671.

Error logs

No response

Minified repro

No response

Versions

main

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

Labels

module: inductor, oncall: pt2, triaged