add support for multiple dtypes in prototype ToDtype #6697
Closed
Description
Right now, the transform takes a single dtype and one or more types to apply that dtype to:
`vision/torchvision/prototype/transforms/_misc.py`, lines 143 to 144 in b482d89:

```python
class ToDtype(Lambda):
    def __init__(self, dtype: torch.dtype, *types: Type) -> None:
```
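To make these semantics concrete, here is a minimal pure-Python sketch. It is not the torchvision code. `ToDtypeSketch`, `Feature`, `Image` and `BoundingBox` are hypothetical stand-ins, Python's `float` plays the role of a torch dtype, and recursing into dicts only approximates how the real transform walks a sample.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, List, Type


# Hypothetical stand-ins for feature classes; the payload is a plain list, not a tensor.
@dataclass(frozen=True)
class Feature:
    data: List[Any]


class Image(Feature):
    pass


class BoundingBox(Feature):
    pass


class ToDtypeSketch:
    """One dtype, applied to every object whose type is listed in `types`."""

    def __init__(self, dtype: Callable[[Any], Any], *types: Type) -> None:
        # Python's int/float stand in for torch dtypes such as torch.int64/torch.float32.
        self.dtype = dtype
        self.types = types

    def __call__(self, sample: Any) -> Any:
        # Recurse into dict samples (a simplified stand-in for the real traversal).
        if isinstance(sample, dict):
            return {key: self(value) for key, value in sample.items()}
        if isinstance(sample, self.types):
            # Convert every element of the payload to the single target dtype.
            return replace(sample, data=[self.dtype(x) for x in sample.data])
        # Objects of any other type pass through unchanged.
        return sample


transform = ToDtypeSketch(float, Image, BoundingBox)
out = transform({"image": Image([1, 2]), "boxes": BoundingBox([0, 5]), "label": 3})
```

Anything whose type is not among `types` (here the plain `3` under `"label"`) is returned untouched, so one instance can only ever target a single dtype.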
If you need several different dtypes, you have to chain multiple transforms, e.g.
```python
import torch
from torchvision.prototype import features, transforms

transform = transforms.Compose(
    [
        transforms.ToDtype(torch.uint8, features.Image),
        transforms.ToDtype(torch.float32, features.BoundingBox),
        transforms.ToDtype(torch.int64, features.Label),
    ]
)
```

Not only does this have runtime implications, since we need to recurse through the same sample three times, it would also be a better UI if `ToDtype` accepted a mapping that specifies the dtype for each type:
```python
# Proposed: a single transform taking a {type: dtype} mapping
transform = transforms.ToDtype(
    {
        features.Image: torch.uint8,
        features.BoundingBox: torch.float32,
        features.Label: torch.int64,
    }
)
```

This would also align this parameter with what we did for `fill`.
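A minimal sketch of how a mapping-based variant could dispatch, in plain Python so it runs without torchvision. `MappingToDtype`, `Feature`, `Image`, `BoundingBox` and `Label` are hypothetical stand-ins, and `int`/`float` take the place of `torch.uint8`/`torch.int64` and `torch.float32`. It only illustrates the two points above: one mapping covers every type, and one traversal replaces three. The `visits` counter exists purely to make the traversal cost visible.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, List, Optional, Type

Converter = Callable[[Any], Any]  # stand-in for a torch dtype


@dataclass(frozen=True)
class Feature:
    data: List[Any]


class Image(Feature):
    pass


class BoundingBox(Feature):
    pass


class Label(Feature):
    pass


class MappingToDtype:
    """Converts each object according to a {type: dtype} mapping in a single pass."""

    def __init__(self, dtypes: Dict[Type, Converter]) -> None:
        self.dtypes = dtypes
        self.visits = 0  # number of leaf objects inspected, to compare traversal cost

    def _lookup(self, obj: Any) -> Optional[Converter]:
        # Walk the MRO so subclasses of a mapped type are converted too.
        for cls in type(obj).__mro__:
            if cls in self.dtypes:
                return self.dtypes[cls]
        return None

    def __call__(self, sample: Any) -> Any:
        # Recurse into containers; every other object is a leaf.
        if isinstance(sample, dict):
            return {key: self(value) for key, value in sample.items()}
        if isinstance(sample, (list, tuple)):
            return type(sample)(self(value) for value in sample)
        self.visits += 1
        dtype = self._lookup(sample)
        if dtype is None:
            return sample  # unmapped types pass through unchanged
        return replace(sample, data=[dtype(x) for x in sample.data])


sample = {"image": Image([1.0, 2.0]), "boxes": BoundingBox([0, 4]), "label": Label([3.0])}

# One transform with a mapping: a single traversal over the three leaves.
single = MappingToDtype({Image: int, BoundingBox: float, Label: int})
out = single(sample)

# Three single-type transforms applied one after another, as with Compose.
chained = [MappingToDtype({Image: int}), MappingToDtype({BoundingBox: float}), MappingToDtype({Label: int})]
result = sample
for step in chained:
    result = step(result)
chained_visits = sum(step.visits for step in chained)
```

Both approaches produce the same sample, but `single.visits` is 3 while the chained transforms inspect 9 leaves in total. Resolving types along the MRO, so that a subclass of a mapped type is converted as well, is a choice of this sketch; an exact-type lookup would be equally reasonable.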