
Resolved: Only add type promotion support to unary pwise, binary pwise, and reduction operations #56356

@mruberry


This issue suggests that, for now, type promotion support should only be added to unary pointwise, binary pointwise, and reduction operations. Note that it does not suggest removing type promotion support from other operations, just that PyTorch should stop accepting hypothetical PRs like "Adds type promotion to torch.matmul" until some challenges with type promotion are resolved. The motivation for this suggestion is:

  • PyTorch's current type promotion support is scattershot, and it's hard for users to know which operations do or don't support type promotion
  • while type promotion is natural for some operations, the value of adding type promotion to all PyTorch operations seems very small
  • type promotion can be confusing to users and a source of bugs, especially when writing performance-sensitive neural networks
  • type promotion is often nontrivial to implement, may regress performance, and requires additional testing

The rest of this issue briefly discusses what type promotion is, elaborates on each of these four points, and then summarizes.

What is Type Promotion?

"Type promotion" occurs when an operation converts one or more of its inputs from its original {d}type to a "promoted" {d}type before performing its computation. Type promotion is often intuitive and regularly occurs in programming languages like C++; for example, when adding a float and a double, the float is "promoted" to double before the computation occurs in double precision. C++ also has a separate concept of "type conversion," but for purposes of PyTorch UX we can think of all dtype promotions and conversions as "type promotion."
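In PyTorch terms, torch.promote_types() reports which dtype two dtypes promote to, and torch.result_type() does the same for concrete operands. The short sketch below mirrors the C++ float/double example; the tensors are illustrative only.

```python
import torch

# float32 promoted with float64 yields float64, mirroring C++'s float + double.
print(torch.promote_types(torch.float32, torch.float64))  # torch.float64

x = torch.ones(2, dtype=torch.float32)
y = torch.ones(2, dtype=torch.float64)

# torch.result_type() computes the promoted dtype for actual operands,
# and addition performs its computation (and returns its result) in that dtype.
print(torch.result_type(x, y))  # torch.float64
print((x + y).dtype)            # torch.float64
```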

Type promotion is often as natural for a tensor library as it is for a programming language. NumPy consistently implements type promotion for its operators, and this allows arrays of different dtypes to be added, subtracted, multiplied, or even matrix multiplied together. JAX also implements NumPy-like type promotion for operations in jax.numpy.
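For comparison, a minimal NumPy sketch: the same float16/float32 pair promotes to float32 whether the arrays are added or matrix multiplied.

```python
import numpy as np

a = np.ones((2, 2), dtype=np.float32)
b = np.ones((2, 2), dtype=np.float16)

# NumPy's promotion table maps (float16, float32) to float32 ...
print(np.promote_types(np.float16, np.float32))  # float32

# ... and applies it to elementwise operators and matrix multiplication alike.
print((a + b).dtype)  # float32
print((a @ b).dtype)  # float32
```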

PyTorch's Type Promotion Behavior is Scattershot

Many operations in PyTorch, like torch.add, implement type promotion. But while add supports tensors with different dtypes, torch.matmul, another binary operation, doesn't:

```python
import torch

a = torch.randn((2, 2))
b = torch.randn((2, 2), dtype=torch.float16)

a @ b
# RuntimeError: expected scalar type Float but found Half

a.numpy() @ b.numpy()
# array([[ 0.01677812, -0.15937534], [ 0.73216796,  1.5073942 ]], dtype=float32)
```

Even when looking at extremely similar operations, like in-place addition and torch.Tensor.index_put_, one often supports type promotion while the other doesn't:

```python
# in-place addition supports type promotion
a += b

# torch.Tensor.index_put_, conceptually an almost identical operation with accumulate=True, doesn't
a.index_put_((torch.tensor((0, 1)),), b, accumulate=True)
# RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
```
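The workaround for index_put_ is an explicit cast of the source tensor to the destination's dtype before the call. A minimal sketch, using zeros and ones instead of random values so the result is easy to check:

```python
import torch

a = torch.zeros((2, 2))                      # float32 destination
b = torch.ones((2, 2), dtype=torch.float16)  # float16 source
rows = torch.tensor((0, 1))                  # int64 index tensor

# index_put_ requires matching source and destination dtypes, so cast the
# source explicitly; the index tensor keeps its int64 dtype.
a.index_put_((rows,), b.to(a.dtype), accumulate=True)
print(a)  # every element is now 1.0
```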

There's no public principle for why PyTorch supports type promotion on some operations and not others, and the lack of clear, easy-to-understand UX principles can make PyTorch confusing and frustrating to use. We should endeavor to have a UX that is familiar, consistent, and empowering, and this "sometimes we support type promotion and sometimes we don't" state is not consistent.

Adding Universal Type Promotion Support isn't that Interesting

One option to make PyTorch's UX more consistent and more compatible with NumPy would be to universally support type promotion. Universal type promotion, however, is not very interesting. Consider the value of type promotion for binary pointwise operations, one of its most natural use cases. At its best, type promotion saves users from having to cast one or more input tensors to a different dtype:

```python
import torch

a = torch.randn((2, 2))
b = torch.randn((2, 2), dtype=torch.float16)

# your type promotion tax dollars at work
a + b
# tensor([[-0.8938,  2.0655], [-1.2176,  0.8563]])

# the simple type promotion workaround
a + b.float()
# tensor([[-0.8938,  2.0655], [-1.2176,  0.8563]])

# a programmatic workaround
compute_type = torch.result_type(a, b)
a.to(compute_type) + b.to(compute_type)
```

For simple binary operations, like the binary pointwise operations, type promotion is convenient. But, as the next two sections discuss, the type promotion tax can be significant, and this convenience is not compelling enough to pay it.

Type Promotion can be a Source of Confusion and Bugs

Supporting type promotion for relatively simple mathematical operations, like add, is consistent with languages like C++. Supporting it for operations like matmul, however, might be consistent with NumPy and JAX, but it could also be a source of confusion and bugs for PyTorch users developing neural networks.

If a PyTorch user is using mixed precision and wants to preserve dtypes, for example, then accidentally multiplying a float16 matrix by a float32 matrix is a real possibility. The ability to multiply float16 matrices by float32 matrices may also suggest to some users that PyTorch uses special kernels that support the operation, rather than simply upcasting the float16 matrix to float32 before performing the operation.

Even on operations like torch.cat and torch.stack, type promotion support means fewer error checks. Maybe the user did want to concatenate a float32 and a float16 tensor and create a float32 result, or maybe they accidentally included a single float32 tensor in a torch.cat call and now they have to debug why they're no longer running in float16. Throwing an error at the source of the type mismatch would have prevented this issue.
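To make the trade-off concrete, here is a sketch of the error check that type promotion gives up. strict_cat is a hypothetical helper, not a PyTorch API: it wraps torch.cat and raises at the call site when its inputs have mixed dtypes.

```python
import torch

def strict_cat(tensors, dim=0):
    """Hypothetical helper: like torch.cat, but refuses to mix dtypes.

    Raising here surfaces an accidental float32 tensor at the call site
    instead of silently producing a float32 result.
    """
    tensors = list(tensors)
    dtypes = {t.dtype for t in tensors}
    if len(dtypes) > 1:
        names = ", ".join(sorted(str(d) for d in dtypes))
        raise TypeError(f"strict_cat expected a single dtype, got: {names}")
    return torch.cat(tensors, dim=dim)

half = torch.ones(2, dtype=torch.float16)
single = torch.ones(2, dtype=torch.float32)

print(strict_cat([half, half]).dtype)  # torch.float16
try:
    strict_cat([half, single])
except TypeError as e:
    print(e)  # strict_cat expected a single dtype, got: torch.float16, torch.float32
```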

Maybe in the future PyTorch will support a "performance mode" or a "strict typing" mode, analogous to its deterministic algorithms mode, that better supports developing performance-critical programs, but today no such mode exists.

Type promotion is often nontrivial to implement, may regress performance, and requires additional testing

Let's consider implementing type promotion for a ternary operation like torch.lerp. TensorIterator's TensorIteratorBase::compute_types(const TensorIteratorConfig& config) and TensorIteratorBase::compute_common_dtype() give an idea of the work required and the impact of that work.

Basically we'd do the following:

  • enumerate the tensor inputs and check whether they all have the same dtype (the fast path) or have different dtypes and require type promotion
  • wrap the type promotion functions, as compute_common_dtype() in TensorIterator does
  • acquire a common dtype
  • cast the input tensors to that common dtype
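The four steps above can be sketched in Python as a wrapper around torch.lerp. lerp_with_promotion is a hypothetical name, and the sketch simplifies TensorIterator's logic: it folds torch.promote_types over the input dtypes and ignores the special handling of zero-dim tensors and Python scalars.

```python
import functools
import torch

def lerp_with_promotion(start, end, weight):
    """Hypothetical wrapper adding type promotion to torch.lerp (simplified)."""
    inputs = (start, end, weight)
    # Enumerate the inputs' dtypes.
    dtypes = [t.dtype for t in inputs]
    # Fast path: every input already has the same dtype.
    if all(d == dtypes[0] for d in dtypes):
        return torch.lerp(start, end, weight)
    # Acquire a common dtype by pairwise promotion.
    common = functools.reduce(torch.promote_types, dtypes)
    # Cast every input to the common dtype, then compute.
    return torch.lerp(*(t.to(common) for t in inputs))

out = lerp_with_promotion(
    torch.zeros(2),                      # float32
    torch.ones(2, dtype=torch.float16),  # float16
    torch.full((2,), 0.5),               # float32
)
print(out.dtype)  # torch.float32
```

Even this simplified version adds a dtype scan and, on the slow path, extra casts and temporary tensors on every call, which is part of the cost described below.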

This is not the end of the world, but it's not nothing, either, and doing it for every operator in PyTorch is daunting. Plus, there are additional hidden costs:

  • a test for the ternary operation's type promotion must be added
  • downstream projects, like the JIT and NNC, may need to replicate the type promotion logic

So while we've saved the hypothetical user who wants to mix and match their torch.lerp dtypes the pain of having to cast their tensors, we've incurred some nontrivial engineering work and maintenance.

In the future PyTorch may have a better type promotion architecture that readily and performantly supports type promotion on every operation, but this feature is nontrivial to build. Some operations have finicky type promotion rules: torch.Tensor.index_put_ (above), for example, would want to promote one tensor to the dtype of another while preserving the dtype of the index tensors. Either that logic needs to be handwritten, a generic type promotion mechanism that most ops can use needs to be developed, or improved compiler technology needs to be deployed so that more operations can be written performantly as composites.

Summary

Type promotion can be a convenience win, and in some cases it's natural and expected. PyTorch's support for type promotion is "scattershot," however, and an inconsistent UX is confusing and frustrating for users. One option to address this inconsistency would be to pursue universal type promotion support, but this is a considerable engineering effort with, at best, marginal benefits, and some real UX risks (do we seriously want to approve PRs adding type promotion support for operations like matmul?). Another option would be to clarify which types of operations support type promotion.

First, I propose we support type promotion on three types of operations: unary pointwise, binary pointwise, and reductions. This is a good set because:

  • type promotion is natural for and most intuitive with operations in these classes, so there's little UX risk
  • users can easily understand these types of operations; that is, they can tell if an operation is a "binary pointwise" operation or not
  • these operations are so uniform that either we already have tools to help implement type promotion for them or such tools are easy to develop
  • PyTorch already implements type promotion for most, if not all, unary and binary pointwise functions, and adding support for reductions is straightforward
  • PyTorch already has automated testing for unary pointwise operations, and it plans to add automated testing for binary pointwise operations and reductions, so type promotion can be tested systematically with no additional cost to developers
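As an illustration of the proposed set, recent PyTorch releases already promote in all three classes. A minimal sketch; the unary result assumes float32 is the default dtype.

```python
import torch

ints = torch.tensor([1, 2, 3])                               # int64
half = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float16)
small = torch.tensor([100, 100], dtype=torch.int8)

# Unary pointwise: torch.sin computes integer inputs in the default float dtype.
print(torch.sin(ints).dtype)  # torch.float32

# Binary pointwise: int64 + float16 promotes to float16 (floating beats integral).
print((ints + half).dtype)    # torch.float16

# Reduction: summing an int8 tensor returns int64, so 100 + 100 doesn't overflow.
print(small.sum().dtype, small.sum().item())  # torch.int64 200
```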

Second, I propose we put a moratorium on adding custom type promotion support to other operators. This moratorium can be revisited after reviewing the UX, engineering, and performance issues associated with broader type promotion support. Note that this moratorium does not mean removing type promotion support from any operators that currently support it.

cc @nairbv @mruberry
