
Add torch.set_default_int_dtype() / extend set_default_dtype() to allow setting default signed integer dtype #141994

@rpsilva-aws

Description


🚀 The feature, motivation and pitch

The proposal is to add torch.set_default_int_dtype() to PyTorch, analogous to the existing torch.set_default_dtype() for floating-point types [1]. This function would let users set the default integer dtype used by PyTorch.

When working with certain datasets or models, users may not need the full range of 64-bit integers; using 32-bit integers instead can lead to significant memory savings (4 bytes per element instead of 8) and potential performance improvements. More importantly, some compilers do not support 64-bit data types at all, as is the case with Neuron. PyTorch currently defaults to 64-bit integers (torch.int64) for many operations.

This makes it harder for backends to enforce such a restriction. TorchXLA is one example: some tensor operations (e.g. Cast) require validating that the raw underlying type can be converted between a source and a target XLA type, so any type downcasting on XLA is inherently limited by PyTorch.
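To make the memory argument concrete, the snippet below uses only current PyTorch behavior (no new API). It lists a few common ways a torch.int64 tensor is produced without the user asking for one, then compares the per-element storage of torch.int64 and torch.int32.

```python
import torch

# Integer tensors that come out as torch.int64 without an explicit dtype
implicit_int64 = {
    "tensor": torch.tensor([1, 2, 3]),
    "arange": torch.arange(5),
    "randint": torch.randint(0, 10, (4,)),
    "bool + int": torch.tensor([True, False]) + 1,  # type promotion with a Python int
}
for name, t in implicit_int64.items():
    print(f"{name:>10}: {t.dtype}")

# Storage cost per element and for a 1M-element tensor
n = 1_000_000
for dtype in (torch.int64, torch.int32):
    t = torch.zeros(n, dtype=dtype)
    total_mib = t.element_size() * t.numel() / 2**20
    print(f"{dtype}: {t.element_size()} bytes/element, {total_mib:.2f} MiB total")
```

Every entry in the first loop reports torch.int64, and the int32 tensor needs half the bytes of the int64 one.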

The proposed torch.set_default_int_dtype() would allow users to easily switch to 32-bit integers (or other integer dtypes) as the default, without having to explicitly specify the dtype in every tensor creation or operation.
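For comparison, here is a rough sketch of the workaround available today: the dtype has to be passed at every creation site, and results of operations that always return torch.int64 (such as torch.argmax) have to be cast back. The helper name to_int32_checked is hypothetical; it uses torch.iinfo so that an out-of-range value fails loudly instead of being silently converted.

```python
import torch

# Today: every creation site must request int32 explicitly
ids = torch.tensor([3, 1, 4, 1, 5], dtype=torch.int32)
positions = torch.arange(ids.numel(), dtype=torch.int32)
noise = torch.randint(0, 10, (5,), dtype=torch.int32)

# Some operations return torch.int64 regardless of the input dtype
best = torch.argmax(ids.float())
print(best.dtype)  # torch.int64


def to_int32_checked(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: downcast an integer tensor to int32,
    refusing values that do not fit in the int32 range."""
    info = torch.iinfo(torch.int32)
    if t.numel() > 0 and (t.min() < info.min or t.max() > info.max):
        raise OverflowError(f"values outside int32 range [{info.min}, {info.max}]")
    return t.to(torch.int32)


best32 = to_int32_checked(best)
print(best32.dtype)  # torch.int32
```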

The function could work similarly to torch.set_default_dtype():

```python
import torch

# Check current default integer dtype
print(torch.tensor([1, 2, 3]).dtype)  # Output: torch.int64

# Set new default integer dtype (proposed API; not available in PyTorch today)
torch.set_default_int_dtype(torch.int32)

# Verify the change
print(torch.tensor([1, 2, 3]).dtype)  # Output: torch.int32
```

This would enhance PyTorch's flexibility and allow users/components to more easily optimize their code for specific use cases.

The scope is limited to signed integer dtypes. Complex dtypes are out of scope, since the default complex dtype is already derived from the default floating-point dtype.
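The snippet below illustrates how the existing torch.set_default_dtype() behaves today: the default complex dtype follows the floating-point default, integer tensors are unaffected, and integer dtypes are rejected. Because the setting is global, the original default is restored in finally blocks. The exact exception type for the integer case is not asserted precisely here, so both TypeError and RuntimeError are caught.

```python
import torch

original = torch.get_default_dtype()
try:
    torch.set_default_dtype(torch.float64)
    float_dtype = torch.tensor([1.2, 3]).dtype     # follows the float default
    complex_dtype = torch.tensor([1.2, 3j]).dtype  # derived from the float default
    int_dtype = torch.tensor([1, 2, 3]).dtype      # unaffected: still torch.int64
finally:
    torch.set_default_dtype(original)

# set_default_dtype() only accepts floating-point dtypes today
try:
    torch.set_default_dtype(torch.int32)
    rejected = False
except (TypeError, RuntimeError):
    rejected = True
finally:
    torch.set_default_dtype(original)

print(float_dtype, complex_dtype, int_dtype, rejected)
```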

[1] https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html#torch.set_default_dtype

Alternatives

No response

Additional context

Draft documentation for torch.set_default_int_dtype:

torch.set_default_int_dtype(d, /)

    Sets the default integer dtype to d. Supports signed integer dtypes as input; any other dtype causes torch to raise an exception.

    When PyTorch is initialized, its default integer dtype is torch.int64 (long). The intent of set_default_int_dtype(torch.int32) is to facilitate using 32-bit integers as the default. The default integer dtype is used to:

        1. Determine the dtype for tensors constructed using Python integers. See examples below.
        2. Determine the result of type promotion between bool tensors and Python integers.
        3. Infer the dtype for integer tensors created without an explicit dtype specified.

    Parameters:
        d (torch.dtype) – the integer dtype to make the default. Must be one of torch.int8, torch.int16, torch.int32, or torch.int64.

    Example::

        >>> torch.tensor([1, 2, 3]).dtype
        torch.int64
        >>> torch.set_default_int_dtype(torch.int32)
        >>> torch.tensor([1, 2, 3]).dtype
        torch.int32

    Warning:
        This function will affect the behavior of all modules and tensors created after it's called. It should be used with caution, preferably at the beginning of a script or program.

    Note:
        This does not affect the default dtype of floating point tensors, which remains controlled by torch.set_default_dtype().
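To make the draft semantics concrete, here is a minimal user-space sketch, assuming the behavior described in the draft above. set_default_int_dtype, get_default_int_dtype, default_int_dtype, and int_tensor are hypothetical names defined in the snippet itself, not PyTorch APIs. It only covers rule 1 (tensors built from Python integers) by casting after construction, whereas a real implementation would change dtype inference inside PyTorch. The context manager reflects the warning about global state by restoring the previous default on exit.

```python
import contextlib
import torch

# Hypothetical module-level state; PyTorch has no such setting today.
_ALLOWED_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
_default_int_dtype = torch.int64


def set_default_int_dtype(d: torch.dtype) -> None:
    """Hypothetical: set the default signed integer dtype, rejecting anything else."""
    global _default_int_dtype
    if d not in _ALLOWED_INT_DTYPES:
        raise TypeError(f"only signed integer dtypes are supported, got {d}")
    _default_int_dtype = d


def get_default_int_dtype() -> torch.dtype:
    """Hypothetical: return the current default integer dtype."""
    return _default_int_dtype


@contextlib.contextmanager
def default_int_dtype(d: torch.dtype):
    """Hypothetical: scope the change and restore the previous default on exit."""
    previous = get_default_int_dtype()
    set_default_int_dtype(d)
    try:
        yield
    finally:
        set_default_int_dtype(previous)


def int_tensor(data, dtype=None, **kwargs) -> torch.Tensor:
    """Hypothetical wrapper around torch.tensor: when no dtype is given and
    PyTorch inferred torch.int64, cast to the configured integer default."""
    t = torch.tensor(data, dtype=dtype, **kwargs)
    if dtype is None and t.dtype == torch.int64 and _default_int_dtype != torch.int64:
        t = t.to(_default_int_dtype)
    return t


print(int_tensor([1, 2, 3]).dtype)       # torch.int64
with default_int_dtype(torch.int32):
    print(int_tensor([1, 2, 3]).dtype)   # torch.int32
    print(int_tensor([1.0, 2.0]).dtype)  # torch.float32 (floats are untouched)
print(int_tensor([1, 2, 3]).dtype)       # torch.int64 again
```

Because the sketch casts after construction, it still allocates an int64 tensor first and does not range-check values; doing the inference inside PyTorch, as proposed, would avoid both.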

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Metadata


    Labels

    feature, module: python frontend, needs design, triaged
