🐛 Describe the bug
We should make sure the following works:
import torch

dtype = torch.float8_e8m0fnu
device = "cuda"

def foo(x0):
    x1 = x0 + 1
    x2 = x1.view(dtype)
    return x2

x0 = torch.randint(0, 255, (16, 16), device=device, dtype=torch.uint8)
foo_c = torch.compile(foo, backend="inductor", fullgraph=True)

with torch.no_grad():
    y_c = foo_c(x0)
This is important for the PT2 support of MX workflows (tracked in pytorch/ao#556). Specifically, once this functionality exists, a user would be able to write a scaling+casting kernel for MX and output the scales directly in the e8m0 dtype, instead of having to output in uint8 and view as e8m0 afterwards.
Versions
main branch
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov