
Stable torch.sort and torch.argsort #38681

@agadetsky


🚀 Feature

Add a stable version of torch.sort and torch.argsort. A stable sort keeps equal elements in the same relative order in which they appear in the input.
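For reference, the intended semantics can be illustrated with Python's built-in `sorted`, which is guaranteed to be stable. The sketch below works on a plain list, and `stable_argsort_list` is a hypothetical helper used only for illustration:

```python
def stable_argsort_list(values, descending=False):
    """Return indices that sort `values`; equal elements keep their input order."""
    # sorted() is guaranteed stable, and reverse=True still keeps
    # equal keys in their original relative order.
    return sorted(range(len(values)), key=values.__getitem__, reverse=descending)


print(stable_argsort_list([3, 1, 2, 1]))                   # [1, 3, 2, 0]
print(stable_argsort_list([3, 1, 2, 1], descending=True))  # [0, 2, 1, 3]
```

The two `1`s sit at indices 1 and 3. In both directions, index 1 stays ahead of index 3, which is exactly the guarantee requested for torch.sort and torch.argsort.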

Motivation

In some applications we need to preserve the order of equal elements while sorting.

Alternatives

A current workaround is to round-trip through NumPy, which supports `kind='stable'`:

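```python
import numpy as np
import torch

def stable_argsort(arr, dim=-1, descending=False):
    arr_np = arr.detach().cpu().numpy()  # copy to host memory for NumPy
    if descending:
        # Negation keeps ties in input order; assumes a signed numeric dtype
        indices = np.argsort(-arr_np, axis=dim, kind='stable')
    else:
        indices = np.argsort(arr_np, axis=dim, kind='stable')
    return torch.from_numpy(indices).long().to(arr.device)
```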

and

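```python
def stable_sort(arr, dim=-1, descending=False):
    arr_np = arr.detach().cpu().numpy()
    if descending:
        # Sort the negated values, then negate back to recover the originals
        values = -np.sort(-arr_np, axis=dim, kind='stable')
    else:
        values = np.sort(arr_np, axis=dim, kind='stable')
    # Restore the input's dtype and device
    return torch.from_numpy(values).to(device=arr.device, dtype=arr.dtype)
```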
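To confirm that an argsort result (for example, the output of the workaround converted with `.tolist()`) is actually stable, two properties can be checked. First, the indices must put the values in order. Second, whenever two neighbouring values are equal, their indices must be increasing. The following is a minimal pure-Python sketch; `is_stable_argsort` is a hypothetical helper, not part of any library:

```python
def is_stable_argsort(values, indices, descending=False):
    """Check that `indices` sorts `values` and that ties keep input order."""
    if sorted(indices) != list(range(len(values))):
        return False  # not a permutation of 0..n-1
    for a, b in zip(indices, indices[1:]):
        va, vb = values[a], values[b]
        # Adjacent pair must respect the requested direction
        if (va < vb) if descending else (va > vb):
            return False
        # Equal values must appear in their original order
        if va == vb and a > b:
            return False
    return True


print(is_stable_argsort([3, 1, 2, 1], [1, 3, 2, 0]))  # True
print(is_stable_argsort([3, 1, 2, 1], [3, 1, 2, 0]))  # False: tied 1s swapped
```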

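It would be more convenient to have this built into torch.sort and torch.argsort. The workaround above detaches the input from autograd, copies it to the CPU and back, and adds a dependency on NumPy.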

Metadata

Assignees: none
Labels: enhancement, triaged