
[pt1][quant] Add the serialization support for FP16 LSTM#26378

Closed
jianyuh wants to merge 3 commits into gh/jianyuh/28/base from gh/jianyuh/28/head

Conversation

@jianyuh
Member

@jianyuh jianyuh commented Sep 17, 2019

Stack from ghstack:


We would like to add the serialization support for FP16 LSTM.

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)

[ghstack-poisoned]
@jianyuh jianyuh requested a review from apaszke as a code owner September 17, 2019 23:48
@pytorchbot pytorchbot added the module: nn Related to torch.nn label Sep 17, 2019
jianyuh added a commit that referenced this pull request Sep 17, 2019
We would like to add the serialization support for FP16 LSTM.

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)

ghstack-source-id: 90265148
Pull Request resolved: #26378
Collaborator

@jamesr66a jamesr66a left a comment


LGTM

    # however there is a JIT compilation error without it. This is just used to
    # workaround that error.
    if dtype == torch.qint8:
        self._orig_weight_values = self._all_weight_values
Collaborator


Can we just assign an annotated empty list here?

self._orig_weight_values = torch.jit.annotate(List[torch.Tensor], [])

Not sure if it works but it's worth a shot. I think the values should be deduplicated in the pickler logic, but it's probably better not to have this unused stuff hanging around in memory as well
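The suggestion amounts to giving an otherwise-empty list an explicit element type. A minimal torch-free sketch of the idea, with a hypothetical `FakeTensor` standing in for `torch.Tensor` (in eager Python terms, `torch.jit.annotate(List[T], [])` behaves like an annotated empty list):

```python
from typing import List

class FakeTensor:
    """Hypothetical stand-in for torch.Tensor in this sketch."""
    def __init__(self, data: float) -> None:
        self.data = data

def make_empty_weight_list() -> List[FakeTensor]:
    # Eager-mode analogue of: torch.jit.annotate(List[torch.Tensor], [])
    # The annotation tells the compiler the element type; the list itself
    # stays empty, so no duplicate weights are kept alive in memory.
    orig_weight_values: List[FakeTensor] = []
    return orig_weight_values

empty = make_empty_weight_list()
assert empty == []
```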

Member Author


I tried that before, but it didn't work. The JIT doesn't allow returning different types from different branches. The error message is:

> ...
> Type mismatch: dynamic_vals is set to type List[Tuple[Tensor, Optional[Tensor]]] in the true branch and type List[Tensor] in the false branch:
> at /data/users/jianyuhuang/fbsource/fbcode/buck-out/dev/gen/caffe2/test/quantization#binary,link-tree/torch/nn/quantized/dynamic/modules/rnn.py:183:8
>             self.batch_first,
>             self.dropout,
>             self.bidirectional,
>             self._all_weight_names,
>             self.__overloads__,
>             self.training,
>             self.dtype,
>         )
>
>         if self.dtype == torch.qint8:
>         ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~...  <--- HERE
>
>             dynamic_vals = torch.jit.annotate(List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
>                                               [])
>
> ...

…STM"

We would like to add the serialization support for FP16 LSTM.

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)

[ghstack-poisoned]
jianyuh added a commit that referenced this pull request Sep 20, 2019
Pull Request resolved: #26378

We would like to add the serialization support for FP16 LSTM.
ghstack-source-id: 90479104

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)
@jamesr66a
Collaborator

@jianyuh any chance you could land this soon?

@jamesr66a
Collaborator

jamesr66a commented Sep 21, 2019

Oh I guess it doesn't actually work:

Previous return statement returned a value of type Tuple[Tuple[str, int, int, int, bool, bool, float, bool, List[str], Dict[str, List[str]], bool, int, List[Tensor]], List[Tuple[Tensor, Optional[Tensor]]]] but this return statement returns a value of type Tuple[Tuple[str, int, int, int, bool, bool, float, bool, List[str], Dict[str, List[str]], bool, int, List[Tensor]], List[Tensor]]:

            for i in range(len(self._all_weight_names)):
                dynamic_vals.append(torch.ops.quantized.linear_unpack(self._all_weight_values[i]))

            return vals, dynamic_vals
        else:
            dynamic_vals_fp16 = torch.jit.annotate(List[torch.Tensor], [])
            for i in range(len(self._all_weight_names)):
                dynamic_vals_fp16.append(self._all_weight_values[i])
            return vals, dynamic_vals_fp16
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

@jianyuh
Member Author

jianyuh commented Sep 21, 2019

> @jianyuh any chance you could land this soon?

Sorry, I just flew back to the Bay Area. The serialization support for FP16 LSTM doesn't work yet: the current PR breaks the unit test. I need to look into how to handle the different return types between the FP16 and INT8 paths.

@jamesr66a
Collaborator

@jianyuh can you just switch back to the scheme you had before (duplicating weight values)?

@jianyuh
Member Author

jianyuh commented Sep 23, 2019

> @jianyuh can you just switch back to the scheme you had before (duplicating weight values)?

Sure! Will check out that version.

…STM"

We would like to add the serialization support for FP16 LSTM.

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)

[ghstack-poisoned]
jianyuh added a commit that referenced this pull request Sep 24, 2019
Pull Request resolved: #26378

We would like to add the serialization support for FP16 LSTM.
ghstack-source-id: 90650081

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)
@jianyuh jianyuh changed the title [WIP][pt1][quant] Add the serialization support for FP16 LSTM [pt1][quant] Add the serialization support for FP16 LSTM Sep 24, 2019
@jianyuh
Member Author

jianyuh commented Sep 24, 2019

@jamesr66a : Please check my current code. It now passes the unit test; however, it duplicates both orig_weight_values and orig_bias_values.

Maybe we should store only the names instead of the values for all_weights? That is, use the same register_buffer mechanism and fetch those buffers with get_attr (e.g., https://github.com/pytorch/pytorch/blob/master/torch/jit/quantized.py#L326)? That way, we avoid the additional orig_bias_values copy on the FP16 path.
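A torch-free sketch of the name-based scheme floated above (all class and method names here are hypothetical stand-ins for the torch.nn.Module buffer machinery): register each weight under a name, and have serialization carry the names while values are fetched by attribute lookup, so no second in-memory copy of the weights is kept.

```python
from typing import Dict, List

class FakeModule:
    """Hypothetical sketch of register_buffer/get_attr-style storage."""
    def __init__(self) -> None:
        self._buffers: Dict[str, object] = {}
        self._all_weight_names: List[str] = []

    def register_buffer(self, name: str, value: object) -> None:
        # Store the value once, keyed by name.
        self._buffers[name] = value
        self._all_weight_names.append(name)

    def get_attr(self, name: str) -> object:
        return self._buffers[name]

    def __getstate__(self) -> Dict[str, object]:
        # Serialize the names; values are looked up on demand rather
        # than duplicated into a parallel list on the module.
        return {"names": list(self._all_weight_names),
                "values": [self.get_attr(n) for n in self._all_weight_names]}

m = FakeModule()
m.register_buffer("weight_ih_l0", [0.5, 0.25])
state = m.__getstate__()
assert state["names"] == ["weight_ih_l0"]
```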

@gottbrath
Contributor

Jianyu, James, Dima -- Please discuss and decide if this is in scope for 1.3 or not.

@gottbrath
Contributor

> Jianyu, James, Dima -- Please discuss and decide if this is in scope for 1.3 or not.

Sounds like the conclusion (from discussion elsewhere) is that this isn't in scope for 1.3.

@pytorchbot
Collaborator

Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark it as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
Stale pull requests are automatically closed 30 days after being marked Stale.

@github-actions github-actions bot closed this May 12, 2022
@facebook-github-bot facebook-github-bot deleted the gh/jianyuh/28/head branch June 11, 2022 14:19