
Static size boolean masking #96111

@ezyang

Description


🐛 Describe the bug

A long-standing request is #62320; the ExecuTorch team has agreed to implement it.

Once this is implemented, we can get static-size boolean masking to work as well. The easiest way is to convert the boolean mask into index tensors, one per mask dimension. You can use the meta implementation for indexing to do this:

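```python
# Excerpt (assumed) from torch/_meta_registrations.py, where register_meta,
# check, and aten (= torch.ops.aten) are already in scope.
from typing import List, Optional

import torch
from torch import Tensor


@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            check(
                index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
                lambda: "tensors used as indices must be long, int, byte or bool tensors",
            )
            if index.dtype in [torch.int8, torch.bool]:
                # Expand a mask into one index tensor per mask dimension: column j of
                # nonzero() holds the coordinates of the True entries along dim j.
                # nonzero() has a data-dependent number of rows, so this is not static-shape.
                nonzero = index.nonzero()
                k = len(result)  # first dimension of self covered by this mask
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
            else:
                # Integer index tensors are passed through unchanged.
                result.append(index)
        else:
            # None means this dimension is not indexed (a full slice).
            result.append(index)
    return result
```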
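As a quick eager-mode sanity check of the conversion the meta function performs, indexing with a boolean mask gives the same result as indexing with the columns of `mask.nonzero()`. This is only a sketch; `mask_to_indices` is a hypothetical helper, not a PyTorch API.

```python
from typing import Tuple

import torch


def mask_to_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Hypothetical helper: nonzero() returns an (N, mask.ndim) tensor of the
    # coordinates of True entries; column j indexes along dimension j.
    nz = mask.nonzero()
    return tuple(nz.select(1, j) for j in range(mask.ndim))


x = torch.arange(12).reshape(3, 4)
mask = x % 3 == 0
idx = mask_to_indices(mask)
assert torch.equal(x[mask], x[idx])
print(x[idx].tolist())  # [0, 3, 6, 9]
```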

The main annoyance is that if the mask has fewer True elements than the static size, the nonzero output will be zero-padded. For a boolean mask, zero padding is inappropriate, because 0 is a valid index and would silently select real elements. Instead, you want an invalid index as padding, and then the indexing op should fill those slots with some placeholder element like 0.
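A minimal sketch of the invalid-index-plus-placeholder idea, assuming the mask has the same shape as `x` (so `x[mask]` is 1-D) and that `size` is chosen ahead of time. `static_nonzero_flat` and `static_masked_select` are hypothetical helpers, not PyTorch APIs. In place of the operator requested in #62320, this sketch uses a sort, so every output shape depends only on input shapes. Missing slots get the invalid index -1, are clamped to a legal index for the gather, and are then overwritten with the placeholder.

```python
from typing import Tuple

import torch


def static_nonzero_flat(mask: torch.Tensor, size: int, fill_value: int = -1) -> torch.Tensor:
    """Hypothetical: flat indices of True entries, always of length `size`."""
    flat = mask.reshape(-1)
    n = flat.numel()
    pos = torch.arange(n, device=mask.device)
    sentinel = torch.full_like(pos, n)
    # True entries keep their own flat index; False entries get n, which sorts last.
    keys = torch.where(flat, pos, sentinel).sort().values
    if size > n:  # static branch: depends only on shapes, never on mask contents
        keys = torch.cat([keys, sentinel.new_full((size - n,), n)])
    keys = keys[:size]  # extra True entries beyond `size` are dropped
    # Padding slots become an invalid index instead of 0.
    return torch.where(keys < n, keys, torch.full_like(keys, fill_value))


def static_masked_select(
    x: torch.Tensor, mask: torch.Tensor, size: int, fill: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical static-size x[mask]; assumes x is non-empty."""
    if mask.shape != x.shape:
        raise ValueError("this sketch only handles masks shaped like x")
    idx = static_nonzero_flat(mask, size)
    valid = idx >= 0
    # Clamp invalid slots to a legal index so the gather cannot fail,
    # then overwrite those slots with the placeholder value.
    gathered = x.reshape(-1)[idx.clamp(min=0)]
    return torch.where(valid, gathered, torch.full_like(gathered, fill)), valid


x = torch.tensor([10, 20, 30, 40, 50])
vals, valid = static_masked_select(x, x > 25, size=4)
print(vals.tolist())   # [30, 40, 50, 0]
print(valid.tolist())  # [True, True, True, False]
```

Here `x > 25` has three True entries, so with `size=4` the last slot is invalid and becomes 0. Zero-padded indices would instead have gathered `x[0]` (10), a real element. Masks that cover only the leading dimensions of `x` would need per-dimension coordinates, as in `meta_index_Tensor`, rather than flat indices.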

Versions

master

Metadata

Assignees: none

Labels: `module: advanced indexing`, `triaged`
