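Integer scalar tensors should behave like Python integers when used as an index. In other words, indexing a 1-d tensor with a 0-d integer tensor should drop the indexed dimension and return a 0-d tensor. Tensors of dtype `torch.uint8` deviate from that: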
```python
import torch

t_1d_single = torch.empty(1)  # 1-d tensor with one element
t_1d_multi = torch.empty(2)   # 1-d tensor with two elements

for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
    # Index each tensor with a 0-d (scalar) tensor of the given integer dtype
    print("single", dtype, t_1d_single[torch.tensor(0, dtype=dtype)].shape)
    print("multi1", dtype, t_1d_multi[torch.tensor(0, dtype=dtype)].shape)
    print("multi2", dtype, t_1d_multi[torch.tensor(1, dtype=dtype)].shape)
    print("#" * 50)
```
Output:

```
single torch.uint8 torch.Size([0, 1])
multi1 torch.uint8 torch.Size([0, 2])
multi2 torch.uint8 torch.Size([1, 2])
##################################################
single torch.int8 torch.Size([])
multi1 torch.int8 torch.Size([])
multi2 torch.int8 torch.Size([])
##################################################
single torch.int16 torch.Size([])
multi1 torch.int16 torch.Size([])
multi2 torch.int16 torch.Size([])
##################################################
single torch.int32 torch.Size([])
multi1 torch.int32 torch.Size([])
multi2 torch.int32 torch.Size([])
##################################################
single torch.int64 torch.Size([])
multi1 torch.int64 torch.Size([])
multi2 torch.int64 torch.Size([])
##################################################
```
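The `torch.uint8` shapes above appear consistent with the 0-d `uint8` tensor being interpreted as a boolean mask (the legacy byte-mask behavior PyTorch historically applied to `uint8` index tensors) rather than as an integer index. The pure-Python sketch below models the two interpretations to make the difference concrete. The helpers `int_index_shape` and `mask_index_shape` are hypothetical names that only reproduce the shapes printed above; they are not PyTorch internals.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def int_index_shape(shape: Shape) -> Shape:
    """Hypothetical model: a 0-d integer index selects one entry along dim 0.

    The indexed dimension is dropped, like indexing with a Python int.
    """
    if not shape:
        raise IndexError("cannot index a 0-d tensor")
    return shape[1:]


def mask_index_shape(shape: Shape, value: int) -> Shape:
    """Hypothetical model: a 0-d index treated as a boolean mask.

    A 0-d mask adds a new leading dimension of size 1 (truthy) or 0 (falsy)
    and keeps all existing dimensions.
    """
    return (1 if value else 0,) + shape


# The cases from the report: (label, tensor shape, index value)
cases = [("single", (1,), 0), ("multi1", (2,), 0), ("multi2", (2,), 1)]
for label, shape, value in cases:
    print(label, "mask:", mask_index_shape(shape, value), "int:", int_index_shape(shape))
```

The "mask" column matches the `torch.uint8` rows of the output, while the "int" column matches the rows for the signed integer dtypes.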
cc @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi