Add out= variants for cuda.comm.broadcast/gather/scatter #39681
ssnl wants to merge 5 commits into pytorch:master
Conversation
💊 CI failures summary and remediations
As of commit 81c20a0 (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 1 fixed upstream failure: these were probably caused by upstream breakages that were already fixed. Please rebase on the
Force-pushed: 4f3305a to cc8c402
Force-pushed: 5155011 to 659d569
moved comm tests to a separate TestCase. Previously test_gather incorrectly included a non-comm gather test.
Would I be correct to assume these moved tests stay intact, except that they now belong to a different test class?
Mostly, with some tests for out= and error messages added. I'll comment to highlight the additions.
Sorry about the delay, I will help review this.
facebook-github-bot left a comment:
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
  std::vector<Tensor> nccl_list;
  nccl_list.reserve(out_tensors.size() + 1);
  nccl_list.push_back(tensor);
  for (auto& out_tensor : out_tensors) {
    nccl_list.push_back(out_tensor);
  }
Would it be better to move these lines into the if branch below? That way, when NCCL is not usable but we built with USE_NCCL=1, we don't have to create this vector.
:) But we need to use this vector<Tensor> to test if NCCL can accept them.
      out_tensors[i].sizes() == tensor.sizes(),
      "Expected all output tensors to have same shape as the source tensor ",
      tensor.sizes(), ", but output tensor at index ", i, " has shape ",
      out_tensors[i].sizes());
do we need to check strides?
We don't need to. If they are not all contiguous, the naive copy_ will handle this fine.
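For context, here is a minimal sketch of why strides don't need checking: copy_ resolves a stride mismatch itself. This is illustrative only (shown on CPU for simplicity) and not part of the PR's diff.

```python
import torch

src = torch.arange(6.).reshape(2, 3)
dst = torch.empty(3, 2).t()   # a non-contiguous destination view
assert not dst.is_contiguous()

# copy_ handles the element-wise copy regardless of the destination's
# strides, so the out= checks above only need to validate shapes.
dst.copy_(src)
assert torch.equal(dst, src)
```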
std::vector<Tensor>& broadcast_out(const Tensor& tensor, std::vector<Tensor> &out_tensors) {
  for (size_t i = 0; i < out_tensors.size(); i++) {
    TORCH_CHECK(
      out_tensors[i].is_cuda(),
nit: could you please run clang-format on this file? It might ask for 4 spaces here and several places below.
// no checks
static inline
std::vector<Tensor>& _broadcast_out_impl(const Tensor& tensor, std::vector<Tensor> &out_tensors) {
Curious: since out_tensors is already in the args, why do we need to return it again?
We don't need to! This could have a void return type. I just followed the Python out= and in-place function signatures, and I don't think it matters.
    }
  }
  return tensors;
  _broadcast_out_impl(tensor, diff_device_dst_tensors);
When using NCCL, this will create two vectors of tensors. I wonder if it would be better to std::move diff_device_dst_tensors and let _broadcast_out_impl take ownership?
_broadcast_out_impl takes a reference though, so I think it would be okay here.
  for (auto device : devices) {
    if (device != tensor.get_device()) {
      dst_tensors.push_back(*it++);
    } else {
I might be missing something, but it doesn't seem like this else branch will ever be reached? This function does not add the input tensor to diff_device_dst_tensors, and it seems neither does _broadcast_out_impl?
If the target device is the same as the source device, we don't broadcast for that device (see line 88 above) and just return the source tensor (var tensor) here, since there was no need to move it.
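As a hedged illustration of that behavior (assuming at least two CUDA devices; not part of the diff), the output for the source device is expected to alias the input rather than copy it:

```python
import torch
import torch.cuda.comm as comm

x = torch.randn(4, device="cuda:0")
outs = comm.broadcast(x, devices=[0, 1])

# No copy is made for the source device; that output shares storage with x.
assert outs[0].data_ptr() == x.data_ptr()
assert outs[1].device == torch.device("cuda", 1)
```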
    }
  }
  TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end());
  return dst_tensors;
Why do we need to create a new dst_tensors instead of returning diff_device_dst_tensors?
Because devices can contain the source tensor's device, and diff_device_dst_tensors doesn't include those.
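A hypothetical Python rendering of the loop being discussed (names mirror the C++ variables above; this is a sketch, not the actual implementation):

```python
def assemble_outputs(tensor, devices, diff_device_dst_tensors):
    # Walk the requested devices in order: reuse the source tensor for its
    # own device, and consume one broadcast result for every other device.
    it = iter(diff_device_dst_tensors)
    dst_tensors = []
    for device in devices:
        if device != tensor.get_device():
            dst_tensors.append(next(it))
        else:
            dst_tensors.append(tensor)
    return dst_tensors
```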
            self.assertEqual(t, input)
            if input.is_cuda and input.get_device() == i:  # test not copying on same device
                self.assertEqual(t.data_ptr(), input.data_ptr())
        # test out=
        for i, t in enumerate(results):
            self.assertEqual(t.get_device(), i)
            self.assertEqual(t, input)
        # test error msg
            self.assertEqual(r, input[tuple(index)], atol=0, rtol=0)
            chunk_start = chunk_end

        # test error msg
        index[dim] = slice(x.size(dim), x.size(dim) + y.size(dim))
        self.assertEqual(result[tuple(index)], y)

        # test error msg
            expected_device = torch.device('cuda', torch.cuda.current_device())
        else:
            expected_device = destination
        for use_out in [True, False]:
            if r.device == input.device:
                self.assertEqual(r.data_ptr(), input.data_ptr())  # for target @ same device, a view should be returned

        # test out
  const int64_t chunk_size_sum =
      std::accumulate(chunk_sizes->begin(), chunk_sizes->end(), int64_t{0});
  TORCH_CHECK(!out_tensors.empty(), "Expected at least one output tensor to scatter to");
  dim = at::maybe_wrap_dim(dim, tensor);
what does maybe_wrap_dim do?
It makes it so that negative dims work!
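Roughly, a hedged Python sketch of what at::maybe_wrap_dim does (a hypothetical mirror for illustration, not the actual C++ helper):

```python
def wrap_dim(dim, ndim):
    # Map a negative dim such as -1 onto the valid range [0, ndim).
    if dim < 0:
        dim += ndim
    if not 0 <= dim < ndim:
        raise IndexError(f"dim out of range for a {ndim}-D tensor")
    return dim

assert wrap_dim(-1, 4) == 3   # scatter(..., dim=-1) behaves like dim=3 on a 4-D tensor
assert wrap_dim(2, 4) == 2
```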
        i, " has device '", out_tensors[i].device(), "'");
    auto out_sizes = out_tensors[i].sizes().vec();
    bool same_ndim = out_sizes.size() == tensor.dim();
    if (same_ndim) {
Since we require same_ndim always to be true, shall we do the TORCH_CHECK before this line and drop the branching here?
The TORCH_CHECK also compares against out_sizes, which can only be constructed when same_ndim is true.
    // more copying than `scatter(src)`.
    out_tensors[i].copy_(chunks[i], /*non_blocking=*/true);
  }
  return out_tensors;
Same question: is it necessary to return it, since it is the same as the reference in the arg list?
mrshenli left a comment:
LGTM! Pending the clang-format correction.
  all_channels_last = all_channels_last &&
      tensor.suggest_memory_format() == MemoryFormat::ChannelsLast;
  if (memory_format != MemoryFormat::Contiguous && tensor.suggest_memory_format() != memory_format) {
    memory_format = MemoryFormat::Contiguous;
This means any disagreement in memory format across all input tensors would fall back to contiguous memory format?
Yeah, I mostly followed the current logic, which is a reasonable choice.
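A hedged Python sketch of that fallback, for illustration only. It approximates the C++ suggest_memory_format check with is_contiguous(memory_format=...), so it is a simplification rather than the actual logic:

```python
import torch

def pick_memory_format(tensors):
    # Keep channels_last only if every input already uses it;
    # any disagreement falls back to contiguous_format.
    if all(t.is_contiguous(memory_format=torch.channels_last) for t in tensors):
        return torch.channels_last
    return torch.contiguous_format
```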
          py::arg("destination_index"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_gather_out",
This is prior to this PR. Just curious, why don't we support providing streams for gather as well?
:) I don't know. I assume that scatter was specially handled to speed up DP.
        devices the tensor should be scattered.
    tensor (Tensor): tensor to scatter. Can be on CPU or CUDA.
    devices (Iterable[torch.device, str or int], optional): an iterable of
        CUDA devices, among which to broadcast.
|
@mrshenli I think this is mergeable now :)
facebook-github-bot left a comment:
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Partially fixes pytorch#38911
Pull Request resolved: pytorch#39681
Differential Revision: D22161342
Pulled By: mrshenli
fbshipit-source-id: 60295077159b02087823e93bb6ebac9d70adea0a
Partially fixes #38911
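For reference, a minimal usage sketch of the out= variants this PR adds (assuming two CUDA devices; argument names follow torch.cuda.comm, and out= is specified instead of devices):

```python
import torch
import torch.cuda.comm as comm

src = torch.randn(4, 6, device="cuda:0")

# broadcast into preallocated outputs, one per target device.
bcast_out = [torch.empty_like(src), torch.empty(4, 6, device="cuda:1")]
comm.broadcast(src, out=bcast_out)

# scatter chunks of src along dim 0 into preallocated outputs.
chunks = [torch.empty(2, 6, device="cuda:0"), torch.empty(2, 6, device="cuda:1")]
comm.scatter(src, dim=0, out=chunks)

# gather the chunks back into a single preallocated tensor.
gathered = torch.empty(4, 6, device="cuda:1")
comm.gather(chunks, dim=0, out=gathered)
```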