
Empty batch support for SyncBatchNorm #36530

@ppwwyyxx

Description


🚀 Feature

Support empty batches in SyncBatchNorm.

Motivation

#36382 fixed SyncBatchNorm for the case where different workers have different batch sizes. But when some or all workers have a zero batch size, the behavior is still broken.

Similar to how BatchNorm now supports empty batches (see #12013 (comment)), the expected behavior for SyncBatchNorm should be:

  1. forward/backward should work properly when some or all workers have a zero batch size. In particular, inputs with no elements should receive non-None gradients, and parameters should receive zero gradients when the total batch size is 0.
  2. when the total batch size is 0, running_mean/running_var should not be updated.
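For reference, the two points above can be checked against plain BatchNorm2d, which already handles empty batches; this is a sketch of the semantics SyncBatchNorm should match (the asserts encode the expected behavior, not a guaranteed contract):

```python
import torch
import torch.nn as nn

# Plain (non-sync) BatchNorm2d on an empty batch, in training mode.
bn = nn.BatchNorm2d(3)
running_mean_before = bn.running_mean.clone()

x = torch.zeros(0, 3, 4, 4, requires_grad=True)  # batch size 0
out = bn(x)           # forward succeeds, returns an empty tensor
out.sum().backward()  # backward succeeds as well

assert x.grad is not None                        # empty input still gets a grad
assert torch.all(bn.weight.grad == 0)            # zero gradients for parameters
assert torch.equal(bn.running_mean, running_mean_before)  # stats not updated
```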

However, SyncBatchNorm currently fails on an empty batch with this error:

  File "..../torch/nn/modules/_functions.py", line 17, in forward                           
    mean, invstd = torch.batch_norm_stats(input, eps)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, 3, -1] because the unspecified dimension size -1 can be any value and is ambiguous   
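The error message suggests the root cause: the stats computation reshapes the input to [N, C, -1], and an inferred -1 dimension is ambiguous when the tensor has 0 elements. The failure can be isolated without any distributed setup:

```python
import torch

# An empty batch, as SyncBatchNorm's forward would see it.
x = torch.zeros(0, 3, 4, 4)

try:
    # Reshaping a 0-element tensor with an inferred (-1) dimension
    # fails, because the known dims [0, 3] multiply to 0 and -1
    # could then take any value.
    x.reshape(x.size(0), x.size(1), -1)
except RuntimeError as e:
    print(e)
```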

Alternatives

We implement it here, but it is an inefficient Python-based implementation.
cc @jjsjann123

Metadata

    Labels

    enhancement - Not as big of a feature, but technically not a bug. Should be easy to fix
    module: nn - Related to torch.nn
    triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
