
OPT - Fix Softmax NaN in half precision mode#17437

Merged
younesbelkada merged 12 commits into huggingface:main from younesbelkada:opt-fix-softmax
Jun 29, 2022

Conversation

@younesbelkada
Contributor

@younesbelkada younesbelkada commented May 26, 2022

What does this PR do?

Fix overflow / unstable operation issues when using large OPT models in half precision

  • As is done in Megatron-DeepSpeed, for large models it appears that you have to first upcast the input to float32 before applying the softmax function to avoid unexpected NaNs. This is because we use very large negative values (e.g. -3.24e+38) to mask the padded tokens. EDIT: it seems we do use correct values to mask padded tokens
  • Linked issue: OPT produce NaN during batched generation #17433
  • We'll probably need to re-compute the logits for slow tests but I am not sure

cc @patrickvonplaten @ArthurZucker @ydshieh @stas00

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 26, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh
Collaborator

ydshieh commented May 26, 2022

Hi @younesbelkada

expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

So when running in half precision, _expand_mask will use torch.finfo(dtype) with dtype = inputs_embeds.dtype = fp16, and the min is -65504.

Am I missing anything here?

Is fp32.min used unexpectedly instead of fp16.min in this particular issue?

I have a PR #17306 for a related issue. If using -65504 has an issue, then I need to hold on that PR to investigate.
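For illustration, the point about the mask value being determined by the input dtype can be checked with a minimal sketch (this is not the actual `_expand_mask` implementation, just a hypothetical `expand_mask_sketch` showing how the fill value follows from the dtype):

```python
import torch

def expand_mask_sketch(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # 1 -> keep (0.0); 0 -> mask with the most negative finite value of `dtype`
    inverted = 1.0 - mask.float()
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min).to(dtype)

mask = torch.tensor([[1, 1, 0]])
print(expand_mask_sketch(mask, torch.float16))
# fp16 min is a large finite value, not -inf:
print(torch.finfo(torch.float16).min)  # -65504.0
```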

@younesbelkada
Contributor Author

Hi @ydshieh !
I think that you are right: when running in half precision I have -65504, not -3.24e+38, in the attention mask as I said. But even with this mask I get NaNs on the padded hidden states for opt-1.3b, and upcasting the input to fp32 then casting back to fp16 seems to solve the issue for now

@ydshieh
Collaborator

ydshieh commented May 26, 2022

Hi @ydshieh ! I think that you are right: when running in half precision I have -65504, not -3.24e+38, in the attention mask as I said. But even with this mask I get NaNs on the padded hidden states for opt-1.3b, and upcasting the input to fp32 then casting back to fp16 seems to solve the issue for now

Let me check - as if this is the case, the PR #17306 needs to find another way out 😢

@ydshieh
Collaborator

ydshieh commented May 26, 2022

I get NaNs on the padded hidden states for opt-1.3b,

@younesbelkada

  • Could you point me to which line in OPTModel you got NaN for the padded hidden states?
  • Did you use the generation script in the linked issue, or did you just run the model with some input ids? If it is the latter case, could you provide the code snippet 🙏 please?

@younesbelkada
Contributor Author

younesbelkada commented May 26, 2022

  • I got NaNs exactly here:
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    - to fix it you can do attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).float() + attention_mask at that line, and then attn_weights = nn.functional.softmax(attn_weights, dim=-1).half() right after
  • Yes use the generation script provided in the issue, ie:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# I have tested and the error happens to opt-1.3b, opt-2.7b, opt-6.7b, and opt-13b.
# opt-125m and opt-350m seem to work fine.
# I haven't tested opt-30b.
model_name = "facebook/opt-1.3b"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
tokenizer.padding_side = "left"
# It works when torch_dtype=torch.float32
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True)
model = model.eval().to("cuda")

batch = tokenizer(
    ["Who are you?", "Joe Biden is the president of"],
    padding=True, return_tensors="pt"
)

# It produces NaN in the early layers for the first sequence.
# I check the pattern, and NaN first appears in the padded token position.
greedy_output = model.generate(
    input_ids=batch["input_ids"].to("cuda"),
    attention_mask=batch["attention_mask"].to("cuda"),
    do_sample=False, top_k=0
)

Note also that everything works fine when torch_dtype is set to torch.float32 or torch.bfloat16
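The upcast-then-downcast fix described in the first bullet can be sketched as follows (a simplified stand-in for the attention code; `masked_softmax_fp16_safe` is a hypothetical helper name, not the actual modeling_opt.py function):

```python
import torch
import torch.nn as nn

def masked_softmax_fp16_safe(attn_weights, attention_mask):
    # Upcast to fp32 before adding the mask and applying softmax,
    # then cast back so downstream fp16 ops (e.g. torch.bmm) keep working.
    orig_dtype = attn_weights.dtype
    scores = attn_weights.float() + attention_mask.float()
    return nn.functional.softmax(scores, dim=-1).to(orig_dtype)

w = torch.randn(1, 4, 4, dtype=torch.float16)
mask = torch.zeros(1, 4, 4, dtype=torch.float16)
mask[..., -1] = torch.finfo(torch.float16).min  # mask out the last position
probs = masked_softmax_fp16_safe(w, mask)
print(probs.dtype)                      # torch.float16
print(bool(torch.isnan(probs).any()))   # False
```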

@ydshieh
Collaborator

ydshieh commented May 26, 2022

@younesbelkada

The root cause is -inf is used here

mask = torch.full((tgt_len, tgt_len), float("-inf"))

Changing it to mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min) should be fine.
(±inf * 0.0 results in NaN.)
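The parenthetical can be verified directly in plain Python (no torch needed):

```python
import math

# An infinite mask value multiplied by a zero weight is undefined:
print(float("-inf") * 0.0)               # nan
print(math.isnan(float("-inf") * 0.0))   # True

# A large finite value like the fp16 min behaves as intended:
print(-65504.0 * 0.0)                    # -0.0
```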

More details

With the above fix, there is still a minor issue. In

return combined_attention_mask

for batch index 0, we will still see a -inf:

tensor([[[[-65504.,    -inf, -65504., -65504., -65504., -65504., -65504.],
          [-65504., -65504., -65504., -65504., -65504., -65504., -65504.],
          [-65504., -65504.,      0., -65504., -65504., -65504., -65504.],
          [-65504., -65504.,      0.,      0., -65504., -65504., -65504.],
          [-65504., -65504.,      0.,      0.,      0., -65504., -65504.],
          [-65504., -65504.,      0.,      0.,      0.,      0., -65504.],
          [-65504., -65504.,      0.,      0.,      0.,      0.,      0.]]],

This is because we have -65504 for the causal mask plus -65504 due to (left) padding, and the sum overflows to -inf in fp16.
Regarding this part, we need to discuss with the team.

In general, we shouldn't have or use -inf (the only safe place to use it is immediately before the softmax).

@ydshieh ydshieh mentioned this pull request May 26, 2022
@younesbelkada
Contributor Author

younesbelkada commented May 26, 2022

Great! My suggestion is to mix both: we can force the attention mask to use -65504 for fp16, and also upcast to fp32 and cast back to fp16 after the softmax as a sanity check, to avoid possible overflow issues. Wdyt?

@patrickvonplaten
Contributor

@stephenroller @suchenzang have you seen something similar in your training / inference runs?

Also cc @patil-suraj - see issue. Would be nice to hear your opinion here

@younesbelkada
Contributor Author

FYI, it can happen that during training you never use padding tokens. I may be mistaken, but I know that for BLOOM we do not train on padded batched inputs but on truncated sequences instead.
Usually these issues can happen at inference time only!

@ngimel

ngimel commented May 26, 2022

Upcasting to fp32 should never be required if masked tokens are masked with something that is not -inf. Upcasting to fp32 is a significant performance penalty. A single -inf value shouldn't be a problem as long as there are some non-zero values in the row; it would change the output a little bit, but that output is meaningless anyway since the whole row is masked out.

@younesbelkada
Contributor Author

Great, thank you all for your comments and help! Following your advice I have added the changes proposed by @ydshieh - let me know if this works for you!

@ydshieh
Collaborator

ydshieh commented May 27, 2022

This change is also in #17306, but I am fine with a quick fix for OPTModel.

I would still like to point out that, although it does not matter for real usage of the model, leaving large non-zero negative values mixed with -inf to mask a whole sequence is not good for testing/debugging purposes -> but this could be addressed in another PR.

Collaborator

@ydshieh ydshieh left a comment


Thanks!

@ydshieh
Collaborator

ydshieh commented May 27, 2022


Forgot to say: with the current change, it's still possible to get [-inf, -inf, dtype.min, dtype.min, ...] or [-inf, -inf, -inf] etc. after summing with the attn_weights (as mentioned, this depends on the values in attn_weights). I will try to implement some processing in #17306 today.

@stephenroller

We perform the upcast in our code, though we do it with softmax(dtype=torch.float32). It's very important.
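For illustration, the `dtype=` argument to `nn.functional.softmax` casts the input to that dtype before the operation, so it matches the explicit two-step upcast minus the extra copy (a small sketch using only public PyTorch APIs):

```python
import torch
import torch.nn as nn

scores = torch.randn(2, 8, dtype=torch.float16)

# One step: softmax casts the input to fp32 internally before the op
a = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)

# Two steps: explicit upcast, then softmax (an extra memory copy)
b = nn.functional.softmax(scores.float(), dim=-1)

print(torch.allclose(a, b))  # True
```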

@stas00
Contributor

stas00 commented May 27, 2022

That's an excellent point, Stephen! Thank you for that crucial reminder.

Indeed, for PyTorch ops that support an accumulate-dtype arg, this approach makes things much more efficient than manual casting.

I remember I discovered that when optimizing the LabelSmoother:

smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)

it made a huge difference.

@stas00 stas00 mentioned this pull request May 27, 2022
@ydshieh
Collaborator

ydshieh commented May 27, 2022

for pytorch ops that support accumulate dtype arg this approach makes things much more efficient than manual casting

For learning purposes, could you share why using softmax(dtype=torch.float32) is more efficient than explicit upcasting?

@stas00
Contributor

stas00 commented May 27, 2022

Because the op kernel does it automatically internally in a single operation by already accumulating in the correct dtype.

When you do it in 2 steps: op(...).to(dtype=...), 2 additional memory copying operations have to happen to perform the casting.

@ngimel, did I explain that correctly? Thank you!

and it should be simple to benchmark the 2 cases to see the difference.
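A rough micro-benchmark sketch of the two cases on CPU (timings will vary by hardware and torch version; this only illustrates how one might compare them, not a definitive result):

```python
import time
import torch
import torch.nn as nn

x = torch.randn(512, 1024, dtype=torch.float16)

def bench(fn, iters=50):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - start

# Fused: the kernel accumulates in fp32 via the dtype argument
t_fused = bench(lambda: nn.functional.softmax(x, dim=-1, dtype=torch.float32))
# Manual: explicit upcast copy, then softmax
t_manual = bench(lambda: nn.functional.softmax(x.float(), dim=-1))
print(f"dtype arg: {t_fused:.4f}s  explicit cast: {t_manual:.4f}s")
```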

@stas00
Contributor

stas00 commented May 27, 2022

@Chillee, would nvfuser fuse explicit casting into the op's accumulate dtype automatically?

@ydshieh
Collaborator

ydshieh commented May 27, 2022

My original understanding of the process is like:

attn_scores = attn_scores.to(torch.float32)
attn_probs = nn.functional.softmax(attn_scores, dim=-1)

So I think the correct way should be:

attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)

right?

Another question regarding dtype

After we get attn_probs in float32, should we cast it back to the target precision for the subsequent ops, like

attn_output = torch.bmm(attn_probs, value_states)

I am talking about the case where a user loads the model in fp16 and specifies the inputs in fp16 too:

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
  • If we don't cast attn_probs back to the target type (here fp16)
    • it will fail (if value_states is fp16) for some ops like torch.bmm
    • or it will propagate the fp32 dtype through some simple ops (like +)

(I am not sure this is the correct/usual way to do inference in fp16, but this is what I see in the code snippet from the issue reporter)

@ydshieh
Collaborator

ydshieh commented May 27, 2022

I think the issue that this PR aims to address is not really about the upcast to float32. (@younesbelkada , right?)

It is mentioned in the PR description as a potential solution, but the original issue we want to address here comes from the fact that we get a sequence with all -inf as attention scores before softmax.

Maybe it is better to move the discussion(s) regarding the upcasting to another issue/PR page.

ydshieh and others added 6 commits June 29, 2022 10:55
- fix bad rebase
- add proposed final changes

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
- upcast to fp32
- downcast to the original dtype
- added a slow test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
- fix quality
- fix final nit
@younesbelkada
Contributor Author

I can confirm that batched generation works fine now, can we merge?

Contributor

@stas00 stas00 left a comment


Looks good now - thank you

@younesbelkada younesbelkada merged commit d444edb into huggingface:main Jun 29, 2022
@patrickvonplaten patrickvonplaten deleted the opt-fix-softmax branch June 29, 2022 22:49
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
@oscmansan oscmansan mentioned this pull request Jul 18, 2022
@itsucks

itsucks commented Aug 29, 2022

Hi, as the new code shows:

if dtype_attn_weights == torch.float16:
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
else:
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

in modeling_opt.py at lines 219 to 222, I wondered if it is equivalent to change these lines to

attn_weights = attn_weights-attn_weights.max(-1, keepdim=True).values
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

In this case, we no longer need to upcast float16 to float32, which might speed up training, or at least inference?

@patrickvonplaten
Contributor

cc @younesbelkada here

@younesbelkada
Contributor Author

younesbelkada commented Sep 29, 2022

hey!
Sorry for the late response here; indeed it seems to be equivalent, i.e. the slow test that verifies this specific issue passes with your proposed change.
Recall that the most crucial part of solving the issue is this line: attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) (at least as of today), and we decided to keep nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32) to make the implementation consistent with the original OPT, as mentioned above.
However, your proposed change breaks 2 torch.fx tests and I did not dig further into that; maybe if you open a PR we could discuss it in more detail.
Thanks!



10 participants