
[CLIP] add: sdpa support to clip.#30390

Closed
sayakpaul wants to merge 40 commits into main from add-clip-sdpa

Conversation

@sayakpaul (Member) commented Apr 22, 2024

What does this PR do?

CLIP is heavily used in the diffusion modeling space for text encoding. So, having native SDPA support for CLIP would be beneficial for diffusion models both for training and inference.

The test failures can be tackled later once we match the logits to non-SDPA. Here's my test script:

from transformers import AutoTokenizer, CLIPTextModel
import torch

model_sdpa = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="sdpa").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
    outputs = model_sdpa(**inputs)
    last_hidden_state_sdpa = outputs.last_hidden_state

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="eager").to("cuda")

with torch.no_grad():
    outputs = model(**inputs)
    last_hidden_state = outputs.last_hidden_state

print(last_hidden_state_sdpa[0, :3, -1].flatten())
print(last_hidden_state[0, :3, -1].flatten())
print(torch.allclose(last_hidden_state_sdpa, last_hidden_state, rtol=1e-2, atol=1e-2))

These don't pass:

tensor([ 0.1013,  0.0805, -0.8429], device='cuda:0')
tensor([ 0.1013, -0.9801, -1.9444], device='cuda:0')
False

I inspected the key_states, value_states, and query_states in CLIPAttention and CLIPSdpaAttention. In both classes, these tensors have identical values. The differences start appearing from attn_output onwards.

SDPA attn_output[0, :3, -1]=tensor([0.3397, 0.2862, 0.5571], device='cuda:0')
SDPA attn_output[0, :3, -1]=tensor([0.0143, 0.0158, 0.0971], device='cuda:0')
SDPA attn_output[0, :3, -1]=tensor([0.0082, 0.0130, 0.1147], device='cuda:0')

Non SDPA attn_output[0, :3, -1]=tensor([0.3397, 0.2576, 0.7800], device='cuda:0')
Non SDPA attn_output[0, :3, -1]=tensor([ 0.0143,  0.0069, -0.0258], device='cuda:0')
Non SDPA attn_output[0, :3, -1]=tensor([ 0.0082,  0.0046, -0.0023], device='cuda:0')

I suspect this is happening because of how masking is handled in CLIPAttention. We first apply causal mask:

# apply the causal_attention_mask first

And then we apply attention mask:

if attention_mask is not None:

But I think this deviates a bit from how CLIPSdpaAttention handles masking. I am not entirely sure, though, hence I am opening this PR to seek feedback.
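For context, here is a minimal sketch (not the actual CLIPAttention code) of how the eager path applies the two additive masks before softmax:

```python
import torch

# Hypothetical sketch: both masks are additive (0 = keep, -inf = drop) and are
# added to the raw attention scores before softmax, causal mask first.
def apply_masks(attn_weights, causal_attention_mask=None, attention_mask=None):
    if causal_attention_mask is not None:
        attn_weights = attn_weights + causal_attention_mask
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    return torch.softmax(attn_weights, dim=-1)

scores = torch.zeros(1, 1, 3, 3)  # (bsz, num_heads, q_len, k_len)
causal = torch.triu(torch.full((3, 3), float("-inf")), diagonal=1)
probs = apply_masks(scores, causal_attention_mask=causal)
print(probs[0, 0])  # lower-triangular: row i attends only to positions 0..i
```

Note that since both masks are purely additive, the order of addition itself cannot change the pre-softmax scores; any divergence more likely comes from how the SDPA path constructs (or omits) the equivalent combined mask.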

I have added some comments in line to provide further clarification on some points.

def __init__(self, *args, **kwargs):
is_causal = kwargs.pop("is_causal", False)
super().__init__(*args, **kwargs)
self.is_causal = is_causal
Member Author

To directly provide is_causal to F.scaled_dot_product_attention.
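A quick sketch of what passing is_causal buys us: SDPA builds the causal mask internally instead of us materializing one (shapes here are illustrative, not CLIP's):

```python
import torch
import torch.nn.functional as F

# With is_causal=True, SDPA applies a lower-triangular causal mask internally.
q = k = v = torch.randn(1, 2, 4, 8)  # (bsz, num_heads, seq_len, head_dim)
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit additive causal mask, for comparison.
mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(out_flag, out_mask, atol=1e-5))  # True
```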

Comment on lines +427 to +430
if config._attn_implementation == "sdpa":
self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)
else:
self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config)
Member Author

We don't use causal masking in the vision tower of CLIP hence this conditioning.

text_config = config.text_config
text_config._attn_implementation = config._attn_implementation
vision_config = config.vision_config
vision_config._attn_implementation = config._attn_implementation
Member Author

The image classification class has:

self.vision_model = CLIPVisionTransformer(config.vision_config)

If we don't propagate vision_config._attn_implementation = config._attn_implementation the vision config won't have any way to know about the actual _attn_implementation.

However, correct me if I am wrong.

Contributor

You're completely right! However, the way it's propagated in other models is by passing it at model construction, e.g.

self.text_model = CLIPTextTransformer.from_config(text_config, attn_implementation=config._attn_implementation)

I'd encourage doing it this way for two reasons:

  • It's the pattern used for other models. If there's an issue we need to fix in the code, we're more likely to find it if this model matches the others.
  • The way the attention implementation is set on the config is less than ideal and (imo) prone to unexpected behaviour. Annoyingly, there's a bunch of magic in the setter which can cause the value to appear to magically change or revert. I'm sure the construction-time route works (at the moment); I'm not sure that setting the value on the config and then passing it to the model will always work, depending on the attention implementation already present on the original config, e.g. vision_config.

Member Author

Done in 89dab66. Keeping this comment open because it seems like an important thing for us to consider.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Contributor)

Thanks for working on this @sayakpaul! Let us know when it's ready for review, i.e. when all tests are passing and there are no more changes to be made.

@sayakpaul (Member Author)

Oh, I thought the PR description made it clear. Sorry if it didn't.

The test failures can be tackled later once we match the logits to non-SDPA.

Not very sure, though. Hence I am opening this PR for seeking feedback.

So, concretely, it would be great to have some initial reviews first so that we can localize why the basic logit assertion test isn't passing.

@ArthurZucker (Collaborator) left a comment

IMO we should avoid the complex prepare_4d etc. and just use _update_causal_mask. That's not really possible right now, as it would require deprecations, but anyway, I'm not against this PR; let's just make sure we have equivalence and support dispatching to the appropriate kernels!

query_states,
key_states,
value_states,
attn_mask=attention_mask,
Collaborator

You seem to be using only the attention mask, rather than both the causal mask and the attention mask.

@sayakpaul (Member Author)

IMO we should avoid the complexe prepare_4d etc, and just use the _update_causal_mask. Now that's not super possible as would require deprecating, but anyways, not against this PR,

@ArthurZucker if it's easier for, or more beneficial to, the library, I'm happy to use the methods you are suggesting. But I don't get why things would need deprecating, etc. If you could provide more references, that would be helpful.

@sayakpaul (Member Author) commented Apr 23, 2024

I disabled is_causal in F.scaled_dot_product_attention() and decided to handle the masking with the factory methods from the library itself, because CLIP applies BOTH a causal mask and a regular attention mask:

https://github.com/huggingface/transformers/blob/092f1fdaa4224fdd88c616dc9678e6fcb37bfffd/src/transformers/models/clip/modeling_clip.py#L287C1-L303C85
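To illustrate the idea (this is a hedged sketch, not the library's mask factory helpers, which handle this more generally): the two additive masks get merged into the single attn_mask argument that F.scaled_dot_product_attention accepts.

```python
import torch

# Illustrative only: merge a causal mask and a padding-based attention mask
# into one additive mask, since SDPA takes a single attn_mask argument.
bsz, seq_len = 2, 4
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
causal = causal.view(1, 1, seq_len, seq_len)

# attention_mask as returned by the tokenizer: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
padding = torch.zeros(bsz, 1, 1, seq_len)
padding.masked_fill_(attention_mask.view(bsz, 1, 1, seq_len) == 0, float("-inf"))

combined = causal + padding  # broadcasts to (bsz, 1, seq_len, seq_len)
print(combined.shape)  # torch.Size([2, 1, 4, 4])
```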

However, this leads to:

=========================== short test summary info ============================
FAILED tests/models/clip/test_modeling_clip.py::CLIPVisionModelTest::test_eager_matches_sdpa_inference_1_bfloat16
FAILED tests/models/clip/test_modeling_clip.py::CLIPTextModelTest::test_eager_matches_sdpa_inference_1_bfloat16
FAILED tests/models/clip/test_modeling_clip.py::CLIPTextModelTest::test_eager_matches_sdpa_inference_2_float32
FAILED tests/models/clip/test_modeling_clip.py::CLIPTextModelTest::test_sdpa_can_dispatch_on_flash
FAILED tests/models/clip/test_modeling_clip.py::CLIPModelTest::test_eager_matches_sdpa_inference_0_float16
FAILED tests/models/clip/test_modeling_clip.py::CLIPModelTest::test_eager_matches_sdpa_inference_1_bfloat16
FAILED tests/models/clip/test_modeling_clip.py::CLIPModelTest::test_eager_matches_sdpa_inference_2_float32
FAILED tests/models/clip/test_modeling_clip.py::CLIPModelTest::test_sdpa_can_dispatch_on_flash
FAILED tests/models/clip/test_modeling_clip.py::CLIPForImageClassificationModelTest::test_eager_matches_sdpa_inference_1_bfloat16
========== 9 failed, 208 passed, 159 skipped, 59 warnings in 40.40s ===========

Amongst these, the failures for test_sdpa_can_dispatch_on_flash() are of particular interest. I would appreciate some pointers on approaching them. It would also be helpful to validate the treatment I am giving to the masks in CLIPSdpaAttention.

@sayakpaul (Member Author)

This passes:

from transformers import AutoTokenizer, CLIPTextModel
import torch

model_sdpa = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="sdpa").to("cuda")
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="eager").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], return_tensors="pt")
print(inputs["attention_mask"].tolist())
inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
    outputs_sdpa = model_sdpa(**inputs)
    last_hidden_state_sdpa = outputs_sdpa.last_hidden_state
    outputs_eager = model(**inputs)
    last_hidden_state = outputs_eager.last_hidden_state


print(last_hidden_state_sdpa[0, :3, -1].flatten())
print(last_hidden_state[0, :3, -1].flatten())
print(torch.allclose(last_hidden_state_sdpa, last_hidden_state, rtol=1e-3, atol=1e-3))

@sayakpaul sayakpaul requested a review from ArthurZucker April 24, 2024 04:05
@ArthurZucker (Collaborator) left a comment

Thanks for the PR, could you share the expected speed boost benchmark? (and add it to the clip.md) 🤗
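A toy CPU sketch of the kind of benchmark being asked for (eager_attention here is a hypothetical stand-in for the eager path; real numbers for clip.md should be measured on GPU with the actual CLIP checkpoints):

```python
import time
import torch
import torch.nn.functional as F

# Hypothetical reference implementation of plain eager attention.
def eager_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

# Toy shapes: (bsz, num_heads, seq_len, head_dim); CLIP's text seq_len is 77.
q = k = v = torch.randn(8, 12, 77, 64)

for name, fn in [("eager", eager_attention),
                 ("sdpa", F.scaled_dot_product_attention)]:
    t0 = time.perf_counter()
    for _ in range(10):
        fn(q, k, v)
    print(f"{name}: {(time.perf_counter() - t0) / 10 * 1e3:.2f} ms/iter")
```

The two paths should also agree numerically, which is worth asserting alongside any timing run.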

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@sayakpaul sayakpaul marked this pull request as ready for review April 26, 2024 14:38
@sayakpaul sayakpaul changed the title [WIP] [CLIP] add: sdpa support to clip. [CLIP] add: sdpa support to clip. Apr 26, 2024
@sayakpaul (Member Author)

@ArthurZucker done :)

@sayakpaul (Member Author)

@amyeroberts am I doing something wrong when running make style && make quality?

Here's what I have done:

  • Created a fresh environment for the formatting related libs.
  • Ran pip install -e ".[quality]" inside the env.
  • Then from the root of transformers, ran make style && make quality.

@amyeroberts (Contributor)

@sayakpaul Could you try running make fixup?

@sayakpaul (Member Author)

Leads to:

(screenshot of the resulting error)

@amyeroberts (Contributor)

@sayakpaul Huh, weird; I haven't seen that before. I'm going to try to re-trigger a fresh CI run, as it doesn't seem to be related to this PR.

@amyeroberts (Contributor) left a comment

Just some final comments on the propagation of attn_implementation.

@sayakpaul sayakpaul requested a review from amyeroberts May 24, 2024 13:17
@amyeroberts (Contributor) left a comment

Thanks for adding this!

Sorry for the bad suggestion re from_config; I didn't realise it comes from the autoclass. Switching back to _from_config as you had it before should resolve it!

@sayakpaul (Member Author)

@amyeroberts I had to touch a couple of loading utilities to make sure the equivalence tests pass. LMK if they should have been approached differently.

@sayakpaul sayakpaul requested a review from amyeroberts May 25, 2024 03:36
@sayakpaul (Member Author)

Uh oh, the Flax tests still fail the equivalence check. I tried a bunch of state-dict rejigging to propagate the PT-side changes into the Flax model, but none of it worked out. I would appreciate some guidance.

@amyeroberts (Contributor)

Pinging @sanchit-gandhi here, who knows most of the intricacies of our Flax models :)

Comment on lines +365 to +371
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
Contributor

We should be careful about (non-)contiguous tensors here, since torch has a bug in at least version 2.1.2. See the reference given in the llama implementation:

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

So calling .contiguous() here (with the given checks) could be a solution, as demonstrated above. You could also move the projections into something like

        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

which calls .contiguous() automatically inside the _shape helper.
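A quick illustration of why the .contiguous() calls matter at all: .transpose() returns a non-contiguous view, which is exactly the input shape pattern the torch==2.1.2 SDPA bug is triggered by.

```python
import torch

# (bsz, seq_len, num_heads, head_dim) -> (bsz, num_heads, seq_len, head_dim)
x = torch.randn(2, 5, 4, 8)
y = x.transpose(1, 2)                  # a non-contiguous view, no copy
print(y.is_contiguous())               # False
print(y.contiguous().is_contiguous())  # True (materializes a contiguous copy)
```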

@huggingface huggingface deleted a comment from github-actions bot Jun 24, 2024
@qubvel qubvel mentioned this pull request Jul 12, 2024
@sayakpaul (Member Author)

Closing in favor of #31940.

@sayakpaul sayakpaul closed this Jul 18, 2024
@sayakpaul sayakpaul deleted the add-clip-sdpa branch July 25, 2024 08:06