
Support DDP in PyTorch-Lightning#4384

Merged
HideakiImamura merged 21 commits into optuna:master from Alnusjaponica:fix-checks-integration-pytorch-lightning-followup
Mar 13, 2023

Conversation

@Alnusjaponica
Contributor

Motivation

Follow-up to #4322.
DDP is temporarily not supported in PyTorchLightningPruningCallback because of the problem described in #4322.
This PR makes PyTorchLightningPruningCallback support DDP again.

Description of the changes

This PR

  • Requires users to call callback.check_pruned() in their objective functions when they use DDP
  • Activates test_pytorch_lightning_pruning_callback_ddp_monitor and test_pytorch_lightning_pruning_callback_ddp_unsupported_storage
  • Stores intermediate values, the pruning state, and the pruning message directly in the storage
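The storage-based handshake behind these changes can be sketched as follows. This is a hypothetical, simplified mock, not the actual integration code: a plain dict stands in for a shared Optuna storage, and the key names, `MockStorage`, and the two helper functions are illustrative. The idea is that one process records the pruning decision in trial system attributes, and `check_pruned()`, called by the user after training on every process, reads it back and raises.

```python
# Hypothetical sketch of the DDP pruning handshake (not Optuna's real code):
# rank 0 records the pruning decision in trial system attributes held in
# shared storage, and every process later raises from check_pruned().

_PRUNED_KEY = "ddp_pl:pruned"  # illustrative key names
_EPOCH_KEY = "ddp_pl:epoch"


class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned."""


class MockStorage:
    """Dict-backed stand-in for an Optuna storage shared across processes."""

    def __init__(self):
        self._attrs = {}

    def set_trial_system_attr(self, trial_id, key, value):
        self._attrs.setdefault(trial_id, {})[key] = value

    def get_trial_system_attrs(self, trial_id):
        return self._attrs.get(trial_id, {})


def record_pruned_on_rank0(storage, trial_id, epoch, should_prune):
    # Only rank 0 talks to the storage; other ranks see the result later.
    if should_prune:
        storage.set_trial_system_attr(trial_id, _PRUNED_KEY, True)
        storage.set_trial_system_attr(trial_id, _EPOCH_KEY, epoch)


def check_pruned(storage, trial_id):
    # Called by the user after Trainer.fit() on every process.
    attrs = storage.get_trial_system_attrs(trial_id)
    if attrs.get(_PRUNED_KEY, False):
        raise TrialPruned(f"Trial was pruned at epoch {attrs[_EPOCH_KEY]}.")


storage = MockStorage()
record_pruned_on_rank0(storage, trial_id=0, epoch=3, should_prune=True)
try:
    check_pruned(storage, trial_id=0)
    pruned = False
except TrialPruned:
    pruned = True
print(pruned)  # → True
```

Because the decision lives in the storage rather than in process-local state, every DDP worker observes the same outcome, which is why the user-facing `check_pruned()` call is required.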

@github-actions github-actions bot added the optuna.integration Related to the `optuna.integration` submodule. This is automatically labeled by github-actions. label Feb 2, 2023
@Alnusjaponica
Contributor Author

I made this PR ready for review. Note that it is based on #4322 and should be merged after it.

@Alnusjaponica Alnusjaponica marked this pull request as ready for review February 2, 2023 04:21
@codecov-commenter

codecov-commenter commented Feb 2, 2023

Codecov Report

Merging #4384 (d14363c) into master (fa54271) will decrease coverage by 0.24%.
The diff coverage is 85.71%.


@@            Coverage Diff             @@
##           master    #4384      +/-   ##
==========================================
- Coverage   90.40%   90.16%   -0.24%     
==========================================
  Files         172      181       +9     
  Lines       13682    14037     +355     
==========================================
+ Hits        12369    12657     +288     
- Misses       1313     1380      +67     
Impacted Files Coverage Δ
optuna/integration/pytorch_lightning.py 92.40% <85.71%> (+92.40%) ⬆️

... and 16 files with indirect coverage changes


@toshihikoyanase
Member

@HideakiImamura This is a follow-up for #4322. Could you join the review, please?

@toshihikoyanase
Member

@Alnusjaponica #4322 was merged into master. Could you rebase master, please?

@toshihikoyanase toshihikoyanase added the feature Change that does not break compatibility, but affects the public interfaces. label Feb 9, 2023
@github-actions
Contributor

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Feb 16, 2023
@github-actions github-actions bot removed the stale Exempt from stale bot labeling. label Feb 19, 2023
Member

@toshihikoyanase toshihikoyanase left a comment


Let me share my early comments.
I guess the change uses system attrs to store intermediate values even when users run only a single process. I'm thinking of keeping them as-is to simplify the logic. What do you think?

Co-authored-by: Toshihiko Yanase <toshihiko.yanase@gmail.com>
@Alnusjaponica
Contributor Author

@toshihikoyanase Thank you for pointing this out. I fixed the comment and will update the code to use system_attrs only in distributed settings.

@toshihikoyanase
Member

@Alnusjaponica Thank you for your update. As we pair-programmed the code, we may simplify it on the following points:

  • The empty intermediate values in system_attrs are only required for DDP. We can skip them by moving the logic to on_fit_start.
  • The logic in on_validation_end is a bit complicated. We may separate the logic for single-process optimization and DDP optimization.
  • check_pruned is only used for DDP optimization. It should be documented.
  • check_pruned is assumed to be used with _CachedStorage. We may remove the `# type: ignore` comments if we assert it.
  • The message can be removed from system_attrs, since it can be generated from the pruned epoch.
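The suggestion to separate the single-process and DDP paths might look roughly like the hedged sketch below. The stub `MockTrial`, the `is_ddp` flag, and the attribute names are all illustrative stand-ins, not the real callback's API: in single-process mode the callback can raise immediately, while under DDP it only records the decision for a later `check_pruned()`.

```python
# Illustrative sketch of splitting on_validation_end into a single-process
# path (report and prune immediately) and a DDP path (defer via storage).
# MockTrial and is_ddp are stand-ins, not Optuna's actual interfaces.

class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned."""


class MockTrial:
    """Stub trial: records reports; prunes when the value falls below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.system_attrs = {}
        self.reported = []

    def report(self, value, step):
        self.reported.append((step, value))

    def should_prune(self):
        _, value = self.reported[-1]
        return value < self.threshold


def on_validation_end(trial, value, epoch, is_ddp):
    trial.report(value, step=epoch)
    if not trial.should_prune():
        return
    if is_ddp:
        # DDP: only record the decision; the user calls check_pruned() later,
        # after Trainer.fit(), so all ranks terminate consistently.
        trial.system_attrs["pruned"] = True
        trial.system_attrs["epoch"] = epoch
    else:
        # Single process: safe to raise right away.
        raise TrialPruned(f"Trial was pruned at epoch {epoch}.")
```

Keeping the two branches explicit avoids writing empty placeholder attributes in the single-process case, which is the first bullet's concern.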

Member

@toshihikoyanase toshihikoyanase left a comment


Thank you for your update. Let me add some small comments.

Alnusjaponica and others added 2 commits March 3, 2023 11:34
Co-authored-by: Toshihiko Yanase <toshihiko.yanase@gmail.com>
Co-authored-by: Toshihiko Yanase <toshihiko.yanase@gmail.com>
Alnusjaponica and others added 2 commits March 3, 2023 11:35
Co-authored-by: Toshihiko Yanase <toshihiko.yanase@gmail.com>
Co-authored-by: Toshihiko Yanase <toshihiko.yanase@gmail.com>
Co-authored-by: Toshihiko Yanase <toshihiko.yanase@gmail.com>
Alnusjaponica and others added 3 commits March 3, 2023 11:53
Co-authored-by: Toshihiko Yanase <toshihiko.yanase@gmail.com>
Member

@toshihikoyanase toshihikoyanase left a comment


I confirmed that the updated callback worked with the PyTorch Lightning DDP example in optuna-examples. I have a small comment, but the change almost looks good to me.

<details>
<summary>Diff of pytorch_lightning_ddp.py</summary>

```diff
diff --git a/pytorch/pytorch_lightning_ddp.py b/pytorch/pytorch_lightning_ddp.py
index 030834d..6d609cb 100644
--- a/pytorch/pytorch_lightning_ddp.py
+++ b/pytorch/pytorch_lightning_ddp.py
@@ -78,6 +78,9 @@ class LightningNet(pl.LightningModule):
         self.log("val_acc", accuracy, sync_dist=True)
         self.log("hp_metric", accuracy, on_step=False, on_epoch=True, sync_dist=True)
 
+    def validation_epoch_end(self, output) -> None:
+        return
+
     def configure_optimizers(self) -> optim.Optimizer:
         return optim.Adam(self.model.parameters())
 
@@ -124,19 +127,22 @@ def objective(trial: optuna.trial.Trial) -> float:
     model = LightningNet(dropout, output_dims)
     datamodule = FashionMNISTDataModule(data_dir=DIR, batch_size=BATCHSIZE)
 
+    callback = PyTorchLightningPruningCallback(trial, monitor="val_acc")
     trainer = pl.Trainer(
         logger=True,
         limit_val_batches=PERCENT_VALID_EXAMPLES,
         enable_checkpointing=False,
         max_epochs=EPOCHS,
         gpus=-1 if torch.cuda.is_available() else None,
-        accelerator="ddp_cpu" if not torch.cuda.is_available() else None,
+        accelerator="cpu" if not torch.cuda.is_available() else None,
+        strategy="ddp_spawn",
         num_processes=os.cpu_count() if not torch.cuda.is_available() else None,
-        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
+        callbacks=[callback],
     )
     hyperparameters = dict(n_layers=n_layers, dropout=dropout, output_dims=output_dims)
     trainer.logger.log_hyperparams(hyperparameters)
     trainer.fit(model, datamodule=datamodule)
 
+    callback.check_pruned()
+
     return trainer.callback_metrics["val_acc"].item()
```

</details>

Member

@toshihikoyanase toshihikoyanase left a comment


We have some follow-up tasks as described in TODO comments, but I think we can work on them in a new PR.

LGTM. Thank you!

@toshihikoyanase toshihikoyanase removed their assignment Mar 3, 2023
Member

@HideakiImamura HideakiImamura left a comment


Thanks for the PR, and sorry for the late reply. I checked the overall code and it basically looks good to me. Could you check my comments?

```python
self._trial.storage.set_trial_system_attr(self._trial._trial_id, _PRUNED_KEY, True)
self._trial.storage.set_trial_system_attr(self._trial._trial_id, _EPOCH_KEY, epoch)

def check_pruned(self) -> None:
```
Member


This function is intended to be called manually by users after Trainer.fit. How about adding concrete instructions to the documentation about where this function should be called?

Contributor Author


I added some explanation and example code to the docstrings.

Alnusjaponica added a commit to Alnusjaponica/optuna-examples that referenced this pull request Mar 10, 2023
Member

@HideakiImamura HideakiImamura left a comment


Thanks for the update. LGTM.

@HideakiImamura HideakiImamura merged commit 6bf2e2e into optuna:master Mar 13, 2023
@HideakiImamura HideakiImamura added this to the v3.2.0 milestone Mar 13, 2023
@Alnusjaponica Alnusjaponica deleted the fix-checks-integration-pytorch-lightning-followup branch March 13, 2023 10:00
nzw0301 pushed a commit to nzw0301/optuna-examples that referenced this pull request May 4, 2023

Labels

feature Change that does not break compatibility, but affects the public interfaces. optuna.integration Related to the `optuna.integration` submodule. This is automatically labeled by github-actions.
