Functionality like #9177 is immensely helpful: load a checkpoint in, say, torch.float8_e5m2, perform computation in, say, torch.float16, and then keep the weights in torch.float8_e5m2 again.
Even though this feature isn't immediately compatible with torch.compile() and we're unsure of its repercussions, we think it's still better to expose it as an experimental API because the memory benefits are significant.
Cc: @vladmandic as you expressed interest in this.
Cc: @a-r-r-o-w @SunMarc as we discussed it in-person.
Cc: @DN6 because #9177 is his brainchild.