Select coo tensor with zero NNZ does not preserve dtype (for integer types?) #82150
Closed
Labels
module: advanced indexing (related to x[i] = y, index functions)
module: sparse (related to torch.sparse)
module: type promotion (related to semantics of type promotion)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Indexing (select) on a sparse COO tensor that has zero nnz and an integral dtype returns a reduced-dimension result whose dtype is torch.int64 instead of the input dtype.
🐛 Describe the bug
import torch
dtypes = [torch.bool, torch.int8, torch.int32, torch.bfloat16, torch.float32, torch.float64]
for dtype in dtypes:
    a = torch.ones((6, 1), dtype=dtype)
    a_s = a.to_sparse_coo()
    if a_s[0].dtype != dtype:
        print(f'2d->1d index select (with nnz != 0) result dtype incorrect, for dtype {dtype}, got {a_s[0].dtype}')
    if a_s[0,0].dtype != dtype:
        print(f'2d->scalar index select (with nnz != 0) result dtype incorrect, for dtype {dtype}, got {a_s[0,0].dtype}')
    b_s = (a * 0).to_sparse_coo()
    if b_s[0].dtype != dtype:
        print(f'2d->1d index select (with nnz == 0) result dtype incorrect, for dtype {dtype}, got {b_s[0].dtype}')
    if b_s[0,0].dtype != dtype:
        print(f'2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype {dtype}, got {b_s[0,0].dtype}')
Output
2d->1d index select (with nnz == 0) result dtype incorrect, for dtype torch.bool, got torch.int64
2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype torch.bool, got torch.int64
2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype torch.int8, got torch.int64
2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype torch.int32, got torch.int64
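A possible confounder in the repro: under PyTorch's type promotion rules, multiplying a torch.bool tensor by the Python integer 0 produces an int64 tensor, so for torch.bool the nnz == 0 input may already be int64 before any select happens. The sketch below is one way to separate the two effects. It builds the empty input with torch.zeros(..., dtype=dtype) and compares each result against the dtype of the sparse input itself. The helper name check_select_dtype is hypothetical, and to_sparse() is used on the assumption that its default COO layout matches to_sparse_coo().

```python
import torch


def check_select_dtype(dtypes, shape=(6, 1)):
    """Return (dtype, case, result_dtype) for each select whose dtype differs from its input."""
    mismatches = []
    for dtype in dtypes:
        inputs = {
            "nnz != 0": torch.ones(shape, dtype=dtype).to_sparse(),
            # torch.zeros keeps the requested dtype, unlike `a * 0` for bool.
            "nnz == 0": torch.zeros(shape, dtype=dtype).to_sparse(),
        }
        for case, s in inputs.items():
            # Guard against the setup itself changing the dtype.
            assert s.dtype == dtype
            for label, picked in (("2d->1d", s[0]), ("2d->scalar", s[0, 0])):
                if picked.dtype != s.dtype:
                    mismatches.append((dtype, f"{label} ({case})", picked.dtype))
    return mismatches


# For torch.bool, `a * 0` promotes to int64, so the input is no longer bool.
promoted_dtype = (torch.ones((6, 1), dtype=torch.bool) * 0).dtype
print("bool * 0 ->", promoted_dtype)

dtypes = [torch.bool, torch.int8, torch.int32, torch.bfloat16, torch.float32, torch.float64]
for dtype, case, got in check_select_dtype(dtypes):
    print(f"{case}: expected {dtype}, got {got}")
```

If only the 2d->scalar, nnz == 0 rows remain for the integer dtypes, that would point at the select itself rather than at type promotion in the setup.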
Versions
Repro generated with current master
cc @nikitaved @pearu @cpuhrsch @amjames @bhosmer @nairbv @mruberry