Fix notebook compatibility for transformers 4.57.6 and TRL 0.22-0.27 #3998
Fixes several notebook failures discovered while testing all 125 notebooks with transformers==4.57.6 + TRL 0.22.2 and TRL 0.27.1.

Warning suppression (import_fixes.py):
- Suppress torch 2.9+ pin_memory/is_pinned device deprecation warnings
- Suppress cuda.cudart/cuda.nvrtc module deprecation FutureWarning
- Filter vllm "Level is deprecated" stderr noise
- Filter PydanticSerializationUnexpectedValue warnings
- Filter Triton "df: No such file" stderr noise

VLM tokenizer loading (vision.py):
- Add _construct_vlm_processor_fallback() for models where AutoProcessor.from_pretrained fails (e.g., ERNIE 4.5 VL, LFM2.5-VL)
- Wrap processor loading in try/except with fallback to manual construction from separate image_processor + tokenizer components
- Add fallback to AutoTokenizer/PreTrainedTokenizerFast when tokenizer loading or patching fails

TRL 0.27.1 trainer compatibility (trainer.py):
- Add _resolve_trainer_params() to handle thin wrapper trainers that only have def __init__(self, *args, **kwargs) (e.g., ORPOTrainer in TRL 0.27.1) by walking MRO for real parameter signature

VLM _is_vlm detection (rl.py):
- Replace blanket _is_vlm=False override with model-architecture-based detection that checks vision_config or ForConditionalGeneration class name, fixing VLM training when bare tokenizer is passed as processing_class

ModernBERT SDPA compatibility (loader.py, sentence_transformer.py):
- Add "modernbert" to DISABLE_SDPA_MODEL_NAMES to avoid stride alignment issues with torch.compile backward pass
- Add DISABLE_SDPA check for sentence transformer models

Other fixes (_utils.py):
- Suppress false uninitialized weight warnings for VLM multi_modal_projector.layer_norm

Tested: 92/125 notebooks pass with TRL 0.22.2, 94/125 with TRL 0.27.1. Remaining failures are infra (missing FFmpeg, network timeouts, GPU arch) not code bugs.
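The `_resolve_trainer_params()` item above describes walking the MRO to find a real parameter signature behind a thin `__init__(self, *args, **kwargs)` wrapper. A minimal sketch of that idea follows. The function name `resolve_init_params` and the two demo classes are hypothetical and only illustrate the technique; this is not the code in trainer.py.

```python
import inspect


def resolve_init_params(cls):
    """Return the named __init__ parameters of the first class in the MRO
    whose __init__ declares more than just *args/**kwargs."""
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    for klass in cls.__mro__:
        # Only look at an __init__ defined directly on this class.
        init = klass.__dict__.get("__init__")
        if init is None:
            continue
        try:
            params = list(inspect.signature(init).parameters.values())
        except (TypeError, ValueError):
            continue
        # Drop `self`, then drop *args / **kwargs.
        named = [p.name for p in params[1:] if p.kind not in variadic]
        if named:
            return named
    return []


# Hypothetical stand-ins for a real trainer and its thin wrapper.
class DemoRealTrainer:
    def __init__(self, model, args=None, train_dataset=None):
        self.model = model


class DemoThinTrainer(DemoRealTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


resolved = resolve_init_params(DemoThinTrainer)
```

Here `resolved` holds the parent's parameter names, because the wrapper's own `__init__` contributes nothing beyond `*args, **kwargs` and is skipped.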
Code Review
This pull request introduces several fixes to improve compatibility with newer versions of transformers and trl, particularly for notebook environments. The changes include suppressing new warnings, adding robust fallbacks for VLM tokenizer loading, and improving trainer compatibility. My review focuses on enhancing the maintainability and robustness of these new additions. I've suggested combining some repetitive warning filters and improving error handling in the new fallback mechanisms to make the code cleaner and easier to debug.
    warnings.filterwarnings(
        "ignore",
        message = r"Expected.*but got.*with value.*is not.*subclass",
    )
The regex r"Expected.*but got.*with value.*is not.*subclass" is very broad and could suppress useful warnings about actual type mismatches in other parts of the codebase. It's generally safer to be more specific when suppressing warnings. If this is intended to catch a specific Pydantic warning, consider anchoring the regex to that warning's message more tightly. If not, it might be better to remove this filter to avoid unintentionally hiding bugs.
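One way to tighten the filter, sketched under assumptions: it supposes the noisy warnings are `UserWarning`s whose text starts with `Pydantic serializer warnings`, which may not hold for every Pydantic version and should be checked against the actual output. The `warnings` module matches the `message` regex against the start of the warning text, so an anchored prefix leaves unrelated "Expected ... but got ..." warnings visible. The helper names are hypothetical.

```python
import warnings


def install_pydantic_filter():
    # Assumption: the serializer warnings begin with this prefix and are
    # UserWarnings. Verify against the real message before relying on it.
    warnings.filterwarnings(
        "ignore",
        message=r"Pydantic serializer warnings",
        category=UserWarning,
    )


def collect_unfiltered(messages):
    """Emit each message as a UserWarning and return the ones that got through."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        install_pydantic_filter()
        for text in messages:
            warnings.warn(text, UserWarning)
    return [str(w.message) for w in caught]
```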
    warnings.filterwarnings(
        "ignore",
        message = r"The `device` argument is deprecated",
        category = DeprecationWarning,
    )
    warnings.filterwarnings(
        "ignore",
        message = r".*pin_memory.*device.*deprecated",
        category = DeprecationWarning,
    )
    warnings.filterwarnings(
        "ignore",
        message = r".*is_pinned.*device.*deprecated",
        category = DeprecationWarning,
    )
The three warnings.filterwarnings calls for related torch deprecation warnings can be combined into a single call with a more general regex. This improves conciseness and maintainability.
Consider replacing this block with:
    # torch 2.9+ pin_memory/is_pinned device arg deprecation
    warnings.filterwarnings(
        "ignore",
        message=r".*(`device` argument|pin_memory.*device|is_pinned.*device).*deprecated",
        category=DeprecationWarning,
    )

    warnings.filterwarnings(
        "ignore",
        message = r".*cuda\.cudart.*deprecated",
        category = FutureWarning,
    )
    warnings.filterwarnings(
        "ignore",
        message = r".*cuda\.nvrtc.*deprecated",
        category = FutureWarning,
    )
These two warnings.filterwarnings calls for related CUDA deprecation warnings can be combined into a single call with a more general regex. This improves conciseness and maintainability.
Consider replacing this block with:
    # cuda.cudart / cuda.nvrtc module deprecation (FutureWarning)
    warnings.filterwarnings(
        "ignore",
        message=r".*cuda\.(cudart|nvrtc).*deprecated",
        category=FutureWarning,
    )

    tok_config = json.load(f)
    # Set model-specific special tokens and their IDs
The broad except Exception: pass can hide important errors during tokenizer config loading. For example, network issues or unexpected file formats would be silently ignored. It would be more robust to catch specific expected exceptions (like EntryNotFoundError from huggingface_hub if the config is optional) and log other exceptions for debugging purposes. This helps in diagnosing issues with new or custom models.
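A sketch of what narrower handling could look like, using only the standard library. It assumes the config is an optional local JSON file; for a Hub download the expected "missing" case would be something like the `EntryNotFoundError` mentioned above rather than `FileNotFoundError`. The function name is hypothetical.

```python
import json
import logging

logger = logging.getLogger(__name__)


def load_optional_tokenizer_config(path):
    """Return the parsed config, or {} if the optional file is missing or unreadable."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Expected case: the config is optional, so stay quiet.
        return {}
    except (json.JSONDecodeError, OSError) as err:
        # Unexpected case: keep going, but leave a trace for debugging.
        logger.warning("Could not read tokenizer config %s: %s", path, err)
        return {}
```

A missing file is still ignored, but a malformed file now shows up in the logs instead of disappearing.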
    if not hasattr(tok, id_key):
        setattr(tok, id_key, token_id)

        model, tokenizer = patch_tokenizer(model, tokenizer)
    except Exception as _patch_err:
        # Some VLM processors (e.g., ERNIE VL) may fail during tokenizer patching.
        # Try loading tokenizer separately via AutoTokenizer as fallback.
        try:
            from transformers import AutoTokenizer as _AutoTokenizer

            _fallback_tok = _AutoTokenizer.from_pretrained(
                tokenizer_name,
                padding_side = "left",
                token = token,
                trust_remote_code = trust_remote_code,
            )
            model, _fallback_tok = patch_tokenizer(model, _fallback_tok)
            # Re-attach as processor wrapper if original was a processor
            if hasattr(tokenizer, "image_processor"):
                tokenizer.tokenizer = _fallback_tok
            else:
                tokenizer = _fallback_tok
        except Exception:
            # If fallback also fails, raise the original error
            raise _patch_err
    model = post_patch_loss_function(model)
This nested try-except block for last-resort tokenizer loading could be refactored into a separate helper function to improve readability and make the main from_pretrained function easier to follow.
For example:
    def _load_tokenizer_last_resort(tokenizer_name, token, trust_remote_code):
        try:
            from transformers import AutoTokenizer
            return AutoTokenizer.from_pretrained(
                tokenizer_name,
                padding_side="left",
                token=token,
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            try:
                from transformers import PreTrainedTokenizerFast
                return PreTrainedTokenizerFast.from_pretrained(
                    tokenizer_name,
                    padding_side="left",
                    token=token,
                    trust_remote_code=trust_remote_code,
                )
            except Exception:
                return None

    # In from_pretrained:
    if tokenizer is None:
        tokenizer = _load_tokenizer_last_resort(tokenizer_name, token, trust_remote_code)
        if tokenizer is None:
            del model
            raise RuntimeError(
                "Unsloth: The tokenizer is weirdly not loaded? Please check if there is one."
            )

Fix KTO shape mismatch on TRL 0.27.2+ and truncation alignment
- Patch KTO get_batch_logps to auto-align logits and labels when Unsloth model forward truncates input_ids beyond max_seq_length. TRL 0.27.2 changed _process_tokens to only truncate completions (not prompts), so sequences with long prompts exceed max_seq_length and trigger model-side truncation. The original ValueError is replaced with min-length alignment.
- Also truncate attention_mask in LlamaModel forward when input_ids are truncated to max_seq_length, preventing shape mismatches in attention.
- Widen except clause in rl_replacements.py openenv import from `except ImportError` to `except (ImportError, NameError, Exception)` to handle vllm SamplingParams NameError in TRL 0.27.2.
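The min-length alignment in the first item can be pictured with plain Python lists standing in for tensors. This is an illustrative sketch, not the patched `get_batch_logps`: the function name is hypothetical, and keeping the leading positions (trimming from the right) is an assumption, chosen because the model-side truncation described above shortens the sequence to `max_seq_length`.

```python
def align_to_min_length(logits, labels):
    """Trim each (logits, labels) pair along the sequence axis to the shorter length.

    logits: list of sequences, each a list of per-token score vectors
    labels: list of sequences, each a list of token ids
    """
    aligned_logits, aligned_labels = [], []
    for seq_logits, seq_labels in zip(logits, labels):
        n = min(len(seq_logits), len(seq_labels))
        # Assumption: keep the first n positions of both.
        aligned_logits.append(seq_logits[:n])
        aligned_labels.append(seq_labels[:n])
    return aligned_logits, aligned_labels


# The model returned 3 positions (truncated), but the labels still have 5.
demo_logits = [[[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]]]
demo_labels = [[1, 0, 1, 1, 0]]
out_logits, out_labels = align_to_min_length(demo_logits, demo_labels)
```

Instead of raising on the length mismatch, both sides end up with three positions and the loss can be computed on the overlap.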
…up warning filters

TRL 0.26+ thin wrapper resolution (rl.py):
- Filter _-prefixed private imports when discovering Trainer/Config classes
- Look up Config in separate *_config.py module when not found in trainer module
- Detect thin wrappers (<1000 chars source) and resolve to experimental parent via MRO walk; use resolved module for imports and create_new_function
- Enables all 15 trainers to patch successfully (was 5/15 before)

ModernBERT SDPA (loader.py):
- Remove "modernbert" from DISABLE_SDPA_MODEL_NAMES
- SDPA works correctly for both classification and sentence transformers
- Verified: 88.9% accuracy on emotion classification, correct domain-specific embeddings after sentence transformer fine-tuning

Warning filter cleanup (import_fixes.py):
- Remove cuda.cudart/cuda.nvrtc FutureWarning filters (no such warnings exist in torch 2.9.1+; proactive suppression is unnecessary)
Additional fixes pushed
Fix 1: TRL 0.26+ thin wrapper resolution (

| TRL | Transformers | ORPO | DPO | KTO |
|---|---|---|---|---|
| 0.22.2 | 4.56.2 | PASS | PASS | PASS |
| 0.25.1 | 4.56.2 | PASS | PASS | PASS |
| 0.26.2 | 4.57.6 | PASS | PASS | PASS |
| 0.27.2 | 4.57.6 | PASS | PASS | PASS |
Full notebook runs also pass: ORPO (Llama-3-8B), DPO (Zephyr-7B), KTO (Qwen2.5-1.5B), ModernBERT classification (88.9% accuracy), ModernBERT sentence transformer (correct embeddings).
The LFM2.5-VL projector LayerNorm is properly initialized by transformers and does not need to be excluded from the uninitialized weight check. The original exclusion was added as a workaround but is no longer needed after the upstream fix.
…rs-4.57-notebook-compat
…on, BatchEncoding guard, try/except for TRL trainer source, push_to_hub_token compiler fix
- llama.py: Add _get_rope_theta() helper handling both config.rope_theta and rope_parameters dict
- llama.py: Handle BatchEncoding in unsloth_fast_generate (transformers 5.0+ returns BatchEncoding from apply_chat_template)
- gemma.py: Detect config passed as dim arg in GemmaFixedRotaryEmbedding
- tokenizer_utils.py: Add try/except for TRL trainer getsource in patch_sft_trainer_tokenizer
- rl_replacements.py: Add compiler fix replacing bare pop("push_to_hub_token") with pop(..., None)
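The `_get_rope_theta()` item can be sketched as a small lookup that tries the flat attribute first and then the dict. This is a hedged illustration, not the helper in llama.py: the function name `get_rope_theta`, the `"rope_theta"` key inside `rope_parameters`, and the `10000.0` fallback are all assumptions made for the example.

```python
from types import SimpleNamespace

DEFAULT_ROPE_THETA = 10000.0  # assumed fallback, not taken from the PR


def get_rope_theta(config, default=DEFAULT_ROPE_THETA):
    """Read rope theta from either config.rope_theta or config.rope_parameters."""
    theta = getattr(config, "rope_theta", None)
    if theta is not None:
        return theta
    rope_parameters = getattr(config, "rope_parameters", None)
    if isinstance(rope_parameters, dict):
        # Assumption: the dict stores the value under the "rope_theta" key.
        theta = rope_parameters.get("rope_theta")
        if theta is not None:
            return theta
    return default


# Hypothetical configs in the older (flat attribute) and newer (dict) layouts.
old_style = SimpleNamespace(rope_theta=500000.0)
new_style = SimpleNamespace(rope_parameters={"rope_theta": 1000000.0, "rope_type": "default"})
```

The same defensive pattern is behind the last item: `kwargs.pop("push_to_hub_token", None)` returns `None` when the key is absent, whereas a bare `pop("push_to_hub_token")` raises `KeyError`.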
… thin wrapper detection

The <1000 / >1000 char threshold was fragile -- XPOConfig's parent is only 994 chars and would be skipped. All thin wrappers in TRL 0.26+ contain "trl.experimental" in their deprecation warning, while no real trainer or config class does, making it a reliable detection marker.
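A rough sketch of the marker-based check combined with an MRO walk to the first non-wrapper parent. The helper names, the demo classes, and the fake source strings are hypothetical. `get_source` is injectable so the example runs without real TRL classes, and it defaults to `inspect.getsource`.

```python
import inspect

THIN_WRAPPER_MARKER = "trl.experimental"


def is_thin_wrapper(source):
    """A class counts as a thin wrapper if its source mentions the marker."""
    return THIN_WRAPPER_MARKER in source


def resolve_real_class(cls, get_source=inspect.getsource):
    """Walk the MRO and return the first class that is not a thin wrapper."""
    for klass in cls.__mro__:
        if klass is object:
            break
        try:
            source = get_source(klass)
        except (OSError, TypeError):
            # Source unavailable (built-in or dynamically created class).
            continue
        if not is_thin_wrapper(source):
            return klass
    return cls


# Hypothetical stand-ins: a short "real" parent and its deprecation wrapper.
class DemoRealConfig:
    pass


class DemoWrapperConfig(DemoRealConfig):
    pass


FAKE_SOURCES = {
    "DemoWrapperConfig": 'warnings.warn("moved to trl.experimental.demo")',
    "DemoRealConfig": "class DemoRealConfig:\n    beta: float = 0.1",
}
resolved_cls = resolve_real_class(
    DemoWrapperConfig, get_source=lambda k: FAKE_SOURCES[k.__name__]
)
```

Because the check looks at content rather than length, a short real parent such as the 994-char case mentioned above is not mistaken for a wrapper.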
Do you know if the Gemma3_(27B)_A100-Conversational.ipynb notebook works without errors after these fixes?

…sformer

The function-level import was redundant since loader.py is already imported at module level. Move it to the existing loader import line.

@kabachuha Yes hopefully - I can re-check
…nslothai#3998)

* Patch before compile?
* Fix notebook compatibility for transformers 4.57.6 and TRL 0.22-0.27
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix KTO shape mismatch on TRL 0.27.2+ and truncation alignment
* Fix TRL 0.26+ thin wrapper resolution, enable ModernBERT SDPA, clean up warning filters
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Remove multi_modal_projector.layer_norm from uninitialized weight guard
* Add transformers 5.0 compat: rope_theta helper, config-as-dim detection, BatchEncoding guard, try/except for TRL trainer source, push_to_hub_token compiler fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Use trl.experimental string check instead of char-count heuristic for thin wrapper detection
* Move DISABLE_SDPA_MODEL_NAMES import to module level in sentence_transformer

---------

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Summary
Fixes several notebook failures discovered during testing all 125 notebooks with transformers==4.57.6 + TRL 0.22.2 and TRL 0.27.1 on 8x B200 GPUs.
Before: 86/125 pass (TRL 0.22.2), 76/125 pass (TRL 0.27.1)
After: 92/125 pass (TRL 0.22.2), 94/125 pass (TRL 0.27.1)
Remaining failures are infra issues (missing FFmpeg, HF Hub timeouts, GPU arch, external services) -- not code bugs.
Changes
- Warning suppression (import_fixes.py): Suppress torch 2.9+ pin_memory deprecation, cuda.cudart/nvrtc FutureWarning, vllm "Level is deprecated" stderr, PydanticSerializationUnexpectedValue, Triton "df: No such file" stderr
- VLM tokenizer loading (vision.py): Add _construct_vlm_processor_fallback() for models where AutoProcessor.from_pretrained fails (ERNIE 4.5 VL, LFM2.5-VL). Wrap processor loading in try/except with fallback to manual construction from separate image_processor + tokenizer. Add AutoTokenizer/PreTrainedTokenizerFast last-resort fallback.
- TRL 0.27.1 trainer compatibility (trainer.py): Add _resolve_trainer_params() to handle thin wrapper trainers with def __init__(self, *args, **kwargs) (ORPOTrainer in TRL 0.27.1) by walking MRO for real parameter signature.
- VLM _is_vlm detection (rl.py): Replace blanket _is_vlm=False override with model-architecture-based detection that checks vision_config or ForConditionalGeneration class name, fixing VLM training when bare tokenizer is passed.
- ModernBERT SDPA compatibility (loader.py, sentence_transformer.py): Add "modernbert" to DISABLE_SDPA_MODEL_NAMES to avoid stride alignment issues with torch.compile backward pass.
- Other fixes (_utils.py): Suppress false uninitialized weight warnings for multi_modal_projector.layer_norm.

Companion PR
Test Results
Fixes confirmed working: