Add Loads from fixed inputs #162031
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).

Merge failed. Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 1, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)

@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours).

## TODO
Check on multi indices
```Python
@cute.jit
def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
    in_ptr4 = buffers[0]
    tmp0 = tSrS_ssa
    tmp1 = b_idx
    tmp2 = h_idx
    tmp3 = cute.make_fragment(1, cutlass.Int32)
    tmp4 = tmp3.store(32*tmp1 + tmp2)
    tmp5 = cute.make_fragment(1, cutlass.BFloat16)
    tmp6 = tmp3[0]
    tmp7 = tmp5[0] = (in_ptr4[tmp6])
    tmp8 = (tmp5.load()).to(cutlass.Float32)
    tmp9 = (tmp0 + tmp8)
    tSrS_ssa = tmp9
    return tSrS_ssa
```
I don't think that
```
tmp4 = tmp3.store(32*tmp1 + tmp2)
tmp5 = cute.make_fragment(1, cutlass.BFloat16)
tmp6 = tmp3[0]
tmp7 = tmp5[0] = (in_ptr4[tmp6])
```
is right, since this `tmp6` value will be larger than the actual index dimension (in this case, B). See whether 1D indexing is possible.
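
To make the concern concrete, here is a minimal pure-Python sketch of the flattened-index arithmetic. It assumes the buffer is a contiguous row-major `(B, H)` tensor with `H = 32`, which would match the constant in `32*tmp1 + tmp2`; the PR does not state the actual shape, and `B = 4`, `flat_offset`, `buf_2d`, and `buf_1d` are hypothetical names chosen for illustration.

```python
# Hypothetical shapes: the PR does not state them. H = 32 matches the
# constant in the generated expression `32*tmp1 + tmp2`.
B, H = 4, 32


def flat_offset(b_idx: int, h_idx: int, h_size: int = H) -> int:
    """Row-major offset of element (b_idx, h_idx) in a contiguous (B, H) buffer."""
    return h_size * b_idx + h_idx


# A (B, H) buffer stored as nested lists, plus its flattened 1D view.
buf_2d = [[float(b * H + h) for h in range(H)] for b in range(B)]
buf_1d = [x for row in buf_2d for x in row]

b_idx, h_idx = 2, 5
off = flat_offset(b_idx, h_idx)

# Indexing the 1D view with the flat offset recovers the 2D element.
assert buf_1d[off] == buf_2d[b_idx][h_idx]

# The same offset is out of range for the leading dimension of size B,
# which is the situation the TODO above is worried about.
offset_exceeds_leading_dim = off >= B
```

Under these assumptions the flat offset is only valid against a 1D view of length `B * H`; if the buffer were indexed along its leading dimension alone, the offset would overrun it.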
Pull Request resolved: pytorch#162031
Approved by: https://github.com/v0i0
ghstack dependencies: pytorch#161118
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben