Make Transformers more torch-exportable and dynamo-friendly#42317
ArthurZucker merged 70 commits into main.
Conversation
```diff
-offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (num_patches_h, num_patches_w)
-pixel_values = torch.cat(
-    [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
-    dim=0,
-)  # (num_patches_h * num_patches_w, pixel_values)
+offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,)
+arange = torch.arange(pixel_values.shape[1], device=offsets.device)  # (max_len,)
+mask = arange.unsqueeze(0) < offsets.unsqueeze(1)  # (batch_size, max_len)
+pixel_values = pixel_values[mask]  # (total_valid_patches, channels, height, width)
```
Avoids looping over a tensor.
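The vectorized replacement can be sketched as a standalone comparison. The helper names and the input layout (a right-padded `(batch_size, max_len, channels)` tensor plus a `(batch_size, 3)` grid) are assumptions for illustration, not the model's exact code:

```python
import torch

# Vectorized filtering of padded patch sequences, as in the diff above
# (helper names and shapes are assumed for this sketch).
def filter_valid_patches(pixel_values, image_grid_thw):
    offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]                # (batch_size,)
    arange = torch.arange(pixel_values.shape[1], device=offsets.device)  # (max_len,)
    mask = arange.unsqueeze(0) < offsets.unsqueeze(1)                    # (batch_size, max_len)
    return pixel_values[mask]                                            # (total_valid, ...)

# Loop-based reference matching the removed code:
def filter_valid_patches_loop(pixel_values, image_grid_thw):
    offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]
    return torch.cat([seq[:off] for seq, off in zip(pixel_values, offsets)], dim=0)
```

Both produce the same concatenated patches; the masked version avoids a Python loop over tensor elements, which is what makes it dynamo-friendly.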
```diff
-        for patch in pixel_values
-    ]
-    return patch_embeddings
+    return self.vision_embed_tokens(pixel_values)
```
Need an opinion about this.
Don't know why I'm seeing this only now 👴 From what I remember, `pixel_values` for that model is a list of tensors, hence the weird list comp; if tests pass, however, it should be ~ok!
Pull Request Overview
This PR makes Transformers more export-friendly by introducing torch_check for dynamic assertions and implementing various export-related optimizations.
Key Changes
- Introduces a new `torch_check` utility function that wraps `torch._check` to enable export-friendly error checking
- Replaces `raise ValueError` with `torch_check` across numerous models for runtime validation
- Implements performance optimizations, including vectorizing batch operations, simplifying list comprehensions, and fixing instance variable assignments
- Corrects error messages (e.g., "Videos features and image tokens" → "Video features and video tokens")
- Adds proper training guards for weight clamping operations
Reviewed Changes
Copilot reviewed 87 out of 87 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| src/transformers/utils/import_utils.py | Adds a `torch_check` function wrapper around `torch._check` |
| src/transformers/utils/__init__.py | Exports the new `torch_check` function |
| `src/transformers/models/*/modeling_*.py` | Replaces `ValueError` raises with `torch_check` calls (50+ files) |
| src/transformers/models/idefics3/modeling_idefics3.py | Vectorizes position embedding computation from loop to batched operations |
| src/transformers/models/llava_next_video/modeling_llava_next_video.py | Fixes bug where instance variables were set in forward method |
| src/transformers/models/timesfm/modeling_timesfm.py | Simplifies frequency handling from loop to slice operation |
| src/transformers/models/tapas/modeling_tapas.py | Fixes tensor shape construction bug |
| src/transformers/models/ctrl/modeling_ctrl.py | Converts pos_encoding to registered buffer |
| src/transformers/models/gemma3n/modeling_gemma3n.py | Guards weight clamping with training check |
| src/transformers/models/fuyu/modeling_fuyu.py | Simplifies get_image_features to remove unnecessary list comprehension |
| src/transformers/models/dac/modeling_dac.py | Adds explicit dtype to torch.full call |
| src/transformers/models/colqwen2/modeling_colqwen2.py | Vectorizes pixel value filtering with mask-based indexing |
| src/transformers/models/biogpt/modeling_biogpt.py | Simplifies position_ids computation |
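For the ctrl row above, the registered-buffer pattern can be illustrated with a minimal module (the class and sizes here are made up for the demo). A buffer moves with the module across devices/dtypes and is visible to export as a constant, unlike a bare Python attribute:

```python
import torch

# Hypothetical module showing the register_buffer pattern (sizes assumed).
class WithPositionalEncoding(torch.nn.Module):
    def __init__(self, max_len=16, d_model=8):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Registered as a (non-persistent) buffer instead of a plain attribute,
        # so .to(device)/.half() and torch.export all see it.
        self.register_buffer("pos_encoding", pe, persistent=False)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        return x + self.pos_encoding[: x.shape[1]]
```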
ArthurZucker
left a comment
My main comment is to use good defaults when you define the checking function; this way, most of the cases where we use it are going to be very simple.
Otherwise, it would be nice to document the good practices that you expose here, and potentially add a test in `make repo-fix` for simple rules.
Great work!
```python
    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
check_with(
```
I think this needs a better name! Something that says "torch_compile_check", something explicit for users as to why we use this!
I will name it `torch_compilable_check`, as it is compilable without being bound to torch.compile; tell me if that works for you.
```python
    lambda: f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
)
```
Given that you defined the `check_with` function, I think we should not have to use a lambda here.
Yes, we can support both a str and a lambda returning a string (for when we want the message to be evaluated only if the condition is false).
```python
    "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
check_with(
    ValueError,
```
We should make `ValueError` the default error type, because it seems to be used everywhere; this way, the more common cases where the `check_with` function is used will be simplified.
```diff
 position_ids = torch.clamp(position_ids, min=0).to(torch.long)
-return attention_mask, position_ids.to(torch.long)
+return attention_mask, position_ids
```
very nice work here!
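The clamp-based pattern in the diff above can be sketched as a small helper (the helper name is an assumption). Cumulative-summing the attention mask yields positions that ignore left padding, and clamping replaces a data-dependent masked fill:

```python
import torch

# Sketch (based on the diff above; the helper name is hypothetical) of computing
# left-padding-safe position ids from the attention mask without branching.
def positions_from_mask(attention_mask):
    position_ids = attention_mask.long().cumsum(-1) - 1  # padded positions go negative
    return torch.clamp(position_ids, min=0)              # clamp instead of masked_fill
```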
```diff
 if attention_mask is not None:
     hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)
 conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
 cache_params.conv_states[self.layer_idx] = conv_state
 hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
-if attention_mask is not None and not torch.all(attention_mask == 1):
+if attention_mask is not None:
```
This change is weird; same for the next one in this file.
The data-dependency on `not torch.all(attention_mask == 1)` breaks graphs; I can revert the change and try to find better alternatives later (in another PR).
No, I mean look at the two if/else branches.
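One branch-free alternative (a sketch under assumed shapes, not this file's actual code) is to apply the masking unconditionally: multiplying by an all-ones mask is a no-op, so the output is unchanged when nothing is padded, and no data-dependent branch reaches the graph.

```python
import torch

# Branch-free masking sketch. Instead of guarding with
#     if not torch.all(attention_mask == 1): ...
# (which reads tensor data and breaks the graph), always apply the mask:
def mask_hidden_states(hidden_states, attention_mask):
    # hidden_states: (batch, channels, seq_len); attention_mask: (batch, seq_len)
    return hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)
```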
```python
if isinstance(cond, torch.Tensor):
    cond = cond.item()
torch._check_with(error_type, cond, msg)
```
Yeah... that does sound good actually, but only if we can catch it to give a good, detailed error!
LGTM, now just the flagged change that looks a bit weird (check the if/else).
[For maintainers] Suggested jobs to run (before merge) run-slow: aria, aya_vision, bart, bigbird_pegasus, biogpt, chameleon, cohere2_vision, colqwen2, ctrl, d_fine, dac, deepseek_vl, deepseek_vl_hybrid, deformable_detr, emu3, ernie4_5_vl_moe
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42317&sha=99be85
…ace#42317)

* make vlms export friendly
* seq2seq lms
* biogpt
* more vlms
* colqwen2
* vision models
* more vlms
* more vlms
* more vlms
* vectorized vision embedding
* fixup
* more vlms
* more vlms
* generate_masks_with_special_tokens_and_transfer_map
* custom torch_check
* use custom torch_check
* revert grounding dino changes
* fixup
* remove file
* undo
* undo
* testing
* fixes
* standard error message
* use torch._check_with to raise value error instead of torch._check's runtime error
* fix recurrent gemma
* only itemize tensors
* use spatial shapes list instead of tensor
* fix udop use_cache default value
* use tracable condition for seq2seq lms
* make smolvlm exportable
* fix fastvlm and t5gemma2
* fix qwen2_audio and idefics
* remove script
* tbc
* skip mra model
* helper
* style and document
* fix
* set experts impl to batched
* make xmod exportable and efficient
* make more ssms exportable
* fix
* revert recurrent gemma
* skip models that use chunked attention or rope_index
* qwen3_next
* assert async
* tensorize (mm) grounding dino mask generation
* style
* fix repo
* address comments
* fix qwen2 audio and vits checks
* skip two models using kernels by default
* skip granite moe hybrid using custom kernels
* disable mamba kernels
* vits splinter and videomae
What does this PR do?
First proposals include:
- `check_with(error_type, cond, lambda: msg)` instead of `if cond: raise error_type(msg)`, which also works with torch.export/torch.compile to hint to the compiler that the condition is expected to be true at export/compile time.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.