
codegen fixes to fix tracing XLA autograd ops #90226

Closed
bdhirsh wants to merge 3 commits into gh/bdhirsh/353/base from gh/bdhirsh/353/head

Conversation


@bdhirsh bdhirsh commented Dec 5, 2022

Today, when XLA registers an autograd.Function to the `AutogradXLA` key, the "add to the autograd graph" step and the "run the forward kernel" step happen all in one go. That's wrong, and it prevents other dispatcher code from executing in between.

While trying to fix this, I noticed a bug in the codegen: we register kernels for both the XLA and AutogradXLA dispatch keys on the same class. This prevents XLA from registering separate kernels to the XLA and AutogradXLA keys, which is what this PR addresses.

Companion patch to fix XLA's max_pool2d registration, which was blocking the dynamo integration: pytorch/xla#4276

After this PR, XLA should generate two separate header files: `XLANativeFunctions.h` and `AutogradXLANativeFunctions.h`. Before, all of the kernels (including autograd kernels) were thrown into `XLANativeFunctions.h`.
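To illustrate the intended split, here is a rough Python-level sketch using `torch.library` (the real registrations are generated C++ kernels declared in the two headers; `CPU`/`AutogradCPU` stand in for `XLA`/`AutogradXLA` so the snippet runs without torch_xla):

```
import torch

# Sketch only, not the generated code. Register two separate kernels for
# one op: a backend kernel and an autograd kernel, on different keys.
lib = torch.library.Library("aten", "IMPL")

def relu_backend(self):
    # The "run the forward kernel" step: backend math only.
    return torch.clamp(self, min=0)

def relu_autograd(self):
    # The "add to the autograd graph" step. Generated code sets up the
    # backward and then redispatches, letting other dispatcher keys run
    # before the backend kernel; here we just call the backend function
    # and let the nested clamp build the graph, to keep the sketch short.
    return relu_backend(self)

lib.impl("relu", relu_backend, "CPU")           # backend-key kernel
lib.impl("relu", relu_autograd, "AutogradCPU")  # autograd-key kernel

x = torch.tensor([-1.0, 2.0], requires_grad=True)
print(torch.relu(x))  # hits relu_autograd first, then relu_backend
```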

cc @JackCaoG

Stack from ghstack (oldest at bottom):

@bdhirsh bdhirsh requested a review from a team as a code owner December 5, 2022 22:25

pytorch-bot Bot commented Dec 5, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90226

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 2c2f4a0:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla Bot commented Dec 5, 2022

CLA: Not Signed (Missing ID)

bdhirsh pushed a commit that referenced this pull request Dec 5, 2022
ghstack-source-id: 9957d2a
Pull Request resolved: #90226
bdhirsh pushed a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: 6dbb425
Pull Request resolved: #90226
@albanD albanD removed their request for review December 28, 2022 11:17
pytorchmergebot pushed a commit that referenced this pull request Jan 5, 2023
We've already shown some promising perf results by integrating dynamo with torchxla for inference. To provide a consistent UX for training and inference, in this PR we try to enable training for dynamo/torchxla.

Training is trickier than inference, and we may not expect much perf gain, since:
1. in the training case, torchxla generates a single combined graph for fwd/bwd/optimizer, while in the `torchxla_trace_once` bridge we added in dynamo, due to how AOT_Autograd works, we generate 3 graphs: one for the forward, one for the backward, and one for the optimizer. XLA favors larger graphs, which give it more room to optimize.
2. in the training case, tracing overhead can be overlapped with computation, so it is not as big a deal for training as for inference. After all, training cares more about throughput while inference cares more about latency.
3. in the training case, people can increase the batch size to 'mitigate' the tracing overhead. Increasing the batch size does not change the tracing overhead itself, but it makes the tracing overhead 'per example' shrink (see the sketch right after this list).
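A back-of-the-envelope illustration of item 3, with made-up numbers (these are assumptions, not measurements):

```
# Hypothetical numbers: a fixed per-step tracing cost amortized over the
# batch. A larger batch does not shrink the tracing overhead itself, only
# its share of the step time and its per-example cost.
trace_ms = 50.0               # assumed per-step tracing overhead
compute_ms_per_example = 5.0  # assumed per-example compute time

for batch in (8, 64, 512):
    step_ms = trace_ms + batch * compute_ms_per_example
    print(f"batch={batch:4d}: tracing = {trace_ms / step_ms:5.1%} of step, "
          f"{trace_ms / batch:6.3f} ms per example")
```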

But we still want to add training support to dynamo/torchxla to make the work complete.

We added an '--iterations-per-run' argument to control how many iterations we run per measure/device sync. This is to understand the impact of item 2 above.
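A minimal sketch of what the flag amounts to (not the actual benchmarks/dynamo/torchbench.py code, just the measurement idea, assuming the standard torch_xla lazy-tensor APIs):

```
import time

import torch_xla.core.xla_model as xm

def measure(step_fn, iterations_per_run):
    # Run N iterations per device sync: mark_step() cuts the lazy-tensor
    # graph each iteration, but we only block on the device once at the
    # end, so tracing of step i+1 can overlap execution of step i.
    start = time.perf_counter()
    for _ in range(iterations_per_run):
        step_fn()
        xm.mark_step()
    xm.wait_device_ops()  # the single measure/device sync per run
    return (time.perf_counter() - start) / iterations_per_run
```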

Results:

With '--iterations-per-run' equal to 1, here are the perf numbers:
```
+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |   XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18                |             0.91   |                0.959    |
+-------------------------+--------------------+-------------------------+
| resnet50                |             0.917  |                0.932    |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |             0.912  |                0.905    |
+-------------------------+--------------------+-------------------------+
| alexnet                 |             1.038  |                0.974    |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |             0.881  |                0.835    |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |             0.903  |                0.931    |
+-------------------------+--------------------+-------------------------+
| vgg16                   |             0.914  |                0.967    |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |             1.359  |                0.84     |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |             1.288  |                0.893    |
+-------------------------+--------------------+-------------------------+
| geomean                 |             1.0006 |                0.913794 |
+-------------------------+--------------------+-------------------------+
```

Overall it looks like graph breaks indeed cause a perf loss. But for BERT_pytorch and timm_vision_transformer we still see a perf gain. We need to do more experiments with larger '--iterations-per-run' values.
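For reference, the geomean row is the geometric mean of the per-model numbers; a quick check against the trace-once column:

```
import math

# Per-model numbers from the "XLA (trace once)" column above.
trace_once = [0.91, 0.917, 0.912, 1.038, 0.881, 0.903, 0.914, 1.359, 1.288]
geomean = math.exp(sum(map(math.log, trace_once)) / len(trace_once))
print(f"{geomean:.4f}")  # -> 1.0006, matching the table
```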

NOTE:
In torchbench.py I added the following code to apply a few workarounds:
```
from myscripts import workaround # TODO will remove this line before landing
```

Here is the content of workaround.py:
```
import os

import torch

# Monkeypatch: replace nn.MaxPool2d with nn.AvgPool2d when the
# REPLACE_MAXPOOL env var is set. Note this only affects MaxPool2d
# modules constructed after this module is imported.
if os.environ.get("REPLACE_MAXPOOL", "0") == "1":
    torch.nn.MaxPool2d = torch.nn.AvgPool2d
```

It works around a few issues we found:
1. MaxPool2d does not work for training in dynamo/torchxla: pytorch/torchdynamo#1837. WIP fix from Brian in #90226, https://github.com/pytorch/xla/pull/4276/files (WIP)
2. A recent change (PR #88697) in op decomposition causes batch_norm ops to fall back in torchxla. Fix from Jack in pytorch/xla#4282 (comment). (Confirmed the fix after adding a Deduper to handle duplicated returns from the fx graph generated by AOTAutograd.)
3. We have an issue handling dropout because the random seed gets out of sync. Here is the fix: pytorch/xla#4293 (confirmed the fix)

Example command:
```
REPLACE_MAXPOOL=1 USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only vgg16
```

Pull Request resolved: #88449
Approved by: https://github.com/wconstab, https://github.com/qihqi, https://github.com/malfet
github-actions Bot commented

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions Bot added the Stale label Feb 26, 2023
@github-actions github-actions Bot closed this Mar 28, 2023
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/353/head branch June 8, 2023 15:44
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026