
[Functionalization] Some ops need additional meta tensor support after functionalization #92916

Description

@wonjoo-wj

🐛 Describe the bug

Summary

With functionalization enabled, PyTorch/XLA saw new test failures caused by ops that need additional meta tensor support. The ops for which we saw these errors are:

  • aten::_amp_foreach_non_finite_check_and_unscale_
  • aten::nan_to_num.out

The full error log for one of these ops is:

C++ exception with description "Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_amp_foreach_non_finite_check_and_unscale_' is only available for these backends: [XLA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
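"Missing meta tensor support" here means the op has no kernel registered for the Meta dispatch key, i.e. nothing that can compute output shapes and dtypes without real data. As a rough illustration only (not the actual fix, and the helper name below is made up for the sketch), a Python-side Meta registration via torch.library could look like this; since the op is in-place and returns nothing, the meta kernel has no metadata to produce:

import torch
from torch.library import Library

# Sketch: register a no-op kernel for the Meta dispatch key.
# (Illustrative only; the real fix would land a proper meta registration in core.)
meta_lib = Library("aten", "IMPL", "Meta")

def amp_unscale_meta(self, found_inf, inv_scale):
    # In-place op: input shapes/dtypes are unchanged, so there is
    # nothing to compute on meta tensors.
    return

meta_lib.impl("_amp_foreach_non_finite_check_and_unscale_", amp_unscale_meta)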

Steps to reproduce

For aten::_amp_foreach_non_finite_check_and_unscale_:

import torch
import functorch

def test():
    # The op mutates self_tensor and found_inf in place, which is why
    # functionalization needs to reason about it via the Meta backend.
    self_tensor = torch.tensor([1, 2, 3, 4])
    found_inf = torch.tensor(0)
    inv_scale = torch.tensor(0.2)
    print(torch._amp_foreach_non_finite_check_and_unscale_([self_tensor], found_inf, inv_scale))

functorch.functionalize(test)()

Output:

/opt/conda/lib/python3.8/site-packages/torch/_functorch/deprecated.py:93: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.functionalize is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.functionalize instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
  warn_deprecated('functionalize')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1582, in wrapped
    func_outputs = func(*func_args, **func_kwargs)
  File "<stdin>", line 5, in test
NotImplementedError: Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_amp_foreach_non_finite_check_and_unscale_' is only available for these backends: [XLA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
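The summary also lists aten::nan_to_num.out. A repro along the same lines (a sketch by analogy, not taken from the report; it should hit the same kind of Meta-backend NotImplementedError) would be:

import torch
import functorch

def test_nan_to_num_out():
    t = torch.tensor([float('nan'), float('inf'), 1.0])
    out = torch.empty(3)
    # The .out overload writes into a preallocated tensor; under
    # functionalization this likewise needs Meta support.
    torch.nan_to_num(t, out=out)
    print(out)

functorch.functionalize(test_nan_to_num_out)()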

Versions

Nightly

cc @bdhirsh @ezyang @eellison @soumith


    Labels

    module: functionalization (used for issues that are specific to functionalization; AOTAutograd bugs should start with aotdispatch)
    module: meta tensors
    module: xla (related to XLA support)
    triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
