[jit] torch.isfinite is broken #29340

@driazati

import torch


def fn(x):
    print(torch.isfinite(x))


s = torch.jit.script(fn)
fn(torch.randn(2, 2))
s(torch.randn(2, 2))

This results in:

tensor([[True, True],
        [True, True]])
graph(%x.1 : Tensor):
  %4 : None = prim::Constant() # ../test.py:14:0
  %2 : float = prim::ImplicitTensorToNum(%x.1) # ../test.py:15:10
  %3 : bool = aten::isfinite(%2) # ../test.py:15:10
   = prim::Print(%3) # ../test.py:15:4
  return (%4)
Traceback (most recent call last):
  File "../test.py", line 21, in <module>
    s(torch.randn(2, 2))
RuntimeError: Cannot input a tensor of dimension other than 0 as a scalar argument
The above operation failed in interpreter, with the following stack trace:
at ../test.py:15:10
def fn(x):
    print(torch.isfinite(x))
          ~~~~~~~~~~~~~~ <--- HERE

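It crashes at runtime, and the signature is wrong: the tensor version of the op is not bound in TorchScript, so the call resolves to the scalar-only registration below. That overload implicitly converts x to a float and returns a single bool (see %2 and %3 in the graph above) instead of a tensor of bools: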

DEFINE_UNARY_FLOAT_OP(aten::isfinite, std::isfinite(a), bool),
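To make the failure mode concrete, here is a pure-Python model of the two code paths, assuming nested lists stand in for tensors. The names implicit_tensor_to_num, scalar_isfinite, scripted_fn_model and elementwise_isfinite are hypothetical and are not PyTorch APIs. The first models the 0-dim check that, judging from the error message, prim::ImplicitTensorToNum performs, and math.isfinite plays the role of std::isfinite(a) in the registration above.

```python
import math
from typing import List, Union

# Hypothetical model: nested lists stand in for tensors, a bare float is 0-dim.
Nested = Union[float, List["Nested"]]


def ndim(x: Nested) -> int:
    """Count nesting depth, i.e. the number of dimensions of the stand-in tensor."""
    d = 0
    while isinstance(x, list):
        d += 1
        if not x:
            break
        x = x[0]
    return d


def implicit_tensor_to_num(x: Nested) -> float:
    """Models the implicit tensor-to-number conversion: only 0-dim inputs are accepted."""
    if ndim(x) != 0:
        raise RuntimeError(
            "Cannot input a tensor of dimension other than 0 as a scalar argument"
        )
    return float(x)


def scalar_isfinite(a: float) -> bool:
    """Models the scalar overload, which returns one bool for one float."""
    return math.isfinite(a)


def scripted_fn_model(x: Nested) -> bool:
    """What the mis-resolved graph does: convert to float, then call the scalar op."""
    return scalar_isfinite(implicit_tensor_to_num(x))


def elementwise_isfinite(x: Nested) -> Nested:
    """The intended semantics: one bool per element, preserving the shape."""
    if isinstance(x, list):
        return [elementwise_isfinite(v) for v in x]
    return math.isfinite(x)
```

With a 2x2 input, elementwise_isfinite returns a 2x2 nested list of bools, matching the eager output, while scripted_fn_model raises the same RuntimeError seen in the traceback. A bare float (the 0-dim case) goes through the scalar path without error, but yields a plain bool rather than a tensor.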



cc @suo

Labels: oncall: jit