OPT - Fix Softmax NaN in half precision mode #17437
younesbelkada merged 12 commits into huggingface:main
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
So when running in half precision we get NaN values out of the softmax. Am I missing anything here? I have PR #17306 for a related issue.
|
Hi @ydshieh!
Let me check - because if that is the case, PR #17306 needs to find another way out 😢
|
Note also that everything works fine when not running in half precision.
|
The root cause is the value used to mask the padded positions: it is more negative than the float16 minimum, so it overflows to -inf in half precision. Change it to torch.finfo(dtype).min instead.

More details: with the above fix, there is still a minor issue. In a batch where the sequence at batch index 0 is entirely padding, we will see a uniform attention distribution for that row. This is because every position in the row gets the same finite mask value. In general, we shouldn't have or use -inf in the attention mask.
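The failure mode can be reproduced in a few lines of plain Python (a minimal sketch, not the transformers code; `naive_softmax` is a hypothetical stand-in for the fp16 softmax, and Python floats stand in for fp16 values):

```python
import math

NEG_INF = float("-inf")  # what a float32-scale mask value becomes once cast to float16
F16_MIN = -65504.0       # most negative finite float16, i.e. torch.finfo(torch.float16).min

def naive_softmax(row):
    # Standard shifted softmax; if every entry is -inf, the shift
    # computes (-inf) - (-inf) = NaN and the whole row becomes NaN.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

nan_row = naive_softmax([NEG_INF] * 4)      # fully masked row -> NaN everywhere
uniform_row = naive_softmax([F16_MIN] * 4)  # finite mask value -> uniform 0.25, no NaN
```

With the finite mask value, a fully padded row degenerates to a uniform distribution rather than NaN; that is the minor issue described above, and it is typically harmless since those positions are ignored downstream.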
|
Great! My suggestion is to mix both - we can force the attention mask to use -65504 for fp16, upcast to fp32, and cast back to fp16 after the softmax, as a sanity check and to avoid possible overflow issues. Wdyt?
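That mixed recipe can be sketched in pure Python (illustrative only; `clamp_mask` and `softmax_upcast` are hypothetical helper names, and Python floats stand in for the fp32 computation):

```python
import math

F16_MIN = -65504.0  # the -65504 mask value proposed for fp16

def clamp_mask(scores):
    # Force every mask value to be representable in fp16 (no -inf).
    return [max(v, F16_MIN) for v in scores]

def softmax_upcast(row):
    # Stand-in for computing the softmax in fp32 and casting the
    # result back to fp16 afterwards.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# last two positions are padding and arrive masked with -inf
scores = [0.3, -0.1, float("-inf"), float("-inf")]
probs = softmax_upcast(clamp_mask(scores))  # finite, sums to 1, no NaN
```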
|
@stephenroller @suchenzang have you seen something similar in your training / inference runs? Also cc @patil-suraj - see issue. Would be nice to hear your opinion here |
|
FYI, it can happen that during training you never use padding tokens. I may be mistaken, but I believe that for Bloom we do not train on padded batch inputs but on truncated sequences instead.
|
Upcast to fp32 should never be required if masked tokens are masked with something that's not -inf. Upcasting to fp32 is a significant performance penalty.
|
Great thank you all for your comments and help! Following your advice I have added the changes proposed by @ydshieh - let me know if this works for you! |
|
This change is also in #17306, but I am fine with a quick fix here. I would still like to point out that, although it is not useful for real usage of the model, leaving large negative values mixed with -inf in the attention mask can still cause problems.
Forgot to say: with the current change, it's still possible to get NaN in some edge cases.
|
We perform the upcast in our code, though we do it with softmax(dtype=torch.float32). It's very important. |
|
That's an excellent point, Stephen! Thank you for that crucial reminder. Indeed, for pytorch ops that support an accumulate dtype arg, this approach makes things much more efficient than manual casting. I remember discovering that when optimizing another model - it made a huge difference.
For learning purposes, could you share why passing the dtype to the op is faster than manually casting?
|
Because the op kernel does it automatically, internally, in a single operation, by already accumulating in the correct dtype. When you do it in 2 steps, you pay for an extra cast and the extra memory traffic of the fp32 intermediate. @ngimel, did I explain that correctly? Thank you! It should be simple to benchmark the 2 cases to see the difference.
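The same fused-accumulator idea exists in numpy, which makes the two variants easy to compare side by side (a numpy analogy for illustration, not the torch softmax kernel; `sum(dtype=...)` plays the role of `softmax(dtype=...)` here):

```python
import numpy as np

x = np.full(4096, 0.1, dtype=np.float16)

# Two steps: materialize a full fp32 copy of x in memory, then reduce it.
two_step = x.astype(np.float32).sum()

# One step: the kernel reads fp16 and accumulates in fp32 internally,
# so no intermediate fp32 copy of x is ever written out.
one_step = x.sum(dtype=np.float32)
```

Both produce the same value; the one-step form simply avoids the extra allocation and memory traffic.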
|
@Chillee, would you have any insight here?
|
My original understanding of the process is: mask the padded positions first, then apply the softmax. So I think the correct way should be: make sure the mask values are representable in the compute dtype before the softmax, right?

Another question regarding dtype: after we get the attention probabilities in fp32, do we cast them back to fp16 before multiplying with the value states? I am talking about the case where a user loads the model in fp16 and specifies the inputs in fp16 too.
(I am not sure this is the correct/usual way to do inference in fp16, but this is what I see in the code snippet from the issue reporter) |
|
I think the issue that this PR aims to address is not really about the upcast to float32 (@younesbelkada, right?). It is mentioned in the PR description as a potential solution, but the original issue we want to address here comes from the fact that we get a sequence where every position is masked. Maybe it is better to move the discussion(s) regarding the upcasting to another issue/PR page.
- fix bad rebase
- add proposed final changes

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

- upcast to fp32
- downcast to the original dtype
- added a slow test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

- fix quality
- fix final nit
be6d4df to 22bcc84
|
I can confirm batched generation works fine now, can we merge?
stas00 left a comment:
Looks good now - thank you
|
Hi, as the new code shows in modeling_opt.py at lines 219 to 222, I wondered if it would be equivalent to change these lines to mask with the minimum value of the current dtype directly. In that case, we no longer need to upcast float16 to float32, which might speed up training and (or at least) inference?
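For illustration, here is what masking directly in the compute dtype looks like in numpy (an analogy only, not the modeling_opt.py code; `masked_softmax_fp16` is a hypothetical helper, and `np.finfo(np.float16).min` plays the role of `torch.finfo(dtype).min`):

```python
import numpy as np

def masked_softmax_fp16(scores, mask):
    # Mask with the most negative *finite* value of the current dtype,
    # so the fp16 softmax never sees a fully -inf (NaN-producing) row.
    min_val = np.finfo(scores.dtype).min  # -65504.0 for float16
    masked = np.where(mask, scores, scores.dtype.type(min_val))
    shifted = masked - masked.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)  # stays in float16 throughout
    return exps / exps.sum(axis=-1, keepdims=True)

scores = np.array([[0.5, -0.2, 0.1, 0.3]], dtype=np.float16)
mask = np.array([[True, True, False, False]])  # last two tokens are padding
probs = masked_softmax_fp16(scores, mask)      # finite fp16 probabilities
```

Whether this actually beats the upcasted softmax in speed would need benchmarking, as discussed above.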
|
cc @younesbelkada here |
|
hey! |
What does this PR do?
Fix overflow / unstable operation issues when using large OPT models in half precision
The model used a very large negative value (-3.24e+38) to mask the padded tokens, which overflows to -inf in float16 and leads to NaN after the softmax. EDIT: it seems we use correct values to mask padded tokens.
cc @patrickvonplaten @ArthurZucker @ydshieh @stas00