
Function request: np.isin #3025

@nrbrd


I believe it would be useful to have a function similar to https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.in1d.html that compares a tensor element-wise against a list of possible values. (I'm using it to filter labels/classes in some classifiers.)
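For reference, here is a minimal sketch of the NumPy behavior being requested. `np.in1d` always returns a flattened 1-D result, while `np.isin` (available since NumPy 1.13) keeps the shape of its first argument, which is the shape-preserving behavior shown in the expected output:

```python
import numpy as np

a = np.array([[1, 2, 3], [1, 1, 2], [3, 5, 1]])

# np.isin keeps the shape of `a` and returns a boolean array:
# True wherever the element appears in the list of test values.
mask = np.isin(a, [1, 2, 5])
```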

Expected behavior:

>>> a = torch.LongTensor([[1,2,3],[1,1,2],[3,5,1]])
>>> a
 1  2  3
 1  1  2
 3  5  1
[torch.LongTensor of size 3x3]
>>> a.isin(torch.LongTensor([1, 2, 5]))
 1 1 0
 1 1 1
 0 1 1
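For comparison, later PyTorch releases (1.10 and newer) ship `torch.isin`, which provides this behavior with a NumPy-style signature. The short sketch below assumes such a version is installed. Note that it returns a `torch.bool` tensor rather than a `ByteTensor`:

```python
import torch

a = torch.tensor([[1, 2, 3], [1, 1, 2], [3, 5, 1]])
wanted = torch.tensor([1, 2, 5])

# torch.isin(elements, test_elements) returns a bool tensor shaped like `elements`.
mask = torch.isin(a, wanted)

# invert=True returns the complement, mirroring NumPy's `invert` flag.
others = torch.isin(a, wanted, invert=True)
```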

Currently this can be implemented by iterating over the filter values and OR-ing each equality mask into a result tensor, but a native TH/THC implementation could be considerably faster.
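One way to avoid the Python-level loop without writing a TH/THC kernel is to use broadcasting, which compares every element against every label in a single vectorized op. This is a sketch rather than a native implementation. The helper name `isin_broadcast` is hypothetical, and the code assumes a PyTorch version with `torch.bool` tensors and `Tensor.any(dim=...)`. It trades memory for speed, because it allocates a temporary of size `y.numel() * len(labels)`:

```python
import torch


def isin_broadcast(y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: element-wise ``y in labels`` via broadcasting."""
    labels = labels.to(device=y.device, dtype=y.dtype).flatten()
    # y.unsqueeze(-1) has shape (*y.shape, 1); comparing it with labels of
    # shape (L,) broadcasts to (*y.shape, L). any() over the last dim then
    # ORs the per-label matches, like the loop above but in one op.
    return (y.unsqueeze(-1) == labels).any(dim=-1)


a = torch.tensor([[1, 2, 3], [1, 1, 2], [3, 5, 1]])
mask = isin_broadcast(a, torch.tensor([1, 2, 5]))
```

For a handful of labels the temporary is small. For many labels, a sort-based or lookup-based approach scales better.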

My current implementation:

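import torch

# ``utils.tensorfy`` is a helper from my own codebase (not part of PyTorch);
# it converts positional arguments 0 and 1 to LongTensors.
@utils.tensorfy(0, 1, tensor_klass=torch.LongTensor)
def filter_labels(y, labels):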
    """Utility used to create a mask to filter values in a tensor.

    Args:
        y (list, torch.Tensor): tensor where each element is an integer
            representing a label.
        labels (list, torch.Tensor): filter used to generate the mask. For each
            value in ``y``, the mask is "1" if the value is in ``labels`` and
            "0" otherwise.

    Shape:
        y: can have any shape; usually :math:`(N, S)` or :math:`(S)`, i.e.
            ``batch x samples`` or a single list of ``samples``.
        labels: a flat list or a 1D LongTensor.

    Returns:
        mask (torch.ByteTensor): a binary mask with "1" where the respective value
        from ``y`` is in the ``labels`` filter, and "0" otherwise.

    Example::

        >>> a = torch.LongTensor([[1,2,3],[1,1,2],[3,5,1]])
        >>> a
         1  2  3
         1  1  2
         3  5  1
        [torch.LongTensor of size 3x3]
        >>> classification.filter_labels(a, [1, 2, 5])
         1  1  0
         1  1  1
         0  1  1
        [torch.ByteTensor of size 3x3]
        >>> classification.filter_labels(a, torch.LongTensor([1]))
         1  0  0
         1  1  0
         0  0  1
        [torch.ByteTensor of size 3x3]
    """
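    # Start from an all-zero mask with the same shape as ``y``.
    mapping = torch.zeros(y.size()).byte()

    # OR in the equality mask for each label: one full pass over ``y`` per label.
    for label in labels:
        mapping = mapping | y.eq(label)

    return mapping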
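When labels are small non-negative class indices, as in the classifier use case above, the per-label loop can also be replaced by a lookup table. Mark the wanted labels in a boolean vector, then index that vector with `y`. This is a hypothetical variant: `filter_labels_lut` is not part of the original code or of PyTorch. It assumes non-negative integer labels and a PyTorch version with `torch.bool`. Call `.byte()` on the result if the original `ByteTensor` dtype is needed:

```python
import torch


def filter_labels_lut(y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical lookup-table variant of ``filter_labels``.

    Assumes every value in ``y`` and ``labels`` is a non-negative integer.
    Apart from building the table, cost is one gather over ``y`` no matter
    how many labels there are.
    """
    y = y.long()
    labels = labels.to(device=y.device, dtype=torch.long).flatten()
    if (y < 0).any() or (labels < 0).any():
        raise ValueError("filter_labels_lut expects non-negative labels")

    # Size the table so every value in ``y`` and ``labels`` is a valid index.
    hi = 0
    if y.numel():
        hi = max(hi, int(y.max()))
    if labels.numel():
        hi = max(hi, int(labels.max()))

    lut = torch.zeros(hi + 1, dtype=torch.bool, device=y.device)
    lut[labels] = True  # mark the wanted labels
    # Indexing with ``y`` gathers one flag per element, shaped like ``y``.
    return lut[y]


a = torch.tensor([[1, 2, 3], [1, 1, 2], [3, 5, 1]])
mask = filter_labels_lut(a, torch.tensor([1, 2, 5]))
```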

cc @mruberry @rgommers @heitorschueroff

Labels:
- function request: A request for a new function or the addition of new arguments/modes to an existing function.
- module: bootcamp: We plan to do a full writeup on the issue, and then get someone to do it for onboarding.
- module: numpy: Related to numpy support, and also numpy compatibility of our operators.
- triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.
