torchdynamo and xla integration #87741

Closed: shunting314 wants to merge 1 commit into master from dynamo-torchxla-integration

Conversation

@shunting314 (Contributor) commented Oct 26, 2022

Motivation

  • torchdynamo and torchxla use different strategies to achieve sound graph capture. The former relies on guards; the latter relies on retracing
  • the guard system has very low overhead, while torchxla's tracing overhead is quite high

The main idea is to leverage torchdynamo's guard system to avoid retracing in torchxla so that

  • we can integrate torchdynamo with XLA
  • we reduce or even completely avoid torchxla's tracing overhead

Technique details

XLA baseline

We found that different frameworks do not generate numerically identical results for the SAME model with the SAME input. By default, torchdynamo uses eager mode as the baseline, so the baseline model runs with PyTorch eager. Comparing a model running on XLA against this baseline is tricky: it's hard to check correctness. To make the comparison easier, we add a flag --use-xla-baseline. When it's enabled, the baseline is also run on XLA.
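
A minimal sketch of the kind of tolerance-based check this implies (illustrative only; the exact tolerances and comparison logic used by the benchmark harness are not shown in this PR):

```
import torch

def close_enough(baseline_out: torch.Tensor, test_out: torch.Tensor) -> bool:
    # Bit-identical results across frameworks are not expected, so compare
    # against the baseline with a tolerance after moving both tensors to CPU.
    return torch.allclose(baseline_out.cpu(), test_out.cpu(), rtol=1e-3, atol=1e-3)
```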

New dynamo backends added

We add 2 new dynamo backends, torchxla_trivial and torchxla_trace_once, to control the optimization targets.

torchxla_trivial simply moves the inputs/model parameters to XLA and runs the model on XLA. There is tracing overhead on each run. We expect this result to be mostly neutral compared to the XLA baseline.
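
A minimal sketch of the idea behind such a backend (not the code from this PR; it assumes torch_xla is installed and glosses over moving outputs back to the original device):

```
import torch
import torch_xla.core.xla_model as xm

def torchxla_trivial_sketch(gm: torch.fx.GraphModule, example_inputs):
    # Move the captured graph (and its parameters) to the XLA device once.
    xla_dev = xm.xla_device()
    xla_gm = gm.to(device=xla_dev)

    def run(*inputs):
        # Inputs go to XLA and the graph runs there; torch_xla still traces
        # lazily on every call, so the per-run tracing overhead remains.
        xla_inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
        out = xla_gm(*xla_inputs)
        xm.mark_step()  # materialize the lazily traced graph
        return out

    return run
```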

torchxla_trace_once traces only once, at AOT compile time. Here are the steps (a rough sketch follows the list):

  1. dynamo captures the guards and the subgraph
  2. the torchxla_trace_once backend traces the graph with torchxla, lowers the graph, and records a hash of the graph for later lookup
  3. at inference time, the hash is used directly to look up the optimized graph and run it.
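
The control flow can be sketched roughly as below (an illustrative toy, not the PR's implementation: the "lowered" graph here is just the FX graph itself, whereas the real backend lowers through torchxla and records the hash it reports):

```
import torch

compiled_graphs = {}  # graph hash -> callable running the optimized graph

def trace_once_sketch(gm: torch.fx.GraphModule, example_inputs):
    # Step 2: trace/lower once at compile time and record a hash for later lookup.
    graph_hash = hash(gm.code)
    compiled_graphs.setdefault(graph_hash, gm.forward)

    def run(*inputs):
        # Step 3: no retracing. Dynamo's guards already ensure the captured
        # graph still applies, so look up the optimized graph by hash and run it.
        return compiled_graphs[graph_hash](*inputs)

    return run
```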

Limitations

We cannot handle LTC/torchxla fallback right now. If an op is missing an LTC kernel, we raise an exception, which results in dynamo falling back (or trying another compiler). People have brainstormed the idea of breaking the graph and stitching the subgraphs together, but it may be easier to just add the missing LTC kernels for those models.

Results

The models we tested are those that do not cause LTC fallback. We ran the tests on GPU. We see a 1.38x geomean speedup for torchxla_trace_once, while torchxla_trivial is mostly neutral as expected.

+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |   XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18                |            1.346   |                 1.045   |
+-------------------------+--------------------+-------------------------+
| resnet50                |            1.153   |                 1.007   |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |            1.381   |                 1.039   |
+-------------------------+--------------------+-------------------------+
| alexnet                 |            1.045   |                 1.018   |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |            1.562   |                 1.021   |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |            1.303   |                 1.069   |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1           |            1.278   |                 1.025   |
+-------------------------+--------------------+-------------------------+
| vgg16                   |            1.076   |                 1.008   |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |            2.224   |                 0.978   |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |            1.81    |                 1.025   |
+-------------------------+--------------------+-------------------------+
| geomean                 |            1.38101 |                 1.02324 |
+-------------------------+--------------------+-------------------------+

The speedup is similar to what we see from previous work for LTC's TorchScript backend (we see 1.40 geomean speedup there):
https://docs.google.com/presentation/d/1G09X8v41u_cLKLtSdf7v6R8G19-iZTPcW_VAdOnvYBI/edit#slide=id.g11bf989cb6b_1_5

Next steps

  • Use AOT autograd to enable training
  • Share results on XLA devices
  • Do more extensive tests on torchbench models

Example command

GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --use-xla-baseline --only resnet18 --backend=torchxla_trace_once

Thanks @JackCaoG from the torchxla team for helping debug various perf issues and merging the torchxla PR! That was super critical for getting the results above. torchxla side PR: pytorch/xla#4119

topic: not user facing

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @jansel

@pytorch-bot (Bot) commented Oct 26, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87741

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 2390063:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla (Bot) commented Oct 26, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

@JackCaoG JackCaoG mentioned this pull request Oct 26, 2022
@shunting314 shunting314 force-pushed the dynamo-torchxla-integration branch from 681485d to b985542 on October 26, 2022 20:39
@shunting314 shunting314 requested a review from jansel October 26, 2022 20:41
@shunting314 shunting314 self-assigned this Oct 26, 2022
@shunting314 shunting314 requested a review from wconstab October 26, 2022 20:42
@shunting314 shunting314 force-pushed the dynamo-torchxla-integration branch 2 times, most recently from ef82c2a to e2deb29 on October 27, 2022 00:03
@wconstab (Contributor) left a comment:

looks good, but i do think it is worth reconsidering how the UX of moving tensors to the XLA device is implemented. I'm ok landing it and then updating if that unblocks the XLA team in the meantime.

Comment thread on benchmarks/dynamo/common.py (Outdated)
eager_dev = inputs[0].device
# We assume the passed in mod is already on xla device
# so we dont need: xla_mod = copy.deepcopy(mod).to(device=xla_dev)
xla_mod = mod
Contributor:

maybe rename 'mod' above to xla_mod, and instead of the comment here you could do an assert?

# We assume the passed in mod is already on xla device
# so we dont need: xla_mod = copy.deepcopy(mod).to(device=xla_dev)
xla_mod = mod
xla_inputs = tree_map(lambda x: x.to(device=xla_dev), inputs)
Contributor:

is 'wrapper' going to stand in for model_iter_fn later in the timing script? so that means we will include .to/from(xla) transfers in our timing?

Contributor Author:

yes, inputs.to(lazy_device) and outputs.to(eager_device) are included in the timing for the baseline. That overhead is included in the timing for the torchxla_trace_once backend as well.


def xla_model_wrapper(*inputs):
orig_device = inputs[0].device if len(inputs) > 0 else "cpu"
xla_inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
Contributor:

same question as above;

i'm wondering if the best UX is to handle to/from XLA device transfers transparently like this, or to require users to pick the XLA device up front.

The tradeoff would be that the user has to pick XLA up front, but then they would be more in control of when the device copy happens.

Contributor Author:

So basically you mean wrapping the model_iter_fn once (to handle the cross-device movement) so we don't need to do it anywhere else?

I actually tried this idea since I agree it makes the code cleaner. But the problem I encountered is that the wrapped model_iter_fn fails to be compiled by dynamo. I guess there is some incompatibility if we apply XLA and dynamo (the graph capture part rather than the backend) together.

def __call__(self, args):
real_input = []
for tensor_id, traced_ivalue in zip(
self.graph_input_tensor_ids, self.graph_input_ivalues
Contributor:

hmm is '_ivalues' here also an overloaded term to mean 'xla_values' in the xla case?

)
xla_args_need_update = []
arg_index_to_need_update_index = {}
for i, nede_update in enumerate(xla_args_need_update_bool):
Contributor:

typo (need)


(
graph_input_tensor_ids,
graph_input_ivalues,
Contributor:

yea i think better to rename 'ivalues' to 'backend_handles' or something more generic?

Contributor Author:

haha, call it ivalue since we return IValue in the pybind API:

  m.def("_get_tensors_xla_device_data_node",
        [](const std::vector<at::Tensor>& tensors)
            -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {

Contributor Author:

backend_handles would be more suitable for opaque lookup keys that will be interpreted somewhere else (i.e. in XLA)?

Just double checking. I'm ok to rename though.

tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues
)

# sync xla tensors
Contributor:

might be worth changing comment here to explain to non-xla familiar folks:

compiles+runs graph rooted at tensors in 'args_and_out'

@pytorch-bot (Bot) added the ciflow/trunk label (Trigger trunk jobs on your pull request) Oct 27, 2022
@JackCaoG (Collaborator):

TPU v4-8 single device speed up is

| model name               | speed up | lazy time  | dynamo time |
|--------------------------|----------|------------|-------------|
| resnet50                 | 1.112    | 0.01545586 | 0.01389865  |
| resnext50_32x4d          | 1.424    | 0.01714463 | 0.01204363  |
| alexnet                  | 0.697    | 0.0620243  | 0.08900219  |
| mobilenet_v2             | 1.505    | 0.01111322 | 0.00738665  |
| mnasnet1_0               | 1.114    | 0.0136932  | 0.01229267  |
| squeezenet1_1            | 0.825    | 0.00689901 | 0.0083646   |
| vgg16                    | 0.994    | 0.00598695 | 0.00602498  |
| BERT_pytorch             | 3.466    | 0.01493634 | 0.00430994  |
| timm_vision_transformer  | 2.611    | 0.01517026 | 0.00580959  |

On BERT and ViT the speedup is very significant. It is puzzling why alexnet and squeezenet1_1 are slower with dynamo.

@@ -0,0 +1,165 @@
import copy
Contributor:

I'd suggest just putting this file under torch/_dynamo/optimizations/torchxla_integration.py, unless you expect it to grow large enough to be many files.

@shunting314 (Contributor Author) commented Oct 27, 2022:

I can move it. The file size should still be manageable

@@ -0,0 +1,127 @@
import copy
Contributor:

Let's put test files in the top level test/dynamo directory, with the right skips if XLA isn't installed.

@shunting314 (Contributor Author):

> on bert and vit the speed up is very significant. It is puzzling why alexnet and squeezenet1_1 is slower with dynamo.

Try --torchxla-trivial on these 2 models? --torchxla-trivial enables dynamo but does not avoid retracing. That can be used to measure the dynamo overhead here.
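
Based on the example command in the description (the actual TPU invocation may differ), presumably something like:

```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --use-xla-baseline --only alexnet --backend=torchxla_trivial
```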

@shunting314 shunting314 force-pushed the dynamo-torchxla-integration branch from e2deb29 to 975cafb on October 27, 2022 21:36
@shunting314 shunting314 requested a review from jansel October 27, 2022 21:38
@shunting314 (Contributor Author):

Resolved the comments as much as I can. Some remaining ones need clarification.

@jansel, @wconstab do you want to take another look?

@shunting314 shunting314 force-pushed the dynamo-torchxla-integration branch 3 times, most recently from 50cc856 to 9b0ad23 on October 28, 2022 06:41
@shunting314 (Contributor Author):

@pytorchbot merge -f "CI clean except some CLA authorization issue"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: The following mandatory check(s) failed (Rule superuser):

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job

@malfet (Contributor) commented Oct 28, 2022

/easycla

@shunting314 shunting314 force-pushed the dynamo-torchxla-integration branch 2 times, most recently from 284973e to 6b485c6 on October 28, 2022 21:08
@github-actions:

This PR needs a label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@shunting314 shunting314 added the `topic: not user facing` label Oct 28, 2022
@shunting314 shunting314 force-pushed the dynamo-torchxla-integration branch from 6b485c6 to 2390063 on October 28, 2022 21:45
@shunting314 (Contributor Author):

The failed job 'trunk / linux-bionic-cuda11.7-py3.10-gcc7 / test (nogpu_NO_AVX2, 1, 1, linux.2xlarge) (push)' has a very long log. Here is the most relevant error message snippet I can find:


```
2022-10-29T03:03:55.3797957Z [ RUN      ] Quantization.QuantUpsampleNearst2dDequantUInt8
2022-10-29T03:03:55.4259856Z x:
2022-10-29T03:03:55.4261226Z (1,1,.,.) = 
2022-10-29T03:03:55.4261716Z   0.1066  0.6779  0.1332  0.2720
2022-10-29T03:03:55.4261979Z   0.7040  0.0439  0.8677  0.0901
2022-10-29T03:03:55.4262195Z   0.2853  0.6457  0.2500  0.6500
2022-10-29T03:03:55.4262449Z   0.7511  0.3352  0.9910  0.4533
2022-10-29T03:03:55.4262860Z [ CPUFloatType{1,1,4,4} ]
2022-10-29T03:03:55.4263085Z q:
2022-10-29T03:03:55.4263239Z (1,1,.,.) = 
2022-10-29T03:03:55.4263424Z   0.1000  0.7000  0.1000  0.3000
2022-10-29T03:03:55.4263582Z   0.7000  0.0000  0.9000  0.1000
2022-10-29T03:03:55.4263817Z   0.3000  0.6000  0.3000  0.6000
2022-10-29T03:03:55.4263996Z   0.8000  0.3000  1.0000  0.5000
2022-10-29T03:03:55.4264338Z [ QuantizedCPUQUInt8Type{1,1,4,4}, qscheme: per_tensor_affine, scale: 0.1000, zero_point: 13 ]
2022-10-29T03:03:55.4264632Z qu:
2022-10-29T03:03:55.4264786Z (1,1,.,.) = 
2022-10-29T03:03:55.4264966Z   0.1000  0.1000  0.7000  0.1000  0.1000  0.3000
2022-10-29T03:03:55.4265145Z   0.1000  0.1000  0.7000  0.1000  0.1000  0.3000
2022-10-29T03:03:55.4265389Z   0.7000  0.7000  0.0000  0.9000  0.9000  0.1000
2022-10-29T03:03:55.4265588Z   0.3000  0.3000  0.6000  0.3000  0.3000  0.6000
2022-10-29T03:03:55.4265764Z   0.3000  0.3000  0.6000  0.3000  0.3000  0.6000
2022-10-29T03:03:55.4265952Z   0.8000  0.8000  0.3000  1.0000  1.0000  0.5000
2022-10-29T03:03:55.4266283Z [ QuantizedCPUQUInt8Type{1,1,6,6}, qscheme: per_tensor_affine, scale: 0.1000, zero_point: 13 ]
2022-10-29T03:03:55.4266523Z y_expected:
2022-10-29T03:03:55.4266686Z (1,1,.,.) = 
2022-10-29T03:03:55.4266904Z   0.1000  0.1000  0.7000  0.1000  0.1000  0.3000
2022-10-29T03:03:55.4267090Z   0.1000  0.1000  0.7000  0.1000  0.1000  0.3000
2022-10-29T03:03:55.4267278Z   0.7000  0.7000  0.0000  0.9000  0.9000  0.1000
2022-10-29T03:03:55.4267467Z   0.3000  0.3000  0.6000  0.3000  0.3000  0.6000
2022-10-29T03:03:55.4267701Z   0.3000  0.3000  0.6000  0.3000  0.3000  0.6000
2022-10-29T03:03:55.4267883Z   0.8000  0.8000  0.3000  1.0000  1.0000  0.5000
2022-10-29T03:03:55.4268076Z [ CPUFloatType{1,1,6,6} ]
2022-10-29T03:03:55.4268253Z y:
2022-10-29T03:03:55.4268412Z (1,1,.,.) = 
2022-10-29T03:03:55.4268624Z   0.1000  0.1000  0.7000  0.1000  0.1000  0.3000
2022-10-29T03:03:55.4268812Z   0.1000  0.1000  0.7000  0.1000  0.1000  0.3000
2022-10-29T03:03:55.4268982Z   0.7000  0.7000  0.0000  0.9000  0.9000  0.1000
2022-10-29T03:03:55.4269186Z   0.3000  0.3000  0.6000  0.3000  0.3000  0.7000
2022-10-29T03:03:55.4269410Z   0.3000  0.3000  0.6000  0.3000  0.3000  0.7000
2022-10-29T03:03:55.4269582Z   0.8000  0.8000  0.3000  1.0000  1.0000  0.5000
2022-10-29T03:03:55.4269767Z [ CPUFloatType{1,1,6,6} ]
2022-10-29T03:03:55.4270073Z [F test_quantization.cpp:352] Check failed: check == 1 (0 vs. 1) 
2022-10-29T03:03:55.5846765Z .jenkins/pytorch/test.sh: line 355: 14643 Aborted                 (core dumped) "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml
```

Looks like a TORCH_CHECK fails for quantization and causes the test process to abort.

I think it's not related to the PR.

@shunting314 (Contributor Author):

@pytorchbot merge -f "The only failed check is not relevant to the PR. Check the previous comment"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@shunting314 shunting314 deleted the dynamo-torchxla-integration branch October 29, 2022 22:23
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2022
In #87741 we added inference support for the dynamo/torchxla integration. Later on, in #88449, we attempted to add training support. That attempt was not smooth because
- we tried 2 things together
   1. let dynamo trace the model on xla rather than eager
   2. enable training
- it turns out neither of these two tasks is trivial.

Furthermore, item 2 (enable training) depends on item 1 (tracing on xla). We enable training via AOTAutograd. AOTAutograd lifts all model parameters/buffers as graph inputs. Without item 1 being done, we would need to copy all graph inputs (including model parameters/buffers) from the eager device to the xla device. That hurts performance a lot. Having a cache to map eager parameters to XLA parameters does not solve the problem, since an update to either side does not automatically sync to the other; they easily go out of sync.

This PR lets dynamo trace the model on XLA rather than eager. This is a preparation step toward enabling training.
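
A rough sketch of what tracing on XLA means from the user's side (assumed usage, mirroring the --trace-on-xla example command below; `model` and `inputs` are placeholders defined elsewhere):

```
import torch
import torch_xla.core.xla_model as xm

xla_dev = xm.xla_device()
model = model.to(xla_dev)                      # model parameters/buffers now live on XLA
inputs = tuple(t.to(xla_dev) for t in inputs)  # so do the inputs

# Dynamo now traces a model whose lifted graph inputs are already XLA tensors,
# so no per-step eager->XLA copies are needed.
opt_model = torch._dynamo.optimize("torchxla_trace_once")(model)
out = opt_model(*inputs)
```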

Also, tracing on XLA makes the data movement more efficient. We see a 1.5x geomean speedup compared to the previous 1.38x.
```
+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |   XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18                |            1.38    |                 1.008   |
+-------------------------+--------------------+-------------------------+
| resnet50                |            1.227   |                 0.998   |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |            1.544   |                 1.008   |
+-------------------------+--------------------+-------------------------+
| alexnet                 |            1.085   |                 1.045   |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |            2.028   |                 1.013   |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |            1.516   |                 0.995   |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1           |            0.868   |                 1.01    |
+-------------------------+--------------------+-------------------------+
| vgg16                   |            1.099   |                 1.008   |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |            3.26    |                 1.027   |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |            2.182   |                 1.015   |
+-------------------------+--------------------+-------------------------+
| geomean                 |            1.50389 |                 1.01261 |
+-------------------------+--------------------+-------------------------+
```

Example command
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --only resnet18 --backend=torchxla_trace_once
```

Pull Request resolved: #88904
Approved by: https://github.com/wconstab, https://github.com/JackCaoG, https://github.com/jansel
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026