Skip to content

Tag dynamo backends as debug/experimental#93878

Closed
jansel wants to merge 8 commits intogh/jansel/40/basefrom
gh/jansel/40/head
Closed

Tag dynamo backends as debug/experimental#93878
jansel wants to merge 8 commits intogh/jansel/40/basefrom
gh/jansel/40/head

Conversation

@jansel
Copy link
Copy Markdown
Contributor

@jansel jansel commented Feb 1, 2023

Stack from ghstack (oldest at bottom):

Hides debug/experimental backends by default.

Before:

torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']

After:

torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']

Fixes #93733

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Feb 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93878

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9d4d66f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
jansel added a commit that referenced this pull request Feb 2, 2023
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

ghstack-source-id: 24ebb92
Pull Request resolved: #93878
@jansel jansel requested a review from voznesenskym February 2, 2023 05:44
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
jansel added a commit that referenced this pull request Feb 2, 2023
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

ghstack-source-id: 03e9f7c
Pull Request resolved: #93878
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
jansel added a commit that referenced this pull request Feb 3, 2023
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

ghstack-source-id: 357c262
Pull Request resolved: #93878
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
jansel added a commit that referenced this pull request Feb 3, 2023
Hides debug/experimental backends by default.

Before:
```
torch._dynamo.list_backends()
['aot_eager', 'aot_eager_decomp_partition', 'aot_torchxla_trace_once', 'aot_torchxla_trivial', 'aot_ts', 'aot_ts_nvfuser', 'cudagraphs', 'dynamo_accuracy_minifier_backend', 'dynamo_minifier_backend', 'eager', 'inductor', 'ipex', 'nvprims_aten', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'torchxla_trace_once', 'torchxla_trivial', 'ts', 'tvm']
```

After:
```
torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tensorrt', 'tvm']
```

Fixes #93733

ghstack-source-id: fb09e1f
Pull Request resolved: #93878
@jansel jansel added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 3, 2023
@jansel
Copy link
Copy Markdown
Contributor Author

jansel commented Feb 3, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@JackCaoG
Copy link
Copy Markdown
Collaborator

Hi @jansel What's standard of being an non-debug backend? Last time I check we can run 70% of the torchbench model without crash and most of them with reasonable->great performance.

@jansel
Copy link
Copy Markdown
Contributor Author

jansel commented Feb 15, 2023

Which backend are you talking about?

Debug backends are intended for debugging (for example "eager" and "aot_eager"), they don't make your program go faster -- but mainly intended to debug issues.

@JackCaoG
Copy link
Copy Markdown
Collaborator

I am talking about xla backends, so torchxla_trace_once and aot_torchxla_trace_once. torchxla_trivial and aot_torchxla_trivial are only for debugging purpose so they should stay in the debug backends.

@jansel
Copy link
Copy Markdown
Contributor Author

jansel commented Feb 15, 2023

@JackCaoG the XLA ones are tagged "experimental" not "debug". Feel free to submit a PR to reclassify them as debug if you think that is better or remove any ones that aren't useful anymore.

For torchxla_trivial:

@register_backend
def torchxla_trivial(gm, fake_tensor_inputs):
return gm

There is a potential "gotcha" where if the user only does:

torch.compile(backend="torchxla_trivial")

they won't actually get XLA, since the backend is identical to the "eager" backend. There are some extra XLA-related incantations the user must do to actually get XLA. I didn't want users seeing that in the list, and naively trying to turn it on.

I'd suggest adding at least something like:

if any(x.device.type != "xla" for x in inputs):
   raise RuntimeError(... link to docs explaining how to use the XLA backend ...)

I mainly marked them experimental to try to avoid that user gotcha, not making some larger statement about the state of XLA+Dynamo.

Feel free to submit PRs to change the XLA backends, remove them, or reclassify them.

cc @alanwaketan

@facebook-github-bot facebook-github-bot deleted the gh/jansel/40/head branch June 8, 2023 17:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants