[Needs someone to complete] Reduce sum on many axes #2116
vlasenkov wants to merge 2 commits into pytorch:master
Conversation
fmassa left a comment:

Thanks for the PR!

I wonder if it would be better to remove the implementation from cwrap, to avoid conflicts? I think this is better than implementing these ops in cwrap, because we automatically get autograd support.

Also, what is the behavior of NumPy for operations like median when multiple axes are passed? Does it perform multiple kernel calls, or a transpose + view + a single kernel call? For sum it might not matter, but for other ops it could make a difference.
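The NumPy question above can be checked directly: NumPy's `median` accepts a tuple of axes and reduces them as a single flattened axis. A sketch of the equivalence (the `np.moveaxis` + `reshape` form here is illustrative, not necessarily NumPy's literal internal implementation):

```python
import numpy as np

x = np.arange(24, dtype=float).reshape(2, 3, 4)

# NumPy accepts a tuple of axes for median directly:
m = np.median(x, axis=(0, 2))

# Equivalent "move the reduced axes together, flatten, single reduction" form:
m2 = np.median(np.moveaxis(x, (0, 2), (-2, -1)).reshape(3, -1), axis=-1)

print(np.allclose(m, m2))  # the two reductions agree
```

This suggests that for order-insensitive ops like sum the looped and flattened forms coincide, while for ops like median only the flattened (transpose + view) form gives the NumPy-compatible answer.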
```python
from ._utils import _range
from operator import mul
from functools import reduce
import collections
```
torch/functional.py (outdated):

```python
            input = input.sum(ax, keepdims=True)
        else:
            for ax in sorted(axes, reverse=True):
                input = input.sum(ax)
```
torch/functional.py (outdated):

```python
def sum(input, axes, keepdims=False, out=None):
    if isinstance(axes, collections.Iterable):
        if a.dim() > 3:
            if keepdims:
```
```python
                # permute
                # reduce single dim
    else:
        return torch._C.sum(input, axes, keepdims, out)
```
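The fallback strategy in the hunks above (reduce one axis at a time) can be sketched as a standalone function. This is a sketch, not the PR's API; `sum_over_axes` is a hypothetical name, and NumPy arrays stand in for tensors:

```python
import numpy as np

def sum_over_axes(a, axes, keepdims=False):
    # Reduce one axis at a time. Without keepdims, dimensions are
    # dropped as we go, so iterating in descending axis order keeps
    # the remaining axis indices valid.
    for ax in sorted(axes, reverse=True):
        a = a.sum(axis=ax, keepdims=keepdims)
    return a

x = np.arange(24).reshape(2, 3, 4)
print(sum_over_axes(x, (0, 2)))              # shape (3,)
print(sum_over_axes(x, (0, 2), True).shape)  # (1, 3, 1)
```

With `keepdims=True` no dimensions are dropped, so the iteration order would not matter; the descending sort is only required on the dimension-dropping path, which is why the two branches in the diff differ.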
bernardohenz: I am trying to perform std over many axes, much like what you are doing with sum. Is this problem solved?
tstandley: I also want to do var over many axes. Is this solved? (Same question as @bernardohenz, except with var.) Note that NumPy supports this, and the only way to do it in PyTorch currently is manual: compute the mean, subtract it (using expand), square, and then take the mean again.
@tstandley We are working on mean, variance, and std over multiple axes; @colesbury should put up a PR for it soon.
For signposting:
Resolves #2006
keepdim to keepdims