[pt2 bug bash] Fix nn.functional.pad compile crash with deterministic mode + replication padding #177166

Closed
ydwu4 wants to merge 4 commits into gh/ydwu4/389/base from gh/ydwu4/389/head

Conversation

@ydwu4 ydwu4 (Contributor) commented Mar 11, 2026

Stack from ghstack (oldest at bottom):

Fixes #170079

## Context

`torch.compile(ReplicationPad1d(...), fullgraph=True)` crashes when
`torch.use_deterministic_algorithms(True)` is set on CUDA. The error: Dynamo can't trace
through `importlib.import_module`.

The deterministic code path exists because the native `replication_pad1d_backward` CUDA
kernel uses `atomicAdd` (non-deterministic). `functional.py` calls `_replication_pad` — a
Python decomposition using `_unsafe_index`, whose backward uses `index_put` (deterministic).
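The decomposition idea can be illustrated with a minimal sketch (hypothetical code, not the actual `_replication_pad` source): replication padding is a gather with clamped indices, so its backward lowers to index/scatter ops that have a deterministic path, rather than a dedicated `atomicAdd`-based kernel.

```python
import torch

def replication_pad1d_sketch(x, pad):
    # Hypothetical sketch of the decomposition idea: express replication
    # padding as indexing with clamped indices. The backward of an index
    # read is an index_put-style scatter, which has a deterministic
    # implementation, unlike the native replication_pad1d_backward kernel.
    left, right = pad
    n = x.size(-1)
    idx = torch.arange(-left, n + right).clamp_(0, n - 1)
    return x[..., idx]

x = torch.tensor([[1.0, 2.0, 3.0]])
print(replication_pad1d_sketch(x, (2, 1)))  # tensor([[1., 1., 1., 2., 3., 3.]])
```

The values match `torch.nn.ReplicationPad1d((2, 1))` applied to the same `(C, W)` input.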

## Dynamo limitations encountered

Three separate Dynamo tracing barriers prevented calling `_replication_pad` directly:

### 1. `importlib.import_module` is marked as skipped

```python
@torch.compile(fullgraph=True)
def fn(x):
    import importlib
    return importlib.import_module("torch").sin(x)
fn(torch.randn(3))  # Unsupported: function marked as skipped
```

### 2. `elementwise_dtypes` returns non-Tensor (from `@pw_cast_for_opmath`)

```python
@torch.compile(fullgraph=True)
def fn(x):
    from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
    dt, _ = elementwise_dtypes(x, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
    return x.to(dt)
fn(torch.randn(3))  # Unsupported: torch.* op returned non-Tensor
```

### 3. `torch._check` with closure lambda

```python
@torch.compile(fullgraph=True)
def fn(x):
    dim = x.dim()
    torch._check(dim in (2, 3), lambda: f"expected 2D or 3D, got {dim}D")
    return x + 1
fn(torch.randn(3, 3))  # Unsupported: Can't extract message from torch._check()
```

## Iteration log

| # | Approach | Who | Tests | Reviewer pushback | Why it failed |
|---|----------|-----|-------|-------------------|---------------|
| 1 | Replace `importlib` with `from...import` | Claude | bilinear/trilinear pass, replicate fails | "why do we need bilinear/trilinear tests?" — scoped fix to reported bug only | Hit limitation #2: `@pw_cast_for_opmath` |
| 2 | Skip decomposition under compile via `is_compiling()`, rely on AOTAutograd's `@register_decomposition` | Claude | forward-only `backend="eager"` passes | "can you verify at inductor level this is actually deterministic?" — inspect AOT graph | No backward decomposition registered; backward still uses native `replication_pad1d_backward` (non-deterministic) |
| 3 | Unwrap `@pw_cast_for_opmath` via `__wrapped__` | Claude | N/A — fails immediately | N/A | Hit limitation #3: `torch._check()` closure |
| 4 | `@nonstrict_trace` — Dynamo skips body, AOTAutograd traces through | Reviewer suggestion | `backend="aot_eager"`, forward + backward under `DeterministicGuard(True)` | N/A — fix is correct | N/A |

## Key insight

The fix isn't about making Dynamo trace the decomposition or skipping it entirely — it's
about putting the boundary in the right place. Dynamo doesn't need to see inside; AOTAutograd
does. `@nonstrict_trace` is exactly this boundary.

Each "obvious" fix had passing tests that weren't testing the right thing. Only when the
reviewer pushed for backward determinism verification and AOT graph inspection did the
weaknesses surface. The backward completing without error under `DeterministicGuard(True)`
proves determinism — PyTorch explicitly raises `RuntimeError` if any non-deterministic CUDA
kernel executes under this mode.

Authored with Claude.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo

ydwu4 added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: f488144
Pull Request resolved: #177166
pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177166

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3fb1255 with merge base 8e7898a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ydwu4 added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: f640706
Pull Request resolved: #177166
@ydwu4 ydwu4 added the `topic: not user facing` label Mar 11, 2026
@github-actions github-actions bot deleted a comment from pytorch-bot bot Mar 11, 2026
ydwu4 added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 9a2f0c8
Pull Request resolved: #177166
```python
# nonstrict_trace makes Dynamo skip the function body
# (which contains Dynamo-untraceable code) while
# AOTAutograd still traces into it for the backward.
return torch._dynamo.decorators.nonstrict_trace(
```
Contributor:
I'm curious what's in the body that makes it non-dynamo traceable but AOT autograd traceable? usually it's the other way around.

@ydwu4 ydwu4 (Contributor, Author) Mar 16, 2026
It's the `elementwise_dtypes` returning a non-Tensor (from `@pw_cast_for_opmath`) and `torch._check` with a closure lambda; see the issue description for minimal repros! Potentially we could make them work, but that seems to need some additional effort. Can do them as follow-ups!

@mlazos mlazos (Contributor) left a comment

Seems fine, just curious if we can make the code dynamo traceable

@ydwu4 ydwu4 added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label Mar 16, 2026
@ydwu4 ydwu4 (Contributor, Author) commented Mar 17, 2026

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot (Collaborator)

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x b4c85e734fe30a6b953e1c14e5816f38d44cd54f returned non-zero exit code 1

```
Auto-merging test/dynamo/test_repros.py
CONFLICT (content): Merge conflict in test/dynamo/test_repros.py
error: could not apply b4c85e734fe... [pt2 bug bash] Fix nn.functional.pad compile crash with deterministic mode + replication padding
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
```

ydwu4 added a commit that referenced this pull request Mar 18, 2026
… mode + replication padding

ghstack-source-id: 0e8021f
Pull Request resolved: #177166
@ydwu4
Contributor Author

ydwu4 commented Mar 18, 2026

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


ryanzhang22 pushed a commit to ryanzhang22/pytorch that referenced this pull request Mar 19, 2026
… mode + replication padding (pytorch#177166)

Pull Request resolved: pytorch#177166
Approved by: https://github.com/mlazos, https://github.com/williamwen42
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
… mode + replication padding (pytorch#177166)

Pull Request resolved: pytorch#177166
Approved by: https://github.com/mlazos, https://github.com/williamwen42
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
… mode + replication padding (pytorch#177166)

Pull Request resolved: pytorch#177166
Approved by: https://github.com/mlazos, https://github.com/williamwen42
ydwu4 added a commit that referenced this pull request Apr 2, 2026
…entwise_dtypes during tracing"


Fix issue 2 discovered in #177166.

elementwise_dtypes was registered as a TorchInGraphFunctionVariable via
torch._higher_order_ops.out_dtype, causing Dynamo to try to put it in the
FX graph. Since it returns (dtype, dtype) rather than tensors, this failed
with "torch.* op returned non-Tensor". The fix adds a handler that evaluates
elementwise_dtypes eagerly on fake tensor metadata during compilation and
returns the result as constants.
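For context, the call is metadata-only and can be evaluated eagerly; this small illustration shows why there is no tensor for Dynamo to graph:

```python
import torch
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND

# elementwise_dtypes inspects only dtypes and returns a pair of plain
# torch.dtype values: (computation_dtype, result_dtype). For a float16 input
# under DEFAULT promotion, computation is upcast to float32 (the opmath dtype)
# while the result stays float16 — this is what @pw_cast_for_opmath relies on.
comp, result = elementwise_dtypes(
    torch.ones(2, dtype=torch.float16),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
```

Since the outputs are dtypes, not tensors, the only workable graphing strategy is the one the fix takes: compute them at compile time and bake them in as constants.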

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo azahed98

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Apr 2, 2026
…uring tracing"


ydwu4 added a commit that referenced this pull request Apr 3, 2026
…entwise_dtypes during tracing"


ydwu4 added a commit that referenced this pull request Apr 3, 2026
…uring tracing"


pytorchmergebot pushed a commit that referenced this pull request Apr 8, 2026
…ng (#177743)


Pull Request resolved: #177743
Approved by: https://github.com/anijain2305