torchdynamo and xla integration #87741
Conversation
wconstab left a comment:
Looks good, but I do think it is worth reconsidering how the UX of moving tensors to the XLA device is implemented. I'm OK landing it and then updating later if that unblocks the XLA team in the meantime.
```python
eager_dev = inputs[0].device
# We assume the passed in mod is already on xla device
# so we dont need: xla_mod = copy.deepcopy(mod).to(device=xla_dev)
xla_mod = mod
```
Maybe rename 'mod' above to xla_mod, and instead of the comment here you could do an assert?
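For illustration only, a minimal sketch of what that suggested assert could look like (`xla_mod` and `xla_dev` are the names from the diff; this is a sketch of the reviewer's suggestion, not the actual change):

```python
import torch_xla.core.xla_model as xm

xla_dev = xm.xla_device()
# Replace the "we assume the passed in mod is already on xla device" comment
# with an explicit check that every parameter really lives on the XLA device.
assert all(p.device == xla_dev for p in xla_mod.parameters()), \
    "expected the passed-in module to already be on the XLA device"
```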
```python
# We assume the passed in mod is already on xla device
# so we dont need: xla_mod = copy.deepcopy(mod).to(device=xla_dev)
xla_mod = mod
xla_inputs = tree_map(lambda x: x.to(device=xla_dev), inputs)
```
Is 'wrapper' going to stand in for model_iter_fn later in the timing script? So that means we will include .to/from(xla) transfers in our timing?
Yes, inputs.to(lazy_device) and outputs.to(eager_device) are included in the timing for the baseline. That overhead is included in the timing for the torchxla_trace_once backend as well.
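For readers skimming this thread, a minimal sketch of the kind of timing wrapper being discussed, assuming the `xla_dev`/`mod` names from the diff (illustrative only, not the exact benchmark code):

```python
import torch_xla.core.xla_model as xm
from torch.utils._pytree import tree_map

xla_dev = xm.xla_device()

def make_xla_wrapper(mod):
    # `mod` is assumed to already live on the XLA device.
    def wrapper(*inputs):
        eager_dev = inputs[0].device if inputs else "cpu"
        # The eager -> XLA input transfer happens inside the timed region ...
        xla_inputs = tree_map(lambda x: x.to(device=xla_dev), inputs)
        out = mod(*xla_inputs)
        # ... and so does the XLA -> eager output transfer.
        return tree_map(lambda x: x.to(device=eager_dev), out)
    return wrapper
```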
```python
def xla_model_wrapper(*inputs):
    orig_device = inputs[0].device if len(inputs) > 0 else "cpu"
    xla_inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
```
Same question as above. I'm wondering if the best UX is to handle to/from XLA device transparently like this, or to require users to pick the XLA device up front.
The tradeoff would be that the user has to pick XLA up front, but then they would be more in control of when the device copy happens.
So basically you mean wrap model_iter_fn once (to handle the cross-device movement) so we don't need to do it anywhere else?
I actually tried this idea, since I agree it makes the code cleaner. But the problem I encountered is that the wrapped model_iter_fn fails to be compiled by dynamo. I guess there is some incompatibility when we apply XLA and dynamo (the graph-capture part rather than the backend) together.
```python
def __call__(self, args):
    real_input = []
    for tensor_id, traced_ivalue in zip(
        self.graph_input_tensor_ids, self.graph_input_ivalues
```
Hmm, is '_ivalues' here also an overloaded term to mean 'xla_values' in the XLA case?
```python
    )
    xla_args_need_update = []
    arg_index_to_need_update_index = {}
    for i, need_update in enumerate(xla_args_need_update_bool):
```
```python
(
    graph_input_tensor_ids,
    graph_input_ivalues,
```
Yeah, I think it's better to rename 'ivalues' to 'backend_handles' or something more generic?
Haha, we call it ivalue since we return IValue in the pybind API:

```cpp
m.def("_get_tensors_xla_device_data_node",
      [](const std::vector<at::Tensor>& tensors)
          -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
```
Wouldn't backend_handles be more suitable for opaque lookup keys which will be interpreted somewhere else (i.e. in XLA)?
Just double-checking. I'm OK to rename, though.
```python
    tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues
)
```
```python
# sync xla tensors
```
Might be worth changing the comment here to explain to non-XLA-familiar folks: compiles + runs the graph rooted at the tensors in 'args_and_out'.
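For non-XLA-familiar readers, "sync" here means asking torch_xla to actually compile and execute the lazily recorded graph whose roots are those tensors. A minimal illustration using the public torch_xla API (the diff itself likely goes through lower-level calls):

```python
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
a = torch.ones(2, 2, device=dev)
b = (a + 1) * 3      # only recorded in the lazy trace so far, nothing executed
xm.mark_step()       # compiles + runs the graph rooted at the live XLA tensors
print(b.cpu())       # values are now materialized and safe to read back
```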
TPU v4-8 single-device speedup: on bert and vit the speedup is very significant. It is puzzling why alexnet and squeezenet1_1 are slower with dynamo.
```
@@ -0,0 +1,165 @@
import copy
```
I'd suggest just putting this file under torch/_dynamo/optimizations/torchxla_integration.py, unless you expect it to grow large enough to be many files.
I can move it. The file size should still be manageable.
```
@@ -0,0 +1,127 @@
import copy
```
Let's put test files in the top level test/dynamo directory, with the right skips if XLA isn't installed.
@pytorchbot merge -f "CI clean except some CLA authorization issue" |
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Merge failed. Reason: the following mandatory check(s) failed. Dig deeper by viewing the failures on hud.
/easycla
Example command:
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --use-xla-baseline --only resnet18 --backend=torchxla_trace_once
```
Regarding the failed job: it looks like a TORCH_CHECK failure in quantization causes the test process to abort. I think it's not related to this PR.
@pytorchbot merge -f "The only failed check is not relevant to the PR. Check the previous comment" |
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
# Motivation

- torchdynamo and torchxla use different strategies to be a sound graph capture technique. The former relies on guards; the latter relies on retracing.
- The guard system is quite low overhead, but torchxla tracing overhead is quite high.

The main idea is to leverage the guard system in torchdynamo to avoid retracing in torchxla so that
- we can integrate torchdynamo with XLA
- we reduce or even completely avoid the tracing overhead of torchxla

# Technique details

## XLA baseline

We found that different frameworks do not generate numerically identical results for the SAME model with the SAME input. By default, torchdynamo uses eager as the baseline, so the model will run with PyTorch. It would be tricky to compare a model running on XLA with this baseline: it's hard to check correctness. To make the comparison easier, we add a flag `--use-xla-baseline`. When it's enabled, the baseline is run on XLA.

## New dynamo backends added

We add 2 new dynamo backends, torchxla_trivial and torchxla_trace_once, to control the optimization targets.

torchxla_trivial simply moves inputs/model parameters to XLA and runs the model on XLA. There is tracing overhead for each run. We expect the result to be mostly neutral compared to the XLA baseline.

torchxla_trace_once only traces once, at AOT compile time. Here are the steps:
1. dynamo captures the guards and the subgraph
2. the torchxla_trace_once backend traces the graph with torchxla, lowers the graph, and records a hash of the graph for later lookup
3. at inference time, the hash is used directly to look up the optimized graph and run it

# Limitations

We cannot handle LTC/torchxla fallback right now. If an op misses an LTC kernel, we raise an exception and that results in a dynamo fallback (or trying another compiler). People have brainstormed the idea of breaking the graph and stitching the subgraphs together, but maybe it's easier to add the missing LTC kernels for those models.

# Results

The models we tested are those not causing LTC fallback. We run the tests on **GPU**. We see a **1.38x** geomean speedup for torchxla_trace_once, and torchxla_trivial is mostly neutral as expected.

```
+-------------------------+--------------------+-------------------------+
| Model                   | XLA (trace once)   | XLA (trace everytime)   |
+=========================+====================+=========================+
| resnet18                | 1.346              | 1.045                   |
+-------------------------+--------------------+-------------------------+
| resnet50                | 1.153              | 1.007                   |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         | 1.381              | 1.039                   |
+-------------------------+--------------------+-------------------------+
| alexnet                 | 1.045              | 1.018                   |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            | 1.562              | 1.021                   |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              | 1.303              | 1.069                   |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1           | 1.278              | 1.025                   |
+-------------------------+--------------------+-------------------------+
| vgg16                   | 1.076              | 1.008                   |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            | 2.224              | 0.978                   |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer | 1.81               | 1.025                   |
+-------------------------+--------------------+-------------------------+
| geomean                 | 1.38101            | 1.02324                 |
+-------------------------+--------------------+-------------------------+
```

The speedup is similar to what we saw from previous work for LTC's TorchScript backend (a 1.40 geomean speedup there): https://docs.google.com/presentation/d/1G09X8v41u_cLKLtSdf7v6R8G19-iZTPcW_VAdOnvYBI/edit#slide=id.g11bf989cb6b_1_5

# Next steps

- Use AOT autograd to enable training
- Share results on XLA devices
- Do more extensive tests on torchbench models

Example command:
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --use-xla-baseline --only resnet18 --backend=torchxla_trace_once
```

Thanks @JackCaoG from the torchxla team for helping debug various perf issues and merging the torchxla PR! That's super critical for us to get the results above. torchxla side PR: pytorch/xla#4119

topic: not user facing

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @jansel

Pull Request resolved: pytorch#87741
Approved by: https://github.com/wconstab
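To make steps 2 and 3 above concrete, here is a minimal, self-contained sketch of the "trace once, replay by hash" shape of such a backend. It is not the real torchxla_trace_once implementation (which lowers the FX graph through torch_xla and keys its cache on torch_xla's graph hash); the cache key and the "compiled" artifact below are simple stand-ins:

```python
import torch
import torch._dynamo

GRAPH_CACHE = {}

def my_trace_once_backend(gm: torch.fx.GraphModule, example_inputs):
    # Expensive work (tracing/lowering the captured subgraph) happens once, here.
    key = id(gm)                    # stand-in for the XLA graph hash
    GRAPH_CACHE[key] = gm.forward   # stand-in for the lowered/compiled graph

    def optimized(*args):
        # At call time we only do a cheap lookup; no retracing.
        return GRAPH_CACHE[key](*args)
    return optimized

# Dynamo's guards decide when this cached graph can be reused for new inputs.
@torch._dynamo.optimize(my_trace_once_backend)
def f(x):
    return torch.relu(x) + 1

print(f(torch.randn(4)))
```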
In #87741 we added the inference support for the dynamo/torchxla integration. Later on, in #88449, we attempted to add the training support. That attempt was not smooth because
- we tried 2 things together:
  1. let dynamo trace the model on XLA rather than eager
  2. enable training
- It turns out neither of these two tasks is trivial enough. Furthermore, item 2 (enable training) depends on item 1 (tracing on XLA). We enable training via AOTAutograd. AOTAutograd lifts all model parameters/buffers as graph inputs. Without item 1 being done, we would need to copy all graph inputs (including model parameters/buffers) from the eager device to the XLA device. That hurts performance a lot. Having a cache to map an eager parameter to an XLA parameter does not solve the problem, since an update to either will not sync automatically to the other. They will easily go out of sync.

This PR lets dynamo trace the model on XLA rather than eager. This is a preparation step for enabling training. Also, tracing on XLA makes the data movement more efficient. We see a 1.5x geomean speedup compared to the previous 1.38x.

```
+-------------------------+--------------------+-------------------------+
| Model                   | XLA (trace once)   | XLA (trace everytime)   |
+=========================+====================+=========================+
| resnet18                | 1.38               | 1.008                   |
+-------------------------+--------------------+-------------------------+
| resnet50                | 1.227              | 0.998                   |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         | 1.544              | 1.008                   |
+-------------------------+--------------------+-------------------------+
| alexnet                 | 1.085              | 1.045                   |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            | 2.028              | 1.013                   |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              | 1.516              | 0.995                   |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1           | 0.868              | 1.01                    |
+-------------------------+--------------------+-------------------------+
| vgg16                   | 1.099              | 1.008                   |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            | 3.26               | 1.027                   |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer | 2.182              | 1.015                   |
+-------------------------+--------------------+-------------------------+
| geomean                 | 1.50389            | 1.01261                 |
+-------------------------+--------------------+-------------------------+
```

Example command:
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --only resnet18 --backend=torchxla_trace_once
```

Pull Request resolved: #88904
Approved by: https://github.com/wconstab, https://github.com/JackCaoG, https://github.com/jansel
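A tiny illustration of the "cached copies go out of sync" problem described above (not code from the PR; it only shows why mapping eager parameters to cached XLA copies is not enough):

```python
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
eager_p = torch.nn.Parameter(torch.zeros(3))
xla_copy = eager_p.detach().to(dev)   # a cached copy of the parameter on XLA

with torch.no_grad():
    eager_p += 1.0                    # e.g. an optimizer step updates the eager side

print(eager_p)         # updated on the eager side
print(xla_copy.cpu())  # the cached XLA copy is stale (still zeros)
```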