Need to support kwargs in nvfuser python API integration #78923

@IvanYashchuk

🚀 The feature, motivation and pitch

primTorch's executor prototype currently uses the torch.fx.experimental.proxy_tensor.make_fx function to create an FX graph of operations. At present make_fx supports only explicitly passed positional arguments; we need to extend it to positional arguments with default values and to keyword-only arguments.

In [1]: import torch

In [2]: from torch.fx.experimental.proxy_tensor import make_fx

In [3]: def foo(a, b, alpha=1.0):
   ...:     return torch.add(a, b, alpha=alpha)
   ...:

In [4]: alpha = 0.5

In [5]: a = torch.randn(2, 2)

In [6]: b = torch.randn(2, 2)

Here are a few cases that should work (currently only case 4 does):

# case 1
make_fx(foo)(a, b) # raises RuntimeError: Tracing expected 3 arguments but got 2 concrete arguments

# case 2
make_fx(foo)(a, b, alpha=alpha) # raises TypeError: wrapped() got an unexpected keyword argument 'alpha'

# case 3
make_fx(foo)(a, b, alpha)(a, b, alpha=alpha) # tracing succeeds, but calling the traced module raises TypeError: forward() got an unexpected keyword argument 'alpha'

# case 4
make_fx(foo)(a, b, alpha)(a, b) # works
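
Since case 4 shows that tracing works when every argument is passed positionally, one possible stopgap is to normalize a call into a fully positional argument tuple before handing it to make_fx. The sketch below uses only the standard library; `to_positional` is a hypothetical helper name (not part of PyTorch), the `foo` here is a plain-Python stand-in for the function above, and the approach assumes the function has no keyword-only, `*args`, or `**kwargs` parameters.

```python
import inspect


def to_positional(fn, *args, **kwargs):
    """Hypothetical helper: return the call's arguments as one positional tuple.

    Defaults are filled in and keyword arguments are moved into their
    positional slots, so fn(*to_positional(fn, ...)) equals fn(...).
    """
    sig = inspect.signature(fn)
    for name, param in sig.parameters.items():
        # Keyword-only, *args and **kwargs parameters cannot be flattened this way.
        if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
            raise TypeError(f"parameter {name!r} cannot be passed positionally")
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()  # fill in any defaults the caller omitted
    return tuple(bound.arguments[name] for name in sig.parameters)


# Plain-Python stand-in for the torch-based foo above (assumption for illustration).
def foo(a, b, alpha=1.0):
    return a + alpha * b


print(to_positional(foo, 1.0, 2.0))             # (1.0, 2.0, 1.0)
print(to_positional(foo, 1.0, 2.0, alpha=0.5))  # (1.0, 2.0, 0.5)
```

With the torch version of foo, the intended use would be `make_fx(foo)(*to_positional(foo, a, b, alpha=alpha))`, which reduces cases 1 and 2 to the all-positional tracing call from case 4. This does not address calling the traced module with keyword arguments (case 3), and it is not a substitute for proper kwargs support in make_fx.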

cc @ezyang @gchanan @zou3519 @mruberry @ngimel

Labels: has workaround, module: fx, module: primTorch, triaged
