Skip to content

Commit a98a314

Browse files
committed
update unit test
1 parent e19ee00 commit a98a314

5 files changed

Lines changed: 22 additions & 4 deletions

File tree

mmselfsup/models/algorithms/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ def val_step(self, data, optimizer):
150150
losses = self(**data)
151151
loss, log_vars = self._parse_losses(losses)
152152

153-
outputs = dict(
154-
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
153+
if isinstance(data['img'], list):
154+
num_samples = len(data['img'][0].data)
155+
else:
156+
num_samples = len(data['img'].data)
157+
outputs = dict(loss=loss, log_vars=log_vars, num_samples=num_samples)
155158

156159
return outputs

tests/test_models/test_algorithms/test_deepcluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_deepcluster():
3636

3737
fake_input = torch.randn((16, 3, 224, 224))
3838
fake_labels = torch.ones(16, dtype=torch.long)
39-
fake_out = alg.forward_test(fake_input)
39+
fake_out = alg.forward(fake_input, mode='test')
4040
assert 'head0' in fake_out
4141
assert fake_out['head0'].size() == torch.Size([16, num_classes])
4242

tests/test_models/test_algorithms/test_densecl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,8 @@ def test_densecl():
7777
assert fake_loss['loss_dense'] > 0
7878
assert alg.queue_ptr.item() == 16
7979
assert alg.queue2_ptr.item() == 16
80+
81+
# test train step with 2 keys in loss
82+
fake_outputs = alg.train_step(dict(img=[fake_input, fake_input]), None)
83+
assert fake_outputs['loss'].item() > -1
84+
assert fake_outputs['num_samples'] == 16

tests/test_models/test_algorithms/test_mocov3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,6 @@ def test_mocov3():
5252
alg.momentum_update()
5353

5454
fake_input = torch.randn((16, 3, 224, 224))
55-
fake_backbone_out = alg.extract_feat(fake_input)
55+
fake_backbone_out = alg.forward(fake_input, mode='extract')
5656
assert fake_backbone_out[0][0].size() == torch.Size([16, 384, 14, 14])
5757
assert fake_backbone_out[0][1].size() == torch.Size([16, 384])

tests/test_models/test_algorithms/test_simsiam.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,13 @@ def test_simsiam():
5151
]
5252
fake_out = alg.forward(fake_input)
5353
assert fake_out['loss'].item() > -1
54+
55+
# test train step
56+
fake_outputs = alg.train_step(dict(img=fake_input), None)
57+
assert fake_outputs['loss'].item() > -1
58+
assert fake_outputs['num_samples'] == 16
59+
60+
# test val step
61+
fake_outputs = alg.val_step(dict(img=fake_input), None)
62+
assert fake_outputs['loss'].item() > -1
63+
assert fake_outputs['num_samples'] == 16

0 commit comments

Comments
 (0)