
Complex Not Supported in Torchinductor #93424

@acohen13

🐛 Describe the bug

A linear layer with complex-valued weights (e.g. nn.Linear(inp_size, out_size, dtype=torch.complex64)) raises the following NotImplementedError when the model is wrapped with dynamo.optimize()(model) and run.

Error logs

Traceback (most recent call last):
  
  File "*/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 169, in _fn
    return fn(*args, **kwargs)
  File "*/lib/python3.10/site-packages/functorch/_src/aot_autograd.py", line 951, in forward
    return compiled_f(
  File "*/lib/python3.10/site-packages/functorch/_src/aot_autograd.py", line 937, in new_func
    compiled_fn = create_aot_dispatcher_function(
  File "*/lib/python3.10/site-packages/functorch/_src/aot_autograd.py", line 657, in create_aot_dispatcher_function
    aot_dispatch_autograd(flat_fn, fake_flat_tensor_args, aot_config)
  File "*/lib/python3.10/site-packages/functorch/_src/aot_autograd.py", line 501, in aot_dispatch_autograd
    fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)
  File "*/lib/python3.10/site-packages/functorch/_src/partitioners.py", line 420, in min_cut_rematerialization_partition
    weight = get_node_weight(node)
  File "*/lib/python3.10/site-packages/functorch/_src/partitioners.py", line 382, in get_node_weight
    mem_sz = _size_of(node)
  File "*/lib/python3.10/site-packages/functorch/_src/partitioners.py", line 225, in _size_of
    return _tensor_nbytes(to_size_hint(val.numel()), val.dtype)
  File "*/lib/python3.10/site-packages/functorch/_src/partitioners.py", line 206, in _tensor_nbytes
    raise NotImplementedError("Don't know the size of dtype ", dtype)
NotImplementedError: ("Don't know the size of dtype ", torch.complex64)
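The traceback ends in `_tensor_nbytes`, which suggests that the partitioner's dtype-size lookup has no entry for complex dtypes. Below is a minimal sketch of that failure mode, assuming a dictionary-based lookup. It is not the functorch source: the names `REAL_DTYPE_SIZES`, `COMPLEX_DTYPE_SIZES`, `tensor_nbytes`, and `demo` are hypothetical, and dtypes are modelled as plain strings so the example runs without PyTorch.

```python
# Illustrative sketch only, NOT the functorch implementation.
# Dtypes are modelled as strings so the example needs no PyTorch install.

# Hypothetical table of bytes per element covering real-valued dtypes only.
REAL_DTYPE_SIZES = {
    "float16": 2,
    "bfloat16": 2,
    "float32": 4,
    "float64": 8,
    "int32": 4,
    "int64": 8,
}

# complex64 stores two float32 values (8 bytes);
# complex128 stores two float64 values (16 bytes).
COMPLEX_DTYPE_SIZES = {
    "complex64": 8,
    "complex128": 16,
}


def tensor_nbytes(numel, dtype, sizes):
    """Return the byte size of a tensor with `numel` elements of `dtype`."""
    if dtype not in sizes:
        # Same message shape as the error in the traceback above.
        raise NotImplementedError("Don't know the size of dtype ", dtype)
    return numel * sizes[dtype]


def demo():
    """Show the failure with real-only sizes and success once complex is added."""
    try:
        tensor_nbytes(1024, "complex64", REAL_DTYPE_SIZES)
        failed = False
    except NotImplementedError:
        failed = True

    extended = {**REAL_DTYPE_SIZES, **COMPLEX_DTYPE_SIZES}
    return failed, tensor_nbytes(1024, "complex64", extended)
```

Under these assumptions, registering the complex element sizes is enough for the size estimate to succeed. Whether the rest of the compile pipeline handles complex tensors is a separate question that this sketch does not address.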

Minified repro

No response

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @soumith @ngimel

Labels: module: inductor, oncall: pt2, triaged
