batching rule for aten::scatter_add_ #148307

@ZhongkuiMa

Description

🚀 The feature, motivation and pitch

Hi all,

I'm a PhD student working on a PyTorch project, and I'm currently encountering the following warning:

UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::scatter_add_. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at ../aten/src/ATen/functorch/BatchedFallback.cpp:81.)

It happens when I apply vmap to a function that calls scatter_add_.
I need to operate on very large tensors (roughly 10 to 40 GB), so I rely on vmap to keep memory usage down while still benefiting from vectorized tensor operations.
Scatter operations are common, and batching rules for similar operations already exist, so I hope this feature can be implemented with priority.
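A minimal sketch of the kind of code that triggers the warning (a hypothetical reproduction, not the author's actual project code): vmap over a function that calls the in-place scatter_add_, which currently has no batching rule, so functorch falls back to a per-example loop and emits the UserWarning.

```python
import torch

def accumulate(out, index, src):
    # In-place scatter-add along dim 0: out[index[i]] += src[i]
    out.scatter_add_(0, index, src)
    return out

B, n, k = 3, 5, 4
out = torch.zeros(B, n)
index = torch.randint(0, n, (B, k))
src = torch.ones(B, k)

# aten::scatter_add_ has no batching rule, so functorch falls back
# to a sequential per-example loop and warns about the slowdown.
result = torch.vmap(accumulate)(out, index, src)
```

The result is still correct (each row accumulates its k source values); only performance suffers, since the fallback loops over the batch dimension instead of issuing one batched kernel.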

Alternatives

Currently, I just ignore the warning.
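Another possible workaround (a hedged sketch, assuming a 2-D batch of scatter targets; `batched_scatter_add` is a name I made up): avoid vmap entirely by flattening the batch into dim 0 and offsetting the indices, so a single scatter_add_ call handles every example at once.

```python
import torch

def batched_scatter_add(out, index, src):
    # Emulate a batched scatter_add_ without vmap: shift each row's
    # indices by b * n so they address disjoint slices of a flat buffer.
    B, n = out.shape
    offsets = torch.arange(B, device=out.device).unsqueeze(1) * n  # (B, 1)
    flat_index = (index + offsets).reshape(-1)                     # (B*k,)
    flat_out = out.reshape(-1)           # view: writes land in `out`
    flat_out.scatter_add_(0, flat_index, src.reshape(-1))
    return flat_out.reshape(B, n)

B, n, k = 3, 5, 4
out = torch.zeros(B, n)
index = torch.randint(0, n, (B, k))
src = torch.ones(B, k)
result = batched_scatter_add(out, index, src)
```

This issues one scatter kernel over the whole batch, so it sidesteps the fallback warning, at the cost of requiring `out` to be contiguous so that `reshape(-1)` is a view.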

Additional context

Thanks for the PyTorch team's hard work.

cc @zou3519 @Chillee @samdow @kshitij12345

Labels: enhancement, module: functorch, module: vmap, triaged
