
[Bugfix][Dynamo] Fix Sparse tensors by graph break in Dynamo#164873

Closed
Lucaskabela wants to merge 17 commits into main from lucaskabela/fix_164823

Conversation

Contributor

@Lucaskabela Lucaskabela commented Oct 7, 2025

Fixes #164823 by making the lack of support for sparse tensors explicit (in fake tensor, Inductor, and lowering code)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben


pytorch-bot bot commented Oct 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164873

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d9b2166 with merge base 14af1dc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Lucaskabela
Contributor Author

@pytorchbot label "topic: not user facing"

@Lucaskabela
Contributor Author

There are perhaps two ways of going about this:

  1. We can graph break on any sparse tensor in graph capture (approach taken here)
  2. We can error inside inductor itself when we see sparse tensors

I am partial to 1, since it is more defensive and clearer for users; the trade-off is that it precludes using dynamo/aot_eager for sparse tensor capture. I think that is acceptable, though, since we already have a handful of graph breaks for sparse tensors and operations today (see the existing check:

if isinstance(obj, TensorVariable):
)
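Approach 1 can be sketched roughly as follows. This is a simplified, duck-typed stand-in, not the actual Dynamo internals: `GraphBreak`, `maybe_wrap_tensor`, and the keyword flags are illustrative names mirroring the condition shown in the review below.

```python
class GraphBreak(Exception):
    """Stand-in for Dynamo's graph-break ("Unsupported") exception."""


def maybe_wrap_tensor(t, *, is_export=False, capture_sparse_compute=False):
    # Sketch of approach 1: refuse to trace sparse tensors at wrap time,
    # unless we are exporting with sparse capture explicitly enabled.
    if getattr(t, "is_sparse", False) and (
        not is_export or not capture_sparse_compute
    ):
        raise GraphBreak("Attempted to wrap sparse Tensor")
    return t
```

With this shape, dense tensors trace as before, while a sparse tensor triggers a graph break and falls back to eager execution.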

@Lucaskabela Lucaskabela marked this pull request as ready for review October 9, 2025 17:28
@Lucaskabela Lucaskabela changed the title [Draft][Bugfix] Fix by graph break on sparse output ops [Bugfix][Dynamo] Fix Sparse tensors by graph break in Dynamo Oct 9, 2025
not tx.export or not config.capture_sparse_compute
):
    unimplemented_v2(
        gb_type="Attempted to wrap sparse Tensor",

Member

Suggested change:
-        gb_type="Attempted to wrap sparse Tensor",
+        gb_type="Attempted to wrap sparse Tensor with VariableTracker",

):
    unimplemented_v2(
        gb_type="Attempted to wrap sparse Tensor",
        context="",

Member

Suggested change:
-        context="",
+        context=str(example_value),

@StrongerXi
Contributor

If it works with aot_eager, I'd expect there's a way for inductor to skip the sparse ops during codegen and run them via aten calls... Did you check with inductor folks?

@Lucaskabela
Contributor Author

Sure let me check - cc @laithsakka @eellison

Contributor

@eellison eellison left a comment


Should this be in FakeTensor instead? E.g., if an operator decomposed to use sparse tensors, or had sparsity in the backward. See the existing check:

"torch.compile does not support strided NestedTensor"

@eellison
Contributor

@StrongerXi yes, there is this check to fall back for unsupported tensors:

if t.is_complex():
    # Complex views are supported with IR ComplexView
    _warn_complex_not_supported()
    return True

@Lucaskabela
Contributor Author

@StrongerXi yes, there is this check to fall back for unsupported tensors:

if t.is_complex():
    # Complex views are supported with IR ComplexView
    _warn_complex_not_supported()
    return True

Okay, let me see if we can move the check there instead of adding another graph break :)
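A minimal sketch of what moving the check might look like, modeled on the complex-tensor fallback quoted above. The function name is illustrative (not the real Inductor code), and `is_complex`/`is_sparse` are treated as plain attributes for illustration.

```python
def should_fallback_to_eager(t):
    # Sketch: decide whether a tensor must bypass Inductor codegen and
    # run via an aten/eager fallback instead.
    if getattr(t, "is_complex", False):
        # Complex views are supported with IR ComplexView, mirrored from
        # the snippet quoted above.
        return True
    if getattr(t, "is_sparse", False):
        # Sparse layouts have no Inductor lowering, so fall back rather
        # than attempting codegen.
        return True
    return False
```

The appeal of this placement is that it catches sparsity wherever it arises (e.g., from a decomposition or a backward pass), not just at graph-capture time.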

@Lucaskabela
Contributor Author

@pytorchmergebot merge -i "Test failure from cuda error not related to changes"

@pytorch-bot

pytorch-bot bot commented Oct 13, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: Test failure from cuda error not related to changes

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick} ...

Try @pytorchbot --help for more info.

@Lucaskabela
Contributor Author

@pytorchmergebot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 13, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: inductor / inductor-test / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 3, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)

Details for Dev Infra team: raised by workflow job


for _ in range(3):
    self.assertEqual(foo_opt(view_6, buf31), foo(view_6, buf31))
with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
Contributor Author


This was previously not supported but just silently didn't cudagraph (see the existing check:

if isinstance(t, torch.Tensor) and t.is_sparse:
)

Letting users know loudly seems preferable.
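The before/after behavior described here can be illustrated with a small sketch. These are duck-typed stand-ins, not the real cudagraphs code; the function names are hypothetical.

```python
def cudagraph_inputs_old(inputs):
    # Old behavior (sketch): sparse tensors were silently filtered out of
    # cudagraphing, so the program ran but was never cudagraphed.
    return [t for t in inputs if not getattr(t, "is_sparse", False)]


def check_inputs_new(inputs):
    # New behavior (sketch): raise loudly so the user learns that sparse
    # tensors are unsupported, instead of silently degrading.
    for t in inputs:
        if getattr(t, "is_sparse", False):
            raise NotImplementedError("sparse tensors are not supported")
    return inputs
```

The test change above reflects exactly this shift: the compiled function now raises a BackendCompilerFailed instead of quietly running without cudagraphs.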

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased lucaskabela/fix_164823 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout lucaskabela/fix_164823 && git pull --rebase)

@Lucaskabela
Contributor Author

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…#164873)

Fixes pytorch#164823 by making lack of support for sparse tensors very explicit (in fake tensor, inductor, and lowering code)

Pull Request resolved: pytorch#164873
Approved by: https://github.com/williamwen42, https://github.com/eellison, https://github.com/mlazos
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
…#164873)

Fixes pytorch#164823 by making lack of support for sparse tensors very explicit (in fake tensor, inductor, and lowering code)

Pull Request resolved: pytorch#164873
Approved by: https://github.com/williamwen42, https://github.com/eellison, https://github.com/mlazos
@github-actions github-actions bot deleted the lucaskabela/fix_164823 branch November 16, 2025 02:21


Development

Successfully merging this pull request may close these issues.

torch.compile with Inductor fails with NotImplementedError for models using to_sparse()

6 participants