Skip to content

Commit e759efd

Browse files
authored
convert tests to use mock server instead of mock backend (#3286)
1 parent 56cb2ec commit e759efd

4 files changed

Lines changed: 260 additions & 172 deletions

File tree

tests/integrations/test_keras.py

Lines changed: 157 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,31 @@ def dummy_data(request):
6565
return data, labels
6666

6767

68-
def graph_json(run):
69-
print(run._backend.summary["graph"])
70-
graph_path = run._backend.summary["graph"]["path"]
71-
return json.load(open(os.path.join(run.dir, graph_path)))
68+
def graph_json(run_dir, summary):
69+
print(summary["graph"])
70+
path = os.path.join(run_dir, summary["graph"]["path"])
71+
with open(path) as fh:
72+
return json.load(fh)
7273

7374

7475
def test_no_init():
7576
with pytest.raises(wandb.errors.Error):
7677
WandbCallback()
7778

7879

79-
def test_basic_keras(dummy_model, dummy_data, wandb_init_run):
80+
def test_basic_keras(
81+
dummy_model, dummy_data, live_mock_server, test_settings, parse_ctx
82+
):
83+
run = wandb.init(settings=test_settings)
8084
dummy_model.fit(*dummy_data, epochs=2, batch_size=36, callbacks=[WandbCallback()])
85+
run_dir = run.dir
86+
wandb.finish()
87+
ctx_util = parse_ctx(live_mock_server.get_ctx())
8188
# wandb.run.summary.load()
82-
assert wandb.run._backend.history[0]["epoch"] == 0
89+
assert ctx_util.history[0]["epoch"] == 0
8390
# NOTE: backend mock doesnt copy history into summary (happens in internal process)
8491
# assert wandb.run._backend.summary["loss"] > 0
85-
assert len(graph_json(wandb.run)["nodes"]) == 3
92+
assert len(graph_json(run_dir, ctx_util.summary)["nodes"]) == 3
8693

8794

8895
def test_keras_telemetry(
@@ -139,139 +146,178 @@ def test_keras_image_bad_data(dummy_model, dummy_data, wandb_init_run):
139146
assert error
140147

141148

142-
def test_keras_image_binary(dummy_model, dummy_data, wandb_init_run):
143-
dummy_model.fit(
144-
*dummy_data,
145-
epochs=2,
146-
batch_size=36,
147-
validation_data=dummy_data,
148-
callbacks=[WandbCallback(data_type="image")]
149-
)
150-
assert len(wandb.run._backend.history[0]["examples"]["captions"]) == 36
149+
def test_keras_image_binary(
150+
dummy_model, dummy_data, test_settings, parse_ctx, live_mock_server
151+
):
152+
with wandb.init(settings=test_settings):
153+
dummy_model.fit(
154+
*dummy_data,
155+
epochs=2,
156+
batch_size=36,
157+
validation_data=dummy_data,
158+
callbacks=[WandbCallback(data_type="image")]
159+
)
151160

161+
ctx_util = parse_ctx(live_mock_server.get_ctx())
162+
assert len(ctx_util.history[0]["examples"]["captions"]) == 36
152163

153-
def test_keras_image_binary_captions(dummy_model, dummy_data, wandb_init_run):
154-
dummy_model.fit(
155-
*dummy_data,
156-
epochs=2,
157-
batch_size=36,
158-
validation_data=dummy_data,
159-
callbacks=[
160-
WandbCallback(data_type="image", predictions=10, labels=["Rad", "Nice"])
161-
]
162-
)
163-
assert wandb.run._backend.history[0]["examples"]["captions"][0] in ["Rad", "Nice"]
164+
165+
def test_keras_image_binary_captions(
166+
dummy_model, dummy_data, test_settings, parse_ctx, live_mock_server
167+
):
168+
169+
with wandb.init(settings=test_settings):
170+
dummy_model.fit(
171+
*dummy_data,
172+
epochs=2,
173+
batch_size=36,
174+
validation_data=dummy_data,
175+
callbacks=[
176+
WandbCallback(data_type="image", predictions=10, labels=["Rad", "Nice"])
177+
]
178+
)
179+
180+
ctx_util = parse_ctx(live_mock_server.get_ctx())
181+
assert ctx_util.history[0]["examples"]["captions"][0] in ["Rad", "Nice"]
164182

165183

166184
@pytest.mark.multiclass
167-
def test_keras_image_multiclass(dummy_model, dummy_data, wandb_init_run):
168-
dummy_model.fit(
169-
*dummy_data,
170-
epochs=2,
171-
batch_size=36,
172-
validation_data=dummy_data,
173-
callbacks=[WandbCallback(data_type="image", predictions=10)]
174-
)
175-
assert len(wandb.run._backend.history[0]["examples"]["captions"]) == 10
185+
def test_keras_image_multiclass(
186+
dummy_model, dummy_data, test_settings, parse_ctx, live_mock_server
187+
):
188+
with wandb.init(settings=test_settings):
189+
dummy_model.fit(
190+
*dummy_data,
191+
epochs=2,
192+
batch_size=36,
193+
validation_data=dummy_data,
194+
callbacks=[WandbCallback(data_type="image", predictions=10)]
195+
)
196+
197+
ctx_util = parse_ctx(live_mock_server.get_ctx())
198+
assert len(ctx_util.history[0]["examples"]["captions"]) == 10
176199

177200

178201
@pytest.mark.multiclass
179-
def test_keras_image_multiclass_captions(dummy_model, dummy_data, wandb_init_run):
180-
dummy_model.fit(
181-
*dummy_data,
182-
epochs=2,
183-
batch_size=36,
184-
validation_data=dummy_data,
185-
callbacks=[
186-
WandbCallback(
187-
data_type="image",
188-
predictions=10,
189-
labels=[
190-
"Rad",
191-
"Nice",
192-
"Fun",
193-
"Rad",
194-
"Nice",
195-
"Fun",
196-
"Rad",
197-
"Nice",
198-
"Fun",
199-
"Rad",
200-
],
201-
)
202-
]
203-
)
204-
assert wandb.run._backend.history[0]["examples"]["captions"][0] in [
202+
def test_keras_image_multiclass_captions(
203+
dummy_model, dummy_data, test_settings, parse_ctx, live_mock_server
204+
):
205+
with wandb.init(settings=test_settings):
206+
dummy_model.fit(
207+
*dummy_data,
208+
epochs=2,
209+
batch_size=36,
210+
validation_data=dummy_data,
211+
callbacks=[
212+
WandbCallback(
213+
data_type="image",
214+
predictions=10,
215+
labels=[
216+
"Rad",
217+
"Nice",
218+
"Fun",
219+
"Rad",
220+
"Nice",
221+
"Fun",
222+
"Rad",
223+
"Nice",
224+
"Fun",
225+
"Rad",
226+
],
227+
)
228+
]
229+
)
230+
231+
ctx_util = parse_ctx(live_mock_server.get_ctx())
232+
assert ctx_util.history[0]["examples"]["captions"][0] in [
205233
"Rad",
206234
"Nice",
207235
"Fun",
208236
]
209237

210238

211239
@pytest.mark.image_output
212-
def test_keras_image_output(dummy_model, dummy_data, wandb_init_run):
213-
dummy_model.fit(
214-
*dummy_data,
215-
epochs=2,
216-
batch_size=36,
217-
validation_data=dummy_data,
218-
callbacks=[WandbCallback(data_type="image", predictions=10)]
219-
)
220-
print(wandb.run._backend.history[0])
221-
assert wandb.run._backend.history[0]["examples"]["count"] == 30
222-
assert wandb.run._backend.history[0]["examples"]["height"] == 10
240+
def test_keras_image_output(
241+
dummy_model, dummy_data, test_settings, parse_ctx, live_mock_server
242+
):
243+
with wandb.init(settings=test_settings):
244+
dummy_model.fit(
245+
*dummy_data,
246+
epochs=2,
247+
batch_size=36,
248+
validation_data=dummy_data,
249+
callbacks=[WandbCallback(data_type="image", predictions=10)]
250+
)
223251

252+
ctx_util = parse_ctx(live_mock_server.get_ctx())
253+
print(ctx_util.history[0])
254+
assert ctx_util.history[0]["examples"]["count"] == 30
255+
assert ctx_util.history[0]["examples"]["height"] == 10
224256

225-
def test_dataset_functional(wandb_init_run):
226-
data = tf.data.Dataset.range(5).map(lambda x: (x, 1)).batch(1)
227-
inputs = tf.keras.Input(shape=(1,))
228-
outputs = tf.keras.layers.Dense(1)(inputs)
229-
wandb_callback = WandbCallback(save_model=False)
230257

231-
model = tf.keras.Model(inputs=inputs, outputs=outputs)
232-
model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")
233-
model.fit(data, callbacks=[wandb_callback])
234-
assert graph_json(wandb.run)["nodes"][0]["class_name"] == "InputLayer"
258+
def test_dataset_functional(live_mock_server, test_settings, parse_ctx):
235259

260+
with wandb.init(settings=test_settings) as run:
236261

237-
def test_keras_log_weights(dummy_model, dummy_data, wandb_init_run):
238-
dummy_model.fit(
239-
*dummy_data,
240-
epochs=2,
241-
batch_size=36,
242-
validation_data=dummy_data,
243-
callbacks=[WandbCallback(data_type="image", log_weights=True)]
244-
)
262+
data = tf.data.Dataset.range(5).map(lambda x: (x, 1)).batch(1)
263+
inputs = tf.keras.Input(shape=(1,))
264+
outputs = tf.keras.layers.Dense(1)(inputs)
265+
266+
wandb_callback = WandbCallback(save_model=False)
267+
268+
model = tf.keras.Model(inputs=inputs, outputs=outputs)
269+
model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")
270+
model.fit(data, callbacks=[wandb_callback])
271+
272+
run_dir = run.dir
273+
274+
ctx_util = parse_ctx(live_mock_server.get_ctx())
245275
assert (
246-
wandb.run._backend.history[0]["parameters/dense.weights"]["_type"]
247-
== "histogram"
276+
graph_json(run_dir, ctx_util.summary)["nodes"][0]["class_name"] == "InputLayer"
248277
)
249278

250279

280+
def test_keras_log_weights(
281+
dummy_model, dummy_data, live_mock_server, test_settings, parse_ctx
282+
):
283+
with wandb.init(settings=test_settings):
284+
dummy_model.fit(
285+
*dummy_data,
286+
epochs=2,
287+
batch_size=36,
288+
validation_data=dummy_data,
289+
callbacks=[WandbCallback(data_type="image", log_weights=True)]
290+
)
291+
292+
ctx_util = parse_ctx(live_mock_server.get_ctx())
293+
assert ctx_util.history[0]["parameters/dense.weights"]["_type"] == "histogram"
294+
295+
251296
# this is flaky on all platforms
252297
@pytest.mark.flaky
253298
@pytest.mark.xfail(reason="flaky test")
254-
def test_keras_log_gradients(dummy_model, dummy_data, wandb_init_run):
255-
dummy_model.fit(
256-
*dummy_data,
257-
epochs=2,
258-
batch_size=36,
259-
validation_data=dummy_data,
260-
callbacks=[
261-
WandbCallback(
262-
data_type="image", log_gradients=True, training_data=dummy_data
263-
)
264-
]
265-
)
266-
print(wandb.run._backend.history)
267-
assert (
268-
wandb.run._backend.history[0]["gradients/dense/kernel.gradient"]["_type"]
269-
== "histogram"
270-
)
299+
def test_keras_log_gradients(
300+
dummy_model, dummy_data, test_settings, parse_ctx, live_mock_server
301+
):
302+
with wandb.init(settings=test_settings):
303+
dummy_model.fit(
304+
*dummy_data,
305+
epochs=2,
306+
batch_size=36,
307+
validation_data=dummy_data,
308+
callbacks=[
309+
WandbCallback(
310+
data_type="image", log_gradients=True, training_data=dummy_data
311+
)
312+
]
313+
)
314+
315+
ctx_util = parse_ctx(live_mock_server.get_ctx())
316+
print(ctx_util.history)
271317
assert (
272-
wandb.run._backend.history[0]["gradients/dense/bias.gradient"]["_type"]
273-
== "histogram"
318+
ctx_util.history[0]["gradients/dense/kernel.gradient"]["_type"] == "histogram"
274319
)
320+
assert ctx_util.history[0]["gradients/dense/bias.gradient"]["_type"] == "histogram"
275321

276322

277323
# @pytest.mark.skip(reason="Coverage insanity error: sqlite3.OperationalError: unable to open database file")

tests/integrations/test_torch.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,25 @@ def conv3x3(in_channels, out_channels, **kwargs):
157157
return layer
158158

159159

160-
def test_all_logging(wandb_init_run):
160+
def test_all_logging(live_mock_server, test_settings, parse_ctx):
161161
# TODO(jhr): does not work with --flake-finder
162+
run = wandb.init(settings=test_settings)
162163
net = ConvNet()
163164
wandb.watch(net, log="all", log_freq=1)
164165
for i in range(3):
165166
output = net(dummy_torch_tensor((32, 1, 28, 28)))
166167
grads = torch.ones(32, 10)
167168
output.backward(grads)
168-
wandb.log({"a": 2})
169-
assert len(wandb.run._backend.history[0]) == 20
170-
assert len(wandb.run._backend.history[0]["parameters/fc2.bias"]["bins"]) == 65
171-
assert len(wandb.run._backend.history[0]["gradients/fc2.bias"]["bins"]) == 65
172-
assert len(wandb.run._backend.history) == 3
169+
run.log({"a": 2})
170+
run.finish()
171+
172+
ctx_util = parse_ctx(live_mock_server.get_ctx())
173+
assert len(ctx_util.history) == 3
174+
for i in range(3):
175+
ctx_util.history[i]["_step"] == i
176+
assert len(ctx_util.history[i]) == 20
177+
assert len(ctx_util.history[i]["parameters/fc2.bias"]["bins"]) == 65
178+
assert len(ctx_util.history[i]["gradients/fc2.bias"]["bins"]) == 65
173179

174180

175181
def test_double_log(wandb_init_run):
@@ -179,7 +185,8 @@ def test_double_log(wandb_init_run):
179185
wandb.watch(net, log_graph=True)
180186

181187

182-
def test_embedding_dict_watch(wandb_init_run):
188+
def test_embedding_dict_watch(live_mock_server, test_settings, parse_ctx):
189+
run = wandb.init(settings=test_settings)
183190
model = EmbModelWrapper()
184191
wandb.watch(model, log_freq=1, idx=0)
185192
opt = torch.optim.Adam(params=model.parameters())
@@ -189,9 +196,13 @@ def test_embedding_dict_watch(wandb_init_run):
189196
loss = F.mse_loss(out, inp.float())
190197
loss.backward()
191198
opt.step()
192-
wandb.log({"loss": loss})
193-
print(wandb.run._backend.history)
194-
assert len(wandb.run._backend.history[0]["gradients/emb.emb1.weight"]["bins"]) == 65
199+
run.log({"loss": loss})
200+
run.finish()
201+
202+
ctx_util = parse_ctx(live_mock_server.get_ctx())
203+
print(ctx_util.history)
204+
assert len(ctx_util.history[0]["gradients/emb.emb1.weight"]["bins"]) == 65
205+
assert ctx_util.history[0]["gradients/emb.emb1.weight"]["_type"] == "histogram"
195206

196207

197208
@pytest.mark.timeout(120)

0 commit comments

Comments
 (0)