Skip to content

Commit 05b2cd8

Browse files
authored
[WB-4336] Annotate data_types.py (Part 1: 2286 lines converted, 1275 to go ) (#1764)
1 parent 47d30f9 commit 05b2cd8

14 files changed

Lines changed: 4837 additions & 2116 deletions

File tree

mypy.ini

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ check_untyped_defs = True
8383
disallow_untyped_decorators = True
8484
strict_equality = True
8585

86+
[mypy-wandb.sdk.data_types]
87+
disallow_incomplete_defs = True
88+
disallow_untyped_defs = True
89+
warn_unused_ignores = True
90+
warn_return_any = True
91+
warn_unreachable = True
92+
check_untyped_defs = True
93+
# disallow_untyped_calls = True
94+
disallow_untyped_decorators = True
95+
strict_equality = True
96+
8697
[mypy-wandb.sdk.lib.telemetry]
8798
disallow_untyped_defs = True
8899
disallow_untyped_calls = True

tests/test_data_types.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,42 @@ def api(runner):
2626
return Api()
2727

2828

29+
def test_wb_value(live_mock_server, test_settings):
30+
run = wandb.init(settings=test_settings)
31+
local_art = wandb.Artifact("N", "T")
32+
public_art = run.use_artifact("N:latest")
33+
34+
wbvalue = data_types.WBValue()
35+
with pytest.raises(NotImplementedError):
36+
wbvalue.to_json(local_art)
37+
38+
with pytest.raises(NotImplementedError):
39+
data_types.WBValue.from_json({}, public_art)
40+
41+
assert data_types.WBValue.with_suffix("item") == "item.json"
42+
43+
table = data_types.WBValue.init_from_json(
44+
{
45+
"_type": "table",
46+
"data": [[]],
47+
"columns": [],
48+
"column_types": wandb.data_types._dtypes.DictType({}).to_json(),
49+
},
50+
public_art,
51+
)
52+
assert isinstance(table, data_types.WBValue) and isinstance(
53+
table, wandb.data_types.Table
54+
)
55+
56+
type_mapping = data_types.WBValue.type_mapping()
57+
assert all(
58+
[issubclass(type_mapping[key], data_types.WBValue) for key in type_mapping]
59+
)
60+
61+
assert wbvalue == wbvalue
62+
assert wbvalue != data_types.WBValue()
63+
64+
2965
def test_raw_data():
3066
wbhist = wandb.Histogram(data)
3167
assert len(wbhist.histogram) == 64
@@ -161,7 +197,7 @@ def test_max_images(caplog, mocked_run):
161197
large_list = [wandb.Image(large_image)] * 200
162198
large_list[0].bind_to_run(mocked_run, "test2", 0, 0)
163199
meta = wandb.Image.seq_to_json(
164-
data_types.prune_max_seq(large_list), mocked_run, "test2", 0
200+
wandb.wandb_sdk.data_types._prune_max_seq(large_list), mocked_run, "test2", 0
165201
)
166202
expected = {
167203
"_type": "images/separated",
@@ -407,6 +443,17 @@ def test_molecule(runner, mocked_run):
407443
assert os.path.exists(mol._path)
408444

409445

446+
def test_molecule_file(runner, mocked_run):
447+
with runner.isolated_filesystem():
448+
with open("test.pdb", "w") as f:
449+
f.write("00000")
450+
mol = wandb.Molecule(open("test.pdb", "r"))
451+
mol.bind_to_run(mocked_run, "rad", "summary")
452+
wandb.Molecule.seq_to_json([mol], mocked_run, "rad", "summary")
453+
454+
assert os.path.exists(mol._path)
455+
456+
410457
def test_html_str(mocked_run):
411458
html = wandb.Html("<html><body><h1>Hello</h1></body></html>")
412459
html.bind_to_run(mocked_run, "rad", "summary")
@@ -446,6 +493,16 @@ def test_html_file(mocked_run):
446493
assert os.path.exists(html._path)
447494

448495

496+
def test_html_file_path(mocked_run):
497+
with open("test.html", "w") as f:
498+
f.write("<html><body><h1>Hello</h1></body></html>")
499+
html = wandb.Html("test.html")
500+
html.bind_to_run(mocked_run, "rad", "summary")
501+
wandb.Html.seq_to_json([html, html], mocked_run, "rad", "summary")
502+
503+
assert os.path.exists(html._path)
504+
505+
449506
def test_table_default():
450507
table = wandb.Table()
451508
table.add_data("Some awesome text", "Positive", "Negative")
@@ -494,6 +551,22 @@ def test_object3d_numpy(mocked_run):
494551
assert obj3.to_json(mocked_run)["_type"] == "object3D-file"
495552

496553

554+
def test_object3d_dict(mocked_run):
555+
obj = wandb.Object3D({"type": "lidar/beta",})
556+
obj.bind_to_run(mocked_run, "object3D", 0)
557+
assert obj.to_json(mocked_run)["_type"] == "object3D-file"
558+
559+
560+
def test_object3d_dict_invalid(mocked_run):
561+
with pytest.raises(ValueError):
562+
obj = wandb.Object3D({"type": "INVALID",})
563+
564+
565+
def test_object3d_dict_invalid_string(mocked_run):
566+
with pytest.raises(ValueError):
567+
obj = wandb.Object3D("INVALID")
568+
569+
497570
def test_object3d_obj(mocked_run):
498571
obj = wandb.Object3D(utils.fixture_open("cube.obj"))
499572
obj.bind_to_run(mocked_run, "object3D", 0)
@@ -651,7 +724,7 @@ def test_graph():
651724

652725

653726
def test_numpy_arrays_to_list():
654-
conv = data_types.numpy_arrays_to_lists
727+
conv = data_types._numpy_arrays_to_lists
655728
assert conv(np.array((1, 2,))) == [1, 2]
656729
assert conv([np.array((1, 2,))]) == [[1, 2]]
657730
assert conv(np.array(({"a": [np.array((1, 2,))]}, 3))) == [{"a": [[1, 2]]}, 3]

tests/test_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def test_classes_type():
329329
]
330330
)
331331

332-
wb_class_type = data_types._ClassesIdType.from_obj(wb_classes)
332+
wb_class_type = wandb.wandb_sdk.data_types._ClassesIdType.from_obj(wb_classes)
333333
assert wb_class_type.assign(1) == wb_class_type
334334
assert wb_class_type.assign(0) == InvalidType()
335335

wandb/_globals.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# WARNING: This is an anti-pattern file and we should avoid
2+
# adding to it and remove entries whenever possible. This file
3+
# contains global objects which need to be referenced by multiple
4+
# submodules. If you need a global object, seriously reconsider. This
5+
# file is intended to be a stop gap to help during code migrations (eg.
6+
# when moving to typing a module) to avoid circular references. Anything
7+
# added here is pure tech debt. Use with care. - Tim
8+
9+
_glob_datatypes_callback = None
10+
11+
12+
def _datatypes_set_callback(cb):
13+
global _glob_datatypes_callback
14+
_glob_datatypes_callback = cb
15+
16+
17+
def _datatypes_callback(fname):
18+
if _glob_datatypes_callback:
19+
_glob_datatypes_callback(fname)

wandb/apis/public.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2793,7 +2793,7 @@ def get(self, name):
27932793
with open(item_path, "r") as file:
27942794
json_obj = json.load(file)
27952795
result = wb_class.from_json(json_obj, self)
2796-
result.artifact_source = {"artifact": self, "name": name}
2796+
result.set_artifact_source(self, name)
27972797
return result
27982798

27992799
def download(self, root=None, recursive=False):

0 commit comments

Comments
 (0)