Conversation
```python
def __init__(self, *args, **kwargs):
    is_causal = kwargs.pop("is_causal", False)
    super().__init__(*args, **kwargs)
    self.is_causal = is_causal
```
To directly provide `is_causal` to `F.scaled_dot_product_attention`.
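As a rough, self-contained sketch of what that enables (toy tensors, not the actual CLIP code): once `is_causal` is stored on the module, it can be handed straight to `F.scaled_dot_product_attention` instead of materializing a causal mask, and is equivalent to passing an explicit lower-triangular boolean mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Let SDPA build the causal mask internally.
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit boolean mask (True = position may be attended to).
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(torch.allclose(out_flag, out_mask, atol=1e-6))  # True
```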
```python
if config._attn_implementation == "sdpa":
    self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)
else:
    self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config)
```
We don't use causal masking in the vision tower of CLIP, hence this conditioning.
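A toy illustration of the idea (class names here are hypothetical stand-ins, not the real transformers classes): dispatch on the implementation string, and opt into causal masking only for the text tower:

```python
# Hypothetical mini-version of the CLIP_ATTENTION_CLASSES dispatch pattern.
class EagerAttention:
    def __init__(self, config, is_causal=False):
        self.config = config
        self.is_causal = is_causal

class SdpaAttention(EagerAttention):
    pass

ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}

def build_attention(attn_implementation, is_text_tower):
    cls = ATTENTION_CLASSES[attn_implementation]
    # Only the text tower of CLIP uses causal masking; the vision tower never does.
    return cls(config=None, is_causal=is_text_tower)

print(build_attention("sdpa", is_text_tower=True).is_causal)   # True
print(build_attention("sdpa", is_text_tower=False).is_causal)  # False
```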
```python
text_config = config.text_config
text_config._attn_implementation = config._attn_implementation
vision_config = config.vision_config
vision_config._attn_implementation = config._attn_implementation
```
The image classification class has:
If we don't propagate `vision_config._attn_implementation = config._attn_implementation`, the vision config won't have any way to know about the actual `_attn_implementation`.
However, correct me if I am wrong.
You're completely right! However, the way it's propagated in other models is by passing it in at model construction, e.g.

```python
self.text_model = CLIPTextTransformer.from_config(text_config, attn_implementation=config._attn_implementation)
```

I'd encourage doing it this way for two reasons:

- It's the pattern used for other models. If there's an issue we need to update in the code, we're more likely to find it if it matches.
- The way the attention implementation is set on the config is less than ideal and (imo) prone to unexpected behaviour. Annoyingly, there's a bunch of magic which happens in the setter, which can cause it to appear to magically change or revert back. This way, I'm sure, (atm) works; I'm not sure that setting it on the config like this and then passing it to the model will always work, depending on the attention implementation in the original config, e.g. `vision_config`.
Done in 89dab66. Keeping this comment open because it seems like an important thing for us to consider.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for working on this @sayakpaul! Let us know when it's ready for review, i.e. when all tests are passing and there are no more changes to be made.
Oh I thought the PR description made it clear. Sorry if I didn't.
So, concretely, it would be great to have some initial reviews first so that we can localize why the basic logit assertion test isn't passing.
IMO we should avoid the complex `prepare_4d` etc., and just use `_update_causal_mask`. Now, that's not super possible as it would require deprecating, but anyway, I'm not against this PR; let's just make sure we have equivalence and support dispatching to the appropriate kernels!
```python
query_states,
key_states,
value_states,
attn_mask=attention_mask,
```
You seem to only be using the attention mask, rather than combining the causal mask and the attention mask.
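A hedged, standalone sketch (toy shapes, not the actual CLIP tensors) of combining the two: build both as additive biases and sum them, so a position is visible only where both masks allow it:

```python
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# Padding mask: 1 = real token, 0 = padding (second sequence has one pad).
padding = torch.tensor([[1, 1, 1, 1, 1],
                        [1, 1, 1, 1, 0]])

# Additive causal mask: -inf above the diagonal.
causal = torch.full((seq_len, seq_len), float("-inf")).triu(1)

# Additive padding bias over the key dimension, broadcastable to
# (bsz, n_heads, seq_len, seq_len).
pad_bias = torch.where(padding.bool(), 0.0, float("-inf"))[:, None, None, :]

# Combine: a key is attendable only if allowed by BOTH masks.
combined = causal + pad_bias

out = F.scaled_dot_product_attention(q, k, v, attn_mask=combined)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```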
@ArthurZucker if it's easier/beneficial for the library, happy to use the methods you are suggesting. But I don't get why things would need deprecating, etc. If you could provide more references, that would be helpful.
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
I disabled … However, this leads to: … Amongst these, failures for …
This passes:

```python
from transformers import AutoTokenizer, CLIPTextModel
import torch

model_sdpa = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="sdpa").to("cuda")
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="eager").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], return_tensors="pt")
print(inputs["attention_mask"].tolist())
inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
    outputs_sdpa = model_sdpa(**inputs)
    last_hidden_state_sdpa = outputs_sdpa.last_hidden_state
    outputs_eager = model(**inputs)
    last_hidden_state = outputs_eager.last_hidden_state

print(last_hidden_state_sdpa[0, :3, -1].flatten())
print(last_hidden_state[0, :3, -1].flatten())
print(torch.allclose(last_hidden_state_sdpa, last_hidden_state, rtol=1e-3, atol=1e-3))
```
ArthurZucker
left a comment
Thanks for the PR, could you share the expected speed boost benchmark? (and add it to the clip.md) 🤗
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@ArthurZucker done :)
@amyeroberts am I doing something wrong when running …? Here's what I have done: …
@sayakpaul Could you try running …?
@sayakpaul Huh, weird, I haven't seen that before. I'm going to try and re-trigger a fresh CI run, as it doesn't seem to be anything related to this PR |
amyeroberts
left a comment
Just some final comments on the propagation of `attn_implementation`.
amyeroberts
left a comment
Thanks for adding this!
Sorry for the bad suggestion re `from_config` - I didn't realise it's from the autoclass. Switching back to `_from_config` as you had it before should resolve it!
@amyeroberts I had to touch a couple of loading utilities to make sure the equivalence tests pass. LMK if they should have been approached differently.
Ah oh, the Flax tests still fail the equivalence check. I tried a bunch of state-dict rejigging in order for the PT-related changes to propagate in the Flax model, but none of them worked out. Would appreciate some guidance.
Pinging @sanchit-gandhi here - who knows most of the intricacies of our flax models :) |
```python
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
```
We should be careful about using these in non-contiguous form, since torch has a bug in at least version 2.1.2. See the reference given in the Llama implementation:
transformers/src/transformers/models/llama/modeling_llama.py
Lines 637 to 642 in 6bd511a
So calling `.contiguous()` here (with the relevant checks) could be a solution, as demonstrated there. You could also move the whole projection into something like

```python
query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
```

which automatically calls `.contiguous()` inside the `_shape` function.
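For reference, a minimal standalone sketch of what such a `_shape` helper does (signature simplified here; the real one is a method on the attention module):

```python
import torch

def _shape(tensor: torch.Tensor, seq_len: int, bsz: int,
           num_heads: int, head_dim: int) -> torch.Tensor:
    # Reshape (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
    # and call .contiguous() so downstream SDPA kernels are safe on the
    # affected torch versions.
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

x = torch.randn(2, 5, 32)  # (bsz, seq_len, embed_dim)
out = _shape(x, 5, 2, num_heads=4, head_dim=8)
print(out.shape, out.is_contiguous())  # torch.Size([2, 4, 5, 8]) True
```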
Closing in favor of #31940.

What does this PR do?
CLIP is heavily used in the diffusion modeling space for text encoding. So, having native SDPA support for CLIP would be beneficial for diffusion models both for training and inference.
The test failures can be tackled later once we match the logits to the non-SDPA path. Here's my test script:
These don't pass:
I inspected the `key_states`, `value_states`, and the `query_states` in `CLIPAttention` and `CLIPSdpaAttention`, respectively. In both classes, these matrices have the same values. The differences start arising from `attn_output`.

I suspect this is happening because of how masking is handled in `CLIPAttention`. We first apply the causal mask:

transformers/src/transformers/models/clip/modeling_clip.py
Line 285 in 8b02bb6

And then we apply the attention mask:

transformers/src/transformers/models/clip/modeling_clip.py
Line 295 in 8b02bb6

But I think this deviates a bit from `CLIPSdpaAttention`. Not very sure, though. Hence I am opening this PR to seek feedback.

I have added some comments inline to provide further clarification on some points.
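To localize this kind of divergence, a hedged standalone check (toy shapes, names assumed) comparing an eager-style masked softmax against SDPA can help; since both masks are additive biases, applying the causal mask first and the attention mask second is the same as summing them up front:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq, dim)
k = torch.randn(bsz, heads, seq, dim)
v = torch.randn(bsz, heads, seq, dim)

# Additive causal mask, then an attention (padding) mask, mirroring the
# order in CLIPAttention; the toy example has no padding.
causal = torch.full((seq, seq), float("-inf")).triu(1)
pad = torch.zeros(bsz, 1, 1, seq)

scores = q @ k.transpose(-1, -2) / dim**0.5
scores = scores + causal  # step 1: causal mask
scores = scores + pad     # step 2: attention mask
eager_out = torch.softmax(scores, dim=-1) @ v

# SDPA takes a single combined additive mask.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal + pad)
print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # True
```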