returning tensors of dtype torch.float8_e8m0fnu should work with torchinductor #147873

@vkuzo

🐛 Describe the bug

We should make sure the following works:

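        import torch

        dtype = torch.float8_e8m0fnu
        device = "cuda"

        def foo(x0):
            # Integer arithmetic happens in uint8 ...
            x1 = x0 + 1
            # ... and the result is reinterpreted bit-for-bit as e8m0 (no value conversion).
            x2 = x1.view(dtype)
            # The compiled graph returns a tensor of dtype torch.float8_e8m0fnu.
            return x2

        x0 = torch.randint(0, 255, (16, 16), device=device, dtype=torch.uint8)
        foo_c = torch.compile(foo, backend="inductor", fullgraph=True)

        with torch.no_grad():
            y_c = foo_c(x0)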
This is important for PT2 support of MX workflows (tracked in pytorch/ao#556). Specifically, once this works, a user could write a scaling+casting kernel for MX that outputs the scales directly in the e8m0 dtype, instead of outputting them in uint8 and viewing the result as e8m0 afterwards.
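
For readers unfamiliar with the dtype, here is a minimal pure-Python sketch of what the uint8 bit patterns mean once they are viewed as e8m0. It assumes the E8M0 encoding from the OCP Microscaling (MX) specification: 8 exponent bits, no sign or mantissa, bias 127, with the all-ones byte reserved for NaN. The helper names `e8m0_decode` and `e8m0_encode_floor` are hypothetical and not part of PyTorch or torchao, and rounding the scale down to a power of two is an illustrative assumption rather than what any particular MX kernel does.

```python
import math

E8M0_BIAS = 127   # exponent bias assumed from the OCP MX E8M0 encoding
E8M0_NAN = 0xFF   # the all-ones byte is reserved for NaN


def e8m0_decode(byte: int) -> float:
    """Interpret one uint8 bit pattern as an E8M0 value (hypothetical helper)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a value in [0, 255]")
    if byte == E8M0_NAN:
        return math.nan
    # Every other byte encodes a pure power of two: 2 ** (byte - bias).
    return math.ldexp(1.0, byte - E8M0_BIAS)


def e8m0_encode_floor(scale: float) -> int:
    """Round a positive finite scale down to a power of two and return its byte.

    Rounding down is an assumption made for this sketch.
    """
    if not (scale > 0.0 and math.isfinite(scale)):
        raise ValueError("expected a positive finite scale")
    # frexp returns (m, e) with scale == m * 2**e and 0.5 <= m < 1,
    # so floor(log2(scale)) == e - 1.
    _, exponent = math.frexp(scale)
    biased = (exponent - 1) + E8M0_BIAS
    # Clamp to the finite range; 0xFF stays reserved for NaN.
    return min(max(biased, 0), 0xFE)
```

For example, `e8m0_decode(127)` is `1.0` and `e8m0_encode_floor(3.0)` is `128`, which decodes back to `2.0`. In the workflow described above, bytes like these are what a scaling kernel currently emits as uint8 before the `.view(torch.float8_e8m0fnu)` step.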

Versions

main branch

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Labels

module: inductor, oncall: pt2, triaged
