[Q][GPU][BF16] torch.mul is lowered to HLO as an f32 multiply #8545

@apivovarov

Description

❓ Questions and Help

torch 2.5.1
torch_xla 2.5.1
cuda 12.4
GPU NVIDIA L4

The following example uses torch.mul where both operands are bf16, but in the HLO graph, I see an f32 multiply operation.

export XLA_FLAGS="--xla_dump_to=/tmp/dump --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.*"

import torch
import torch_xla as xla

device = xla.device(0)

def foo(a, b):
  y = torch.mul(a, b)
  return y

a = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
b = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)

y = foo(a, b)
print(y)

hlo: module_0000.SyncTensorsGraph.16.before_optimizations.txt

HloModule SyncTensorsGraph.16, entry_computation_layout={()->(bf16[5,9216,64]{2,1,0})}

ENTRY SyncTensorsGraph.16 {
  constant.7 = bf16[] constant(1)
  reshape.8 = bf16[1,1,1]{2,1,0} reshape(constant.7)
  broadcast.9 = bf16[1,1,1]{2,1,0} broadcast(reshape.8), dimensions={0,1,2}
  reshape.10 = bf16[] reshape(broadcast.9)
  broadcast.11 = bf16[5,9216,64]{2,1,0} broadcast(reshape.10), dimensions={}
  convert.12 = f32[5,9216,64]{2,1,0} convert(broadcast.11)
  constant.1 = bf16[] constant(1)
  reshape.2 = bf16[1,1,1]{2,1,0} reshape(constant.1)
  broadcast.3 = bf16[1,1,1]{2,1,0} broadcast(reshape.2), dimensions={0,1,2}
  reshape.4 = bf16[] reshape(broadcast.3)
  broadcast.5 = bf16[5,9216,64]{2,1,0} broadcast(reshape.4), dimensions={}
  convert.6 = f32[5,9216,64]{2,1,0} convert(broadcast.5)
  multiply.13 = f32[5,9216,64]{2,1,0} multiply(convert.12, convert.6)
  convert.14 = bf16[5,9216,64]{2,1,0} convert(multiply.13)
  ROOT tuple.15 = (bf16[5,9216,64]{2,1,0}) tuple(convert.14)
} // SyncTensorsGraph.16

I was able to get a bf16 multiply by setting XLA_USE_BF16=1, but I received the following warning:

XLA_USE_BF16 will be deprecated after the 2.5 release, please convert your model to bf16 directly

I'm not sure what the correct way is to get a bf16 multiply in the HLO graph without using the deprecated flag.
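For reference, what the deprecation warning seems to suggest ("convert your model to bf16 directly") might look like the following plain-PyTorch sketch. The Linear model here is hypothetical, just for illustration; whether this actually keeps the multiply in bf16 in the lowered HLO under torch_xla is exactly what this issue is asking about.

```python
import torch

# Hypothetical small model, used only to illustrate a direct bf16 conversion.
model = torch.nn.Linear(4, 4)

# Convert parameters and buffers to bf16 explicitly, instead of relying on
# the deprecated XLA_USE_BF16=1 environment variable.
model = model.to(torch.bfloat16)

# Inputs are created directly in bf16 as well.
x = torch.ones(2, 4, dtype=torch.bfloat16)

y = model(x)
print(y.dtype)  # torch.bfloat16
```

On CPU this keeps the whole computation in bf16; the open question above is why, on the XLA device, the lowering still inserts f32 converts around the multiply even though both operands are already bf16.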

Labels

lowering: ATen Operation lowering
