
trace_backward does not work for complex tensors #50381

@anjali411


I ran the code below after locally rebasing on #50380 and adding trace to GRADIENT_IMPLEMENTED_FOR_COMPLEX in gen_variable_type.py.

```python
>>> x=torch.randn(3, 3, dtype=torch.cfloat, requires_grad=True)
>>> y=x.trace()
>>> y.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chourdiaanjali/pytorch2/torch/tensor.py", line 225, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/chourdiaanjali/pytorch2/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: _th_index_fill_ not supported on CPUType for ComplexFloat
```
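For context: trace(X) = sum_k X[k][k] is linear in X, so its gradient with respect to X is the incoming gradient placed on the main diagonal and zero everywhere else, and because trace is linear the same pattern applies to complex inputs. The error names `_th_index_fill_` (the `_th_` prefix marks a legacy TH binding), which suggests the backward writes that diagonal with `index_fill_` into a flattened zero tensor, and that kernel has no ComplexFloat implementation on CPU. The following is a hypothetical pure-Python sketch of the computation, not PyTorch's code; `trace_backward_sketch` is an illustrative name.

```python
def trace_backward_sketch(grad, rows, cols):
    """Hypothetical helper: gradient of trace() for a rows x cols input.

    trace(X) = sum_k X[k][k] is linear, so d trace / d X[i][j] is 1 when
    i == j and 0 otherwise. The incoming gradient (real or complex) is
    therefore written onto the main diagonal of an all-zero matrix.
    """
    # Flattened, row-major zero buffer (analogous to a zero tensor with rows*cols elements).
    flat = [0j] * (rows * cols)
    # Diagonal element (k, k) sits at flat index k * cols + k == k * (cols + 1).
    for k in range(min(rows, cols)):
        flat[k * (cols + 1)] = grad
    # Reshape the flat buffer back to rows x cols.
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


print(trace_backward_sketch(1 + 2j, 3, 3))
# [[(1+2j), 0j, 0j], [0j, (1+2j), 0j], [0j, 0j, (1+2j)]]
```

Bounding the loop by `min(rows, cols)` keeps the fill on the true diagonal for non-square inputs; stepping through the whole buffer with a stride of `cols + 1` would also hit off-diagonal positions once `rows > cols + 1`.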

cc @ezyang @anjali411 @dylanbespalko @mruberry

Labels: complex_autograd, module: complex (Related to complex number support in PyTorch), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
