[aoti][mps] Fix update constants buffer #158349

Closed
angelayi wants to merge 9 commits into gh/angelayi/102/base from gh/angelayi/102/head

[ghstack-poisoned]

pytorch-bot bot commented Jul 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158349

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit deaf8aa with merge base 4060f30 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions
Contributor

Attention! One of PyTorch's C-stable API files was changed

You MUST NOT change existing function declarations in this, as this header defines a stable C ABI. If you need to change the signature for a function, introduce a new v2 version of the function and modify code generation to target the new version of the function.


Caused by:

@angelayi angelayi marked this pull request as draft July 15, 2025 16:28
@angelayi angelayi added ciflow/mps Run MPS tests (subset of trunk) topic: not user facing topic category and removed release notes: inductor (aoti) labels Jul 15, 2025
[ghstack-poisoned]
angelayi added 3 commits July 15, 2025 13:59
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@angelayi angelayi requested a review from desertfire July 16, 2025 20:00
@angelayi angelayi requested a review from malfet July 16, 2025 20:00
@angelayi angelayi marked this pull request as ready for review July 16, 2025 20:00
->memcpy(internal_constants_ptr, user_constant_ptr, constant_size)
.wait();
#elif USE_MPS
internal_constants_ptr = constants_blob_ptr;
Contributor

Why do we need to treat this differently for MPS?

Contributor Author

I think because Metal has unified memory, we can't directly access the contents of the constants ptr through pointer arithmetic. Instead we have to pass the offset around and access it with the Metal functions.

Contributor

Other way around: we don't have unified memory, and the storage ptr is not really a pointer but a reference to an id<MTLBuffer> object.
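
To illustrate the point above, here is a minimal host-only sketch of why pointer arithmetic on such a "storage pointer" does not work. It assumes nothing about the real Metal or AOTI code: `FakeDeviceBuffer`, `BufferHandle`, `write_at`, and `read_byte` are hypothetical names invented for this example, and the fake buffer lives in ordinary host memory. The handle identifies a buffer object, so `handle + 8` is not the address of byte 8 of its contents; the offset has to travel alongside the handle and be applied by a function that knows how to reach the contents.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for a device buffer object such as id<MTLBuffer>.
// The bytes live behind the object; callers only ever hold a handle to it.
struct FakeDeviceBuffer {
  std::vector<std::uint8_t> bytes;
};

// What a "storage pointer" amounts to in this sketch: an opaque handle to the
// buffer object, not the address of the first data byte.
using BufferHandle = void*;

// Hypothetical helper: copy `size` bytes from host memory into the buffer,
// starting `offset` bytes into its contents. The offset is passed separately
// instead of being added to the handle.
inline void write_at(BufferHandle handle, std::size_t offset,
                     const void* src, std::size_t size) {
  auto* buf = static_cast<FakeDeviceBuffer*>(handle);
  assert(offset <= buf->bytes.size() && size <= buf->bytes.size() - offset);
  if (size == 0) {
    return;
  }
  std::memcpy(buf->bytes.data() + offset, src, size);
}

// Hypothetical helper: read back the byte at `offset` within the contents.
inline std::uint8_t read_byte(BufferHandle handle, std::size_t offset) {
  auto* buf = static_cast<FakeDeviceBuffer*>(handle);
  assert(offset < buf->bytes.size());
  return buf->bytes[offset];
}
```

A caller would keep one handle for the whole constants blob and pass a per-constant offset to each access, for example `write_at(blob_handle, 8, payload, 4)`.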

.wait();
#elif USE_MPS
internal_constants_ptr = constants_blob_ptr;
offset = constants_internal_offset_[idx];
Contributor

The meaning of offset diverges from the other paths, where offset really means the storage_offset of a tensor. I think this will cause a problem for aoti_torch_create_tensor_from_blob later.

Contributor Author

Oh I also mean storage offset here
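
For readers following this thread, the sketch below shows one way the two notions of "offset" could be kept apart if they ever did diverge. It is an illustration only: `ConstantView` and `first_element_byte` are hypothetical names, not part of the AOTI runtime. The one fact it relies on is that a PyTorch tensor's storage offset is counted in elements, whereas a position inside a packed constants blob is naturally counted in bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical description of one constant packed into a shared blob.
struct ConstantView {
  std::size_t blob_byte_offset;  // start of the constant's storage in the blob, in bytes
  std::int64_t storage_offset;   // start of the tensor within that storage, in elements
  std::size_t element_size;      // bytes per element
};

// Byte position of the tensor's first element, measured from the blob start.
// The element-based storage offset is scaled by the element size before it is
// added to the byte-based blob offset.
inline std::size_t first_element_byte(const ConstantView& v) {
  assert(v.storage_offset >= 0);
  return v.blob_byte_offset +
      static_cast<std::size_t>(v.storage_offset) * v.element_size;
}
```

Giving the two fields different names and units makes it harder to pass a byte position where an element count is expected.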

[ghstack-poisoned]
constants_internal_offset_[idx],
offset,
constant_size,
user_constant_ptr);
Contributor

Hmm, I still don't think this is correct based on this signature: aoti_torch_mps_copy_buffer(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, void* constant_buffer). What is the difference between bytes_read and data_size, and why is offset being passed as bytes_read?

Contributor Author

ah good point. I renamed the params so that it hopefully makes more sense.
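
As an illustration of the naming concern raised here, the following is a hedged sketch of a buffer-to-buffer copy whose parameters each say which buffer they refer to and in what unit. It is not the renamed signature from this PR (which is not shown in the thread): `FakeBuffer` and `copy_buffer_region` are hypothetical, and the buffers are plain host vectors rather than Metal buffers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical host-side stand-in for a device buffer object.
struct FakeBuffer {
  std::vector<std::uint8_t> bytes;
};

// Hypothetical signature: every size_t names the buffer it applies to and its
// unit, so arguments cannot be silently swapped at the call site.
// Returns false, copying nothing, if either range would run past its buffer.
inline bool copy_buffer_region(FakeBuffer& dst, std::size_t dst_offset_bytes,
                               const FakeBuffer& src, std::size_t src_offset_bytes,
                               std::size_t size_bytes) {
  // This sketch does not support copying within a single buffer.
  assert(&dst != &src);
  if (src_offset_bytes > src.bytes.size() ||
      size_bytes > src.bytes.size() - src_offset_bytes) {
    return false;
  }
  if (dst_offset_bytes > dst.bytes.size() ||
      size_bytes > dst.bytes.size() - dst_offset_bytes) {
    return false;
  }
  if (size_bytes == 0) {
    return true;
  }
  std::memcpy(dst.bytes.data() + dst_offset_bytes,
              src.bytes.data() + src_offset_bytes, size_bytes);
  return true;
}
```

With names like these, a call such as `copy_buffer_region(blob, 4, user_data, 1, 3)` reads as "copy 3 bytes from byte 1 of user_data to byte 4 of blob", which avoids the ambiguity between a read position and a data size.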

auto constant_metal_buffer = (id<MTLBuffer>)constant_buffer;
auto metal_buffer = (id<MTLBuffer>)buffer;

at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
Contributor

Nit

Suggested change
- at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+ const auto* stream = at::mps::getCurrentMPSStream();

[ghstack-poisoned]
angelayi added 2 commits July 22, 2025 10:08
[ghstack-poisoned]
[ghstack-poisoned]
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #158351

pytorchmergebot pushed a commit that referenced this pull request Jul 23, 2025
In the case where we have both mps and cpu code which can be inductor compiled, we need to case on the device -- this requires the device field to be correctly passed.

Pull Request resolved: #158350
Approved by: https://github.com/malfet
ghstack dependencies: #158349
pytorchmergebot pushed a commit that referenced this pull request Jul 23, 2025
pytorchmergebot pushed a commit that referenced this pull request Jul 23, 2025
@github-actions github-actions bot deleted the gh/angelayi/102/head branch August 22, 2025 02:14