properly compute batch_element_count in warp_softmax #82927
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit 34f64a7 — 💚 Looks good so far! There are no failures yet.
Thank you! Local test passes on my side, too. Running the cherry-pick through CI now.
@pytorchbot merge -g
@pytorchbot successfully started a merge job. Check the current status here.
Summary: It turns out that local_batches can sometimes be smaller than WARP_BATCH (I had assumed that for masked softmax they are guaranteed to be equal), so local_batches needs to be taken into account when computing the real number of elements.
cc ptrblck
Pull Request resolved: #82927
Approved by: https://github.com/erichan1
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/88e43ca409dbd3e3da41faefe82821f38b90299a
Reviewed By: seemethere
Differential Revision: D38542295
Pulled By: ngimel
fbshipit-source-id: 2c7b6984bd1f275931cfaeee1b7e390d2a367543
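To illustrate the issue the summary describes: each warp processes up to WARP_BATCH rows of the softmax input, but the last warp may own fewer real rows (local_batches). A minimal sketch of the intended per-row accounting follows — the function name and constant value are illustrative, not the actual kernel code:

```python
WARP_BATCH = 2  # rows handled per warp in this sketch (illustrative value)

def batch_element_counts(local_batches, element_count):
    """Per-row-slot element counts for one warp.

    The tail warp may own fewer than WARP_BATCH real rows
    (local_batches < WARP_BATCH); its unused slots get a count of 0
    so later loads and reductions skip them instead of touching
    out-of-range data.
    """
    return [element_count if i < local_batches else 0
            for i in range(WARP_BATCH)]
```

For example, a full warp (local_batches == WARP_BATCH) sees element_count for every slot, while a tail warp owning a single row sees element_count for slot 0 and 0 for the rest — which is why local_batches, not WARP_BATCH, must drive the computation.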