Skip to content

Commit d6db5d5

Browse files
committed
Apply changes from #37846 to test_topk_smallest_unsorted
1 parent 3bb338a commit d6db5d5

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1484,7 +1484,11 @@ def forward(self, x):
14841484
def test_topk_smallest_unsorted(self):
14851485
class MyModule(torch.nn.Module):
14861486
def forward(self, x, k):
1487-
return torch.topk(x, k, largest=False, sorted=False)
1487+
# When sorted=False, order of elements in the outout tensors
1488+
# are not expected to match between PyTorch and ORT
1489+
topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
1490+
topk_sorted = torch.topk(x, k, largest=False, sorted=True)
1491+
return topk_sorted, torch.sort(topk_unsorted.values).values
14881492

14891493
x = torch.arange(1., 6., requires_grad=True)
14901494
k = torch.tensor(3)

0 commit comments

Comments
 (0)