Skip to content

[Pallas] Remove torch.empty in tracing#6897

Merged
alanwaketan merged 4 commits intomasterfrom
alanwaketan/pallas_no_empty
Apr 8, 2024
Merged

[Pallas] Remove torch.empty in tracing#6897
alanwaketan merged 4 commits intomasterfrom
alanwaketan/pallas_no_empty

Conversation

@alanwaketan
Copy link
Copy Markdown
Collaborator

Summary:
Previously we rely on torch.empty to create some empty tensors as the outputs from the Pallas and then make Pallas as in-place ops. However, it turns out that torch.empty will actually allocate memory and therefore it's expansive to use. In this change, I switched to simply pass the shapes and dtypes to construct the graph node.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py

Performance benchmarks can be found: http://shortn/_wdmom7I6q7

@alanwaketan alanwaketan self-assigned this Apr 6, 2024
@alanwaketan alanwaketan requested review from JackCaoG and lsy323 April 6, 2024 02:57
@alanwaketan
Copy link
Copy Markdown
Collaborator Author

This PR is ready for review.

@JackCaoG
Copy link
Copy Markdown
Collaborator

JackCaoG commented Apr 8, 2024

I will take a look

@JackCaoG
Copy link
Copy Markdown
Collaborator

JackCaoG commented Apr 8, 2024

Do we need this in 2.3 release? I don't consider this as the critical fix, wdyt?

@alanwaketan
Copy link
Copy Markdown
Collaborator Author

Do we need this in 2.3 release? I don't consider this as the critical fix, wdyt?

I think it is, but I will hold it until I finished some e2e modeling tests. Thanks for the review.

@alanwaketan alanwaketan merged commit 66ed39b into master Apr 8, 2024
@alanwaketan alanwaketan deleted the alanwaketan/pallas_no_empty branch April 8, 2024 22:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants