Nondeterministic behavior of PyTorch LSTM
import torch
import torch.nn as nn
import numpy as np
import random

print('cuDNN version:', torch.backends.cudnn.version())

# Request deterministic cuDNN kernels and seed every RNG in play
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
# Helper functions to check reproducibility
def hash_tensor(tensor):
    return hash(tuple(tensor.cpu().view(-1).tolist()))

def hash_model(model):
    return hash(tuple(hash_tensor(p.data) for p in model.parameters()))
# Model
class RNNClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # 4-layer LSTM; dropout=0.2 is applied between stacked layers
        self.lstm = nn.LSTM(3, 64, 4, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(64, 2)

    def forward(self, x):
        # Compact the weights into one contiguous chunk for cuDNN
        self.lstm.flatten_parameters()
        out, _ = self.lstm(x)
        out = out[:, -1, :]  # keep only the last time step
        return self.fc(out)
model = RNNClassifier().cuda()
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Data and loss
data = torch.randn(1024, 20, 3).cuda()
labels = torch.randint(high=2, size=(1024,)).cuda()
criterion = nn.CrossEntropyLoss()
print('Original model hash:', hash_model(model))  # Consistent
print('Input hashes:', hash_tensor(data), hash_tensor(labels))  # Consistent

# Run forward and backward pass
logits = model(data)
loss = criterion(logits, labels)
print('Output hash and loss:', hash_tensor(logits), loss.item())  # Consistent

optimizer.zero_grad()
loss.backward()
print('Gradient hashes:', {n: hash_tensor(p.grad) for n, p in model.named_parameters()})  # Not consistent

optimizer.step()
print('Updated model hash:', hash_model(model))  # Not consistent
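Even with torch.backends.cudnn.deterministic = True, the forward pass stays consistent while the gradient hashes (and hence the updated model hash) vary across runs, which points at the backward pass. A frequently reported culprit in PyTorch issue threads is cuDNN's RNN backward kernels. As a minimal sketch (not part of the original gist), the experiment can be rerun with cuDNN disabled to test that hypothesis; if the hashes become reproducible, the cuDNN path is responsible.

# Minimal sketch (assumption: cuDNN's RNN backward is the source of the
# nondeterminism): disable cuDNN so nn.LSTM falls back to PyTorch's native
# CUDA kernels, then repeat the backward pass and compare hashes across runs.
torch.backends.cudnn.enabled = False

logits = model(data)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
print('Gradient hashes (cuDNN disabled):',
      {n: hash_tensor(p.grad) for n, p in model.named_parameters()})

torch.backends.cudnn.enabled = True  # restore the default backend

Moving the whole model and data to CPU is another way to isolate the GPU kernels, at the cost of speed; the native implementation is considerably slower than cuDNN for multi-layer LSTMs.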