
[inductor] TLParse tensor metadata logging + test #160132

Closed

skarjala wants to merge 13 commits into gh/skarjala/17/base from gh/skarjala/17/head

Conversation

Contributor

@skarjala skarjala commented Aug 7, 2025

Summary:

  • Add TLParse artifact logging per op with output tensor shape, stride, and dtype for cross-rank aggregation.

Testing:

  • Add test to verify structure and contents of tlparse artifact
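For readers skimming the thread, the per-op payload described in the summary can be sketched as below. The artifact name and the `outputs` field names follow the code snippets quoted later in the review; the exact structure is an assumption, not the final schema.

```python
import json

# Hypothetical sketch of the per-op tensor metadata payload this PR logs.
# Field names ("ops", "outputs", "shape", "stride", "dtype") are taken from
# the review snippets in this thread; the builder itself is illustrative.
def build_tensor_meta_artifact(ops):
    """ops: iterable of (op_type, [(shape, stride, dtype), ...]) tuples."""
    return {
        "ops": [
            {
                "type": op_type,
                "outputs": [
                    {"shape": list(shape), "stride": list(stride), "dtype": dtype}
                    for shape, stride, dtype in outputs
                ],
            }
            for op_type, outputs in ops
        ]
    }

artifact = build_tensor_meta_artifact(
    [("mm", [((2, 2), (2, 1), "torch.float32")])]
)
payload = json.dumps(artifact, sort_keys=True)
```

Serializing to JSON keeps the artifact easy to aggregate across ranks, which is the stated motivation.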

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Aug 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160132

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 10e7567 with merge base 01bcf9a:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

skarjala added a commit that referenced this pull request Aug 7, 2025
ghstack-source-id: 86972c4
Pull-Request: #160132
@skarjala skarjala changed the title logging for tensor metadata [inductor] TLParse tensor metadata logging + test Aug 7, 2025
@skarjala skarjala requested review from xmfan and yushangdi August 7, 2025 20:10
[ghstack-poisoned]
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #160260

1 similar comment

[ghstack-poisoned]
skarjala added a commit that referenced this pull request Aug 11, 2025
ghstack-source-id: 1f9321b
Pull-Request: #160132
trace_structured(
    "artifact",
    metadata_fn=lambda: {
        "name": "inductor_tlparse_tensor_meta",
Member

this should be combined with the existing runtime estimates artifact
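One way to read this suggestion is to attach the tensor metadata to the existing per-op runtime-estimates artifact instead of emitting a second artifact. A minimal illustration, with all field names assumed rather than taken from the actual artifact schema:

```python
# Hypothetical illustration of merging tensor metadata into an existing
# per-op runtime-estimates artifact. Field names ("estimated_runtime_ns",
# "outputs") are assumptions for the sketch, not the real schema.
runtime_estimates = {"ops": [{"type": "mm", "estimated_runtime_ns": 1200}]}
tensor_meta = {
    "mm": [{"shape": [2, 2], "stride": [2, 1], "dtype": "torch.float32"}],
}

for op in runtime_estimates["ops"]:
    # Attach output metadata keyed by op type; real code would key by node.
    op["outputs"] = tensor_meta.get(op["type"], [])
```

Emitting one combined artifact keeps tlparse consumers from having to join two streams per op.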

for out in op.get("outputs", []):
    self.assertIn("shape", out)
    self.assertIn("stride", out)
    self.assertIn("dtype", out)
Member

please write this test using assertExpectedInline so that we can see the artifact's outputs from the test file
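The pattern the reviewer asks for can be mimicked with a plain equality check: serialize the parsed ops and compare against an inline expected string. The real helper (`assertExpectedInline`, built on expecttest) can additionally rewrite the expected text in the test file when run with `EXPECTTEST_ACCEPT=1`; this stand-in only shows the shape of the comparison.

```python
import json

# Stand-in sketch of the assertExpectedInline pattern: dump the ops to a
# stable string and compare it to an inline expectation. A plain assert
# replaces the expecttest helper here to keep the example self-contained.
ops = [{"type": "mm", "dtype": "torch.float32"}]
actual = json.dumps(ops, indent=2, sort_keys=True)
expected = """\
[
  {
    "dtype": "torch.float32",
    "type": "mm"
  }
]"""
assert actual == expected
```

The payoff is exactly what the reviewer wants: the artifact's contents are visible in the test file itself.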

Comment on lines +1351 to +1353
def fn(x):
    y = x @ x
    return y + 1
Member

this is insufficient test coverage. at the very minimum, we should consider 1 case of each where shapes/stride/dtype are different between input and outputs
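The coverage requested above could look something like the sketch below: a single function whose outputs differ from the input in shape (matmul), stride (a transposed view), and dtype (a cast). This is illustrative, not the test that eventually landed in the PR.

```python
import torch

# Hedged sketch of the requested coverage: each step changes one piece of
# the metadata the artifact records.
def fn(x):
    y = x @ x.t()               # shape changes: (2, 3) @ (3, 2) -> (2, 2)
    z = y.t()                   # stride changes: transposed view of y
    return z.to(torch.float16)  # dtype changes: float32 -> float16

x = torch.randn(2, 3)
out = fn(x)
```

Compiling `fn` with the inductor backend and checking the logged artifact against these known shape/stride/dtype transitions would exercise all three fields.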

[ghstack-poisoned]
skarjala added a commit that referenced this pull request Aug 12, 2025
ghstack-source-id: 1f2c252
Pull-Request: #160132

fix pr feedback
[ghstack-poisoned]
skarjala added a commit that referenced this pull request Aug 12, 2025
ghstack-source-id: cad2dba
Pull-Request: #160132

fix pr feedback

update
Comment on lines +1396 to +1401
"shape": [
2
],
"stride": [
1
]
Member

we should include a test with dynamic shapes to test out the to_size_hints codepath

Contributor Author

added

[ghstack-poisoned]
self.assertParses()

@requires_tlparse
@torch._dynamo.config.patch(dynamic_shapes=True)
Member

this config doesnt do anything

Contributor Author

removed

    w = z.to(torch.float16)
    return w

compiled = torch.compile(f, backend="inductor", fullgraph=True)
Member

either compile with dynamic=True or use mark_dynamic like in the starter tasks

[ghstack-poisoned]
Comment on lines +1435 to +1456
simplified_ops = []
for op in ops:
    outs = [
        {
            "shape": out.get("shape", []),
            "stride": out.get("stride", []),
            "dtype": out.get("dtype", None),
        }
        for out in op.get("outputs", [])
    ]
    if outs:
        simplified_ops.append(
            {
                "type": op.get("type", ""),
                "outputs": outs,
            }
        )

simplified = (
    {"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []}
)
actual = json.dumps(simplified, indent=2, sort_keys=True)
Member

you could just self.assertExpectedInline(ops, ...)

self.assertExpectedInline(
    actual,
    r"""{
  "ops": [
Member

we should really add tests with at least multiple ops

skarjala added a commit that referenced this pull request Aug 15, 2025
ghstack-source-id: 11030a1
Pull-Request: #160132

fix pr feedback

update

add dynamic test

fix dynamic test

add tests w/ more ops

fix cuda

update

change to cuda and triton

fix cuda and triton
[ghstack-poisoned]
@skarjala
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@clee2000
Contributor

@pytorchbot revert -m "broke lint GH job link HUD commit link. landrace with another PR that changed some had_cuda related things" -c landrace

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@skarjala your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Aug 16, 2025
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Aug 16, 2025
[ghstack-poisoned]
@skarjala
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3.9-clang12 / test (crossref, 2, 2, lf.linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
Summary:
- Add TLParse artifact logging per op with output tensor shape, stride, and dtype for cross-rank aggregation.

Testing:
- Add test to verify structure and contents of tlparse artifact

Pull Request resolved: pytorch#160132
Approved by: https://github.com/xmfan
ghstack dependencies: pytorch#160260
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
@github-actions github-actions bot deleted the gh/skarjala/17/head branch September 17, 2025 02:07

4 participants