
[torchax] requires_grad in some tensor constructors not propagated #8983

@qihqi

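🐛 Bug

With torchax enabled globally, the `requires_grad=True` argument passed to some tensor constructors (such as `torch.randn`) is not propagated to the resulting tensor.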

To Reproduce

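```python
import torch
import torchax

# Enable torchax globally so subsequent torch calls are handled by torchax.
torchax.enable_globally()

# requires_grad=True should make A a leaf tensor that tracks gradients.
A = torch.randn(8, requires_grad=True)
A.sum().backward()
```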

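The `backward()` call raises an error. Expected behavior, matching eager PyTorch without torchax: `A.requires_grad` is `True` and, after `backward()`, `A.grad` is a tensor of eight ones.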
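To narrow down which constructors are affected, a diagnostic along the following lines could be run. This is a minimal sketch: `check_requires_grad` and the `CONSTRUCTORS` table are hypothetical helpers written for this report (not torchax APIs), and the list of constructors is an assumption, since the issue names only `torch.randn`. Each entry builds a tensor with `requires_grad=True`, calls `sum().backward()`, and records whether the flag survived and `.grad` came back as all ones.

```python
import torch

# Hypothetical table of constructors to probe; only torch.randn is named in
# the issue, the rest are assumptions. Each lambda passes requires_grad=True.
CONSTRUCTORS = {
    "randn": lambda: torch.randn(8, requires_grad=True),
    "rand": lambda: torch.rand(8, requires_grad=True),
    "zeros": lambda: torch.zeros(8, requires_grad=True),
    "ones": lambda: torch.ones(8, requires_grad=True),
    "full": lambda: torch.full((8,), 2.0, requires_grad=True),
    "tensor": lambda: torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
    "arange": lambda: torch.arange(8.0, requires_grad=True),
}


def check_requires_grad(constructors=CONSTRUCTORS):
    """Hypothetical helper: map each constructor name to True if the tensor
    it builds tracks gradients and backward() fills .grad with ones."""
    results = {}
    for name, make in constructors.items():
        try:
            t = make()
            t.sum().backward()
            # d(sum)/dt is 1 for every element, so .grad should be all ones.
            ok = (
                t.requires_grad
                and t.grad is not None
                and bool(torch.all(t.grad == 1))
            )
        except Exception:
            # Any failure (e.g. backward() raising) counts as "not propagated".
            ok = False
        results[name] = ok
    return results


if __name__ == "__main__":
    for name, ok in check_requires_grad().items():
        print(f"{name:8s} {'ok' if ok else 'NOT propagated'}")
```

In plain PyTorch every entry should come out `True`; running the same function after `torchax.enable_globally()` would show which constructors drop the flag (not verified here).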
