Remove all cuDNN specific inputs to RNN functions#10581
Remove all cuDNN specific inputs to RNN functions#10581apaszke wants to merge 3 commits intopytorch:masterfrom
Conversation
There was no way to handle those in the JIT for now, and they turned out to be completely unnecessary. It should make the Python and C++ module code much simpler too, since all the logic is now centralized in the native functions. The downside is that RNN modules no longer own their dropout buffers, which are shared per-device instead (with appropriate locking and synchronization). This might appear as a perf regression at first, but in reality it's highly unlikely that anyone will want to run cuDNN RNNs on the same GPU in parallel.
facebook-github-bot
left a comment
There was a problem hiding this comment.
apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zdevito
left a comment
There was a problem hiding this comment.
This looks good to me, but I don't really know the details of the rnn code, so it is hard for me to know if it broke something.
|
|
||
| // Pointer-based API (for internal use) | ||
| // Note: ATen/Context is preferred to work with streams safely | ||
| AT_API CUDAEventInternals* CUDAEvent_create(unsigned int flags); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| #if defined(__CUDACC__) | ||
| #define REGISTER_DISPATCH(name, fn) \ | ||
| #define REGISTER_ARCH_DISPATCH(name, arch, fn) \ |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return std::make_tuple(hx, cx); | ||
| } | ||
|
|
||
| struct DropoutState { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Summary: This is still not the final PR, but it removes all blockers for actually using the RNN functions directly in the JIT. Next patch should be final, and will actually remove the symbolic_override code, and change it to proper symbolics for those ATen functions. Turns out the symbolic code can be also cleaned up a bit, and I'll do that too. zdevito ezyang colesbury (for minor DispatchStub.h) changes There was no way to handle those in the JIT for now, and they turned out to be completely unnecessary. It should make the Python and C++ module code much simpler too, since all the logic is now centralized in the native functions. The downside is that RNN modules no longer own their dropout buffers, which are shared per-device instead (with appropriate locking and synchronization). This might appear as a perf regression at first, but in reality it's highly unlikely that anyone will want to run cuDNN RNNs on the same GPU in parallel. Pull Request resolved: pytorch/pytorch#10581 Reviewed By: colesbury Differential Revision: D9365541 Pulled By: apaszke fbshipit-source-id: 3ef8677ee5481bae60c74a9117a2508665b476b5
Summary: After submitting PR #9726, PR #10581 created a different CUDAEvent class. The CUDAEvent proposed in #9726 was similar to the c10d::CUDAEvent class with additional testing and functionality. In particular, it was movable but not copyable. The CUDAEvent created by #10581 is refcounted and copyable. This PR retains the refcounting of the latter PR while fixing several bugs, adding tests, and extending the functionality to support testing and usage like in PR #8354. In particular, this PR: - Adds set_device() to CUDAContext - Adds three CUDAEvent tests to stream_test.cpp - Fixes three bugs: - Refcounting was broken. Destroying an of the RAIIs holding a particular CUDAEvent would destroy the event UNLESS it was the last RAII (the check was backwards). - Moving an event would cause a segfault. - Events were not destroyed on the device they were created on. See PR #9415 (pietern) - Adds the happened() and recordOnce() functions - Changes the record() functions to not be const - Adds additional assertions to verify correctness This PR does not: - Make c10d use the ATen CUDAEvent (this is appropriate for a separate PR) Whether events should be refcounted is an interesting question. It adds some atomic operations and makes event creation eager. Making events movable but not copyable (like the c10d events) avoids these costs and allows events to be lazily constructed. Lazy construction is preferable when working with containers (like std::array or std::vector) and because the event's device can be set automatically to the first stream it's recorded on. With eager construction the user is required to understand that events have a device and acquire the device of the stream the event will be recorded on upfront. This can be seen here: https://github.com/pytorch/pytorch/blob/542aadd9a7609892e207c1e15de08a975b697752/aten/src/ATen/native/cudnn/RNN.cpp#L1130-L1132 and that file is the only one which currently uses the ATen CUDAEvent. Refcounting does allow single writer multi-reader scenarios, although these scenarios can be also be supported by providing indirect access to the underlying CUDAEvent. I believe all current and planned usage scenarios do not require refcounting, and if desired I can update this PR to remove refcounting and make the ATen event movable but not copyable like the c10d event. I think not refcounting is preferable because it can improve performance, ease usability, and simplify the code (as seen with two of the above bugs). I have decided to separate this from PR #8354 since while it's required for PR #8354 the changes are, clearly, of independent interest. PR #8354 has a new dependency on this one, however. I am closing PR #9726 in favor of this PR. apaszke ezyang pietern Pull Request resolved: #11293 Differential Revision: D9665836 Pulled By: soumith fbshipit-source-id: a1513fa4f9761e2f304d126e402f6b6950e1c1d2
Summary: After submitting PR pytorch#9726, PR pytorch#10581 created a different CUDAEvent class. The CUDAEvent proposed in pytorch#9726 was similar to the c10d::CUDAEvent class with additional testing and functionality. In particular, it was movable but not copyable. The CUDAEvent created by pytorch#10581 is refcounted and copyable. This PR retains the refcounting of the latter PR while fixing several bugs, adding tests, and extending the functionality to support testing and usage like in PR pytorch#8354. In particular, this PR: - Adds set_device() to CUDAContext - Adds three CUDAEvent tests to stream_test.cpp - Fixes three bugs: - Refcounting was broken. Destroying an of the RAIIs holding a particular CUDAEvent would destroy the event UNLESS it was the last RAII (the check was backwards). - Moving an event would cause a segfault. - Events were not destroyed on the device they were created on. See PR pytorch#9415 (pietern) - Adds the happened() and recordOnce() functions - Changes the record() functions to not be const - Adds additional assertions to verify correctness This PR does not: - Make c10d use the ATen CUDAEvent (this is appropriate for a separate PR) Whether events should be refcounted is an interesting question. It adds some atomic operations and makes event creation eager. Making events movable but not copyable (like the c10d events) avoids these costs and allows events to be lazily constructed. Lazy construction is preferable when working with containers (like std::array or std::vector) and because the event's device can be set automatically to the first stream it's recorded on. With eager construction the user is required to understand that events have a device and acquire the device of the stream the event will be recorded on upfront. This can be seen here: https://github.com/pytorch/pytorch/blob/542aadd9a7609892e207c1e15de08a975b697752/aten/src/ATen/native/cudnn/RNN.cpp#L1130-L1132 and that file is the only one which currently uses the ATen CUDAEvent. Refcounting does allow single writer multi-reader scenarios, although these scenarios can be also be supported by providing indirect access to the underlying CUDAEvent. I believe all current and planned usage scenarios do not require refcounting, and if desired I can update this PR to remove refcounting and make the ATen event movable but not copyable like the c10d event. I think not refcounting is preferable because it can improve performance, ease usability, and simplify the code (as seen with two of the above bugs). I have decided to separate this from PR pytorch#8354 since while it's required for PR pytorch#8354 the changes are, clearly, of independent interest. PR pytorch#8354 has a new dependency on this one, however. I am closing PR pytorch#9726 in favor of this PR. apaszke ezyang pietern Pull Request resolved: pytorch#11293 Differential Revision: D9665836 Pulled By: soumith fbshipit-source-id: a1513fa4f9761e2f304d126e402f6b6950e1c1d2
This is still not the final PR, but it removes all blockers for actually using the RNN functions directly in the JIT. Next patch should be final, and will actually remove the symbolic_override code, and change it to proper symbolics for those ATen functions. Turns out the symbolic code can be also cleaned up a bit, and I'll do that too.
@zdevito @ezyang
@colesbury (for minor DispatchStub.h) changes
Commit message:
There was no way to handle those in the JIT for now, and they turned
out to be completely unnecessary. It should make the Python and C++
module code much simpler too, since all the logic is now centralized
in the native functions.
The downside is that RNN modules no longer own their dropout buffers,
which are shared per-device instead (with appropriate locking and
synchronization). This might appear as a perf regression at first, but
in reality it's highly unlikely that anyone will want to run cuDNN RNNs
on the same GPU in parallel.