Conversation
JackCaoG
approved these changes
Jun 9, 2023
Collaborator
JackCaoG
left a comment
There was a problem hiding this comment.
LGTM! Let's add autocast test to TPUCI in https://github.com/pytorch/xla/blob/master/test/tpu/xla_test_job.yaml as well, so we can test the TPU amp.
JackCaoG
added a commit
that referenced
this pull request
Jun 9, 2023
pytorchmergebot
pushed a commit
to pytorch/pytorch
that referenced
this pull request
Jun 23, 2023
With pytorch/xla#5148, pytorch/xla#4740 With these changes XLA:GPU users should use `torch.cuda.amp.autocast()` for AMP with float16 XLA:TPU users should use `torch.amp.autocast('xla')` for AMP with bfloat16 Pull Request resolved: #96370 Approved by: https://github.com/bdhirsh, https://github.com/malfet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Retrying PR #4740
With pytorch/pytorch#96370
Enable AMP on TPUs with bfloat16.
Currently, torch_xla.amp.autocast(args...) aliases torch.cuda.amp.autocast(args...)
This change proposes updating this to torch_xla.amp.autocast(device, args...) which will call the appropriate autocast depending on the XLA device.
Alternatively, users can call the appropriate autocast directly with torch.cuda.amp.autocast(args...) for XLA:GPU devices and torch.amp.autocast('xla', args...) for XLA:TPU devices.
HLO dump of MNIST model from test/test_train_mp_mnist_amp.py
ENTRY %IrToHlo.149 (p0.1: f32[10], p1.2: f32[10,50], p2.4: f32[50], p3.5: f32[50,320], p4.7: f32[20], p5.8: f32[20], p6.9: f32[20], p7.10: f32[20], p8.11: f32[20], p9.12: f32[20,10,5,5], p10.13: f32[10], p11.14: f32[10], p12.15: f32[10], p13.16: f32[10], p14.17: f32[10], p15.18: f32[10,1,5,5], p16.19: f32[128,1,28,28]) -> (f32[128,10]) {
...
%p16.19 = f32[128,1,28,28]{0,3,2,1} parameter(16)
%p15.18 = f32[10,1,5,5]{0,3,2,1} parameter(15)
%convolution.20 = f32[128,10,24,24]{3,2,1,0} convolution(f32[128,1,28,28]{0,3,2,1} %p16.19, f32[10,1,5,5]{0,3,2,1} %p15.18), window={size=5x5}, dim_labels=bf01_oi01->bf01
%p14.17 = f32[10]{0} parameter(14)
%broadcast.21 = f32[128,24,24,10]{3,2,1,0} broadcast(f32[10]{0} %p14.17), dimensions={3}
%transpose.22 = f32[128,10,24,24]{1,3,2,0} transpose(f32[128,24,24,10]{3,2,1,0} %broadcast.21), dimensions={0,3,1,2}
%add.23 = f32[128,10,24,24]{3,2,1,0} add(f32[128,10,24,24]{3,2,1,0} %convolution.20, f32[128,10,24,24]{1,3,2,0} %transpose.22)
...
%broadcast.136 = f32[128,10]{1,0} broadcast(f32[128]{0} %reduce.135), dimensions={0}
%subtract.137 = f32[128,10]{1,0} subtract(f32[128,10]{1,0} %add.129, f32[128,10]{1,0} %broadcast.136)
%exponential.138 = f32[128,10]{1,0} exponential(f32[128,10]{1,0} %subtract.137)
%constant.139 = f32[] constant(0)
%reduce.144 = f32[128]{0} reduce(f32[128,10]{1,0} %exponential.138, f32[] %constant.139), dimensions={1}, to_apply=%AddComputation.140
%log.145 = f32[128]{0} log(f32[128]{0} %reduce.144)
%broadcast.146 = f32[128,10]{1,0} broadcast(f32[128]{0} %log.145), dimensions={0}
%subtract.147 = f32[128,10]{1,0} subtract(f32[128,10]{1,0} %subtract.137, f32[128,10]{1,0} %broadcast.146)
ROOT %tuple.148 = (f32[128,10]{1,0}) tuple(f32[128,10]{1,0} %subtract.147)
}
}
HLO dump with AMP
ENTRY %IrToHlo.162 (p0.1: f32[10], p1.3: f32[10,50], p2.6: f32[50], p3.8: f32[50,320], p4.11: f32[20], p5.12: f32[20], p6.13: f32[20], p7.14: f32[20], p8.15: f32[20], p9.17: f32[20,10,5,5], p10.19: f32[10], p11.20: f32[10], p12.21: f32[10], p13.22: f32[10], p14.23: f32[10], p15.25: f32[10,1,5,5], p16.27: f32[128,1,28,28]) -> (bf16[128,10]) {
...
%p16.27 = f32[128,1,28,28]{0,3,2,1} parameter(16)
%convert.28 = bf16[128,1,28,28]{0,3,2,1} convert(f32[128,1,28,28]{0,3,2,1} %p16.27)
%p15.25 = f32[10,1,5,5]{0,3,2,1} parameter(15)
%convert.26 = bf16[10,1,5,5]{0,3,2,1} convert(f32[10,1,5,5]{0,3,2,1} %p15.25)
%convolution.29 = bf16[128,10,24,24]{3,2,1,0} convolution(bf16[128,1,28,28]{0,3,2,1} %convert.28, bf16[10,1,5,5]{0,3,2,1} %convert.26), window={size=5x5}, dim_labels=bf01_oi01->bf01
%p14.23 = f32[10]{0} parameter(14)
%convert.24 = bf16[10]{0} convert(f32[10]{0} %p14.23)
%broadcast.30 = bf16[128,24,24,10]{3,2,1,0} broadcast(bf16[10]{0} %convert.24), dimensions={3}
%transpose.31 = bf16[128,10,24,24]{1,3,2,0} transpose(bf16[128,24,24,10]{3,2,1,0} %broadcast.30), dimensions={0,3,1,2}
%add.32 = bf16[128,10,24,24]{3,2,1,0} add(bf16[128,10,24,24]{3,2,1,0} %convolution.29, bf16[128,10,24,24]{1,3,2,0} %transpose.31)
...
%broadcast.149 = bf16[128,10]{1,0} broadcast(bf16[128]{0} %reduce.148), dimensions={0}
%subtract.150 = bf16[128,10]{1,0} subtract(bf16[128,10]{1,0} %add.142, bf16[128,10]{1,0} %broadcast.149)
%exponential.151 = bf16[128,10]{1,0} exponential(bf16[128,10]{1,0} %subtract.150)
%constant.152 = bf16[] constant(0)
%reduce.157 = bf16[128]{0} reduce(bf16[128,10]{1,0} %exponential.151, bf16[] %constant.152), dimensions={1}, to_apply=%AddComputation.153
%log.158 = bf16[128]{0} log(bf16[128]{0} %reduce.157)
%broadcast.159 = bf16[128,10]{1,0} broadcast(bf16[128]{0} %log.158), dimensions={0}
%subtract.160 = bf16[128,10]{1,0} subtract(bf16[128,10]{1,0} %subtract.150, bf16[128,10]{1,0} %broadcast.159)
ROOT %tuple.161 = (bf16[128,10]{1,0}) tuple(bf16[128,10]{1,0} %subtract.160)
}