RFC: enable bfloat16 x float16 -> fp32 type promotion #43049

@mruberry

Description

Recently @xuhdev enabled bfloat16 type promotion in PyTorch with the exception of bfloat16 x float16 promotion, which I asked him to delay implementing so we could discuss it.

I propose we be consistent with JAX and promote bfloat16 x float16 to fp32:

import jax.numpy as jnp
a = jnp.array(1, dtype=jnp.float16)
b = jnp.array(1, dtype=jnp.bfloat16)
a + b  # DeviceArray(2., dtype=float32)

Other reasonable options for this choice are:

  • continue to not support bfloat16 x float16 type promotion and tell users to solve this problem for themselves
  • promote to fp16, which is the more precise of the two dtypes but has less dynamic range than bfloat16
  • promote to bfloat16, which has the greater dynamic range but is less precise
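To make the precision-versus-range tradeoff concrete, here is a small pure-Python sketch. The `to_bfloat16` helper is hypothetical and illustrative only: it emulates bfloat16 by truncating (not rounding) a float32 bit pattern, while `to_float16` round-trips through IEEE half precision via `struct`'s `e` format:

```python
import struct

def to_bfloat16(x):
    # Hypothetical helper: emulate bfloat16 by keeping only the top 16 bits
    # of the float32 representation (8-bit exponent, 7-bit significand).
    # This truncates rather than rounds, so it is illustrative only.
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

def to_float16(x):
    # Round-trip through IEEE half precision ('e' struct format, Python 3.6+).
    return struct.unpack('<e', struct.pack('<e', x))[0]

# bfloat16 keeps fp32's dynamic range but drops precision:
big = to_bfloat16(3.0e38)       # still finite: bfloat16 shares fp32's exponent
lossy = to_bfloat16(1.2345678)  # only ~3 decimal digits of significand survive

# float16 is more precise but its largest finite value is 65504:
exact = to_float16(65504.0)     # exactly representable in fp16
```

Neither narrow type can stand in for the other: values near fp32's maximum overflow fp16, while bfloat16 visibly perturbs values that fp16 carries exactly.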

Being consistent with JAX while not sacrificing the precision of fp16 or the dynamic range of bfloat16 seems like the best option to me.
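As a sketch, the proposed rule can be written as a tiny lookup. `promote` here is a hypothetical helper for illustration, not PyTorch's API; the real implementation would live in PyTorch's type-promotion tables:

```python
# Hypothetical sketch of the proposed mixed-precision rule;
# `promote` is illustrative, not PyTorch's actual API.
def promote(a: str, b: str) -> str:
    if a == b:
        return a
    # The case under discussion: promoting to fp32 discards neither
    # fp16's precision nor bfloat16's dynamic range.
    if {a, b} == {"float16", "bfloat16"}:
        return "float32"
    raise NotImplementedError(f"promotion {a} x {b} not covered in this sketch")
```

With this rule, `promote("float16", "bfloat16")` and `promote("bfloat16", "float16")` both yield `"float32"`, matching the JAX behavior shown above.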

cc @xuhdev @gchanan @nairbv

Metadata

    Labels

    enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix.
    module: bfloat16
    module: half: Related to float16 half-precision floats.
    module: type promotion: Related to semantics of type promotion.
    triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.
