
add support for multiple dtypes in prototype ToDtype  #6697

@pmeier

Description

Right now, the transform takes a single dtype and one or more types to apply that dtype to:

class ToDtype(Lambda):
    def __init__(self, dtype: torch.dtype, *types: Type) -> None:
        ...

If different types need different dtypes, you have to chain multiple transforms, e.g.

transform = transforms.Compose(
    [
        transforms.ToDtype(torch.uint8, features.Image),
        transforms.ToDtype(torch.float32, features.BoundingBox),
        transforms.ToDtype(torch.int64, features.Label),
    ]
)
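With the current API, each ToDtype in the chain is a separate pass over the sample. The following is a minimal, framework-free sketch of that behavior, assuming hypothetical stand-ins: dtypes are plain strings instead of torch.dtype, Image/BoundingBox/Label are toy classes rather than the prototype features, and ToDtype, Compose and apply_to_leaves only mimic the chaining, not the actual torchvision code. It counts how many leaves are touched in total.

```python
from typing import Any, Callable, List, Type


# Hypothetical stand-ins: string dtypes instead of torch.dtype and toy
# feature classes instead of the prototype features module.
class Feature:
    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "Feature":
        return type(self)(dtype)


class Image(Feature): ...
class BoundingBox(Feature): ...
class Label(Feature): ...


visits = {"count": 0}  # total number of leaves touched across all passes


def apply_to_leaves(sample: Any, fn: Callable[[Any], Any]) -> Any:
    """Recurse through dicts, lists and tuples and call fn on every leaf."""
    if isinstance(sample, dict):
        return {key: apply_to_leaves(value, fn) for key, value in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(apply_to_leaves(item, fn) for item in sample)
    visits["count"] += 1
    return fn(sample)


class ToDtype:
    """Current-style API: a single dtype applied to one or more types."""

    def __init__(self, dtype: str, *types: Type) -> None:
        self.dtype = dtype
        self.types = types

    def __call__(self, sample: Any) -> Any:
        # Every call is a full pass over the sample.
        return apply_to_leaves(
            sample,
            lambda leaf: leaf.to(self.dtype) if isinstance(leaf, self.types) else leaf,
        )


class Compose:
    def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
        self.transforms = transforms

    def __call__(self, sample: Any) -> Any:
        for transform in self.transforms:
            sample = transform(sample)
        return sample


transform = Compose(
    [
        ToDtype("uint8", Image),
        ToDtype("float32", BoundingBox),
        ToDtype("int64", Label),
    ]
)
sample = {"image": Image("float32"), "boxes": BoundingBox("int64"), "label": Label("int32")}
out = transform(sample)
print(visits["count"])  # 9: three passes over three leaves
```

Each transform only changes one type, yet still has to visit every leaf, so the number of visits grows with the number of chained transforms.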

This is not only costly at runtime, since we need to recurse through the same sample three times; it would also be a better UI if ToDtype accepted a mapping that specifies the dtype for each type:

transform = transforms.ToDtype(
    {
        features.Image: torch.uint8,
        features.BoundingBox: torch.float32,
        features.Label: torch.int64,
    }
)

This would also align this parameter with what we did for fill.
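A mapping-based ToDtype could dispatch on the type of each leaf in a single pass. Below is a minimal sketch under assumptions similar to a toy model: string dtypes, hypothetical Image/BoundingBox/Label classes, and a hypothetical apply_to_leaves helper. It is not the torchvision implementation, only an illustration of the single-pass idea.

```python
from typing import Any, Callable, Dict, Type


# Hypothetical stand-ins: string dtypes instead of torch.dtype and toy
# feature classes instead of the prototype features module.
class Feature:
    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "Feature":
        return type(self)(dtype)


class Image(Feature): ...
class BoundingBox(Feature): ...
class Label(Feature): ...


visits = {"count": 0}  # total number of leaves touched


def apply_to_leaves(sample: Any, fn: Callable[[Any], Any]) -> Any:
    """Recurse through dicts, lists and tuples and call fn on every leaf."""
    if isinstance(sample, dict):
        return {key: apply_to_leaves(value, fn) for key, value in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(apply_to_leaves(item, fn) for item in sample)
    visits["count"] += 1
    return fn(sample)


class ToDtype:
    """Proposed API: a mapping from feature type to target dtype."""

    def __init__(self, dtype: Dict[Type, str]) -> None:
        self.dtype = dtype

    def _transform(self, leaf: Any) -> Any:
        # The first entry whose type matches the leaf wins; unmapped leaves pass through.
        for feature_type, dtype in self.dtype.items():
            if isinstance(leaf, feature_type):
                return leaf.to(dtype)
        return leaf

    def __call__(self, sample: Any) -> Any:
        # A single pass handles all mapped types at once.
        return apply_to_leaves(sample, self._transform)


transform = ToDtype({Image: "uint8", BoundingBox: "float32", Label: "int64"})
sample = {"image": Image("float32"), "boxes": BoundingBox("int64"), "label": Label("int32")}
out = transform(sample)
print(visits["count"])  # 3: one pass over three leaves
```

Matching with isinstance in mapping order means the first matching entry wins, so subclasses of a mapped type are converted too; an exact `self.dtype.get(type(leaf))` lookup would be a simpler alternative if subclasses should not match.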

cc @vfdev-5 @datumbox @bjuncek
