
revisit guarding in mark dynamic APIs#176341

Open
laithsakka wants to merge 6 commits into gh/laithsakka/417/base from gh/laithsakka/417/head

Conversation

@laithsakka
Contributor

@laithsakka laithsakka commented Mar 3, 2026

Stack from ghstack (oldest at bottom):

This PR changes how we guard explicit dynamic attributes on tensors.
After further thought, and based on the comments, I convinced myself again that this is the right thing to do: obey what the user asks.

The high-level idea:

  1. If the user explicitly specifies dynamic properties, we do an exact match on those
  2. If the user does not specify, we will not fail due to guards on dynamic properties installed on compiled frames.

However, for (2) we need to distinguish between:

  • User did not specify (running phase, doesn't care, happy to reuse compiled code)
  • User explicitly wants nothing dynamic/static (compilation phase, actively opting out)

To solve this, we add support for passing an empty list [] to dimension marking APIs, which explicitly sets an empty set on the attribute.

Examples:

  1. Compile with mark_dynamic(x, 0), call with mark_dynamic(x, 0) → no recompile (exact match)
  2. Compile with mark_dynamic(x, 0), call with mark_dynamic(x, 1) → recompile (different dims)
  3. Compile with mark_dynamic(x, [0,1]), call with mark_dynamic(x, 1) → recompile (different dims)
  4. Compile with mark_dynamic(x, 0), call with plain tensor → no recompile (unspecified = don't care)
  5. Compile with plain tensor, call with mark_dynamic(x, 0) → recompile (new marking added)
  6. Compile with mark_dynamic(x, 0), call with mark_dynamic(x, []) → recompile (explicit empty != {0})

See [Note: Dimension Marking Guards] in torch/_dynamo/guards.py for implementation details.
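The examples above can be sketched as a tiny pure-Python predicate. This is an illustrative model only: guard_ok and its arguments are made-up names, not the actual Dynamo guard implementation.

```python
def guard_ok(compiled_indices, runtime_indices):
    # compiled_indices: set recorded at compile time, or None if the tensor
    # had no marking attribute when the frame was compiled.
    # runtime_indices: value on the incoming tensor, or None if unspecified.
    if runtime_indices is None:
        # Unspecified at call time: "don't care", happy to reuse this frame.
        return True
    # Specified (possibly the explicit empty set): require an exact match.
    compiled = compiled_indices if compiled_indices is not None else set()
    return runtime_indices == compiled

# Frame compiled with mark_dynamic(x, 0) -> compiled_indices == {0}
assert guard_ok({0}, {0})        # exact match, no recompile
assert not guard_ok({0}, {1})    # different dims, recompile
assert guard_ok({0}, None)       # plain tensor, don't care
assert not guard_ok(None, {0})   # new marking added, recompile
assert not guard_ok({0}, set())  # explicit empty != {0}, recompile
```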

One might argue that compiling a specialization like s0, s1, static after s0, s1, s2 has no value, since the more dynamic version should cover it and would always be picked. In practice this is not true. Because the Dynamo cache is LRU-ordered, the most recently compiled graph is checked first, so the s0, s1, static specialization can actually be selected if its guards match. See the example in the comment: compiling s0, s1, static after s0, s1, s2 can make it preferable for inputs where the user did not specify markings (“I don’t care”), since it will be tried first during cache lookup.
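The point about LRU order versus guard validity can be made concrete with a small model: guards define which cached entries are valid, while order decides which valid entry wins. All names here are illustrative, not real Dynamo internals.

```python
cache = []  # most recently compiled entry is checked first

def compile_entry(name, guard):
    cache.insert(0, (name, guard))  # new entries go to the front

def lookup(runtime_indices):
    for name, guard in cache:
        if guard(runtime_indices):
            return name
    return "recompile"

# First compile: dims {1, 2, 3} dynamic; guard passes on exact match or "don't care" (None).
compile_entry("s0,s1,s2", lambda idx: idx is None or idx == {1, 2, 3})
# Later compile: dims {1, 2} dynamic, dim 3 static.
compile_entry("s0,s1,static", lambda idx: idx is None or idx == {1, 2})

# An unmarked tensor matches both entries, but the most recent one is tried first:
assert lookup(None) == "s0,s1,static"
# Explicit markings still dispatch to the matching specialization:
assert lookup({1, 2, 3}) == "s0,s1,s2"
```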

The change is justified because LRU only determines the order in which cached graphs are checked, not which graphs are considered valid matches. Correctness is defined by the guards, so if a user explicitly specifies dynamic properties, the system must respect that intent. Using exact-match semantics ensures that explicit markings reliably control specialization and dispatch, while still allowing reuse in the common case where the user does not specify markings. Supporting an explicit empty list [] is also important, as it distinguishes unspecified (“don’t care”) from explicitly opting out, enabling consistent and predictable guard behavior.

For example:

```
import torch
import torch._dynamo

compile_count = 0


def custom_backend(gm, example_inputs):
    global compile_count
    compile_count += 1
    compilation_id = compile_count
    print(f"\n=== Compilation #{compilation_id} ===")
    print(gm.graph)

    compiled_forward = gm.forward

    def wrapper(*args, **kwargs):
        print(f"  [Runtime] Using graph #{compilation_id}")
        return compiled_forward(*args, **kwargs)

    return wrapper


def fn(x):
    return x + 1


compiled_fn = torch.compile(fn, backend=custom_backend)

x = torch.randn(4, 5, 6, 7)
torch._dynamo.mark_dynamic(x, [1, 2, 3])
print(f"Call 1 - dynamic_indices: {x._dynamo_dynamic_indices}")
compiled_fn(x)

x = torch.randn(4, 5, 6, 7)
torch._dynamo.mark_dynamic(x, [1, 2])
torch._dynamo.mark_static(x, [3])
compiled_fn(x)

compiled_fn(x)

x = torch.randn(4, 5, 6, 7)
compiled_fn(x)
torch._dynamo.mark_dynamic(x, [1, 2, 3])
compiled_fn(x)
print(f"\nTotal compilations: {compile_count}")
```

will give:


```
Call 1 - dynamic_indices: {1, 2, 3}

=== Compilation #1 ===
graph():
    %s27 : torch.SymInt [num_users=0] = placeholder[target=s27]
    %s53 : torch.SymInt [num_users=0] = placeholder[target=s53]
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %add : [num_users=1] = call_function[target=operator.add](args = (%l_x_, 1), kwargs = {})
    return (add,)
  [Runtime] Using graph #1

=== Compilation #2 ===
graph():
    %s27 : torch.SymInt [num_users=0] = placeholder[target=s27]
    %s53 : torch.SymInt [num_users=0] = placeholder[target=s53]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %add : [num_users=1] = call_function[target=operator.add](args = (%l_x_, 1), kwargs = {})
    return (add,)
  [Runtime] Using graph #2
  [Runtime] Using graph #2
  [Runtime] Using graph #2
  [Runtime] Using graph #1
```

Note that what is picked at runtime is now different: before this PR, graph #1 was always picked!
Previous Behavior:

  • mark_dynamic used issubset semantics: if marked dims were a subset of compiled dims, no recompilation
  • mark_static and mark_unbacked had missing guards for exact-match checking!
  • maybe_mark_dynamic had NO guards at all - changes to weak dynamic indices were ignored!
  • No way to explicitly say "I want NO marks on this tensor's dimensions"!

New Behavior:

  • All dimension marking APIs use exact-match semantics
  • Added missing guards for _dynamo_weak_dynamic_indices (maybe_mark_dynamic)
  • Guards now check for exact equality, not subset relationship
  • Users can pass [] to explicitly opt out of any marking type, with guards enforcing that.
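The issubset-versus-exact-match change for mark_dynamic can be contrasted in a short sketch (illustrative only, not the real guard code):

```python
def old_guard(compiled_dims, runtime_dims):
    # Old: marked dims being a subset of compiled dims avoided recompilation.
    return runtime_dims.issubset(compiled_dims)

def new_guard(compiled_dims, runtime_dims):
    # New: exact equality is required once the user specifies markings.
    return runtime_dims == compiled_dims

compiled = {1, 2, 3}
assert old_guard(compiled, {1, 2})      # old: subset, so silently reuse
assert not new_guard(compiled, {1, 2})  # new: mismatch, so recompile
assert new_guard(compiled, {1, 2, 3})   # exact match still reuses
```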

@pytorch-bot

pytorch-bot bot commented Mar 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176341

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6d4ab20 with merge base 07efc60:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Mar 3, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

laithsakka added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: ab070dd
Pull Request resolved: #176341
@laithsakka laithsakka changed the title from "revisit explicit dynamsism guarding" to "revisit guarding in mark dynamic APIs" Mar 3, 2026

laithsakka added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: df7cb50
Pull Request resolved: #176341
@laithsakka laithsakka requested review from Lucaskabela, aorenste, ezyang and williamwen42 and removed request for Lucaskabela and aorenste March 4, 2026 00:11

laithsakka added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: b34a520
Pull Request resolved: #176341

laithsakka added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: 282feca
Pull Request resolved: #176341
@laithsakka laithsakka marked this pull request as draft March 4, 2026 01:33
@laithsakka laithsakka marked this pull request as ready for review March 4, 2026 04:29
@laithsakka laithsakka requested a review from anijain2305 March 4, 2026 04:37
@laithsakka laithsakka marked this pull request as draft March 4, 2026 07:04
@laithsakka
Contributor Author

laithsakka commented Mar 4, 2026

Having second thoughts on this.

On one hand:

mark_static after mark_dynamic with the current Dynamo cache lookup is effectively useless.

mark_dynamic([1,2]) after mark_dynamic([1,2,3]) also seems useless, since the subsequent cache lookup will still pick the [1,2,3] specialization.

But consider something like this:


```
import torch
import torch._dynamo

compile_count = 0


def custom_backend(gm, example_inputs):
    global compile_count
    compile_count += 1
    print(f"\n=== Compilation #{compile_count} ===")
    print(gm.graph)
    return gm.forward


def fn(x):
    return x + 1


compiled_fn = torch.compile(fn, backend=custom_backend)

# --- Call 1: mark dims 1, 2, 3 as dynamic ---
x = torch.randn(4, 5, 6, 7)
torch._dynamo.mark_dynamic(x, [1, 2, 3])
print(f"Call 1 - dynamic_indices: {x._dynamo_dynamic_indices}")
compiled_fn(x)

# --- Call 2: new tensor, mark dims 1, 2 as dynamic, mark dim 3 as static ---
x = torch.randn(4, 5, 6, 7)
torch._dynamo.mark_dynamic(x, [1, 2])
torch._dynamo.mark_static(x, [3])
print(f"\nCall 2 - dynamic_indices: {x._dynamo_dynamic_indices}")
compiled_fn(x)

print(f"\nTotal compilations: {compile_count}")
```

This would force s0, s1 to be treated as dynamic and s2 as static on the second call.

Now, if later someone passes a tensor with no attributes, the system would pick s0, s1, s2 as dynamic. At that point, the user is essentially indicating they don’t care which specialization gets selected.

So in a way, these APIs can also be seen as controlling dispatch between different compiled variants.
The above program would print:

```
Call 1 - dynamic_indices: {1, 2, 3}

=== Compilation #1 ===
graph():
    %s27 : torch.SymInt [num_users=0] = placeholder[target=s27]
    %s53 : torch.SymInt [num_users=0] = placeholder[target=s53]
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %add : [num_users=1] = call_function[target=operator.add](args = (%l_x_, 1), kwargs = {})
    return (add,)

Call 2 - dynamic_indices: {1, 2}

=== Compilation #2 ===
graph():
    %s27 : torch.SymInt [num_users=0] = placeholder[target=s27]
    %s53 : torch.SymInt [num_users=0] = placeholder[target=s53]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %add : [num_users=1] = call_function[target=operator.add](args = (%l_x_, 1), kwargs = {})
    return (add,)
```

@laithsakka
Contributor Author

Since the Dynamo cache is LRU, the latest compilation is actually looked up first, so the comment above is not entirely accurate: we do not always pick s0, s1, s2 over s0, s1, static. If the latest compile was s0, s1, static, we pick s0, s1, static over s0, s1, s2.

```
import torch
import torch._dynamo

compile_count = 0


def custom_backend(gm, example_inputs):
    global compile_count
    compile_count += 1
    compilation_id = compile_count
    print(f"\n=== Compilation #{compilation_id} ===")
    print(gm.graph)

    compiled_forward = gm.forward

    def wrapper(*args, **kwargs):
        print(f"  [Runtime] Using graph #{compilation_id}")
        return compiled_forward(*args, **kwargs)

    return wrapper


def fn(x):
    return x + 1


compiled_fn = torch.compile(fn, backend=custom_backend)

# --- Call 1: mark dims 1, 2, 3 as dynamic ---
x = torch.randn(4, 5, 6, 7)
torch._dynamo.mark_dynamic(x, [1, 2, 3])
print(f"Call 1 - dynamic_indices: {x._dynamo_dynamic_indices}")
compiled_fn(x)

# --- Call 2: new tensor, mark dims 1, 2 as dynamic, mark dim 3 as static ---
x = torch.randn(4, 5, 6, 7)
torch._dynamo.mark_dynamic(x, [1, 2])
torch._dynamo.mark_static(x, [3])
print(f"\nCall 2 - dynamic_indices: {x._dynamo_dynamic_indices}")
compiled_fn(x)

# --- Call 3: same tensor again, should be a cache hit ---
print(f"\nCall 3 - same tensor, expecting cache hit")
compiled_fn(x)

x = torch.randn(4, 5, 6, 7)
compiled_fn(x)

print(f"\nTotal compilations: {compile_count}")
```

@laithsakka laithsakka marked this pull request as ready for review March 4, 2026 07:40
@laithsakka
Contributor Author

After further thought, and based on the comments, I convinced myself again that this is the right thing to do: obey what the user asks.

Contributor

@Lucaskabela Lucaskabela left a comment


I need some convincing on this API personally

t (Any): The tensor to mark as having an unbacked dimension.
index (int or list/tuple of int): The dimension(s) to mark as unbacked. Can be a single integer or a list/tuple of integers.
index (int or list/tuple of int): The dimension(s) to mark as unbacked. Can be a single
integer or a list/tuple of integers. Pass an empty list [] to explicitly mark the
Contributor


I am not in love with this empty list - I think a boolean kwarg would make these semantics more obvious. Something like

ignore_all_unbacked=True

This just seems more obvious to me than [] signifying no unbacked; or better yet, a separate API entirely, like mark_never_unbacked. Because imo, as a user, if my tensors are not unbacked, why would I be mucking around with the mark_unbacked API at all to even discover this option?

Contributor Author

@laithsakka laithsakka Mar 4, 2026


Imagine collecting the dimensions that should be unbacked into a list and passing that list to mark_unbacked. If that list happens to be empty, what should happen? We have to support that case anyway, and we probably do not want to error. Reusing a compilation that marked [a, b, c] as unbacked would likely be incorrect, because by calling mark_unbacked the user explicitly indicated "I am controlling unbacked", but set it to empty (same for dynamic/static).

On the other hand, if mark_unbacked was never called and the tensor has no dynamic attributes, the call is treated as “don’t care,” and any compiled frame that works for the input can be reused. In other words, mark_unbacked(x, []) is not the same as not calling mark_unbacked at all. Once you call mark_unbacked, you are explicitly participating in the guard semantics, and that distinction affects cache matching and dispatch.

I don't expect this pattern to be used very often, but it is included for completeness. In most cases users only call mark_* once with a small number of dimensions. Adding a separate argument or API for something rarely needed would feel like overkill and unnatural, and we have to support the empty list anyway. mark_unbacked(x, []) reads naturally as "when it comes to unbacked, nothing is unbacked," and it keeps the API consistent with the rest of the mark_* functions: one function, pass the dims you want, done.
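The "collected list may happen to be empty" scenario looks like this in practice. Here dims_that_should_be_unbacked is a hypothetical helper, and the size-based policy inside it is invented purely for illustration:

```python
def dims_that_should_be_unbacked(shape, candidates):
    # Hypothetical policy: only candidate dims of size 1 stay unbacked.
    return [d for d in candidates if shape[d] == 1]

shape = (4, 5, 6)
dims = dims_that_should_be_unbacked(shape, [0, 1, 2])
assert dims == []  # nothing qualified: the list is legitimately empty

# With this PR, torch._dynamo.mark_unbacked(x, dims) with dims == [] would
# explicitly record "nothing unbacked" rather than erroring or being a no-op.
```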

Contributor


I think this justification makes a bit more sense - that said I still don't love the ergonomics of this API. I would like some other reviewers to weigh in as well (cc @bobrenjc93 for instance)

@laithsakka laithsakka requested a review from Lucaskabela March 4, 2026 21:27

laithsakka added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: c725f36
Pull Request resolved: #176341
@Lucaskabela Lucaskabela dismissed their stale review March 5, 2026 17:11

Withdrawing hard objection

@Lucaskabela Lucaskabela requested a review from bobrenjc93 March 5, 2026 17:13
Comment on lines +1077 to 1078
hint_override (Optional[int], default=None):
hint_override (Optional[int], default=None): An optional integer to override the size hint for this dimension.
Contributor


dup docstring?

if len(index) == 0:
# Empty list explicitly means "no dynamic dims"
# This is different from not calling mark_dynamic at all (unspecified)
if not hasattr(t, "_dynamo_dynamic_indices"):
Contributor


So if it was already set then [] doesn't clear it?
i.e. what should this do?

mark_dynamic(x, 0)
mark_dynamic(x, [])

Contributor Author


Hmm, it shall append, no? When we do

mark_dynamic(x, 0)
mark_dynamic(x, 1)

we append. Let me add a unit test for that. TL;DR: it shall work as appending.
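The appending behavior described here can be modeled in plain Python. The attribute name matches Dynamo's, but mark_dynamic_sketch itself is an illustrative stand-in, not the real implementation:

```python
class FakeTensor:
    pass  # stand-in object that can carry the marking attribute

def mark_dynamic_sketch(t, index):
    dims = index if isinstance(index, (list, tuple)) else [index]
    if not hasattr(t, "_dynamo_dynamic_indices"):
        t._dynamo_dynamic_indices = set()
    t._dynamo_dynamic_indices |= set(dims)  # union with prior marks, never clear

x = FakeTensor()
mark_dynamic_sketch(x, 0)
mark_dynamic_sketch(x, 1)
assert x._dynamo_dynamic_indices == {0, 1}  # repeated calls accumulate

mark_dynamic_sketch(x, [])
assert x._dynamo_dynamic_indices == {0, 1}  # [] does not clear prior marks
```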
