-
Notifications
You must be signed in to change notification settings - Fork 195
Closed as not planned
Labels
bugIssue/PR about behavior that is broken. Not for typos/CI but for example itself.Issue/PR about behavior that is broken. Not for typos/CI but for example itself.staleExempt from stale bot labeling.Exempt from stale bot labeling.
Description
Expected behavior
I have implemented a simple example where I want to do hp tuning using optuna and each trial is spawn over two gpus (as my original data is huge). I am not looking for parallelizing trial runs themselves rather doing data parellism across 2 gpus within each trial. I expected that the trial.report() would work from custom PruningCallback I implemented, hwoever it's not pruning. I am trying the replicate the result (same trials should get pruned) when running on a single GPU with everythign else the same.
Environment
- Optuna version: 2.10.1
- Python version: 3.10.15
- OS: Linux
- torch: 2.2.2+cu121
Error messages, stack traces, or logs
When I run mlp_ddp.py script below none of the trials get pruned as a result of PruningCallback not working properly with optuna
Steps to reproduce
- The mlp_ddp.py code contains implementation with ddp within each trial (expect that same trials as "mlp.py" script to be pruned, but don't see any trials getting pruned)
- mlp.py code below contains implementation for a single gpu (reference solution for which trials should be pruned)
Reproducible examples (optional)
# my mlp_ddp.py code
import os
import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler
optuna.logging.set_verbosity(optuna.logging.DEBUG)
class MLP(torch.nn.Module):
def __init__(self,n_layers,hidden_dim,in_dim=10,out_dim=3):
super().__init__()
layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
for _ in range(n_layers):
layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
layers.append(nn.Linear(hidden_dim,out_dim))
self.model = nn.Sequential(*layers)
def forward(self,x):
return self.model(x)
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class PruningCallback:
def __init__(self,trial,monitor="accuracy"):
self.trial = trial
self.monitor = monitor
def on_epoch_end(self,epoch,metrics):
value = metrics.get(self.monitor)
if value is None:
return
self.trial.report(value,step=epoch)
if self.trial.should_prune():
raise optuna.TrialPruned()
def objective(rank, world_size, params, callback, return_dict):
setup(rank, world_size)
torch.manual_seed(42)
device = torch.device(f"cuda:{rank}")
in_dim = 10
out_dim = 3
num_train_samples=500
num_val_samples=100
num_epochs = 10
batch_size = 64
train_data = torch.rand(num_train_samples,in_dim).to(device)
val_data = torch.rand(num_val_samples,in_dim).to(device)
train_targets = torch.randint(0,out_dim,(num_train_samples,)).to(device)
val_targets = torch.randint(0,out_dim,(num_val_samples,)).to(device)
model = MLP(params["n_layers"],params["hidden_dim"],in_dim,out_dim).to(device)
model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
loss_function = torch.nn.CrossEntropyLoss()
out_dir = "./multirun/mlp-optuna-test"
os.makedirs(out_dir, exist_ok=True)
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
train_outputs = model(train_data)
loss = loss_function(train_outputs,train_targets)
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
val_outputs = model(val_data)
val_predictions = torch.argmax(val_outputs,dim=1)
val_correct = (val_predictions == val_targets)
acc = int(val_correct.sum())/len(val_targets)
acc_tensor = torch.tensor(acc,device=device)
dist.all_reduce(acc_tensor)
acc_avg = acc_tensor.item()/world_size
# print(rank,acc,acc_tensor,acc_avg)
if rank==0:
callback.on_epoch_end(epoch,{"accuracy":acc_avg})
if rank==0:
return_dict["result"] = acc_avg
cleanup()
def ddp_objective(trial):
params = {
"n_layers": trial.suggest_int("n_layers", 1, 5),
"hidden_dim": trial.suggest_int("hidden_dim", 32, 64),
}
world_size = 2 # Number of GPUs
manager = mp.Manager()
return_dict = manager.dict()
callback = PruningCallback(trial,monitor="accuracy")
mp.spawn(
objective,
args=(world_size, params, callback, return_dict),
nprocs=world_size,
join=True,
)
return return_dict["result"]
if __name__ == "__main__":
sampler = TPESampler(seed=42)
pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
print(f"Pruner:{study.pruner}")
print(f"Sampler:{study.sampler}")
study.optimize(ddp_objective,n_trials=20)
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
print("Study statistics:")
print(f" Number of finished trials: {len(study.trials)}")
print(f" Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
print(f" Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")
print(" Number of pruned trials ---: ", len(pruned_trials))
print(" Number of complete trials ---: ", len(complete_trials))
print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(f" Params: ")
for key, value in trial.params.items():
print(f" {key}: {value}")# my mlp_ddp.py code
import os
import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler
optuna.logging.set_verbosity(optuna.logging.DEBUG)
class MLP(torch.nn.Module):
def __init__(self,n_layers,hidden_dim,in_dim=10,out_dim=3):
super().__init__()
layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
for _ in range(n_layers):
layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
layers.append(nn.Linear(hidden_dim,out_dim))
self.model = nn.Sequential(*layers)
def forward(self,x):
return self.model(x)
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class PruningCallback:
def __init__(self,trial,monitor="accuracy"):
self.trial = trial
self.monitor = monitor
def on_epoch_end(self,epoch,metrics):
value = metrics.get(self.monitor)
if value is None:
return
self.trial.report(value,step=epoch)
if self.trial.should_prune():
raise optuna.TrialPruned()
def objective(rank, world_size, params, callback, return_dict):
setup(rank, world_size)
torch.manual_seed(42)
device = torch.device(f"cuda:{rank}")
in_dim = 10
out_dim = 3
num_train_samples=500
num_val_samples=100
num_epochs = 10
batch_size = 64
train_data = torch.rand(num_train_samples,in_dim).to(device)
val_data = torch.rand(num_val_samples,in_dim).to(device)
train_targets = torch.randint(0,out_dim,(num_train_samples,)).to(device)
val_targets = torch.randint(0,out_dim,(num_val_samples,)).to(device)
model = MLP(params["n_layers"],params["hidden_dim"],in_dim,out_dim).to(device)
model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
loss_function = torch.nn.CrossEntropyLoss()
out_dir = "./multirun/mlp-optuna-test"
os.makedirs(out_dir, exist_ok=True)
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
train_outputs = model(train_data)
loss = loss_function(train_outputs,train_targets)
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
val_outputs = model(val_data)
val_predictions = torch.argmax(val_outputs,dim=1)
val_correct = (val_predictions == val_targets)
acc = int(val_correct.sum())/len(val_targets)
acc_tensor = torch.tensor(acc,device=device)
dist.all_reduce(acc_tensor)
acc_avg = acc_tensor.item()/world_size
# print(rank,acc,acc_tensor,acc_avg)
if rank==0:
callback.on_epoch_end(epoch,{"accuracy":acc_avg})
if rank==0:
return_dict["result"] = acc_avg
cleanup()
def ddp_objective(trial):
params = {
"n_layers": trial.suggest_int("n_layers", 1, 5),
"hidden_dim": trial.suggest_int("hidden_dim", 32, 64),
}
world_size = 2 # Number of GPUs
manager = mp.Manager()
return_dict = manager.dict()
callback = PruningCallback(trial,monitor="accuracy")
mp.spawn(
objective,
args=(world_size, params, callback, return_dict),
nprocs=world_size,
join=True,
)
return return_dict["result"]
if __name__ == "__main__":
sampler = TPESampler(seed=42)
pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
print(f"Pruner:{study.pruner}")
print(f"Sampler:{study.sampler}")
study.optimize(ddp_objective,n_trials=20)
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
print("Study statistics:")
print(f" Number of finished trials: {len(study.trials)}")
print(f" Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
print(f" Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")
print(" Number of pruned trials ---: ", len(pruned_trials))
print(" Number of complete trials ---: ", len(complete_trials))
print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(f" Params: ")
for key, value in trial.params.items():
print(f" {key}: {value}")Additional context (optional)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bugIssue/PR about behavior that is broken. Not for typos/CI but for example itself.Issue/PR about behavior that is broken. Not for typos/CI but for example itself.staleExempt from stale bot labeling.Exempt from stale bot labeling.