Skip to content

Commit 989d877

Browse files
malfetfacebook-github-bot
authored andcommitted
[JIT] Do not allow creating generics with None types (#44958)
Summary: Otherwise, invoking something like `python -c "import torch._C;print(torch._C.ListType(None))"` will result in SIGSEGV Discovered while trying to create a torch script for function with the following type annotation `Tuple[int, Ellipsis] -> None` Pull Request resolved: #44958 Reviewed By: suo Differential Revision: D23799906 Pulled By: malfet fbshipit-source-id: 916a243007d13ed3e7a5b282dd712da3d66e3bf7
1 parent 0a9ac98 commit 989d877

3 files changed

Lines changed: 14 additions & 1 deletion

File tree

aten/src/ATen/core/jit_type.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,12 @@ struct SingleElementType : public Type {
263263
}
264264

265265
protected:
266-
SingleElementType(TypePtr elem) : Type(Kind), elem(std::move(elem)) {}
266+
SingleElementType(TypePtr elem) : Type(Kind), elem(std::move(elem)) {
267+
if (!this->elem) {
268+
throw std::runtime_error(c10::str(
269+
"Can not create ", typeKindToString(Kind), " with None type"));
270+
}
271+
}
267272

268273
private:
269274
TypePtr elem;

aten/src/ATen/core/type.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,9 @@ TupleType::TupleType(
716716
schema_(std::move(schema)) {
717717
has_free_variables_ =
718718
std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
719+
if (!v) {
720+
throw std::runtime_error("Can not create tuple with None type");
721+
}
719722
return v->hasFreeVariables();
720723
});
721724
if (schema_) {

test/jit/test_list_dict.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,11 @@ def annotated_fn(x: torch.Tensor) -> List:
11551155
with self.assertRaisesRegex(RuntimeError, r"Attempted to use List without a contained type"):
11561156
torch.jit.script(annotated_fn)
11571157

1158+
def test_list_none(self):
1159+
with self.assertRaisesRegex(RuntimeError, "Can not create ListType with None type"):
1160+
x = torch._C.ListType(None)
1161+
1162+
11581163

11591164
class TestDict(JitTestCase):
11601165
def dict(self):

0 commit comments

Comments
 (0)