Skip to content

Commit f4cbcff

Browse files
jananisrirampytorchmergebot
authored andcommitted
[TorchScript] Expand TorchScript __init__ annotation warning (#127045)
Summary: Expand TorchScript `__init__` annotation warning to `list` and `dict` with reference to GSD task T187638414 and annotation warning reproduction D56834720. Currently, the TorchScript compiler ignores and throws `UserWarning`s for the following annotation types for empty values within the `__init__` function: `List`, `Dict`, `Optional`. However, the compiler should additionally cover warnings for `list` and `dict`. This diff adds support for `list` and `dict`. Test Plan: Added 4 new unit tests: `test_annotated_empty_list_lowercase` and `test_annotated_empty_dict_lowercase` verify that TorchScript throws UserWarnings for the list and dict type annotations on empty values. ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_empty_list_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_empty_dict_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` `test_annotated_with_jit_empty_list_lowercase` and `test_annotated_with_jit_empty_dict_lowercase` verify that TorchScript throws UserWarnings for the list and dict type annotations on empty values with the jit annotation. ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_with_jit_empty_list_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` ``` (base) [jananisriram@devvm2248.cco0 /data/users/jananisriram/fbsource/fbcode (e4ce427eb)]$ buck2 test @mode/{opt,inplace} //caffe2/test:jit -- --regex test_annotated_with_jit_empty_dict_lowercase ... Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Differential Revision: D57752002 Pull Request resolved: #127045 Approved by: https://github.com/davidberard98
1 parent 1be7e40 commit f4cbcff

2 files changed

Lines changed: 98 additions & 1 deletion

File tree

test/jit/test_scriptmod_ann.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import sys
5+
import unittest
56
import warnings
67
from typing import Dict, List, Optional
78

@@ -150,6 +151,30 @@ def forward(self, x: List[int]):
150151
):
151152
torch.jit.script(M())
152153

154+
@unittest.skipIf(
155+
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
156+
)
157+
def test_annotated_empty_list_lowercase(self):
158+
class M(torch.nn.Module):
159+
def __init__(self):
160+
super().__init__()
161+
self.x: list[int] = []
162+
163+
def forward(self, x: list[int]):
164+
self.x = x
165+
return 1
166+
167+
with self.assertRaisesRegexWithHighlight(
168+
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
169+
):
170+
with self.assertWarnsRegex(
171+
UserWarning,
172+
"doesn't support "
173+
"instance-level annotations on "
174+
"empty non-base types",
175+
):
176+
torch.jit.script(M())
177+
153178
def test_annotated_empty_dict(self):
154179
class M(torch.nn.Module):
155180
def __init__(self):
@@ -171,6 +196,30 @@ def forward(self, x: Dict[str, int]):
171196
):
172197
torch.jit.script(M())
173198

199+
@unittest.skipIf(
200+
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
201+
)
202+
def test_annotated_empty_dict_lowercase(self):
203+
class M(torch.nn.Module):
204+
def __init__(self):
205+
super().__init__()
206+
self.x: dict[str, int] = {}
207+
208+
def forward(self, x: dict[str, int]):
209+
self.x = x
210+
return 1
211+
212+
with self.assertRaisesRegexWithHighlight(
213+
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
214+
):
215+
with self.assertWarnsRegex(
216+
UserWarning,
217+
"doesn't support "
218+
"instance-level annotations on "
219+
"empty non-base types",
220+
):
221+
torch.jit.script(M())
222+
174223
def test_annotated_empty_optional(self):
175224
class M(torch.nn.Module):
176225
def __init__(self):
@@ -213,6 +262,30 @@ def forward(self, x: List[int]):
213262
):
214263
torch.jit.script(M())
215264

265+
@unittest.skipIf(
266+
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
267+
)
268+
def test_annotated_with_jit_empty_list_lowercase(self):
269+
class M(torch.nn.Module):
270+
def __init__(self):
271+
super().__init__()
272+
self.x = torch.jit.annotate(list[int], [])
273+
274+
def forward(self, x: list[int]):
275+
self.x = x
276+
return 1
277+
278+
with self.assertRaisesRegexWithHighlight(
279+
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
280+
):
281+
with self.assertWarnsRegex(
282+
UserWarning,
283+
"doesn't support "
284+
"instance-level annotations on "
285+
"empty non-base types",
286+
):
287+
torch.jit.script(M())
288+
216289
def test_annotated_with_jit_empty_dict(self):
217290
class M(torch.nn.Module):
218291
def __init__(self):
@@ -234,6 +307,30 @@ def forward(self, x: Dict[str, int]):
234307
):
235308
torch.jit.script(M())
236309

310+
@unittest.skipIf(
311+
sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
312+
)
313+
def test_annotated_with_jit_empty_dict_lowercase(self):
314+
class M(torch.nn.Module):
315+
def __init__(self):
316+
super().__init__()
317+
self.x = torch.jit.annotate(dict[str, int], {})
318+
319+
def forward(self, x: dict[str, int]):
320+
self.x = x
321+
return 1
322+
323+
with self.assertRaisesRegexWithHighlight(
324+
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
325+
):
326+
with self.assertWarnsRegex(
327+
UserWarning,
328+
"doesn't support "
329+
"instance-level annotations on "
330+
"empty non-base types",
331+
):
332+
torch.jit.script(M())
333+
237334
def test_annotated_with_jit_empty_optional(self):
238335
class M(torch.nn.Module):
239336
def __init__(self):

torch/jit/_check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def visit_AnnAssign(self, node):
156156
# cannot be reassigned later to a non-empty tuple. Same
157157
# deal with `NamedTuple`
158158

159-
containers = {"List", "Dict", "Optional"}
159+
containers = {"List", "list", "Dict", "dict", "Optional"}
160160

161161
# If we're not evaluating one of the specified problem types
162162
try:

0 commit comments

Comments
 (0)