W&B logger not working as expected with accumulate_grad_batches>1 #5405

@Tomiinek

🐛 Bug

When logging from training_step with the W&B logger (WandbLogger) and accumulate_grad_batches > 1, the logged values do not behave as expected: they are logged at every forward pass instead of once per optimizer step. This is similar to #4304 for TensorBoard, which was closed after the fix in #4738 was merged.

[Image: W&B chart; the first half of the run uses accumulate_grad_batches == 1, the second half accumulate_grad_batches == 8]

Moreover, the logging step advances as accumulate_grad_batches * number_of_backward_passes, so when LearningRateMonitor is used, none of its values get logged (similar to #4811). W&B prints:

wandb: WARNING Step must only increase in log calls.  Step 499 < 2497; dropping {'lr-AdamW': 4.9900000000000005e-06}.
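The warning comes from W&B rejecting any log call whose step is lower than the highest step it has already recorded. Below is a minimal stdlib sketch of that constraint, assuming (as the step values in the warning suggest) that train_loss is logged at the batch (forward-pass) index while LearningRateMonitor logs at the optimizer-step index. MonotonicStepLog and simulate are hypothetical names for illustration only, not wandb or Lightning APIs.

```python
# Hypothetical model of the W&B rule behind "Step must only increase":
# a log call whose step is lower than the highest step seen so far is dropped.
# This is NOT wandb's implementation, only an illustration of the constraint.


class MonotonicStepLog:
    def __init__(self):
        self.last_step = -1
        self.rows = []     # (step, metrics) pairs that were accepted
        self.dropped = []  # (step, metrics) pairs that were rejected

    def log(self, metrics, step):
        if step < self.last_step:
            self.dropped.append((step, metrics))
            return False
        self.last_step = step
        self.rows.append((step, metrics))
        return True


def simulate(num_batches, accumulate_grad_batches):
    """Assumption: train_loss uses the batch index as its step, while the
    LR monitor uses the optimizer-step index (one step per full window)."""
    run = MonotonicStepLog()
    for batch_idx in range(num_batches):
        run.log({"train_loss": 0.0}, step=batch_idx)
        if (batch_idx + 1) % accumulate_grad_batches == 0:
            optimizer_step = (batch_idx + 1) // accumulate_grad_batches
            run.log({"lr": 1e-3}, step=optimizer_step)
    return run


no_accum = simulate(16, 1)
accum_8 = simulate(16, 8)
print(len(no_accum.dropped), len(accum_8.dropped))
```

With accumulate_grad_batches=1 both counters advance together and nothing is dropped; with 8, every learning-rate entry lands on a step that train_loss has already passed, so all of them are discarded.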

To Reproduce

Sorry for not using the BoringModel; this is the example from #4304, updated to use W&B:

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torchvision.datasets.mnist import MNIST
from torchvision import transforms

class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten 28x28 images into vectors
        x = torch.relu(self.l1(x))
        return self.l2(x)  # raw logits; F.cross_entropy applies log-softmax itself

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

def run_test(accumulate_grad_batches, batch_size, num_workers=4):
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    # pass num_workers through; it was accepted but never used
    train_loader = DataLoader(mnist_train, batch_size, num_workers=num_workers)
    val_loader = DataLoader(mnist_val, batch_size, num_workers=num_workers)

    model = LitClassifier()

    trainer = pl.Trainer(
        logger=WandbLogger(name="bug", project='.....', save_dir=".", log_model=False),
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=2,
    )
    trainer.fit(model, train_loader, val_loader)

run_test(1, 32)  # no gradient accumulation
run_test(8, 32)  # accumulate gradients over 8 batches

Expected behavior

Take a mean (or some other reduction) of the values logged within the same optimizer step and log it once per step, instead of logging at every forward pass.
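One way to get this behavior is to buffer the per-forward-pass values and emit a single reduced value once per optimizer step. Below is a minimal sketch, assuming the optimizer steps after every full window of accumulate_grad_batches batches and using the mean as the reduction. StepAveragingBuffer is a hypothetical helper for illustration, not part of Lightning or wandb.

```python
from statistics import mean


class StepAveragingBuffer:
    """Hypothetical helper: collects one value per forward pass and returns
    (optimizer_step, mean) once every `accumulate_grad_batches` batches.
    Assumes the optimizer steps after each full window of batches."""

    def __init__(self, accumulate_grad_batches):
        self.k = accumulate_grad_batches
        self.values = []

    def add(self, batch_idx, value):
        self.values.append(value)
        if (batch_idx + 1) % self.k == 0:
            optimizer_step = (batch_idx + 1) // self.k - 1  # 0-based step index
            result = (optimizer_step, mean(self.values))
            self.values = []  # start a new accumulation window
            return result
        return None  # window not complete yet: nothing to log


buf = StepAveragingBuffer(4)
emitted = [out for i in range(8) if (out := buf.add(i, float(i))) is not None]
print(emitted)  # one (step, mean) pair per window of 4 batches
```

A partial window at the end of an epoch stays in the buffer in this sketch; a real fix would also have to decide how to flush it.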

Environment

  • pytorch-lightning: 1.1.2
  • PyTorch: 1.7.1
  • OS: Linux
  • How you installed PyTorch: pip
  • Python version: 3.7.0
  • wandb: 0.10.12

Labels

bug (Something isn't working), help wanted (Open to be worked on), logger (Related to the Loggers), priority: 0 (High priority task)