Skip to content

Commit 4b56a5e

Browse files
haje01richardliaw
authored andcommitted
[tune] missing torch.load in mnist_pytorch_trainable.py (#5103)
1 parent c5253cc commit 4b56a5e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/ray/tune/examples/mnist_pytorch_trainable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _save(self, checkpoint_dir):
167167
return checkpoint_path
168168

169169
def _restore(self, checkpoint_path):
170-
self.model.load_state_dict(checkpoint_path)
170+
self.model.load_state_dict(torch.load(checkpoint_path))
171171

172172

173173
if __name__ == "__main__":

0 commit comments

Comments
 (0)