❓ Questions and Help
Hi folks, I have a question about the XLA mul op.
When both inputs are bf16, the generated graph converts them to f32, performs the multiply, then converts the result back to bf16 (a minimal repro sketch is below). Two questions:
1. In this case, is the op's math type effectively f32 rather than bf16?
2. If this upcast exists primarily for TPU accuracy/stability, would it be acceptable to gate it behind a flag (e.g., an environment option) so that path becomes a no-op and the op stays in native bf16 when desired?
Reference code path:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L187-L211
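Minimal repro sketch (assumes a working torch_xla install and an available XLA device such as a TPU; the `_XLAC` dump helpers are internal and used here only for inspection). It multiplies two bf16 tensors and prints the lazy IR/HLO, where the bf16 -> f32 convert ops around the multiply should be visible:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Two bf16 inputs on the XLA device.
a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
b = torch.randn(4, 4, dtype=torch.bfloat16, device=device)

# Dispatches to the XLA mul op referenced above.
c = a * b

# Dump the lazy-tensor IR and the lowered HLO for inspection; I expect to see
# convert(bf16 -> f32), multiply in f32, then convert back to bf16.
print(torch_xla._XLAC._get_xla_tensors_text([c]))
print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
```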
If there’s a better approach, please let me know. Thanks!