validate_input_col is impractical for partial functions (with large arguments) #95066

@BeckerFelix

🐛 Describe the bug

One can map a Callable over a datapipe using torchdata.datapipes.iter.Mapper, and that Callable may be a functools.partial.

The mapper ultimately calls validate_input_col, which in turn calls str() on the Callable to get its name.

  • str(Callable) is not the correct way to get the fn_name
  • str() on a functools.partial that has large arguments is very slow
  • str() on a functools.partial includes the complete bound arguments in the returned string. These are then printed in any stack trace that mentions the function name.
import functools
from timeit import default_timer as timer

from torch.utils.data.datapipes.utils.common import validate_input_col


def foo(*args):
    pass

d = {i: list(range(i)) for i in range(10_000)}
partial_foo = functools.partial(foo, d)

start = timer()
validate_input_col(fn=partial_foo, input_col=[1, 2])
end = timer()
print(f"elapsed time: {round(end - start, 2)}")

-> elapsed time: 6.21
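
The cost comes from formatting the bound arguments, which can be seen without torch at all. The following is a minimal sketch using only the standard library and a much smaller dict than the one above so that it runs quickly. It checks that the string form of the partial contains the full repr of the bound dict, while the wrapped function's name is available directly through the partial's `func` attribute.

```python
import functools


def foo(*args):
    pass


# Deliberately smaller than the dict in the report so this runs quickly.
d = {i: list(range(i)) for i in range(200)}
partial_foo = functools.partial(foo, d)

# str() falls back to repr(), and the repr of a partial embeds the
# repr of every bound argument.
text = str(partial_foo)
contains_arg = repr(d) in text

print(f"length of str(partial_foo): {len(text)}")
print(f"contains repr of d: {contains_arg}")

# The wrapped function's name is reachable without formatting any argument.
print(f"wrapped function name: {partial_foo.func.__name__}")
```

The length of the string grows with the size of the bound arguments, which is consistent with the slowdown measured above.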

Changing def foo(*args): to def foo(arg1) results in a stack trace of around 300 million characters, because it includes the whole of d.
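
One possible way to obtain a short name for error messages, without touching the bound arguments, is sketched below. The helper `describe_callable` is hypothetical and is not part of PyTorch or torchdata. It only illustrates the idea of unwrapping a functools.partial and reading `__name__`, and it is not necessarily the fix that should be adopted.

```python
import functools


def describe_callable(fn):
    """Hypothetical helper: return a short name for fn without
    formatting any arguments bound by functools.partial."""
    # Unwrap (possibly nested) partials to reach the underlying callable.
    while isinstance(fn, functools.partial):
        fn = fn.func
    name = getattr(fn, "__name__", None)
    if name is None:
        # Callable instances usually have no __name__; use the class name.
        name = type(fn).__name__
    return name


def foo(arg1):
    pass


class Scaler:
    def __call__(self, x):
        return 2 * x


big = {i: list(range(i)) for i in range(200)}
nested = functools.partial(functools.partial(foo, big))

print(describe_callable(nested))    # name of the wrapped function
print(describe_callable(Scaler()))  # class name for a callable instance
```

With a helper along these lines, the size of an error message no longer depends on the size of the arguments bound to the partial.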

Versions

PyTorch version: 1.13.1+cu117
Is debug build: False
Python version: 3.9.12 (main, Feb 1 2023, 14:04:09) [GCC 9.4.0] (64-bit runtime)

cc @VitalyFedyunin @ejguan @NivekT @dzhulgakov

Labels: module: data, triaged
