[benchmarks] Fix AMP setup for torchbench models. #7067
Conversation
Confirmed it also fixes #6833.

Hmm, according to https://github.com/pytorch/xla/blob/master/docs/amp.md we should be able to use

That document is correct. The problem is that I didn't notice XLA:CUDA is supposed to run with CUDA autocast, i.e.
```python
# https://github.com/pytorch/xla/issues/6511
if self.is_accelerator_cuda():
    # For inductor and XLA:CUDA, we use CUDA autocast.
    autocast = torch.cuda.amp.autocast
```
I guess `torch.cuda.amp.autocast` is the same as `torch.amp.autocast("cuda")`?
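For what it's worth, here is a minimal sketch of the two equivalent spellings (the toy model and inputs are made up for illustration and assume a CUDA device is available; they are not from the PR):

```python
import torch

# Hypothetical toy model/inputs just to exercise autocast.
model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

# Per-backend spelling used in this PR for XLA:CUDA.
with torch.cuda.amp.autocast():
    y1 = model(x)

# Generic spelling; device_type="cuda" is what the per-backend class targets.
with torch.amp.autocast(device_type="cuda"):
    y2 = model(x)

# Both default to float16 on CUDA for matmul-like ops such as nn.Linear.
assert y1.dtype == y2.dtype == torch.float16
```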
Do you need to set `kwargs["device_type"] = "xla"` for the XLA:GPU case?
Not really. `torch.cuda.amp.autocast` already does that.
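Putting the thread together, a hedged sketch of the dispatch this implies (only `is_accelerator_cuda` appears in the diff above; `is_accelerator_tpu` and the surrounding method are illustrative, not the actual benchmark runner):

```python
import contextlib
import torch

def _resolve_autocast(self, **kwargs):
    """Illustrative only: pick the autocast entry point per accelerator."""
    # https://github.com/pytorch/xla/issues/6511
    if self.is_accelerator_cuda():
        # For inductor and XLA:CUDA, use CUDA autocast. It already targets
        # device_type="cuda", so no extra kwarg is needed (per the reply above).
        autocast = torch.cuda.amp.autocast
    elif self.is_accelerator_tpu():  # hypothetical helper for XLA:TPU
        # XLA:TPU keeps the generic entry point with the "xla" device type.
        autocast = torch.amp.autocast
        kwargs["device_type"] = "xla"
    else:
        # No AMP elsewhere; fall back to a no-op context manager.
        autocast, kwargs = contextlib.nullcontext, {}
    return autocast, kwargs
```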
Fixes: #6556 (and, possibly, #6833)
This PR fixes the benchmarks script when running with AMP. Previously, we were calling `torch.amp.autocast(..., device_type="xla")` for both XLA:CUDA and XLA:TPU. However, we should be using `torch.cuda.amp.autocast` for XLA:CUDA (see this for more details).

Context: after #6518, `Super_Slomo` inference started being run using AMP. However, due to #6511, that PR tried to mimic `torch_xla.amp.autocast` behavior, using `torch.amp.autocast`.

cc @miladm @JackCaoG @vanbasten23 @zpcore
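For contrast with the CUDA path above, a minimal sketch of the XLA:TPU side that docs/amp.md describes (assumes a torch_xla install and an available XLA device; the model and inputs are made up):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)
x = torch.randn(4, 8, device=device)

# On XLA:TPU the generic autocast with device_type="xla" is the right path;
# matmul-like ops run in bfloat16 under it.
with torch.amp.autocast("xla"):
    y = model(x)
```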