[pt1][quant] Add the serialization support for FP16 LSTM #26378
jianyuh wants to merge 3 commits into gh/jianyuh/28/base from
Conversation
We would like to add serialization support for FP16 LSTM.

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)

ghstack-source-id: 90265148
Pull Request resolved: #26378
# however there is a JIT compilation error without it. This is just used to
# workaround that error.
if dtype == torch.qint8:
    self._orig_weight_values = self._all_weight_values
Can we just assign an annotated empty list here?
self._orig_weight_values = torch.jit.annotate(List[torch.Tensor], [])
Not sure if it works but it's worth a shot. I think the values should be deduplicated in the pickler logic, but it's probably better not to have this unused stuff hanging around in memory as well
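For reference, the suggested pattern can be sketched in plain Python (a hypothetical stand-in: in eager mode `torch.jit.annotate(T, v)` simply returns `v`, carrying the type only for the TorchScript compiler):

```python
from typing import List

# Hypothetical eager-mode stand-in for torch.jit.annotate(T, v): it just
# returns v, while telling the TorchScript compiler that v has type T.
def jit_annotate(the_type, value):
    return value

# The suggestion above: initialize an explicitly typed empty list, so the
# compiler sees a concrete element type even though the list starts empty.
orig_weight_values = jit_annotate(List[float], [])
print(orig_weight_values)  # []
```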
I tried that before, but it didn't work. JIT doesn't allow returning different data types from different branches. The error message is:
> ...
> Type mismatch: dynamic_vals is set to type List[Tuple[Tensor, Optional[Tensor]]] in the true branch and type List[Tensor] in the false branch:
> at /data/users/jianyuhuang/fbsource/fbcode/buck-out/dev/gen/caffe2/test/quantization#binary,link-tree/torch/nn/quantized/dynamic/modules/rnn.py:183:8
> self.batch_first,
> self.dropout,
> self.bidirectional,
> self._all_weight_names,
> self.__overloads__,
> self.training,
> self.dtype,
> )
>
> if self.dtype == torch.qint8:
> ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
>
> dynamic_vals = torch.jit.annotate(List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
> [])
>
> ...
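The error boils down to a TorchScript rule: a variable must have a single static type across both branches of an `if`. One way around it (a hypothetical sketch in plain Python, with floats standing in for tensors and the real module's attributes omitted) is to give both branches the same annotated type, wrapping the FP16 case's values as `(value, None)` pairs:

```python
from typing import List, Optional, Tuple

def collect_vals(is_qint8: bool, weights: List[float]) -> List[Tuple[float, Optional[float]]]:
    # Both branches build the SAME type, List[Tuple[float, Optional[float]]],
    # so a TorchScript-style checker would accept the assignment.
    vals: List[Tuple[float, Optional[float]]] = []
    if is_qint8:
        for w in weights:
            vals.append((w, w))      # qint8 branch: a (weight, extra) pair
    else:
        for w in weights:
            vals.append((w, None))   # fp16 branch: no second element
    return vals

print(collect_vals(False, [1.0, 2.0]))  # [(1.0, None), (2.0, None)]
```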
Update on "[pt1][quant] Add the serialization support for FP16 LSTM"

We would like to add serialization support for FP16 LSTM.

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)

ghstack-source-id: 90479104
Pull Request resolved: #26378
@jianyuh any chance you could land this soon?
Oh I guess it doesn't actually work:
Sorry, I just flew back to the Bay Area. The serialization support for FP16 LSTM doesn't work yet: the current PR breaks the unit test. We need to look into a possible solution for the different return values between FP16 and INT8.
@jianyuh can you just switch back to the scheme you had before (duplicating weight values)?
Sure! Will check out that version. |
Update on "[pt1][quant] Add the serialization support for FP16 LSTM"

We would like to add serialization support for FP16 LSTM.

Differential Revision: [D17391638](https://our.internmc.facebook.com/intern/diff/D17391638/)

ghstack-source-id: 90650081
Pull Request resolved: #26378
@jamesr66a: Please check my current code. It now passes the unit test. However, it still has the duplication of weight values. Maybe we should only store the names instead of the values.
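One possible shape of that idea, as a hypothetical sketch in plain Python (floats stand in for tensors; the class and method names here are made up, not the actual module code): keep only the ordered list of attribute names and look the values up on demand, so nothing is stored twice:

```python
from typing import Dict, List

class RNNParamsSketch:
    """Hypothetical stand-in for the dynamic RNN module's weight storage."""

    def __init__(self, weights: Dict[str, float]) -> None:
        # Store each value exactly once, plus the ordered list of names.
        self._params = dict(weights)
        self._all_weight_names: List[str] = list(weights)

    def all_weight_values(self) -> List[float]:
        # Reconstruct the value list from the names; no duplicate storage.
        return [self._params[name] for name in self._all_weight_names]

m = RNNParamsSketch({"weight_ih_l0": 1.0, "weight_hh_l0": 2.0})
print(m.all_weight_values())  # [1.0, 2.0]
```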
Jianyu, James, Dima -- Please discuss and decide if this is in scope for 1.3 or not.
Sounds like the conclusion (from discussion elsewhere) is that this isn't in scope for 1.3.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Stack from ghstack:
We would like to add serialization support for FP16 LSTM.
Differential Revision: D17391638