
Issue with torch.max() over dim 2 #1310

@bunelr

Description


When operating on a tensor with 4 dimensions, the indices returned by torch.max() are wrong and can go beyond the size of the dimension that was reduced over.

Surprisingly, this doesn't happen for other numbers of dimensions.

Code demonstrating the problem:

import torch
from torch.autograd import Variable

# This is fine and works as expected
t_2d = torch.randn(5, 25)
max_val_2d, max_idxs_2d = torch.max(t_2d, 0)
# max_val_2d / max_idxs_2d is of size 1 x 25 -> fine
# The values of max_idxs_2d go from 0 to 4   -> fine
assert(max_idxs_2d.max() < 5)

# This is fine and works as expected
t_3d = torch.randn(26, 5, 25)
max_val_3d, max_idxs_3d = torch.max(t_3d, 1)
# max_val_3d / max_idxs_3d is of size 26 x 1 x 25 -> fine
# The values of max_idxs_3d go from 0 to 4        -> fine
assert(max_idxs_3d.max() < 5)

# This is fine and works as expected
t_5d = torch.randn(1, 1, 26, 5, 25)
max_val_5d, max_idxs_5d = torch.max(t_5d, 3)
# max_val_5d / max_idxs_5d is of size 1 x 1 x 26 x 1 x 25 -> fine
# The values of max_idxs_5d go from 0 to 4                -> fine
assert(max_idxs_5d.max() < 5)

# This is fine and works as expected
t_6d = torch.randn(1, 1, 1, 26, 5, 25)
max_val_6d, max_idxs_6d = torch.max(t_6d, 4)
# max_val_6d / max_idxs_6d is of size 1 x 1 x 1 x 26 x 1 x 25 -> fine
# The values of max_idxs_6d go from 0 to 4                    -> fine
assert(max_idxs_6d.max() < 5)

# This is fine and works as expected
t_7d = torch.randn(1, 1, 1, 1, 26, 5, 25)
max_val_7d, max_idxs_7d = torch.max(t_7d, 5)
# max_val_7d / max_idxs_7d is of size 1 x 1 x 1 x 1 x 26 x 1 x 25 -> fine
# The values of max_idxs_7d go from 0 to 4                        -> fine
assert(max_idxs_7d.max() < 5)



# This is not fine
t_4d = torch.randn(1, 26, 5, 25)
max_val_4d, max_idxs_4d = torch.max(t_4d, 2)
# max_val_4d / max_idxs_4d is of size 1 x 26 x 1 x 25 -> fine
# The values of max_idxs_4d go from 0 to 24           -> ???
print(max_idxs_4d)
assert(max_idxs_4d.max() < 5)
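Until this is fixed, a possible workaround (a sketch, assuming the 3-D code path is correct, as the examples above suggest) is to collapse the leading dimensions with view, reduce over the now 3-D tensor, and then restore the original layout:

```python
import torch

t_4d = torch.randn(1, 26, 5, 25)

# Collapse the leading dimensions so the reduction goes through the
# (apparently correct) 3-D code path, then restore the original layout.
flat = t_4d.view(1 * 26, 5, 25)
max_val_flat, max_idxs_flat = torch.max(flat, 1)
max_val_4d = max_val_flat.view(1, 26, 1, 25)
max_idxs_4d = max_idxs_flat.view(1, 26, 1, 25)

# Indices now stay within the size of the reduced dimension
assert max_idxs_4d.max() < 5
```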

Can you confirm this is a bug and not a misunderstanding on my part of what torch.max() should return?
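One way to sanity-check the returned indices independently, sketched with torch.gather: if the indices were valid, gathering along the reduced dimension would reproduce the max values, while out-of-range indices would make gather fail.

```python
import torch

t_4d = torch.randn(1, 26, 5, 25)
vals, idxs = torch.max(t_4d, 2)

# Gather expects an index tensor with the reduced dimension kept,
# so reinsert it before gathering and squeeze it back out afterwards.
gathered = t_4d.gather(2, idxs.view(1, 26, 1, 25)).view(1, 26, 25)

# With correct indices this holds; invalid indices would error or mismatch.
assert (gathered == vals.view(1, 26, 25)).all()
```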
