Batch Sampler and Dataset indexing too restricted #8652

@flauted

Description

I was trying to use a DataLoader with a sampler that yields strings to index my dataset. The default BatchSampler threw an error saying, essentially, that a str cannot be cast to int. The cast in question happens on line 139 (currently):

batch.append(int(idx))

I copied the source for BatchSampler from master, removed the cast, and everything worked as expected.

The torch.utils.data DataLoader docs don't suggest that samplers must return integers, and the various Samplers just say they yield a list of indices. An index set doesn't have to consist of integers. Now, Dataset does say __getitem__ should support integer indexing, but why?

Really, the choice of sampler (including falling back to the default sampler) is what drives whether a Dataset must be indexable by integers. On the other hand, sometimes it's more sensible to "key" the dataset. For example, the sampler could "request" that a filename be loaded. Or the actual data could be organized in a dict (as in my examples below).
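To make the "keyed" idea concrete, here is a minimal sketch (plain Python, no torch import needed, since the Dataset protocol only really requires __getitem__ and __len__). The class name KeyedDataset is hypothetical, not anything in torch:

```python
class KeyedDataset:
    """A dataset indexed by string keys (e.g. filenames) instead of ints."""

    def __init__(self, data):
        # data is a dict, e.g. {"a.wav": sample_a, "b.wav": sample_b}
        self.data = data
        self.keys = list(data)  # a sampler could draw from these keys

    def __getitem__(self, key):
        # key is a string, not an int -- the lookup works all the same
        return self.data[key]

    def __len__(self):
        return len(self.data)


ds = KeyedDataset({"a": 0, "b": 1, "c": 2})
print(ds["b"])  # 1
```

With the int cast removed from BatchSampler, a sampler drawing from ds.keys could drive this dataset directly.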

Along the way, I noticed that BatchSampler now asserts that sampler is a Sampler. (This isn't the case in 0.4.0, only on master.) I don't really understand why. Why not duck type that sampler is iterable? Just wrap the for-loop in a try-except that catches TypeError and raises an appropriate error. If I'm not mistaken, the isinstance check breaks the docstring example, and it rules out a very useful way to unit test.

MWE

from string import ascii_lowercase
import torch
from torch._six import int_classes as _int_classes
from torch.utils.data import sampler, DataLoader


# CHANGES to BatchSampler in master: Removing type cast and removing sampler
# explicit type-check, adding iterable duck-type
class BatchSampler(sampler.Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.
    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """
    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        try:
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch
        except TypeError:
            raise ValueError("sampler should be an iterable, but got "
                             "sampler={}".format(self.sampler))

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size


dataset_stub = {ascii_lowercase[i]: i for i in range(20)}
sampler = sampler.SubsetRandomSampler(list(dataset_stub.keys()))
loader = DataLoader(dataset_stub, batch_sampler=BatchSampler(sampler, 1, True))

for elem in loader:
    print(elem)

M.should-be-W.E.:

from string import ascii_lowercase
import torch
from torch.utils.data import sampler, DataLoader

dataset_stub = {ascii_lowercase[i]: i for i in range(20)}
sampler = sampler.SubsetRandomSampler(list(dataset_stub.keys()))
loader = DataLoader(dataset_stub, sampler=sampler)

for elem in loader:
    print(elem)
