Fix SyncBatchNorm for empty inputs #74944
mrshenli wants to merge 5 commits into gh/mrshenli/341/base from
Conversation
TODO: 1. avoid copying count_all to CPU if possible 2. it no longer crashes, but the output is NaN. Next step: move the fix into the CUDA kernel of `batch_norm_gather_stats_with_counts` accordingly. [ghstack-poisoned]
💊 CI failures summary, as of commit d5f20a8 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
datumbox left a comment:
Thanks for the change @mrshenli.
Overall the approach looks good to me; I've added minor comments for nits. I'm currently testing this patch on a cluster with real data, and it seems the problem is resolved. If something breaks, I'll let you know.
```python
    combined = torch.cat([mean, invstd, count], dim=0)
else:
    # for empty input, directly set all stats to 0
    combined = torch.zeros(
```
Wouldn't something like `torch.zeros(dtype=input.dtype, device=input.device).expand(2 * num_channels + 1)` also work and reduce the wasted bandwidth?
I'm not sure how RPC handles non-contiguous Tensors.
`torch.zeros(dtype=input.dtype, device=input.device).expand(2 * num_channels + 1)`
Curious, what bandwidth does the above code save? And why is RPC relevant here?
This `combined` Tensor is shared with all other nodes during the all_reduce below, right?
While the Tensor in the code today has 2 * num_channels + 1 elements that need to go over the wire, the expanded version has only 1 element. So if it is sent over the wire efficiently, you save a lot of bandwidth.
Oh I see. I'm not sure this is going to work, though. Collectives use ProcessGroup and call NCCL APIs under the hood. IIRC, NCCL expects contiguous tensors and directly reads numel() elements from the memory pointer. Let me double-check that.
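To illustrate the contiguity concern, here is a minimal sketch (not the PR's code): expanding a single zero produces a stride-0 view, which is exactly the kind of non-contiguous tensor NCCL-backed collectives reject.

```python
import torch

num_channels = 3

# Expanding a single zero to 2 * num_channels + 1 elements creates a view
# backed by one storage element: its stride is 0 and it is not contiguous.
combined = torch.zeros(1).expand(2 * num_channels + 1)

print(combined.stride())          # (0,)
print(combined.is_contiguous())   # False

# Materializing the view restores contiguity, but also allocates the full
# buffer again, which cancels the hoped-for saving before the data hits
# the wire.
materialized = combined.contiguous()
print(materialized.is_contiguous())  # True
```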
```
File "/raid/shenli/pytorch/torch/distributed/distributed_c10d.py", line 2130, in _all_gather_base
    work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: Tensors must be contiguous
Exception raised from check_gpu_single_tensor at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1227 (most recent call first)
```
Hit the above error, caused by the check in pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp, lines 1226 to 1228 (at 835cc66).
Ok then.
As a side note, I think you should look into that, as it is potentially a major bandwidth gain (and if I understand correctly, bandwidth is an expensive commodity).
```python
num_channels = saved_input.shape[1]
if self.needs_input_grad[0]:
    # launch all_reduce to unblock other peer processes
    combined = torch.zeros(
```
Same question as above about using an expanded Tensor to reduce bandwidth use.
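For context, the backward-pass idea behind the snippet above can be sketched like this (illustrative only; a plain sum stands in for the distributed `all_reduce`, and the names are not the exact PR code):

```python
import torch

num_channels = 3

# A rank whose input was empty still contributes an all-zero buffer to the
# all_reduce so its peers are not blocked; zeros leave the summed gradient
# contributions from the other ranks unchanged.
local_contribution = torch.zeros(2 * num_channels + 1)

# Simulate the all_reduce sum across two ranks within one process:
peer_contribution = torch.arange(2 * num_channels + 1, dtype=torch.float32)
reduced = local_contribution + peer_contribution

print(torch.equal(reduced, peer_contribution))  # True
```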
fixes #36530

Prior to this commit, SyncBatchNorm crashes with the following error message.

```
File "..../torch/nn/modules/_functions.py", line 17, in forward
    mean, invstd = torch.batch_norm_stats(input, eps)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, 3, -1] because the unspecified dimension size -1 can be any value and is ambiguous
```

This PR adds a dedicated branch to handle empty inputs. When a process receives empty inputs, it sets its local `mean`, `invstd`, and `count` to zero and participates in the `all_gather` collective communications in the forward pass. `mean` and `invstd` entries with zero count are then filtered out before computing the global mean and invstd. In the backward pass, it likewise participates in the `all_reduce` communication with zero tensors to unblock its peers.

Differential Revision: [D35273409](https://our.internmc.facebook.com/intern/diff/D35273409)

[ghstack-poisoned]
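The forward-pass branch described above can be sketched roughly as follows. This is a simplified, CPU-friendly stand-in: plain reductions replace `torch.batch_norm_stats` (which is CUDA-only), the collective is omitted, and the names are illustrative rather than the exact PR code.

```python
import torch

def local_stats(input, eps=1e-5):
    # Compute per-channel mean/invstd plus a sample count, packed into one
    # tensor ready for all_gather; input has shape (N, C, ...).
    num_channels = input.shape[1]
    if input.numel() > 0:
        flat = input.transpose(0, 1).reshape(num_channels, -1)
        var, mean = torch.var_mean(flat, dim=1, unbiased=False)
        invstd = torch.rsqrt(var + eps)
        count = torch.full((1,), flat.shape[1], dtype=input.dtype)
    else:
        # empty input: set all local stats to zero; peers filter this rank
        # out after the all_gather by checking count == 0
        mean = torch.zeros(num_channels, dtype=input.dtype)
        invstd = torch.zeros(num_channels, dtype=input.dtype)
        count = torch.zeros(1, dtype=input.dtype)
    return torch.cat([mean, invstd, count], dim=0)

# An empty batch no longer crashes: it yields 2 * C + 1 zeros that are safe
# to all_gather and are ignored when computing the global statistics.
empty_combined = local_stats(torch.empty(0, 3, 4))
print(empty_combined)  # tensor of 7 zeros
```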
```python
# input does not require grad
x.requires_grad = False
self._test_not_nan(model, x)
```
I think I agree. It's not going to be the same gradient because the minibatch statistics will be different in the two cases.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
datumbox left a comment:
LGTM from my side. My tests on real data show that the issue is fixed.
Summary: Pull Request resolved: #74944

fixes #36530

Prior to this commit, SyncBatchNorm crashes with the following error message.

```
File "..../torch/nn/modules/_functions.py", line 17, in forward
    mean, invstd = torch.batch_norm_stats(input, eps)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, 3, -1] because the unspecified dimension size -1 can be any value and is ambiguous
```

This PR adds a dedicated branch to handle empty inputs. When a process receives empty inputs, it sets its local `mean`, `invstd`, and `count` to zero and participates in the `all_gather` collective communications in the forward pass. `mean` and `invstd` entries with zero count are then filtered out before computing the global mean and invstd. In the backward pass, it likewise participates in the `all_reduce` communication with zero tensors to unblock its peers.

Differential Revision: D35273409

Test Plan: Imported from OSS
Reviewed By: datumbox
Pulled By: mrshenli
fbshipit-source-id: 1cee51eea866773c329b3fbf5da2be8a5fee6f0f
Hey @mrshenli.
Stack from ghstack:
fixes #36530
Prior to this commit, SyncBatchNorm crashes with the following
error message.
This PR adds a dedicated branch to handle empty inputs. When a process receives empty inputs, it will set its local `mean`, `invstd`, and `count` to zero, and participate in the `all_gather` collective communications in the forward pass. Then `mean` and `invstd` with zero count will be filtered out before computing the global mean and invstd. In the backward pass, it also participates in the `all_reduce` communication with zero tensors to unblock its peers.

Differential Revision: D35273409