Skip to content

Commit b792f48

Browse files
authored
Update test_pytorch_onnx_onnxruntime.py
1 parent 6e7e7d5 commit b792f48

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3001,6 +3001,7 @@ def forward(self, input, target):
30013001
target = torch.empty(N, dtype=torch.long).random_(0, C)
30023002
self.run_test(NLLModel(), (input, target))
30033003

3004+
@unittest.skip("Enable this once ORT version is updated")
30043005
@skipIfUnsupportedMinOpsetVersion(12)
30053006
def test_nllloss_2d_none(self):
30063007
class NLLModel(torch.nn.Module):
@@ -3019,6 +3020,7 @@ def forward(self, input, target):
30193020
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
30203021
self.run_test(NLLModel(), (input, target))
30213022

3023+
@unittest.skip("Enable this once ORT version is updated")
30223024
@skipIfUnsupportedMinOpsetVersion(12)
30233025
def test_nllloss_2d_mean(self):
30243026
class NLLModel(torch.nn.Module):
@@ -3037,6 +3039,7 @@ def forward(self, input, target):
30373039
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
30383040
self.run_test(NLLModel(), (input, target))
30393041

3042+
@unittest.skip("Enable this once ORT version is updated")
30403043
@skipIfUnsupportedMinOpsetVersion(12)
30413044
def test_nllloss_2d_sum(self):
30423045
class NLLModel(torch.nn.Module):
@@ -3055,6 +3058,7 @@ def forward(self, input, target):
30553058
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
30563059
self.run_test(NLLModel(), (input, target))
30573060

3061+
@unittest.skip("Enable this once ORT version is updated")
30583062
@skipIfUnsupportedMinOpsetVersion(12)
30593063
def test_nllloss_2d_mean_weights(self):
30603064
class NLLModel(torch.nn.Module):
@@ -3073,6 +3077,7 @@ def forward(self, input, target):
30733077
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
30743078
self.run_test(NLLModel(), (input, target))
30753079

3080+
@unittest.skip("Enable this once ORT version is updated")
30763081
@skipIfUnsupportedMinOpsetVersion(12)
30773082
def test_nllloss_2d_mean_ignore_index(self):
30783083
class NLLModel(torch.nn.Module):
@@ -3091,6 +3096,7 @@ def forward(self, input, target):
30913096
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
30923097
self.run_test(NLLModel(), (input, target))
30933098

3099+
@unittest.skip("Enable this once ORT version is updated")
30943100
@skipIfUnsupportedMinOpsetVersion(12)
30953101
def test_nllloss_2d_mean_ignore_index_weights(self):
30963102
class NLLModel(torch.nn.Module):

0 commit comments

Comments
 (0)