
Support multi-dim reductions for torch.prod, torch.all, torch.any #56586

@vishwakftw

🚀 Feature

Currently, only torch.sum and torch.mean can reduce over multiple dimensions in a single call. It would be nice to extend this to other reduction operations such as torch.prod, torch.all and torch.any. This would also bring the interface in line with NumPy, whose np.prod, np.all and np.any already accept a tuple of axes.
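For reference, here is a small NumPy-only sketch of the tuple-of-axes interface mentioned above. The inputs are arbitrary random arrays. It checks that a single call with `axis=(-1, -2)` gives the same result as chaining single-axis reductions, which is the behavior being requested for PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Product over the last two axes in one call vs. chaining two reductions.
x = rng.standard_normal((2, 3, 5))
multi = np.prod(x, axis=(-1, -2))
chained = x.prod(-1).prod(-1)
print(multi.shape)                   # (2,)
print(np.allclose(multi, chained))   # True

# The same holds for the boolean reductions np.all and np.any.
y = rng.integers(0, 2, size=(2, 3, 5)).astype(bool)
print(np.array_equal(np.all(y, axis=(-1, -2)), y.all(-1).all(-1)))  # True
print(np.array_equal(np.any(y, axis=(-1, -2)), y.any(-1).any(-1)))  # True
```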

The current workaround for a multi-dim reduction is to chain single-dim reductions, like so:

```python
import torch

x = torch.randn(2, 3, 5)
x.prod(-1).prod(-1)  # proposed equivalent: x.prod(dim=(-1, -2))

y = torch.randint(2, (2, 3, 5), dtype=torch.bool)
y.all(-1).all(-1)  # proposed equivalent: y.all(dim=(-1, -2))
# analogously for torch.any: y.any(-1).any(-1) -> y.any(dim=(-1, -2))
```
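Until such an API exists, the chaining can be wrapped in a helper. The sketch below is an assumption, not a PyTorch API: `reduce_dims` and its `op` argument are hypothetical names. It relies only on `t.ndim` and on methods such as `prod`, `all` and `any` accepting a positional dim, which holds for `torch.Tensor` (and, incidentally, for `numpy.ndarray`). It reduces the highest dim first, so earlier reductions never shift the indices of the dims still to be reduced.

```python
def reduce_dims(t, op, dims):
    """Hypothetical helper (not part of PyTorch): reduce `t` over every dim
    in `dims` by chaining a single-dim reduction method such as "prod",
    "all" or "any".

    Only `t.ndim` and `getattr(t, op)(d)` with a positional dim are used,
    so it works on torch.Tensor (and on numpy.ndarray).
    """
    ndim = t.ndim
    for d in dims:
        if not -ndim <= d < ndim:
            raise IndexError(f"dim {d} out of range for a {ndim}-D input")
    # Map negative dims to their non-negative equivalents, e.g. -1 -> ndim - 1.
    normalized = [d % ndim for d in dims]
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"dims {tuple(dims)} name the same dim more than once")
    # Reduce the highest dim first: removing dim k only shifts dims above k,
    # so the smaller dims still to be reduced keep their original indices.
    for d in sorted(normalized, reverse=True):
        t = getattr(t, op)(d)
    return t


# Usage with the tensors from the example above:
#   reduce_dims(x, "prod", (-1, -2))  # same result as x.prod(-1).prod(-1)
#   reduce_dims(y, "all", (-1, -2))   # same result as y.all(-1).all(-1)
```

An empty `dims` returns the input unchanged in this sketch. Rejecting repeated dims mirrors what NumPy does for duplicate axes.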

Metadata

    Labels

    enhancement, module: reductions, triaged
