Skip to content

PyTorch PruningCallback not pruning #294

@deepakpokkalla

Description

@deepakpokkalla

Expected behavior

I have implemented a simple example where I want to do hp tuning using optuna and each trial is spawn over two gpus (as my original data is huge). I am not looking for parallelizing trial runs themselves rather doing data parellism across 2 gpus within each trial. I expected that the trial.report() would work from custom PruningCallback I implemented, hwoever it's not pruning. I am trying the replicate the result (same trials should get pruned) when running on a single GPU with everythign else the same.

Environment

  • Optuna version: 2.10.1
  • Python version: 3.10.15
  • OS: Linux
  • torch: 2.2.2+cu121

Error messages, stack traces, or logs

When I run mlp_ddp.py script below none of the trials get pruned as a result of PruningCallback not working properly with optuna

Steps to reproduce

  1. The mlp_ddp.py code contains implementation with ddp within each trial (expect that same trials as "mlp.py" script to be pruned, but don't see any trials getting pruned)
  2. mlp.py code below contains implementation for a single gpu (reference solution for which trials should be pruned)

Reproducible examples (optional)

# my mlp_ddp.py code

import os 

import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler

optuna.logging.set_verbosity(optuna.logging.DEBUG)

class MLP(torch.nn.Module):
    def __init__(self,n_layers,hidden_dim,in_dim=10,out_dim=3):
        super().__init__()

        layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
        for _ in range(n_layers):
            layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim,out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self,x):
        return self.model(x)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class PruningCallback:
    def __init__(self,trial,monitor="accuracy"):
        self.trial = trial
        self.monitor = monitor
    
    def on_epoch_end(self,epoch,metrics):
        value = metrics.get(self.monitor)
        if value is None:
            return
        
        self.trial.report(value,step=epoch)
        if self.trial.should_prune():
            raise optuna.TrialPruned()

def objective(rank, world_size, params, callback, return_dict):

    setup(rank, world_size)
    torch.manual_seed(42)
    device = torch.device(f"cuda:{rank}")

    in_dim = 10
    out_dim = 3
    num_train_samples=500
    num_val_samples=100
    num_epochs = 10
    batch_size = 64

    train_data = torch.rand(num_train_samples,in_dim).to(device)
    val_data = torch.rand(num_val_samples,in_dim).to(device)
    train_targets = torch.randint(0,out_dim,(num_train_samples,)).to(device)
    val_targets = torch.randint(0,out_dim,(num_val_samples,)).to(device)

    model = MLP(params["n_layers"],params["hidden_dim"],in_dim,out_dim).to(device)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    loss_function = torch.nn.CrossEntropyLoss()

    out_dir = "./multirun/mlp-optuna-test"
    os.makedirs(out_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        train_outputs = model(train_data)
        loss = loss_function(train_outputs,train_targets)
        loss.backward()
        optimizer.step()
    
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_data)
            val_predictions = torch.argmax(val_outputs,dim=1)
            val_correct = (val_predictions == val_targets)
            acc = int(val_correct.sum())/len(val_targets)

        acc_tensor = torch.tensor(acc,device=device)
        dist.all_reduce(acc_tensor)
        acc_avg = acc_tensor.item()/world_size
        # print(rank,acc,acc_tensor,acc_avg)

        if rank==0:
            callback.on_epoch_end(epoch,{"accuracy":acc_avg})

    if rank==0:
        return_dict["result"] = acc_avg
    
    cleanup()

def ddp_objective(trial):
    params = {
        "n_layers": trial.suggest_int("n_layers", 1, 5),
        "hidden_dim": trial.suggest_int("hidden_dim", 32, 64),
    }

    world_size = 2  # Number of GPUs
    manager = mp.Manager()
    return_dict = manager.dict()

    callback = PruningCallback(trial,monitor="accuracy")
    mp.spawn(
        objective,
        args=(world_size, params, callback, return_dict),
        nprocs=world_size,
        join=True,
    )

    return return_dict["result"]


if __name__ == "__main__":

    sampler = TPESampler(seed=42)
    pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
    print(f"Pruner:{study.pruner}")
    print(f"Sampler:{study.sampler}")

    study.optimize(ddp_objective,n_trials=20)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics:")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
    print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")

    print("  Number of pruned trials ---: ", len(pruned_trials))
    print("  Number of complete trials ---: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")
    print(f"  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
# my mlp_ddp.py code

import os 

import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler

optuna.logging.set_verbosity(optuna.logging.DEBUG)

class MLP(torch.nn.Module):
    def __init__(self,n_layers,hidden_dim,in_dim=10,out_dim=3):
        super().__init__()

        layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
        for _ in range(n_layers):
            layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim,out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self,x):
        return self.model(x)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class PruningCallback:
    def __init__(self,trial,monitor="accuracy"):
        self.trial = trial
        self.monitor = monitor
    
    def on_epoch_end(self,epoch,metrics):
        value = metrics.get(self.monitor)
        if value is None:
            return
        
        self.trial.report(value,step=epoch)
        if self.trial.should_prune():
            raise optuna.TrialPruned()

def objective(rank, world_size, params, callback, return_dict):

    setup(rank, world_size)
    torch.manual_seed(42)
    device = torch.device(f"cuda:{rank}")

    in_dim = 10
    out_dim = 3
    num_train_samples=500
    num_val_samples=100
    num_epochs = 10
    batch_size = 64

    train_data = torch.rand(num_train_samples,in_dim).to(device)
    val_data = torch.rand(num_val_samples,in_dim).to(device)
    train_targets = torch.randint(0,out_dim,(num_train_samples,)).to(device)
    val_targets = torch.randint(0,out_dim,(num_val_samples,)).to(device)

    model = MLP(params["n_layers"],params["hidden_dim"],in_dim,out_dim).to(device)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    loss_function = torch.nn.CrossEntropyLoss()

    out_dir = "./multirun/mlp-optuna-test"
    os.makedirs(out_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        train_outputs = model(train_data)
        loss = loss_function(train_outputs,train_targets)
        loss.backward()
        optimizer.step()
    
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_data)
            val_predictions = torch.argmax(val_outputs,dim=1)
            val_correct = (val_predictions == val_targets)
            acc = int(val_correct.sum())/len(val_targets)

        acc_tensor = torch.tensor(acc,device=device)
        dist.all_reduce(acc_tensor)
        acc_avg = acc_tensor.item()/world_size
        # print(rank,acc,acc_tensor,acc_avg)

        if rank==0:
            callback.on_epoch_end(epoch,{"accuracy":acc_avg})

    if rank==0:
        return_dict["result"] = acc_avg
    
    cleanup()

def ddp_objective(trial):
    params = {
        "n_layers": trial.suggest_int("n_layers", 1, 5),
        "hidden_dim": trial.suggest_int("hidden_dim", 32, 64),
    }

    world_size = 2  # Number of GPUs
    manager = mp.Manager()
    return_dict = manager.dict()

    callback = PruningCallback(trial,monitor="accuracy")
    mp.spawn(
        objective,
        args=(world_size, params, callback, return_dict),
        nprocs=world_size,
        join=True,
    )

    return return_dict["result"]


if __name__ == "__main__":

    sampler = TPESampler(seed=42)
    pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
    print(f"Pruner:{study.pruner}")
    print(f"Sampler:{study.sampler}")

    study.optimize(ddp_objective,n_trials=20)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics:")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
    print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")

    print("  Number of pruned trials ---: ", len(pruned_trials))
    print("  Number of complete trials ---: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")
    print(f"  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

Additional context (optional)

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugIssue/PR about behavior that is broken. Not for typos/CI but for example itself.staleExempt from stale bot labeling.

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions