
Select coo tensor with zero NNZ does not preserve dtype (for integer types?) #82150

@amjames

Description

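Selecting an index from a sparse COO tensor that has no specified elements (nnz == 0) and an integral dtype returns a result with dtype torch.int64 instead of the input dtype. In the repro below, torch.int8 and torch.int32 are affected when selecting down to a scalar (`b_s[0, 0]`), while the 2-D to 1-D select (`b_s[0]`) keeps the input dtype. The floating-point dtypes are not affected.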

🐛 Describe the bug

import torch

dtypes = [torch.bool, torch.int8, torch.int32, torch.bfloat16, torch.float32, torch.float64]
for dtype in dtypes:
    # All-ones input: the sparse COO tensor has nnz == 6.
    a = torch.ones((6, 1), dtype=dtype)
    a_s = a.to_sparse_coo()
    if a_s[0].dtype != dtype:
        print(f'2d->1d index select (with nnz != 0) result dtype incorrect, for dtype {dtype}, got {a_s[0].dtype}')
    if a_s[0,0].dtype != dtype:
        print(f'2d->scalar index select (with nnz != 0) result dtype incorrect, for dtype {dtype}, got {a_s[0,0].dtype}')

    # All-zero input: the sparse COO tensor has nnz == 0.
    # Caveat: for torch.bool, `a * 0` type-promotes to torch.int64, so b_s is already int64 before indexing.
    b_s = (a * 0).to_sparse_coo()
    if b_s[0].dtype != dtype:
        print(f'2d->1d index select (with nnz == 0) result dtype incorrect, for dtype {dtype}, got {b_s[0].dtype}')
    if b_s[0,0].dtype != dtype:
        print(f'2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype {dtype}, got {b_s[0,0].dtype}')

Output

2d->1d index select (with nnz == 0) result dtype incorrect, for dtype torch.bool, got torch.int64
2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype torch.bool, got torch.int64
2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype torch.int8, got torch.int64
2d->scalar index select (with nnz == 0) result dtype incorrect, for dtype torch.int32, got torch.int64
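The two torch.bool lines above do not by themselves show the bug: under PyTorch's type-promotion rules, multiplying a bool tensor by the Python int 0 yields torch.int64, so `b_s` is already int64 before it is indexed. A minimal sketch of a variant that avoids this, assuming that building the nnz == 0 input with `torch.zeros` and an explicit dtype is an acceptable substitute for `a * 0`. The helper name `check_select_dtypes` is hypothetical, and which mismatches it reports depends on the PyTorch build it runs against.

```python
import torch

# A bool tensor times a Python int scalar promotes to the default integer dtype,
# which is why `(a * 0)` in the repro is int64 for torch.bool inputs.
print(torch.result_type(torch.ones(1, dtype=torch.bool), 0))  # torch.int64


def check_select_dtypes(dtypes):
    """Hypothetical helper: return (dtype, case, result dtype) for every select
    on an nnz == 0 sparse COO tensor whose result dtype differs from the input."""
    mismatches = []
    for dtype in dtypes:
        # torch.zeros with an explicit dtype keeps the input dtype, unlike `a * 0` for bool.
        b_s = torch.zeros((6, 1), dtype=dtype).to_sparse_coo()
        assert b_s._nnz() == 0  # all-zero dense input -> no specified elements
        for case, result in (("2d->1d", b_s[0]), ("2d->scalar", b_s[0, 0])):
            if result.dtype != dtype:
                mismatches.append((dtype, case, result.dtype))
    return mismatches


for mismatch in check_select_dtypes([torch.bool, torch.int8, torch.int32, torch.float32]):
    print(mismatch)
```

With this variant, any torch.bool mismatch that is still reported would point at the select itself rather than at the construction of the input.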

Versions

Repro generated with current master

cc @nikitaved @pearu @cpuhrsch @amjames @bhosmer @nairbv @mruberry

Labels

module: advanced indexing, module: sparse, module: type promotion, triaged
