AMP for TPUs #4740
Conversation
Do you have a corresponding upstream PR? I saw you use:

```python
loss_fn = nn.NLLLoss()
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
if device_hw == 'TPU':
    autocast = torch.xla.amp.autocast
```
Question: can you point me to the mapping that explains which ops are autocast and which are not?
The op list is registered on the PyTorch side here: https://github.com/pytorch/pytorch/blob/8455aac198e4ee68003f2b38a5e15631fb82690c/aten/src/ATen/autocast_mode.cpp#L520
@chandrasekhard2 @cowanmeg Who is exploring the ResNet model test? Is this part of this PR or a separate PR?
I will send a PR today. Thanks!
@chandrasekhard2 @cowanmeg Can you please share a few words on how the AMP perf looks at the moment? It's fine if we need more work to improve the numbers.
nit: I assume this PR needs your PyTorch changes -- you can pin the PyTorch PR in this PR so it builds with your PyTorch PR. Example of PyTorch pinning here.
@cowanmeg Can we do an operator-by-operator mapping of the current model vs. what we see in HLO, to verify that what we added to upstream PyTorch for AMP maps to what emerges on the HLO side?
```python
elif device_hw == 'GPU':
    autocast = torch.cuda.amp.autocast
    # GradScaler only used for GPU
    scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
```
Why don't we use a scaler for TPU?
That comment was a bit misleading. GradScaler is necessary for a float16 loss but not for bfloat16. float16 has less dynamic range than float32 since it uses only 5 exponent bits vs. 8, so the loss needs to be scaled to prevent small values from disappearing. bfloat16 has the same dynamic range as float32, so it doesn't have this problem.
That means TPU uses bfloat16 by default?
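The exponent-bit arithmetic behind this can be sketched directly. The following is a standalone illustration of the three formats (not a torch API); the limits are computed from each format's exponent/fraction bit counts:

```python
# Rough dynamic-range comparison of float16 vs. bfloat16 vs. float32,
# derived from exponent/fraction bit counts (illustrative sketch).
def fp_limits(exp_bits, frac_bits):
    bias = 2 ** (exp_bits - 1) - 1
    min_normal = 2.0 ** (1 - bias)                    # smallest positive normal
    max_finite = 2.0 ** bias * (2 - 2.0 ** -frac_bits)  # largest finite value
    return min_normal, max_finite

f16 = fp_limits(5, 10)    # float16: 5 exponent bits -> ~6.1e-5 .. 65504
bf16 = fp_limits(8, 7)    # bfloat16: 8 exponent bits, same as float32
f32 = fp_limits(8, 23)

# bfloat16 covers the same exponent range as float32, so small loss
# values don't underflow and no GradScaler is needed.
print(f16, bf16, f32)
```

Note that `bf16` and `f32` share the same smallest normal, which is exactly the "same dynamic range" point above; float16's floor of roughly 6e-5 is what loss scaling works around.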
```cpp
// CastPolicy::lower_precision_fp General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::lower_precision_fp, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
```
wondering how this file is being used in the PR
These are templates that help dispatch an op to its lower precision/fp32/promoted version.
Is this the proper place for this file?
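For readers unfamiliar with the C++ dispatcher, the behavior of that `lower_precision_fp` wrapper can be mimicked in a few lines of Python. This is a conceptual sketch only (the `Tensor`, `lower_precision_fp`, and `matmul` names here are toy stand-ins, not PyTorch code): cast floating-point arguments down to the autocast dtype, then redispatch to the wrapped op.

```python
# Conceptual Python analogue of the WrapFunction_ specialization for
# CastPolicy::lower_precision_fp (a sketch, not the real dispatcher).

class Tensor:  # toy stand-in for at::Tensor
    def __init__(self, dtype='float32'):
        self.dtype = dtype
    def to(self, dtype):
        return Tensor(dtype)

def lower_precision_fp(op, autocast_dtype='bfloat16'):
    def wrapper(*args):
        cast = [a.to(autocast_dtype)
                if isinstance(a, Tensor) and a.dtype in ('float32', 'float64')
                else a
                for a in args]
        return op(*cast)  # "redispatch" with the lowered inputs
    return wrapper

def matmul(a, b):  # toy op whose result dtype follows its inputs
    return Tensor(a.dtype)

out = lower_precision_fp(matmul)(Tensor(), Tensor())
print(out.dtype)  # 'bfloat16'
```

The real C++ template does the same thing generically over the op's signature, using the registered cast policy to decide whether arguments go to lower precision, fp32, or a promoted type.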
Could you add a few unit tests on the C++ or Python side to exercise AMP for a simple case, e.g. a single operation?
```python
if device_hw == 'TPU':
    scaler = None
elif device_hw == 'GPU':
    scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
```
Now that we have a full AMP story, I think we should add a doc under https://github.com/pytorch/xla/tree/master/docs to explain how to use it on TPU and GPU. We should also mention tricks like GradScaler there. WDYT?
I had https://github.com/pytorch/xla/blob/master/docs/gpu.md#amp-automatic-mixed-precision here, but I think it is better to have a standalone doc we can refer to during release.
I agree, will add some documentation.
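One "trick" worth documenting is why loss scaling exists at all. The following pure-Python sketch illustrates the idea behind a gradient scaler (it is not the `torch.cuda.amp.GradScaler` API; `to_fp16ish` is a deliberately simplified, hypothetical cast): multiply the loss before backprop so float16-sized gradients don't underflow, then unscale before the optimizer step.

```python
# Minimal loss-scaling illustration (not the torch GradScaler API).
FP16_MIN_NORMAL = 2.0 ** -14   # smallest positive normal float16

def to_fp16ish(x):
    # Hypothetical cast that flushes values below float16's normal range
    # to zero (real float16 has subnormals; this exaggerates for clarity).
    return 0.0 if abs(x) < FP16_MIN_NORMAL else x

grad = 1e-6                     # a gradient too small for our "float16"
lost = to_fp16ish(grad)         # vanishes to 0.0 without scaling

scale = 2.0 ** 16               # power-of-two scale, so scaling is exact
kept = to_fp16ish(grad * scale) # survives the cast
recovered = kept / scale        # unscaled in float32 before the step
```

Because the scale is a power of two, scaling and unscaling are exact; bfloat16 skips all of this since its exponent range matches float32's.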
```python
enabled=enabled,
dtype=torch.float16,
cache_enabled=cache_enabled)
else:
```
Hmm, so CPU and TPU share the same autocast rule here?
JackCaoG
left a comment
Mostly LGTM, this is exciting!
I have one UX problem: if a user does `torch.amp.autocast('xla')` on a GPU device, is there any way for us to throw a warning and get them to either use torch_xla's autocast or use CUDA directly?
That code will be in PyTorch, so I don't think we can generate a warning, since there is no way to distinguish XLA:GPU vs. XLA:TPU that I am aware of.
JackCaoG
left a comment
LGTM! Let's follow up with the doc update in next pr.
With pytorch/xla#5148 and pytorch/xla#4740. With these changes:
- XLA:GPU users should use `torch.cuda.amp.autocast()` for AMP with float16
- XLA:TPU users should use `torch.amp.autocast('xla')` for AMP with bfloat16

Pull Request resolved: #96370
Approved by: https://github.com/bdhirsh, https://github.com/malfet
With pytorch/pytorch#96370, enable AMP on TPUs with bfloat16.

Currently, `torch_xla.amp.autocast(args...)` aliases `torch.cuda.amp.autocast(args...)`. This change proposes updating this to `torch_xla.amp.autocast(device, args...)`, which will call the appropriate autocast depending on the XLA device. Alternatively, users can call the appropriate autocast directly: `torch.cuda.amp.autocast(args...)` for XLA:GPU devices and `torch.amp.autocast('xla', args...)` for XLA:TPU devices.

HLO dump of MNIST model from test/test_train_mp_mnist_amp.py
HLO dump with AMP
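The user-facing rule above can be captured as a tiny selection table. This is a hypothetical helper (`pick_autocast` is not a real torch or torch_xla API); it just encodes which entry point and dtype each XLA backend should use after these changes:

```python
# Hypothetical helper encoding the PR's guidance (not a real API):
# which autocast entry point and AMP dtype each XLA backend uses.
def pick_autocast(device_hw):
    if device_hw == 'GPU':
        return ("torch.cuda.amp.autocast()", "float16")   # needs GradScaler
    if device_hw == 'TPU':
        return ("torch.amp.autocast('xla')", "bfloat16")  # no scaler needed
    raise ValueError(f"unsupported XLA device: {device_hw}")

print(pick_autocast('TPU'))
```

In real training code this branch would wrap the forward pass in the chosen context manager, with a `GradScaler` added only on the float16 (GPU) path.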