
Make Transformers more torch-exportable and dynamo-friendly#42317

Merged
ArthurZucker merged 70 commits intomainfrom
export-friendly
Jan 22, 2026

Conversation

Member

@IlyasMoutawwakil IlyasMoutawwakil commented Nov 21, 2025

What does this PR do?

First proposals include:

  • check_with(error_type, cond, lambda: msg) in place of if not cond: raise error_type(msg); under torch.export/torch.compile this also hints to the compiler that the condition is expected to hold at export/compile time.
  • vectorization of some loops / list comprehensions into traceable, optimized and non-blocking versions.
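The first bullet can be sketched as a tiny helper. This is a hypothetical, pure-Python approximation; the actual PR additionally routes the condition through torch._check_with so that torch.export/torch.compile treat it as an assumed-true guard rather than a data-dependent branch:

```python
def check_with(error_type, cond, msg):
    # Hypothetical eager-mode sketch of the helper described above; the real
    # helper also forwards `cond` to torch._check_with so the compiler can
    # assume the condition holds at export/compile time.
    if callable(msg):
        msg = msg()  # messages may be lambdas, built only when needed
    if not cond:
        raise error_type(msg)

# Usage mirroring the PR's call sites (toy values):
n_image_tokens, n_image_features = 4, 4
check_with(
    ValueError,
    n_image_tokens == n_image_features,
    lambda: f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}",
)
```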

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@IlyasMoutawwakil IlyasMoutawwakil marked this pull request as draft November 21, 2025 08:04
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines -152 to +155
-        offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (num_patches_h, num_patches_w)
-        pixel_values = torch.cat(
-            [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
-            dim=0,
-        )  # (num_patches_h * num_patches_w, pixel_values)
+        offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,)
+        arange = torch.arange(pixel_values.shape[1], device=offsets.device)  # (max_len,)
+        mask = arange.unsqueeze(0) < offsets.unsqueeze(1)  # (batch_size, max_len)
+        pixel_values = pixel_values[mask]  # (total_valid_patches, channels, height, width)
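The mask-based rewrite can be checked for equivalence with the loop it replaces. A toy sketch with plain Python lists standing in for tensors (data and names are illustrative):

```python
# Two padded "pixel sequences" of max_len 4, with valid lengths 3 and 2.
pixel_values = [[1, 2, 3, 0], [4, 5, 0, 0]]
offsets = [3, 2]  # per-sequence number of valid entries

# Old approach: slice each sequence in a Python loop, then concatenate.
looped = [x for seq, off in zip(pixel_values, offsets) for x in seq[:off]]

# New approach: build a (batch_size, max_len) boolean mask from the lengths
# and select everything in one pass; this is what `arange < offsets` plus
# boolean indexing does on tensors, with no per-sample Python loop.
max_len = len(pixel_values[0])
mask = [[j < off for j in range(max_len)] for off in offsets]
masked = [x for seq, row in zip(pixel_values, mask) for x, keep in zip(seq, row) if keep]

assert looped == masked == [1, 2, 3, 4, 5]
```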
Member Author
avoiding looping over the tensor

Collaborator
very nice indeed!

-            for patch in pixel_values
-        ]
-        return patch_embeddings
+        return self.vision_embed_tokens(pixel_values)
Member Author

@IlyasMoutawwakil IlyasMoutawwakil Nov 21, 2025

Need an opinion about this.

Collaborator

cc @molbap maybe (looks like this was added in #27007)

Contributor

Don't know why I'm seeing this only now 👴 From what I remember, pixel_values for that model is a list of tensors, hence the weird list comp; if tests pass, however, it should be ~ok!

Contributor

Copilot AI left a comment

Pull Request Overview

This PR makes Transformers more export-friendly by introducing torch_check for dynamic assertions and implementing various export-related optimizations.

Key Changes

  • Introduces a new torch_check utility function that wraps torch._check to enable export-friendly error checking
  • Replaces raise ValueError with torch_check across numerous models for runtime validation
  • Implements performance optimizations including vectorizing batch operations, simplifying list comprehensions, and fixing instance variable assignments
  • Corrects error messages (e.g., "Videos features and image tokens" → "Video features and video tokens")
  • Adds proper training guards for weight clamping operations

Reviewed Changes

Copilot reviewed 87 out of 87 changed files in this pull request and generated 1 comment.

Show a summary per file
File | Description
src/transformers/utils/import_utils.py | Adds torch_check function wrapper around torch._check
src/transformers/utils/__init__.py | Exports the new torch_check function
src/transformers/models/*/modeling_*.py | Replaces ValueError raises with torch_check calls (50+ files)
src/transformers/models/idefics3/modeling_idefics3.py | Vectorizes position embedding computation from loop to batched operations
src/transformers/models/llava_next_video/modeling_llava_next_video.py | Fixes bug where instance variables were set in the forward method
src/transformers/models/timesfm/modeling_timesfm.py | Simplifies frequency handling from loop to slice operation
src/transformers/models/tapas/modeling_tapas.py | Fixes tensor shape construction bug
src/transformers/models/ctrl/modeling_ctrl.py | Converts pos_encoding to a registered buffer
src/transformers/models/gemma3n/modeling_gemma3n.py | Guards weight clamping with a training check
src/transformers/models/fuyu/modeling_fuyu.py | Simplifies get_image_features to remove an unnecessary list comprehension
src/transformers/models/dac/modeling_dac.py | Adds explicit dtype to torch.full call
src/transformers/models/colqwen2/modeling_colqwen2.py | Vectorizes pixel value filtering with mask-based indexing
src/transformers/models/biogpt/modeling_biogpt.py | Simplifies position_ids computation
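One line in the summary, guarding weight clamping with a training check, can be illustrated with a minimal sketch (a hypothetical module, not the actual gemma3n code): the in-place clamp is only needed while training, and skipping it at inference keeps the mutation out of the exported graph.

```python
class ClampedWeight:
    # Hypothetical stand-in for a module that clamps its own weight;
    # `training` plays the role of nn.Module.training.
    def __init__(self, weight, lo=-1.0, hi=1.0):
        self.weight = weight
        self.lo, self.hi = lo, hi
        self.training = True

    def forward(self, x):
        if self.training:
            # clamp only while training; at inference the weight is assumed
            # already in range, and skipping the mutation keeps tracing pure
            self.weight = max(self.lo, min(self.hi, self.weight))
        return x * self.weight

m = ClampedWeight(weight=2.5)
m.forward(1.0)
assert m.weight == 1.0  # clamped during training

m.training = False
m.weight = 2.5
assert m.forward(2.0) == 5.0  # no mutation at inference
```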

@IlyasMoutawwakil IlyasMoutawwakil changed the title Make Transformers more export-friendly Make Transformers more torch-exportable and dynamo-friendly Jan 8, 2026
Collaborator

@ArthurZucker ArthurZucker left a comment

My main comment is to use good defaults when you define the checking function; this way most of the cases where we use it are going to be very simple.

Otherwise it would be nice to document the good practices that you expose here, and potentially add a check in make repo-fix for simple rules.

Great work!

-            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-        )
         special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        check_with(
Collaborator

I think this needs a better name! Something like "torch_compile_check", something explicit for users as to why we use this!

Member Author

@IlyasMoutawwakil IlyasMoutawwakil Jan 19, 2026

I will name it torch_compilable_check, as it is compilable without being bound to torch.compile; tell me if that works for you.

Comment on lines +282 to +283
+                lambda: f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
+            )
Collaborator

Given that you defined the function check_with, I think we should not have to use a lambda here.

Member Author

@IlyasMoutawwakil IlyasMoutawwakil Jan 19, 2026

Yes, we can support both a str and a lambda returning a string (for when we want the message to be evaluated only if cond is false).
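A sketch of the str-or-lambda support being discussed, showing that a callable message is only built when the check fails (hypothetical helper, simplified from the discussion):

```python
def check_with(error_type, cond, msg):
    # Hypothetical sketch: `msg` may be a plain string or a zero-argument
    # callable; the callable is invoked only when the check fails, so
    # expensive f-string formatting is skipped on the happy path.
    if not cond:
        raise error_type(msg() if callable(msg) else msg)

built = []
def expensive_message():
    built.append(True)  # record that formatting actually ran
    return "tokens and features do not match"

check_with(ValueError, True, expensive_message)
assert built == []  # message never constructed when the check passes

try:
    check_with(ValueError, False, expensive_message)
except ValueError as exc:
    assert str(exc) == "tokens and features do not match"
assert built == [True]
```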

-            "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
-        )
+        check_with(
+            ValueError,
Collaborator

We should make ValueError the default, because it seems to be used everywhere; this way the more common cases where the checking function is used will be simplified.

Member Author

Makes sense!

         position_ids = torch.clamp(position_ids, min=0).to(torch.long)

-        return attention_mask, position_ids.to(torch.long)
+        return attention_mask, position_ids
Collaborator

very nice work here!

Comment on lines +494 to +499
         if attention_mask is not None:
             hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)
         conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
         cache_params.conv_states[self.layer_idx] = conv_state
         hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
-        if attention_mask is not None and not torch.all(attention_mask == 1):
+        if attention_mask is not None:
Collaborator

This change is weird; same for the next one in this file.

Member Author

The data dependency on not torch.all(attention_mask == 1) breaks graphs; I can revert the change and try to find better alternatives later (in another PR).

Collaborator

No, I mean look at the two if/else branches.


    if isinstance(cond, torch.Tensor):
        cond = cond.item()
    torch._check_with(error_type, cond, msg)
Collaborator

Yeah... that does sound good actually, but only if we can catch it to give a good, detailed error!

@ArthurZucker
Collaborator

LGTM, now just the flagged change that looks a bit weird (check the if else)

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: aria, aya_vision, bart, bigbird_pegasus, biogpt, chameleon, cohere2_vision, colqwen2, ctrl, d_fine, dac, deepseek_vl, deepseek_vl_hybrid, deformable_detr, emu3, ernie4_5_vl_moe

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42317&sha=99be85

@ArthurZucker ArthurZucker merged commit eff263c into main Jan 22, 2026
24 of 26 checks passed
@ArthurZucker ArthurZucker deleted the export-friendly branch January 22, 2026 09:07
vaibhav-research pushed a commit to vaibhav-research/transformers that referenced this pull request Jan 22, 2026
…ace#42317)

* make vlms export friendly

* seq2seq lms

* biogpt

* more vlms

* colqwen2

* vision models

* more vlms

* more vlms

* more vlms

* vectorized vision embedding

* fixup

* more vlms

* more vlms

* generate_masks_with_special_tokens_and_transfer_map

* custom torch_check

* use custom torch_check

* revert grounding dino changes

* fixup

* remove file

* undo

* undo

* testing

* fixes

* standard error message

* use torch._check_with to raise value error instead of torch._check's runtime error

* fix recurrent gemma

* only itemize tensors

* use spatial shapes list instead of tensor

* fix udop use_cache default value

* use tracable condition for seq2seq lms

* make smolvlm exportable

* fix fastvlm and t5gemma2

* fix qwen2_audio and idefics

* remove script

* tbc

* skip mra model

* helper

* style and document

* fix

* set experts impl to batched

* make xmod exportable and efficient

* make more ssms exportable

* fix

* revert recurrent gemma

* skip models that use chunked attention or rope_index

* qwen3_next

* assert async

* tensorize (mm) grounding dino mask generation

* style

* fix repo

* address comments

* fix qwen2 audio and vits checks

* skip two models using kernels by default

* skip granite moe hybrid using custom kernels

* disable mamba kernels

* vits splinter and videomae
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
…ace#42317)

