@@ -65,24 +65,31 @@ def dummy_data(request):
6565 return data , labels
6666
6767
68- def graph_json (run ):
69- print (run ._backend .summary ["graph" ])
70- graph_path = run ._backend .summary ["graph" ]["path" ]
71- return json .load (open (os .path .join (run .dir , graph_path )))
68+ def graph_json (run_dir , summary ):
69+ print (summary ["graph" ])
70+ path = os .path .join (run_dir , summary ["graph" ]["path" ])
71+ with open (path ) as fh :
72+ return json .load (fh )
7273
7374
7475def test_no_init ():
7576 with pytest .raises (wandb .errors .Error ):
7677 WandbCallback ()
7778
7879
79- def test_basic_keras (dummy_model , dummy_data , wandb_init_run ):
80+ def test_basic_keras (
81+ dummy_model , dummy_data , live_mock_server , test_settings , parse_ctx
82+ ):
83+ run = wandb .init (settings = test_settings )
8084 dummy_model .fit (* dummy_data , epochs = 2 , batch_size = 36 , callbacks = [WandbCallback ()])
85+ run_dir = run .dir
86+ wandb .finish ()
87+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
8188 # wandb.run.summary.load()
82- assert wandb . run . _backend .history [0 ]["epoch" ] == 0
89+ assert ctx_util .history [0 ]["epoch" ] == 0
8390 # NOTE: backend mock doesnt copy history into summary (happens in internal process)
8491 # assert wandb.run._backend.summary["loss"] > 0
85- assert len (graph_json (wandb . run )["nodes" ]) == 3
92+ assert len (graph_json (run_dir , ctx_util . summary )["nodes" ]) == 3
8693
8794
8895def test_keras_telemetry (
@@ -139,139 +146,178 @@ def test_keras_image_bad_data(dummy_model, dummy_data, wandb_init_run):
139146 assert error
140147
141148
142- def test_keras_image_binary (dummy_model , dummy_data , wandb_init_run ):
143- dummy_model .fit (
144- * dummy_data ,
145- epochs = 2 ,
146- batch_size = 36 ,
147- validation_data = dummy_data ,
148- callbacks = [WandbCallback (data_type = "image" )]
149- )
150- assert len (wandb .run ._backend .history [0 ]["examples" ]["captions" ]) == 36
149+ def test_keras_image_binary (
150+ dummy_model , dummy_data , test_settings , parse_ctx , live_mock_server
151+ ):
152+ with wandb .init (settings = test_settings ):
153+ dummy_model .fit (
154+ * dummy_data ,
155+ epochs = 2 ,
156+ batch_size = 36 ,
157+ validation_data = dummy_data ,
158+ callbacks = [WandbCallback (data_type = "image" )]
159+ )
151160
161+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
162+ assert len (ctx_util .history [0 ]["examples" ]["captions" ]) == 36
152163
153- def test_keras_image_binary_captions (dummy_model , dummy_data , wandb_init_run ):
154- dummy_model .fit (
155- * dummy_data ,
156- epochs = 2 ,
157- batch_size = 36 ,
158- validation_data = dummy_data ,
159- callbacks = [
160- WandbCallback (data_type = "image" , predictions = 10 , labels = ["Rad" , "Nice" ])
161- ]
162- )
163- assert wandb .run ._backend .history [0 ]["examples" ]["captions" ][0 ] in ["Rad" , "Nice" ]
164+
165+ def test_keras_image_binary_captions (
166+ dummy_model , dummy_data , test_settings , parse_ctx , live_mock_server
167+ ):
168+
169+ with wandb .init (settings = test_settings ):
170+ dummy_model .fit (
171+ * dummy_data ,
172+ epochs = 2 ,
173+ batch_size = 36 ,
174+ validation_data = dummy_data ,
175+ callbacks = [
176+ WandbCallback (data_type = "image" , predictions = 10 , labels = ["Rad" , "Nice" ])
177+ ]
178+ )
179+
180+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
181+ assert ctx_util .history [0 ]["examples" ]["captions" ][0 ] in ["Rad" , "Nice" ]
164182
165183
166184@pytest .mark .multiclass
167- def test_keras_image_multiclass (dummy_model , dummy_data , wandb_init_run ):
168- dummy_model .fit (
169- * dummy_data ,
170- epochs = 2 ,
171- batch_size = 36 ,
172- validation_data = dummy_data ,
173- callbacks = [WandbCallback (data_type = "image" , predictions = 10 )]
174- )
175- assert len (wandb .run ._backend .history [0 ]["examples" ]["captions" ]) == 10
185+ def test_keras_image_multiclass (
186+ dummy_model , dummy_data , test_settings , parse_ctx , live_mock_server
187+ ):
188+ with wandb .init (settings = test_settings ):
189+ dummy_model .fit (
190+ * dummy_data ,
191+ epochs = 2 ,
192+ batch_size = 36 ,
193+ validation_data = dummy_data ,
194+ callbacks = [WandbCallback (data_type = "image" , predictions = 10 )]
195+ )
196+
197+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
198+ assert len (ctx_util .history [0 ]["examples" ]["captions" ]) == 10
176199
177200
178201@pytest .mark .multiclass
179- def test_keras_image_multiclass_captions (dummy_model , dummy_data , wandb_init_run ):
180- dummy_model .fit (
181- * dummy_data ,
182- epochs = 2 ,
183- batch_size = 36 ,
184- validation_data = dummy_data ,
185- callbacks = [
186- WandbCallback (
187- data_type = "image" ,
188- predictions = 10 ,
189- labels = [
190- "Rad" ,
191- "Nice" ,
192- "Fun" ,
193- "Rad" ,
194- "Nice" ,
195- "Fun" ,
196- "Rad" ,
197- "Nice" ,
198- "Fun" ,
199- "Rad" ,
200- ],
201- )
202- ]
203- )
204- assert wandb .run ._backend .history [0 ]["examples" ]["captions" ][0 ] in [
202+ def test_keras_image_multiclass_captions (
203+ dummy_model , dummy_data , test_settings , parse_ctx , live_mock_server
204+ ):
205+ with wandb .init (settings = test_settings ):
206+ dummy_model .fit (
207+ * dummy_data ,
208+ epochs = 2 ,
209+ batch_size = 36 ,
210+ validation_data = dummy_data ,
211+ callbacks = [
212+ WandbCallback (
213+ data_type = "image" ,
214+ predictions = 10 ,
215+ labels = [
216+ "Rad" ,
217+ "Nice" ,
218+ "Fun" ,
219+ "Rad" ,
220+ "Nice" ,
221+ "Fun" ,
222+ "Rad" ,
223+ "Nice" ,
224+ "Fun" ,
225+ "Rad" ,
226+ ],
227+ )
228+ ]
229+ )
230+
231+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
232+ assert ctx_util .history [0 ]["examples" ]["captions" ][0 ] in [
205233 "Rad" ,
206234 "Nice" ,
207235 "Fun" ,
208236 ]
209237
210238
211239@pytest .mark .image_output
212- def test_keras_image_output (dummy_model , dummy_data , wandb_init_run ):
213- dummy_model . fit (
214- * dummy_data ,
215- epochs = 2 ,
216- batch_size = 36 ,
217- validation_data = dummy_data ,
218- callbacks = [ WandbCallback ( data_type = "image" , predictions = 10 )]
219- )
220- print ( wandb . run . _backend . history [ 0 ])
221- assert wandb . run . _backend . history [ 0 ][ "examples" ][ "count" ] == 30
222- assert wandb . run . _backend . history [ 0 ][ "examples" ][ "height" ] == 10
240+ def test_keras_image_output (
241+ dummy_model , dummy_data , test_settings , parse_ctx , live_mock_server
242+ ):
243+ with wandb . init ( settings = test_settings ):
244+ dummy_model . fit (
245+ * dummy_data ,
246+ epochs = 2 ,
247+ batch_size = 36 ,
248+ validation_data = dummy_data ,
249+ callbacks = [ WandbCallback ( data_type = "image" , predictions = 10 )]
250+ )
223251
252+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
253+ print (ctx_util .history [0 ])
254+ assert ctx_util .history [0 ]["examples" ]["count" ] == 30
255+ assert ctx_util .history [0 ]["examples" ]["height" ] == 10
224256
225- def test_dataset_functional (wandb_init_run ):
226- data = tf .data .Dataset .range (5 ).map (lambda x : (x , 1 )).batch (1 )
227- inputs = tf .keras .Input (shape = (1 ,))
228- outputs = tf .keras .layers .Dense (1 )(inputs )
229- wandb_callback = WandbCallback (save_model = False )
230257
231- model = tf .keras .Model (inputs = inputs , outputs = outputs )
232- model .compile (optimizer = tf .keras .optimizers .Adam (), loss = "mse" )
233- model .fit (data , callbacks = [wandb_callback ])
234- assert graph_json (wandb .run )["nodes" ][0 ]["class_name" ] == "InputLayer"
258+ def test_dataset_functional (live_mock_server , test_settings , parse_ctx ):
235259
260+ with wandb .init (settings = test_settings ) as run :
236261
237- def test_keras_log_weights (dummy_model , dummy_data , wandb_init_run ):
238- dummy_model .fit (
239- * dummy_data ,
240- epochs = 2 ,
241- batch_size = 36 ,
242- validation_data = dummy_data ,
243- callbacks = [WandbCallback (data_type = "image" , log_weights = True )]
244- )
262+ data = tf .data .Dataset .range (5 ).map (lambda x : (x , 1 )).batch (1 )
263+ inputs = tf .keras .Input (shape = (1 ,))
264+ outputs = tf .keras .layers .Dense (1 )(inputs )
265+
266+ wandb_callback = WandbCallback (save_model = False )
267+
268+ model = tf .keras .Model (inputs = inputs , outputs = outputs )
269+ model .compile (optimizer = tf .keras .optimizers .Adam (), loss = "mse" )
270+ model .fit (data , callbacks = [wandb_callback ])
271+
272+ run_dir = run .dir
273+
274+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
245275 assert (
246- wandb .run ._backend .history [0 ]["parameters/dense.weights" ]["_type" ]
247- == "histogram"
276+ graph_json (run_dir , ctx_util .summary )["nodes" ][0 ]["class_name" ] == "InputLayer"
248277 )
249278
250279
280+ def test_keras_log_weights (
281+ dummy_model , dummy_data , live_mock_server , test_settings , parse_ctx
282+ ):
283+ with wandb .init (settings = test_settings ):
284+ dummy_model .fit (
285+ * dummy_data ,
286+ epochs = 2 ,
287+ batch_size = 36 ,
288+ validation_data = dummy_data ,
289+ callbacks = [WandbCallback (data_type = "image" , log_weights = True )]
290+ )
291+
292+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
293+ assert ctx_util .history [0 ]["parameters/dense.weights" ]["_type" ] == "histogram"
294+
295+
251296# this is flaky on all platforms
252297@pytest .mark .flaky
253298@pytest .mark .xfail (reason = "flaky test" )
254- def test_keras_log_gradients (dummy_model , dummy_data , wandb_init_run ):
255- dummy_model .fit (
256- * dummy_data ,
257- epochs = 2 ,
258- batch_size = 36 ,
259- validation_data = dummy_data ,
260- callbacks = [
261- WandbCallback (
262- data_type = "image" , log_gradients = True , training_data = dummy_data
263- )
264- ]
265- )
266- print (wandb .run ._backend .history )
267- assert (
268- wandb .run ._backend .history [0 ]["gradients/dense/kernel.gradient" ]["_type" ]
269- == "histogram"
270- )
299+ def test_keras_log_gradients (
300+ dummy_model , dummy_data , test_settings , parse_ctx , live_mock_server
301+ ):
302+ with wandb .init (settings = test_settings ):
303+ dummy_model .fit (
304+ * dummy_data ,
305+ epochs = 2 ,
306+ batch_size = 36 ,
307+ validation_data = dummy_data ,
308+ callbacks = [
309+ WandbCallback (
310+ data_type = "image" , log_gradients = True , training_data = dummy_data
311+ )
312+ ]
313+ )
314+
315+ ctx_util = parse_ctx (live_mock_server .get_ctx ())
316+ print (ctx_util .history )
271317 assert (
272- wandb .run ._backend .history [0 ]["gradients/dense/bias.gradient" ]["_type" ]
273- == "histogram"
318+ ctx_util .history [0 ]["gradients/dense/kernel.gradient" ]["_type" ] == "histogram"
274319 )
320+ assert ctx_util .history [0 ]["gradients/dense/bias.gradient" ]["_type" ] == "histogram"
275321
276322
277323# @pytest.mark.skip(reason="Coverage insanity error: sqlite3.OperationalError: unable to open database file")
0 commit comments