The array API specification stipulates clear type promotion rules that are independent of array shapes and values.

PyTorch mostly adheres to this, with one exception: within a dtype category (integral, floating-point, complex), 0d tensors do not participate in type promotion:
```python
import torch

dtype_categories = (
    (torch.int8, torch.uint8, torch.int32, torch.int64),
    (torch.float16, torch.bfloat16, torch.float32, torch.float64),
    (torch.complex32, torch.complex64, torch.complex128),
)

for dtypes in dtype_categories:
    for idx in range(len(dtypes) - 1):
        dtype_nd = dtypes[idx]
        for dtype_0d in dtypes[idx + 1:]:
            # nd tensor with the lower dtype, 0d tensor with the higher one
            a = torch.empty((1,), dtype=dtype_nd)
            b = torch.empty((), dtype=dtype_0d)
            print(f"{a.dtype} + {b.dtype} = {torch.result_type(a, b)}")
```
```
torch.int8 + torch.uint8 = torch.int8
torch.int8 + torch.int32 = torch.int8
torch.int8 + torch.int64 = torch.int8
torch.uint8 + torch.int32 = torch.uint8
torch.uint8 + torch.int64 = torch.uint8
torch.int32 + torch.int64 = torch.int32
torch.float16 + torch.bfloat16 = torch.float16
torch.float16 + torch.float32 = torch.float16
torch.float16 + torch.float64 = torch.float16
torch.bfloat16 + torch.float32 = torch.bfloat16
torch.bfloat16 + torch.float64 = torch.bfloat16
torch.float32 + torch.float64 = torch.float32
torch.complex32 + torch.complex64 = torch.complex32
torch.complex32 + torch.complex128 = torch.complex32
torch.complex64 + torch.complex128 = torch.complex64
```
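For comparison, `torch.promote_types` operates on dtypes alone (it never sees tensor shapes), and it does follow the category-internal promotion one would expect from the spec. A quick sketch to illustrate the contrast:

```python
import torch

# promote_types takes only dtypes, so dimensionality cannot influence it:
# the wider dtype wins within a category, as the array API prescribes.
print(torch.promote_types(torch.int8, torch.int64))      # torch.int64
print(torch.promote_types(torch.float16, torch.float64)) # torch.float64
```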
This is not well documented (see #58489), but seems to be intentional.
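The exception appears to apply only *within* a category. Across categories, a 0d tensor still pulls the result into its category, though (if I read the promotion logic correctly) to that category's default dtype rather than the 0d tensor's own dtype. A minimal sketch:

```python
import torch

a = torch.empty((1,), dtype=torch.int64)  # nd integer tensor
b = torch.empty((), dtype=torch.float64)  # 0d floating-point tensor

# Across categories the 0d tensor does participate: the result is a
# floating-point dtype (though not necessarily float64).
print(torch.result_type(a, b))
```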
cc @nairbv @mruberry