|
1 | | -import functools |
2 | | -from typing import Any |
3 | | -from typing import Callable |
4 | | -from typing import Dict |
5 | | -from typing import Optional |
6 | | -from typing import Sequence |
7 | | -from typing import Union |
| 1 | +from optuna_integration.wandb import WeightsAndBiasesCallback |
8 | 2 |
|
9 | | -import optuna |
10 | | -from optuna._experimental import experimental_class |
11 | | -from optuna._experimental import experimental_func |
12 | | -from optuna._imports import try_import |
13 | | -from optuna.study.study import ObjectiveFuncType |
14 | 3 |
|
15 | | - |
16 | | -with try_import() as _imports: |
17 | | - import wandb |
18 | | - |
19 | | - |
20 | | -@experimental_class("2.9.0") |
21 | | -class WeightsAndBiasesCallback: |
22 | | - """Callback to track Optuna trials with Weights & Biases. |
23 | | -
|
24 | | - This callback enables tracking of Optuna study in |
25 | | - Weights & Biases. The study is tracked as a single experiment |
26 | | - run, where all suggested hyperparameters and optimized metrics |
27 | | - are logged and plotted as a function of optimizer steps. |
28 | | -
|
29 | | - .. note:: |
30 | | - User needs to be logged in to Weights & Biases before |
31 | | - using this callback in online mode. For more information, please |
32 | | - refer to `wandb setup <https://docs.wandb.ai/quickstart#1-set-up-wandb>`_. |
33 | | -
|
34 | | - .. note:: |
35 | | - Users who want to run multiple Optuna studies within the same process |
36 | | - should call ``wandb.finish()`` between subsequent calls to |
37 | | - ``study.optimize()``. Calling ``wandb.finish()`` is not necessary |
38 | | - if you are running one Optuna study per process. |
39 | | -
|
40 | | - .. note:: |
41 | | - To ensure correct trial order in Weights & Biases, this callback |
42 | | - should only be used with ``study.optimize(n_jobs=1)``. |
43 | | -
|
44 | | -
|
45 | | - Example: |
46 | | -
|
47 | | - Add Weights & Biases callback to Optuna optimization. |
48 | | -
|
49 | | - .. code:: |
50 | | -
|
51 | | - import optuna |
52 | | - from optuna.integration.wandb import WeightsAndBiasesCallback |
53 | | -
|
54 | | -
|
55 | | - def objective(trial): |
56 | | - x = trial.suggest_float("x", -10, 10) |
57 | | - return (x - 2) ** 2 |
58 | | -
|
59 | | -
|
60 | | - study = optuna.create_study() |
61 | | -
|
62 | | - wandb_kwargs = {"project": "my-project"} |
63 | | - wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs) |
64 | | -
|
65 | | - study.optimize(objective, n_trials=10, callbacks=[wandbc]) |
66 | | -
|
67 | | -
|
68 | | -
|
69 | | - Weights & Biases logging in multirun mode. |
70 | | -
|
71 | | - .. code:: |
72 | | -
|
73 | | - import optuna |
74 | | - from optuna.integration.wandb import WeightsAndBiasesCallback |
75 | | -
|
76 | | - wandb_kwargs = {"project": "my-project"} |
77 | | - wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs, as_multirun=True) |
78 | | -
|
79 | | -
|
80 | | - @wandbc.track_in_wandb() |
81 | | - def objective(trial): |
82 | | - x = trial.suggest_float("x", -10, 10) |
83 | | - return (x - 2) ** 2 |
84 | | -
|
85 | | -
|
86 | | - study = optuna.create_study() |
87 | | - study.optimize(objective, n_trials=10, callbacks=[wandbc]) |
88 | | -
|
89 | | -
|
90 | | - Args: |
91 | | - metric_name: |
92 | | - Name assigned to optimized metric. In case of multi-objective optimization, |
93 | | - list of names can be passed. Those names will be assigned |
94 | | - to metrics in the order returned by objective function. |
95 | | - If single name is provided, or this argument is left to default value, |
96 | | - it will be broadcasted to each objective with a number suffix in order |
97 | | - returned by objective function e.g. two objectives and default metric name |
98 | | - will be logged as ``value_0`` and ``value_1``. The number of metrics must be |
99 | | - the same as the number of values objective function returns. |
100 | | - wandb_kwargs: |
101 | | - Set of arguments passed when initializing Weights & Biases run. |
102 | | - Please refer to `Weights & Biases API documentation |
103 | | - <https://docs.wandb.ai/ref/python/init>`_ for more details. |
104 | | - as_multirun: |
105 | | - Creates new runs for each trial. Useful for generating W&B Sweeps like |
106 | | - panels (for ex., parameter importance, parallel coordinates, etc). |
107 | | -
|
108 | | - """ |
109 | | - |
110 | | - def __init__( |
111 | | - self, |
112 | | - metric_name: Union[str, Sequence[str]] = "value", |
113 | | - wandb_kwargs: Optional[Dict[str, Any]] = None, |
114 | | - as_multirun: bool = False, |
115 | | - ) -> None: |
116 | | - _imports.check() |
117 | | - |
118 | | - if not isinstance(metric_name, Sequence): |
119 | | - raise TypeError( |
120 | | - "Expected metric_name to be string or sequence of strings, got {}.".format( |
121 | | - type(metric_name) |
122 | | - ) |
123 | | - ) |
124 | | - |
125 | | - self._metric_name = metric_name |
126 | | - self._wandb_kwargs = wandb_kwargs or {} |
127 | | - self._as_multirun = as_multirun |
128 | | - |
129 | | - if not self._as_multirun: |
130 | | - self._initialize_run() |
131 | | - |
132 | | - def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None: |
133 | | - if isinstance(self._metric_name, str): |
134 | | - if len(trial.values) > 1: |
135 | | - # Broadcast default name for multi-objective optimization. |
136 | | - names = ["{}_{}".format(self._metric_name, i) for i in range(len(trial.values))] |
137 | | - |
138 | | - else: |
139 | | - names = [self._metric_name] |
140 | | - |
141 | | - else: |
142 | | - if len(self._metric_name) != len(trial.values): |
143 | | - raise ValueError( |
144 | | - "Running multi-objective optimization " |
145 | | - "with {} objective values, but {} names specified. " |
146 | | - "Match objective values and names, or use default broadcasting.".format( |
147 | | - len(trial.values), len(self._metric_name) |
148 | | - ) |
149 | | - ) |
150 | | - |
151 | | - else: |
152 | | - names = [*self._metric_name] |
153 | | - |
154 | | - metrics = {name: value for name, value in zip(names, trial.values)} |
155 | | - |
156 | | - if self._as_multirun: |
157 | | - metrics["trial_number"] = trial.number |
158 | | - |
159 | | - attributes = {"direction": [d.name for d in study.directions]} |
160 | | - |
161 | | - step = trial.number if wandb.run else None |
162 | | - run = wandb.run |
163 | | - |
164 | | - # Might create extra runs if a user logs in wandb but doesn't use the decorator. |
165 | | - |
166 | | - if not run: |
167 | | - run = self._initialize_run() |
168 | | - run.name = f"trial/{trial.number}/{run.name}" |
169 | | - |
170 | | - run.log({**trial.params, **metrics}, step=step) |
171 | | - |
172 | | - if self._as_multirun: |
173 | | - run.config.update({**attributes, **trial.params}) |
174 | | - run.tags = tuple(self._wandb_kwargs.get("tags", ())) + (study.study_name,) |
175 | | - run.finish() |
176 | | - else: |
177 | | - run.config.update(attributes) |
178 | | - |
179 | | - @experimental_func("3.0.0") |
180 | | - def track_in_wandb(self) -> Callable: |
181 | | - """Decorator for using W&B for logging inside the objective function. |
182 | | -
|
183 | | - The run is initialized with the same ``wandb_kwargs`` that are passed to the callback. |
184 | | - All the metrics from inside the objective function will be logged into the same run |
185 | | - which stores the parameters for a given trial. |
186 | | -
|
187 | | - Example: |
188 | | -
|
189 | | - Add additional logging to Weights & Biases. |
190 | | -
|
191 | | - .. code:: |
192 | | -
|
193 | | - import optuna |
194 | | - from optuna.integration.wandb import WeightsAndBiasesCallback |
195 | | - import wandb |
196 | | -
|
197 | | - wandb_kwargs = {"project": "my-project"} |
198 | | - wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs, as_multirun=True) |
199 | | -
|
200 | | -
|
201 | | - @wandbc.track_in_wandb() |
202 | | - def objective(trial): |
203 | | - x = trial.suggest_float("x", -10, 10) |
204 | | - wandb.log({"power": 2, "base of metric": x - 2}) |
205 | | -
|
206 | | - return (x - 2) ** 2 |
207 | | -
|
208 | | -
|
209 | | - study = optuna.create_study() |
210 | | - study.optimize(objective, n_trials=10, callbacks=[wandbc]) |
211 | | -
|
212 | | -
|
213 | | - Returns: |
214 | | - Objective function with W&B tracking enabled. |
215 | | - """ |
216 | | - |
217 | | - def decorator(func: ObjectiveFuncType) -> ObjectiveFuncType: |
218 | | - @functools.wraps(func) |
219 | | - def wrapper(trial: optuna.trial.Trial) -> Union[float, Sequence[float]]: |
220 | | - run = wandb.run # Uses global run when `as_multirun` is set to False. |
221 | | - if not run: |
222 | | - run = self._initialize_run() |
223 | | - run.name = f"trial/{trial.number}/{run.name}" |
224 | | - |
225 | | - return func(trial) |
226 | | - |
227 | | - return wrapper |
228 | | - |
229 | | - return decorator |
230 | | - |
231 | | - def _initialize_run(self) -> "wandb.sdk.wandb_run.Run": |
232 | | - """Initializes Weights & Biases run.""" |
233 | | - run = wandb.init(**self._wandb_kwargs) |
234 | | - if not isinstance(run, wandb.sdk.wandb_run.Run): |
235 | | - raise RuntimeError( |
236 | | - "Cannot create a Run. " |
237 | | - "Expected wandb.sdk.wandb_run.Run as a return. " |
238 | | - f"Got: {type(run)}." |
239 | | - ) |
240 | | - return run |
| 4 | +__all__ = ["WeightsAndBiasesCallback"] |
0 commit comments