[JIT] Partially support ForwardRef type annotations for NamedTuple attributes#96933

Closed
davidberard98 wants to merge 3 commits into pytorch:master from davidberard98:namedtuple-forwardref-v1

@davidberard98 davidberard98 commented Mar 16, 2023

Summary: NamedTuple attributes can be annotated to declare their type:

class MyNamedTuple(NamedTuple):
    x: int
    y: torch.Tensor
    z: MyOtherType

Normally in Python you can also declare your types as strings, e.g. x: 'int'. NamedTuples previously didn't support this under jit.script, because their annotations are evaluated through a slightly different process. This PR updates the evaluation of NamedTuple attribute type annotations to support ForwardRef declarations (i.e. types declared as strings).

Details

Below I repeat the comment I left in _jit_internal.py:

NamedTuple types are slightly different from normal types.

Normally, annotations are evaluated like this (during jit.script):

  1. Load strings of Python code into C++ and parse them.
  2. Get annotations as strings.
  3. Use the PythonResolver's resolution callback (rcb) to convert the string into a Python object.
  4. Call into annotations.py:ann_to_type to convert the Python object from step 3 into a type that TorchScript understands.
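Steps 3 and 4 above can be pictured with a toy model in plain Python. This is only a sketch under stated assumptions: `make_toy_rcb`, `toy_ann_to_type`, and `resolve_annotation` are hypothetical names invented for illustration, and the lookup table stands in for the real `ann_to_type` logic, which is not shown in this PR description.

```python
# Hypothetical stand-in for a resolution callback (rcb): it maps a name
# that appears in an annotation string to the Python object it refers to.
def make_toy_rcb(namespace):
    def rcb(name):
        return namespace[name]
    return rcb


# Hypothetical stand-in for ann_to_type: it maps a Python object to the
# name of a type that the compiler understands.
def toy_ann_to_type(py_obj):
    table = {int: "int", float: "float", str: "str"}
    if py_obj not in table:
        raise ValueError(f"unsupported annotation: {py_obj!r}")
    return table[py_obj]


def resolve_annotation(annotation_source, rcb):
    py_obj = rcb(annotation_source)   # step 3: string -> Python object
    return toy_ann_to_type(py_obj)    # step 4: Python object -> compiler type


toy_rcb = make_toy_rcb({"int": int, "float": float})
resolved = resolve_annotation("int", toy_rcb)
```

The point of the sketch is the ordering: the callback needs access to the scope in which the annotation was written, while the conversion step only ever sees real Python objects.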

NamedTuples are more complicated, because they have sub-types. Normally, once we have the NamedTuple type object from step 3, we can just look at the annotation literal values and use ann_to_type directly on them.

But sometimes, users will annotate with string literals, e.g.

   x: 'int'

This also happens with PEP 563 (from __future__ import annotations).

These annotations appear in the annotation dict as ForwardRef('int').
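A short standard-library check illustrates this. The class `Point` is a made-up example; the behaviour shown is that of `typing.NamedTuple` itself, independent of TorchScript.

```python
from typing import ForwardRef, NamedTuple


class Point(NamedTuple):
    x: 'int'   # string annotation
    y: float   # ordinary annotation


annotations = Point.__annotations__
x_ann = annotations['x']
y_ann = annotations['y']

# The string annotation is stored wrapped in a ForwardRef, while the
# ordinary annotation is stored as the type object itself.
x_is_forward_ref = isinstance(x_ann, ForwardRef)
x_source = x_ann.__forward_arg__   # the original annotation string
```

Because the dict holds a `ForwardRef` rather than `int`, code that passes the annotation values straight to a type converter receives an object it cannot interpret without first resolving the string.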

Then, we need to convert the string into a Python object. This requires the local context in which custom objects or imported types were defined, and rcb() is what gives us this. So, we plumb rcb through the stack so that it can be used in the if block that follows this comment in _jit_internal.py.
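The resolution step described here might look roughly like the sketch below. It is not the code from _jit_internal.py: `resolve_fields`, `Celsius`, `Reading`, and the dict-backed `rcb` are all hypothetical, and the real callback looks names up in the defining scope rather than in a hand-written table.

```python
from typing import ForwardRef, NamedTuple


class Celsius:
    """A custom type that only the defining scope knows about."""


class Reading(NamedTuple):
    value: 'float'
    unit: 'Celsius'


def rcb(name):
    # Hypothetical resolution callback backed by a fixed table.
    scope = {"float": float, "Celsius": Celsius}
    return scope[name]


def resolve_fields(named_tuple_cls, rcb):
    """Return the field annotations with every ForwardRef resolved via rcb."""
    resolved = {}
    for field_name, ann in named_tuple_cls.__annotations__.items():
        if isinstance(ann, ForwardRef):
            # Hand the original annotation string to the callback.
            ann = rcb(ann.__forward_arg__)
        resolved[field_name] = ann
    return resolved


fields = resolve_fields(Reading, rcb)
```

After this step every field annotation is an ordinary Python object, so it can be handed to the same conversion used for non-string annotations.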

FAQ:

  • Why do we need this special handling for NamedTuple when string annotations work fine for normal types? Normally, the annotation string is parsed in C++, which then calls rcb() directly. NamedTuple attributes are instead read from the Python annotation dict.
  • Why not use ForwardRef._evaluate? For that, we need globals() and locals() for the local context where the NamedTuple was defined. rcb is what lets us look names up in these. So, basically rcb does the hard work for us.
  • What is rcb? rcb is a ResolutionCallback: a Python callable that takes a string and returns a type. It's generated by the createResolutionCallback.* functions in _jit_internal.py.
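To make the last two answers concrete, here is a simplified sketch of a callback that captures the caller's scope. It is an assumption-laden illustration, not the PyTorch implementation: `make_rcb_from_caller` and `define_and_resolve` are hypothetical names, and returning `None` for unknown names is a choice made for this example only.

```python
import builtins
import inspect


def make_rcb_from_caller():
    """Hypothetical helper: build a callback that resolves names in the
    scope of whoever called this function."""
    frame = inspect.currentframe().f_back
    f_locals = dict(frame.f_locals)   # snapshot of the caller's locals
    f_globals = frame.f_globals
    del frame                         # avoid keeping the frame alive

    def rcb(name):
        if name in f_locals:
            return f_locals[name]
        if name in f_globals:
            return f_globals[name]
        # Fall back to builtins such as int; None if the name is unknown.
        return getattr(builtins, name, None)

    return rcb


def define_and_resolve():
    class LocalType:
        pass

    rcb = make_rcb_from_caller()
    # LocalType exists only in this function's locals, yet the callback
    # can still find it because it captured this scope.
    return rcb("LocalType"), LocalType, rcb("int")


resolved_local, expected_local, resolved_int = define_and_resolve()
```

The callback carries the defining scope with it, which is exactly the information that evaluating a `ForwardRef` by hand would otherwise need to be given as globals and locals.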

Why is this only partial support:

This only plumbs the rcb through some paths. In particular, the toSugaredValue path uses a fake rcb.

Alternatives:

We could also treat this the way we treat non-nn.Module classes: we evaluate them separately, ahead of time. That solution is probably better, but it probably requires a riskier refactor of the way NamedTuples are handled.

Fixes #95858

pytorch-bot bot commented Mar 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96933

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bec629a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: jit release notes category label Mar 16, 2023
@davidberard98 davidberard98 changed the title [JIT] Support ForwardRef type annotations for NamedTuple attributes [JIT] Partially support ForwardRef type annotations for NamedTuple attributes Mar 17, 2023
@davidberard98 davidberard98 force-pushed the namedtuple-forwardref-v1 branch from 1540193 to f7d80b3 Compare March 17, 2023 04:17
@facebook-github-bot
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@davidberard98 davidberard98 requested review from eellison and qihqi March 21, 2023 22:54
@davidberard98 davidberard98 marked this pull request as ready for review March 21, 2023 22:55
@davidberard98 davidberard98 added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 22, 2023
@davidberard98
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 23, 2023
…tributes (#96933)

Pull Request resolved: pytorch/pytorch#96933
Approved by: https://github.com/qihqi
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 27, 2023
…tributes (#96933)

Pull Request resolved: pytorch/pytorch#96933
Approved by: https://github.com/qihqi
davidberard98 added a commit that referenced this pull request Mar 28, 2023
Follow-up to #96933. This test was intended to have quotes around the
type annotations for the attributes of the NamedTuple. This PR adds the
missing quotes.

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Mar 28, 2023
Follow-up to #96933. This test was intended to have quotes around the
type annotations for the attributes of the NamedTuple. This PR adds the
missing quotes.

ghstack-source-id: e6f42e4
Pull Request resolved: #97736
pytorchmergebot pushed a commit that referenced this pull request Mar 28, 2023
Follow-up to #96933. This test was intended to have quotes around the
type annotations for the attributes of the NamedTuple. This PR adds the
missing quotes.
Pull Request resolved: #97736
Approved by: https://github.com/eellison
Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: jit (release notes category)

Development

Successfully merging this pull request may close these issues.

[JIT] Support string type annotations in NamedTuples

4 participants