[tune] best practices for parallel hyper-param search on a pytorch-lightning module ? #8976
Description
What does the Ray community see as good ways to interface with PyTorch Lightning?
I'm new to Ray, and to some extent to PyTorch Lightning as well (less than a month). My use case was as follows.
I developed a model as a LightningModule and tried a few hyperparameters that seemed relevant. I then wanted to do a more thorough hyperparameter search, parallelized across trials. I saw the tutorial on the different hyperparameter search schedulers, which look helpful, so I would like to use Tune for that reason as well.
Ideally, I would hand my LightningModule to ray.tune and that would be enough for it to run the search (perhaps with minor modifications to the dataloader methods, e.g. to control the number of workers), but it doesn't look like there is a tutorial on this at the moment.
My first take was to make a Trainable that wraps both the LightningModule and a pl.Trainer. The downside is that pl.Trainer.fit() is not very flexible, as far as I can tell: I cannot tell it to run for just one more epoch, so I had to work around it to coerce it into doing that. Here is roughly what I ended up doing:
https://gist.github.com/orm011/6b8e7465a2a80124f97fd9bb7b8a87ec
I trained a few things with it, but I'm not sure the epoch count and learning rate are being adjusted properly. The TensorBoard logs, at least, don't show the learning rate changing.
Do you recommend some other way of achieving the same goal without these downsides? There is probably also a question for the PyTorch Lightning project in here, about how to use their Trainer in situations like this, but I will take that up separately.