
[Compilation Cache] Add API to query number of cached computation graphs#8822

Merged
lsy323 merged 3 commits into master from lsiyuan/get-num-graph-hash
Mar 12, 2025

Conversation

Collaborator

@lsy323 lsy323 commented Mar 12, 2025

Add an API torch_xla.runtime.get_num_cached_compilation_graph to query the number of cached computation graphs.

This is useful for checking whether recompilation is happening in application code, or for debugging purposes.

Behavior when the persistent cache is enabled:

When the persistent cache is used, the API returns the number of compilation graph hashes in the in-memory cache. A persistent-cache lookup checks the in-memory cache first; on a miss, it falls back to the on-disk cache. On an on-disk cache hit, the compiled graph is fetched into the in-memory cache.
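The lookup order described above can be sketched as a two-level cache. This is a simplified illustration of the described behavior, not the actual torch_xla implementation; the class and method names here are hypothetical:

```python
class TwoLevelCache:
    """Sketch of the lookup order described above: check the in-memory
    cache first, then the on-disk cache; promote disk hits to memory."""

    def __init__(self):
        self.memory = {}  # stands in for the in-memory graph cache
        self.disk = {}    # stands in for the persistent on-disk cache

    def lookup(self, graph_hash):
        if graph_hash in self.memory:  # in-memory hit
            return self.memory[graph_hash]
        if graph_hash in self.disk:    # on-disk hit: fetch into memory
            self.memory[graph_hash] = self.disk[graph_hash]
            return self.memory[graph_hash]
        return None                    # miss: would trigger compilation

    def num_cached_graphs(self):
        # Mirrors the API's described behavior: counts only in-memory
        # entries, even when the persistent cache is enabled.
        return len(self.memory)


cache = TwoLevelCache()
cache.disk["h1"] = "compiled_graph_1"  # graph present only on disk
print(cache.num_cached_graphs())       # 0: nothing in memory yet
cache.lookup("h1")                     # disk hit, promoted to memory
print(cache.num_cached_graphs())       # 1: now counted
```

This matches the note above that the count can increase after a persistent-cache hit, since the graph is pulled into the in-memory cache.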

Test:
Covers the {dynamo, non-dynamo} x {in-memory cache, persistent cache} cases.

@lsy323 lsy323 marked this pull request as ready for review March 12, 2025 04:41
Comment thread torch_xla/csrc/runtime/cache.h Outdated
Collaborator

@yaochengji yaochengji left a comment


LGTM, thanks!

Collaborator Author

lsy323 commented Mar 12, 2025

Merging it now; the failed GPU CI is unrelated and will be fixed in #8825.

@lsy323 lsy323 merged commit bdf30f6 into master Mar 12, 2025
@lsy323 lsy323 deleted the lsiyuan/get-num-graph-hash branch March 12, 2025 18:58
Comment thread torch_xla/csrc/xla_graph_executor.h
Comment on lines +42 to +51
input1 = torch.rand(input_shape).to(xla_dev)
xm.mark_step()  # materialize input1 on the XLA device
model(input1)
xm.mark_step()  # compile and execute the model's graph; its hash is cached
xm.wait_device_ops()
graph_cnt = xr.get_num_cached_compilation_graph()
input2 = torch.rand(input_shape).to(xla_dev)
model(input2)
xm.mark_step()  # same shapes as input1, so this should hit the cached graph
xm.wait_device_ops()
Collaborator


It might be nice to add a comment or two here to explain what is happening at each mark_step.

@tengyifei
Collaborator

lgtm
