[aoti][mps] Fix update constants buffer#158349
[aoti][mps] Fix update constants buffer#158349angelayi wants to merge 9 commits intogh/angelayi/102/basefrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158349
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit deaf8aa with merge base 4060f30 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Attention! PyTorch one of the C-stable API file was changedYou MUST NOT change existing function declarations in this, as this header defines a stable C ABI. If you need to change the signature for a function, introduce a new v2 version of the function and modify code generation to target the new version of the function. Caused by: |
| ->memcpy(internal_constants_ptr, user_constant_ptr, constant_size) | ||
| .wait(); | ||
| #elif USE_MPS | ||
| internal_constants_ptr = constants_blob_ptr; |
There was a problem hiding this comment.
Why do we need to treat this differently for MPS?
There was a problem hiding this comment.
I think because metal has unified memory, we can't directly access the contents of the constants ptr through ptr arithmetic. Instead we have to pass the offset around and access it with the metal functions.
There was a problem hiding this comment.
Other way around: we don't have a unified memory, and storage ptr is not really a pointer but reference to id<MTLBuffer> object
| .wait(); | ||
| #elif USE_MPS | ||
| internal_constants_ptr = constants_blob_ptr; | ||
| offset = constants_internal_offset_[idx]; |
There was a problem hiding this comment.
The meaning of offset diverges from other paths, where offset really means a storage_offset of a tensor. I think this will cause a problem to aoti_torch_create_tensor_from_blob later.
There was a problem hiding this comment.
Oh I also mean storage offset here
| constants_internal_offset_[idx], | ||
| offset, | ||
| constant_size, | ||
| user_constant_ptr); |
There was a problem hiding this comment.
Hmm, I still don't think this is correct based on this signature, aoti_torch_mps_copy_buffer(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, void* constant_buffer). What is the difference between bytes_read and data_size, and why using offset as bytes_read?
There was a problem hiding this comment.
ah good point. I renamed the params so that it hopefully makes more sense.
| auto constant_metal_buffer = (id<MTLBuffer>)constant_buffer; | ||
| auto metal_buffer = (id<MTLBuffer>)buffer; | ||
|
|
||
| at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); |
There was a problem hiding this comment.
Nit
| at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); | |
| const auto* stream = at::mps::getCurrentMPSStream(); |
| ->memcpy(internal_constants_ptr, user_constant_ptr, constant_size) | ||
| .wait(); | ||
| #elif USE_MPS | ||
| internal_constants_ptr = constants_blob_ptr; |
There was a problem hiding this comment.
Other way around: we don't have a unified memory, and storage ptr is not really a pointer but reference to id<MTLBuffer> object
|
Starting merge as part of PR stack under #158351 |
In the case where we have both mps and cpu code which can be inductor compiled, we need to case on the device -- this requires the device field to be correctly passed. Pull Request resolved: #158350 Approved by: https://github.com/malfet ghstack dependencies: #158349
Pull Request resolved: #158351 Approved by: https://github.com/desertfire, https://github.com/malfet ghstack dependencies: #158349, #158350
Pull Request resolved: #158703 Approved by: https://github.com/malfet, https://github.com/desertfire ghstack dependencies: #158349, #158350, #158351
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben