
[dynamo] handle unbacked SymInt data in Tensor.new_tensor#176390

Closed
vvvdwbvvv wants to merge 9 commits into pytorch:main from vvvdwbvvv:fix-unbacked-SymInt

Conversation

@vvvdwbvvv
Contributor

@vvvdwbvvv vvvdwbvvv commented Mar 4, 2026

PR Summary

Fixes #176067, a `torch.compile(..., fullgraph=True)` failure for `Tensor.new_tensor(...)`

Repro

```python
import torch

def f1(a):
    num = a.nonzero().squeeze(-1).numel()
    return torch.tensor([num])  # works

def f2(a):
    num = a.nonzero().squeeze(-1).numel()
    return a.new_tensor([num])  # fails before this change

a = torch.tensor([1, 0])
torch.compile(f1, fullgraph=True)(a)
torch.compile(f2, fullgraph=True)(a)
```

Prior to this change, f2 could fail with:

```text
torch._dynamo.exc.UserError: Could not extract specialized integer from
data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)
```

related issue: #176067
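Until a build with this fix is available, a user-level workaround follows from the repro itself: `torch.tensor` already handles the data-dependent value, so the `new_tensor` call can be rewritten in terms of it. A minimal sketch — `backend="eager"` is used here only to keep the example self-contained and is not part of the original repro, and the explicit `dtype`/`device` arguments mimic `new_tensor`'s property of matching the source tensor:

```python
import torch

def f2_workaround(a):
    num = a.nonzero().squeeze(-1).numel()
    # torch.tensor has a handler for the data-dependent value; passing
    # dtype/device preserves new_tensor's match-the-source-tensor behavior.
    return torch.tensor([num], dtype=a.dtype, device=a.device)

a = torch.tensor([1, 0])
compiled = torch.compile(f2_workaround, backend="eager")
print(compiled(a))  # tensor([1])
```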

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @mlazos

@pytorch-bot

pytorch-bot bot commented Mar 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176390

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 2 Unrelated Failures

As of commit e1a4bb2 with merge base 7643509:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Mar 4, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

@Lucaskabela Lucaskabela left a comment


Before we are ready to review, we need to add a unit test (the repro is a good starting point; given this is quite a hot path, we may also want to think about a few other paths that might cause this graph break and test those as well).

@Lucaskabela
Contributor

Please re-request review once ready

@vvvdwbvvv
Contributor Author

vvvdwbvvv commented Mar 4, 2026

@Lucaskabela Thanks for replying. Yes, I am writing a test for it. I am also investigating whether there is a better way to fix it. If you have a better implementation for the fix, please let me know.

@vvvdwbvvv
Contributor Author

vvvdwbvvv commented Mar 9, 2026

Now reproduced with the following probe script:

```python
"""
Probes whether Dynamo has a dedicated handler for new_tensor vs torch.tensor
when given a data-dependent SymInt.

Expected finding:
  - torch.tensor([data_dep_symint])  → handled → no break
  - a.new_tensor([data_dep_symint])  → no handler → break
"""
import torch
import torch._dynamo as dynamo


def probe(name, fn, x):
    # 1. check whether Dynamo records a break reason
    torch._dynamo.reset()
    exp = dynamo.explain(fn)(x)
    has_break = bool(exp.break_reasons)

    # 2. fullgraph compile
    torch._dynamo.reset()
    compile_breaks = False
    try:
        torch.compile(fn, fullgraph=True)(x)
    except Exception:
        compile_breaks = True

    # 3. inspect the handler registry
    from torch._dynamo.trace_rules import lookup
    try:
        handler_info = str(lookup(fn))
    except Exception:
        handler_info = "lookup failed"

    print(f"{'─' * 60}")
    print(f"case            : {name}")
    print(f"handler         : {handler_info}")
    print(f"explain() break : {has_break}")
    print(f"fullgraph break : {compile_breaks}")
    if has_break:
        for r in exp.break_reasons:
            print(f"  reason        : {r.reason}")
    print()


a = torch.tensor([1, 0, 1])

# Case 1: torch.tensor with data-dependent SymInt
def fn_torch_tensor(a):
    num = a.nonzero().squeeze(-1).numel()
    return torch.tensor([num])

# Case 2: new_tensor with data-dependent SymInt
def fn_new_tensor(a):
    num = a.nonzero().squeeze(-1).numel()
    return a.new_tensor([num])

# Case 3: new_tensor with static scalar (control — should not break)
def fn_new_tensor_static(a):
    return a.new_tensor([42.0])

# Case 4: torch.zeros with data-dependent size as a functional baseline (should not break)
def fn_zeros(a):
    num = a.nonzero().squeeze(-1).numel()
    return torch.zeros(num)

print(f"torch version: {torch.__version__}\n")
probe("torch.tensor(data_dep_numel)",  fn_torch_tensor,      a)
probe("a.new_tensor(data_dep_numel)",  fn_new_tensor,        a)
probe("a.new_tensor(static)",          fn_new_tensor_static, a)
probe("torch.zeros(data_dep_numel)",   fn_zeros,             a)
```

Run with TORCH_LOGS="+dynamo" to see the full log; here are the important excerpts.

Selected log:

```text
V0309 15:12:02.513000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks] Graph break in user code at /Users/eddietsai/repro.py:57
V0309 15:12:02.513000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks] Graph Break Reason: Encountered graph break when attempting to trace CALL: a function call, e.g. f(x, y):
V0309 15:12:02.513000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks] Dynamic shape operator
V0309 15:12:02.513000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks]   Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data.
V0309 15:12:02.513000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks]   Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
V0309 15:12:02.513000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks]   Developer debug context: aten.nonzero.default
V0309 15:12:02.513000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks]  For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html

I0309 15:12:02.515000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:4447] [0/0_1] Step 1: torchdynamo start tracing fn_new_tensor /Users/eddietsai/repro.py:56

V0309 15:12:02.518000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:1467] [0/0_1] COMPILING GRAPH due to GraphCompileReason(reason="Dynamic shape operator\n  Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data.\n  Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n\n  Developer debug context: aten.nonzero.default\n\n For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html", user_stack=[<FrameSummary file /Users/eddietsai/repro.py, line 57 in fn_new_tensor>], graph_break=True)

V0309 15:12:02.540000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py:1715] [1/0] torchdynamo start compiling torch_dynamo_resume_in_fn_new_tensor_at_57 /Users/eddietsai/repro.py:57, stack (elided 5 frames):

V0309 15:12:02.541000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:1289] [1/0] [__trace_source]         num = a.nonzero().squeeze(-1).numel()

V0309 15:12:02.548000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:1315] [1/0] [__trace_bytecode] TRACE STORE_FAST num [ConstantVariable(int: 2)]

V0309 15:12:02.554000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:2184] [1/0] [__graph_code]         new_tensor: "i64[1][1]cpu" = l_a_.new_tensor([2]);  l_a_ = None

V0309 15:12:02.577000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py:1715] [0/0] torchdynamo start compiling fn_new_tensor /Users/eddietsai/repro.py:56, stack (elided 4 frames):
V0309 15:12:02.577000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py:1715] [0/0]   File "/Users/eddietsai/repro.py", line 71, in <module>
V0309 15:12:02.577000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py:1715] [0/0]     probe("a.new_tensor(data_dep_numel)",    fn_new_tensor,      a)
V0309 15:12:02.577000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py:1715] [0/0]   File "/Users/eddietsai/repro.py", line 24, in probe
V0309 15:12:02.577000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py:1715] [0/0]     torch.compile(fn, fullgraph=True)(x)

V0309 15:12:02.579000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:1315] [0/0] [__trace_bytecode] TRACE PRECALL 0 [NullVariable, GetAttrVariable(TensorVariable(), nonzero)]
V0309 15:12:02.580000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:1315] [0/0] [__trace_bytecode] TRACE CALL 0 [NullVariable, GetAttrVariable(TensorVariable(), nonzero)]
V0309 15:12:02.580000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:3145] [0/0] [__trace_call] TRACE FX call nonzero from /Users/eddietsai/repro.py:57 in fn_new_tensor (fn_new_tensor)
V0309 15:12:02.580000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:3145] [0/0] [__trace_call]     num = a.nonzero().squeeze(-1).numel()
V0309 15:12:02.580000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:3145] [0/0] [__trace_call]           ~~~~~~~~~^^
I0309 15:12:02.581000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py:4885] [0/0] create_unbacked_symint u0 [-int_oo, int_oo] num = a.nonzero().squeeze(-1).numel()  # repro.py:57 in fn_new_tensor (_subclasses/fake_impls.py:720 in nonzero)

V0309 15:12:02.616000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:1315] [0/0] [__trace_bytecode] TRACE BUILD_LIST 1 [NullVariable, GetAttrVariable(TensorVariable(), new_tensor), SymNodeVariable()]
V0309 15:12:02.616000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:1315] [0/0] [__trace_bytecode] TRACE PRECALL 1 [NullVariable, GetAttrVariable(TensorVariable(), new_tensor), ListVariable(length=1)]
V0309 15:12:02.617000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:1315] [0/0] [__trace_bytecode] TRACE CALL 1 [NullVariable, GetAttrVariable(TensorVariable(), new_tensor), ListVariable(length=1)]
V0309 15:12:02.617000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:3145] [0/0] [__trace_call] TRACE FX call new_tensor from /Users/eddietsai/repro.py:58 in fn_new_tensor (fn_new_tensor)
V0309 15:12:02.617000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:3145] [0/0] [__trace_call]     return a.new_tensor([num])
V0309 15:12:02.617000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/_dynamo/output_graph.py:3145] [0/0] [__trace_call]            ~~~~~~~~~~~~^^^^^^^
V0309 15:12:02.621000 82094 .local/pythons/3.11/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py:6649] [0/0] Data dependent variable 'u0' allocated at:
```

Finding

- explain() path: nonzero() → graph break → restart → resume with num concretized to a constant (e.g. 2) → new_tensor([2]) → succeeds, no break recorded
- fullgraph=True path: nonzero() → creates unbacked SymInt u0 → continues tracing → new_tensor([u0]) → no handler for SymInt in list → compile fails

Conclusion:

- torch.tensor has dedicated support for data-dependent SymInts in a shape/list → no break
- Tensor.new_tensor lacks that handler → fullgraph compile breaks on the unbacked SymInt

This might be related to https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html

cc @Lucaskabela What do you suggest for the workaround? Can I directly route new_tensor to the torch.tensor path? My current implementation is just a hotfix, and I would like to do something more to fix this problem properly :)

@Lucaskabela
Copy link
Copy Markdown
Contributor

Why aren't we just making the repro:

```python
def f1(a):
    num = a.nonzero().squeeze(-1).numel()
    return torch.tensor([num])  # works

def f2(a):
    num = a.nonzero().squeeze(-1).numel()
    return a.new_tensor([num])  # fails before this change

a = torch.tensor([1, 0])
torch.compile(f1, fullgraph=True)(a)
torch.compile(f2, fullgraph=True)(a)
```

into a unit test case? This should be done and added to our unit tests (i.e. checked into the code) first as a prerequisite before we go any further.

Please make this a formal PyTorch test case and add it to one of our unit test files.
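A hedged sketch of what such a test could look like. The class name and use of plain `unittest` here are placeholders (Dynamo tests in the repo live under test/dynamo and subclass `torch._dynamo.test_case.TestCase`), and `backend="eager"` with graph breaks tolerated stands in for the real test's `fullgraph=True`, so this sketch also runs on builds that don't yet have the fix:

```python
import unittest
import torch

class NewTensorTests(unittest.TestCase):
    # Hypothetical test shape; the PR's actual test would use
    # fullgraph=True to assert that no graph break occurs.
    def test_new_tensor_data_dependent(self):
        def f(a):
            num = a.nonzero().squeeze(-1).numel()
            return a.new_tensor([num])

        a = torch.tensor([1, 0])
        ref = f(a)
        opt_f = torch.compile(f, backend="eager")  # graph breaks tolerated
        res = opt_f(a)
        self.assertEqual(ref.tolist(), res.tolist())
```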

```python
# are common keys.
all_tensor_attrs = torch._C.TensorBase.__dict__ | torch.Tensor.__dict__

def _contains_unspec_tensor_data(x):
```
Contributor


nit: needs type hints, like:

```python
def _contains_unspec_tensor_data(x: VariableTracker) -> bool:
```

```python
        return self.call_method(tx, "new_empty", args, kwargs)
    return None

def method_new_tensor(self, *args, **kwargs):
```
Contributor


nit: needs typehints (see above for example)


```python
from ..symbolic_convert import InstructionTranslator

tx = InstructionTranslator.current_tx()
```
Contributor


Why aren't we just passing tx in as an arg?

@vvvdwbvvv
Contributor Author

Understood. I’ll first turn this into a formal PyTorch unit test case, add it to the appropriate test file, and make sure it is checked in before proceeding further.

@vvvdwbvvv vvvdwbvvv marked this pull request as ready for review March 10, 2026 09:10
@vvvdwbvvv
Contributor Author

Added a unit test test_new_tensor_break and a handler in method_new_tensor

@vvvdwbvvv vvvdwbvvv requested a review from Lucaskabela March 10, 2026 09:14
@Lucaskabela Lucaskabela added the ciflow/dynamo Trigger jobs ran periodically on main for dynamo tests label Mar 10, 2026
```python
res = opt_f(x)
self.assertEqual(ref, res)

@torch._dynamo.config.patch(
```
Contributor


We should also add a multi-element test case, like:

```python
return a.new_tensor([n, n + 1, n * 2])
```

I believe this particular logic will fail for that example
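To illustrate the reviewer's point at the eager level: building one scalar tensor per element and stacking generalizes to any number of data-dependent elements, whereas a single scalar_tensor + unsqueeze covers only the one-element case. This is an illustrative sketch, not the PR's actual Dynamo-level implementation (which builds the equivalent FX graph through variable trackers):

```python
import torch

def new_tensor_via_stack(base, values):
    # One scalar tensor per element, then stack: each element can stay
    # symbolic under tracing, so nothing forces specialization.
    scalars = [torch.scalar_tensor(v, dtype=base.dtype, device=base.device)
               for v in values]
    return torch.stack(scalars)

a = torch.tensor([1, 0, 1])
n = a.nonzero().squeeze(-1).numel()  # 2; data-dependent under compile
out = new_tensor_via_stack(a, [n, n + 1, n * 2])
print(out)  # tensor([2, 3, 4])
```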

```python
all_tensor_attrs = torch._C.TensorBase.__dict__ | torch.Tensor.__dict__


def _contains_unspec_tensor_data(x: VariableTracker) -> bool:
```
Contributor


Since this is only used in method_new_tensor, let's just check inside method_new_tensor (no need for this helper function up here, so far from the logic where it is used).

```python
new_kwargs["device"] = kwargs.get("device", self.var_getattr(tx, "device"))

# Build scalar_tensor
scalar = variables.TorchInGraphFunctionVariable(
```
Contributor


The logic below needs refactoring to support the more general case; we don't really need unsqueeze (we can use stack instead).

@pytorch-bot pytorch-bot bot removed the ciflow/dynamo Trigger jobs ran periodically on main for dynamo tests label Mar 12, 2026
@vvvdwbvvv vvvdwbvvv requested a review from Lucaskabela March 12, 2026 02:07
@Lucaskabela Lucaskabela added the ciflow/dynamo Trigger jobs ran periodically on main for dynamo tests label Mar 12, 2026
Contributor

@Lucaskabela Lucaskabela left a comment


Thanks for iterating on this - I think this looks ready to me!

@Lucaskabela
Contributor

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 12, 2026
@vvvdwbvvv
Contributor Author

Thank you!🙏

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (openreg, 1, 1, macos-m1-stable)

Details for Dev Infra team Raised by workflow job

@Lucaskabela
Contributor

@pytorchmergebot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / macos-py3-arm64 / test (openreg, 1, 1, macos-m1-stable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14)

Details for Dev Infra team Raised by workflow job

@Lucaskabela
Contributor

@claude explain this test failure: "Build left local git repository checkout dirty"

@claude

claude bot commented Mar 12, 2026

Claude finished @Lucaskabela's task in 1m 41s


Analysis of "Build left local git repository checkout dirty" failure

  • Check CI status and identify failing jobs
  • Download and analyze job logs
  • Explain root cause and suggest resolution

This failure is unrelated to the PR changes. Both failing jobs show the same issue:

```text
+ assert_git_not_dirty
++ git status --porcelain
++ grep -v '?? third_party'
+ git_status=' M third_party/kineto'
Build left local git repository checkout dirty
git status --porcelain:
 M third_party/kineto
+ exit 1
```

What happened

After all tests complete, PyTorch CI runs an assert_git_not_dirty check to verify the build/test process didn't modify any tracked files. In both failing jobs (macos-py3-arm64 / test (mps) and macos-py3-arm64 / test (openreg)), the third_party/kineto submodule is showing as modified.

This is a known macOS CI infrastructure issue — the third_party/kineto submodule gets marked dirty during the build or test process on macOS ARM64 runners. It's not caused by any code change in this PR (which only touches test/dynamo/test_misc.py and torch/_dynamo/variables/tensor.py).

The actual tests themselves passed successfully (the logs show OK right before the dirty-check step).

Resolution

This is already being handled correctly: @Lucaskabela used @pytorchmergebot merge -i to ignore these known-flaky infra failures. The second merge attempt failed because only the openreg job was ignored, not the mps job. Both jobs need to be ignored:

@pytorchmergebot merge -i

Since the pytorchmergebot should pick up both failures this time around, the next merge attempt should succeed (assuming no new failures appear).


@vvvdwbvvv
Contributor Author

Is there anything I can help with?

@Lucaskabela
Contributor

@pytorchmergebot merge -i

@Lucaskabela
Contributor

Is there anything I can help with?

I will ping if so - pretty sure these are just flaky test signals though

@pytorchmergebot
Collaborator

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…6390)

## PR Summary

Fixes pytorch#176067  a `torch.compile(..., fullgraph=True)` failure for `Tensor.new_tensor(...)`

## Repro

```python
import torch

def f1(a):
    num = a.nonzero().squeeze(-1).numel()
    return torch.tensor([num])  # works

def f2(a):
    num = a.nonzero().squeeze(-1).numel()
    return a.new_tensor([num])  # fails before this change

a = torch.tensor([1, 0])
torch.compile(f1, fullgraph=True)(a)
torch.compile(f2, fullgraph=True)(a)
```

Prior to this change, `f2` could fail with:

```text
torch._dynamo.exc.UserError: Could not extract specialized integer from
data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)
```

related issue: pytorch#176067

Pull Request resolved: pytorch#176390
Approved by: https://github.com/Lucaskabela

Labels

ciflow/dynamo Trigger jobs ran periodically on main for dynamo tests ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo open source release notes: dynamo


Development

Successfully merging this pull request may close these issues.

a.new_tensor([num]) should not introduce a graph break

4 participants