Skip to content

Fix notebook compatibility for transformers 4.57.6 and TRL 0.22-0.27#3998

Merged
danielhanchen merged 12 commits into
mainfrom
fix/transformers-4.57-notebook-compat
Feb 9, 2026
Merged

Fix notebook compatibility for transformers 4.57.6 and TRL 0.22-0.27#3998
danielhanchen merged 12 commits into
mainfrom
fix/transformers-4.57-notebook-compat

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

Fixes several notebook failures discovered during testing all 125 notebooks with transformers==4.57.6 + TRL 0.22.2 and TRL 0.27.1 on 8x B200 GPUs.

Before: 86/125 pass (TRL 0.22.2), 76/125 pass (TRL 0.27.1)
After: 92/125 pass (TRL 0.22.2), 94/125 pass (TRL 0.27.1)
Remaining failures are infra issues (missing FFmpeg, HF Hub timeouts, GPU arch, external services) -- not code bugs.

Changes

  • Warning suppression (import_fixes.py): Suppress torch 2.9+ pin_memory deprecation, cuda.cudart/nvrtc FutureWarning, vllm "Level is deprecated" stderr, PydanticSerializationUnexpectedValue, Triton "df: No such file" stderr
  • VLM tokenizer loading (vision.py): Add _construct_vlm_processor_fallback() for models where AutoProcessor.from_pretrained fails (ERNIE 4.5 VL, LFM2.5-VL). Wrap processor loading in try/except with fallback to manual construction from separate image_processor + tokenizer. Add AutoTokenizer/PreTrainedTokenizerFast last-resort fallback.
  • TRL 0.27.1 trainer compat (trainer.py): Add _resolve_trainer_params() to handle thin wrapper trainers with def __init__(self, *args, **kwargs) (ORPOTrainer in TRL 0.27.1) by walking MRO for real parameter signature.
  • VLM _is_vlm detection (rl.py): Replace blanket _is_vlm=False override with model-architecture-based detection that checks vision_config or ForConditionalGeneration class name, fixing VLM training when bare tokenizer is passed.
  • ModernBERT SDPA compat (loader.py, sentence_transformer.py): Add "modernbert" to DISABLE_SDPA_MODEL_NAMES to avoid stride alignment issues with torch.compile backward pass.
  • VLM weight warnings (_utils.py): Suppress false uninitialized weight warnings for multi_modal_projector.layer_norm.

Companion PR

  • unslothai/unsloth-zoo: tokenizer None guard, ModernBERT attention mask fix, gpt_oss ParamWrapper unwrap

Test Results

Config Pass Fail Notes
transformers 4.57.6 + TRL 0.22.2 92/125 26+4 killed 22 infra, 3 code (pre-existing), 1 notebook
transformers 4.57.6 + TRL 0.27.1 94/125 30+1 killed 16 infra, 8 killed/flaky, 3 notebook, 2 CUDA errors, 1 SIGBUS

Fixes confirmed working:

  • ERNIE 4.5 VL: tokenizer loading fallback (100 steps, loss 0.091)
  • Paddle_OCR_1B_Vision: _is_vlm override (100 steps, loss 0.16)
  • Pixtral_12B_Vision: _is_vlm override (smoke test pass)
  • Llama3_8B_ORPO: thin wrapper MRO walk (100 steps, loss 2.30)
  • ModernBERT: forced eager attention (sentence transformer training pass)
  • gpt_oss 20B variants: ParamWrapper unwrap (GRPO, RL, fine-tuning all pass)

Datta0 and others added 2 commits February 6, 2026 15:13
Fixes several notebook failures discovered during testing all 125
notebooks with transformers==4.57.6 + tRL 0.22.2 and TRL 0.27.1.

Warning suppression (import_fixes.py):
- Suppress torch 2.9+ pin_memory/is_pinned device deprecation warnings
- Suppress cuda.cudart/cuda.nvrtc module deprecation FutureWarning
- Filter vllm "Level is deprecated" stderr noise
- Filter PydanticSerializationUnexpectedValue warnings
- Filter Triton "df: No such file" stderr noise

VLM tokenizer loading (vision.py):
- Add _construct_vlm_processor_fallback() for models where
  AutoProcessor.from_pretrained fails (e.g., ERNIE 4.5 VL, LFM2.5-VL)
- Wrap processor loading in try/except with fallback to manual
  construction from separate image_processor + tokenizer components
- Add fallback to AutoTokenizer/PreTrainedTokenizerFast when tokenizer
  loading or patching fails

TRL 0.27.1 trainer compatibility (trainer.py):
- Add _resolve_trainer_params() to handle thin wrapper trainers that
  only have def __init__(self, *args, **kwargs) (e.g., ORPOTrainer
  in TRL 0.27.1) by walking MRO for real parameter signature

VLM _is_vlm detection (rl.py):
- Replace blanket _is_vlm=False override with model-architecture-based
  detection that checks vision_config or ForConditionalGeneration class
  name, fixing VLM training when bare tokenizer is passed as
  processing_class

ModernBERT SDPA compatibility (loader.py, sentence_transformer.py):
- Add "modernbert" to DISABLE_SDPA_MODEL_NAMES to avoid stride
  alignment issues with torch.compile backward pass
- Add DISABLE_SDPA check for sentence transformer models

Other fixes (_utils.py):
- Suppress false uninitialized weight warnings for VLM
  multi_modal_projector.layer_norm

Tested: 92/125 notebooks pass with TRL 0.22.2, 94/125 with TRL 0.27.1.
Remaining failures are infra (missing FFmpeg, network timeouts, GPU
arch) not code bugs.
@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses critical compatibility issues arising from updates in transformers and TRL libraries, specifically targeting versions 4.57.6 and 0.22-0.27 respectively. The changes aim to stabilize notebook execution by implementing robust error handling for VLM tokenizer loading, adapting to new TRL trainer initialization patterns, and resolving specific model-related issues like SDPA compatibility for ModernBERT. These improvements collectively enhance the reliability and user experience when working with Unsloth notebooks across different library versions.

Highlights

  • Enhanced Compatibility: Significantly improved notebook pass rates for transformers==4.57.6 and TRL versions 0.22-0.27.
  • Robust VLM Tokenizer Loading: Implemented fallback mechanisms for VLM models where AutoProcessor.from_pretrained fails, ensuring successful tokenizer and processor initialization.
  • TRL Trainer Adaptability: Introduced logic to correctly resolve __init__ parameters for TRL trainers, including thin wrapper classes, improving compatibility with newer TRL versions.
  • Refined VLM Detection: Updated _is_vlm detection in rl.py to be architecture-based, preventing incorrect VLM identification when bare tokenizers are used.
  • SDPA Compatibility for ModernBERT: Added ModernBERT to the list of models disabling SDPA to resolve stride alignment issues with torch.compile backward pass.
  • Warning and Error Suppression: Suppressed various non-critical warnings and stderr messages from PyTorch, vLLM, Pydantic, and Triton to clean up console output.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth/import_fixes.py
    • Added warning filters for PyTorch pin_memory/is_pinned device argument deprecation.
    • Added warning filters for PyTorch cuda.cudart/nvrtc module deprecation.
    • Added sys.stderr filters to suppress "Level is deprecated" from vLLM and "df: No such file" from Triton.
    • Added warning filters for PydanticSerializationUnexpectedValue.
  • unsloth/models/_utils.py
    • Extended the list of ignored uninitialized weights to include multi_modal_projector.layer_norm to prevent false warnings.
  • unsloth/models/loader.py
    • Included "modernbert" in DISABLE_SDPA_MODEL_NAMES to mitigate stride alignment issues with torch.compile backward pass.
  • unsloth/models/rl.py
    • Replaced the hardcoded _is_vlm=False override with a dynamic, model-architecture-based detection mechanism for VLM models, especially when a bare tokenizer is provided.
  • unsloth/models/sentence_transformer.py
    • Integrated DISABLE_SDPA_MODEL_NAMES check to force attn_implementation="eager" for specific models (like ModernBERT) to avoid SDPA-related issues.
  • unsloth/models/vision.py
    • Introduced _construct_vlm_processor_fallback to manually build VLM processors from AutoImageProcessor and PreTrainedTokenizerFast when AutoProcessor.from_pretrained fails.
    • Implemented try-except blocks around auto_processor.from_pretrained calls, utilizing the new fallback function.
    • Added a final fallback to AutoTokenizer or PreTrainedTokenizerFast if all other processor loading attempts fail.
  • unsloth/trainer.py
    • Created _resolve_trainer_params to intelligently determine the actual __init__ parameters of TRL trainer classes, traversing the Method Resolution Order (MRO) for thin wrapper trainers.
Activity
  • Initial testing revealed 86/125 notebooks passing with TRL 0.22.2 and 76/125 with TRL 0.27.1.
  • After applying the fixes, the pass rate improved to 92/125 for TRL 0.22.2 and 94/125 for TRL 0.27.1.
  • Remaining failures are attributed to infrastructure issues (e.g., missing FFmpeg, HF Hub timeouts, GPU architecture, external services) rather than code bugs.
  • Confirmed fixes for specific models include ERNIE 4.5 VL (tokenizer loading), Paddle_OCR_1B_Vision and Pixtral_12B_Vision (_is_vlm override), Llama3_8B_ORPO (trainer compatibility), ModernBERT (SDPA), and gpt_oss 20B variants (ParamWrapper unwrap).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several fixes to improve compatibility with newer versions of transformers and trl, particularly for notebook environments. The changes include suppressing new warnings, adding robust fallbacks for VLM tokenizer loading, and improving trainer compatibility. My review focuses on enhancing the maintainability and robustness of these new additions. I've suggested combining some repetitive warning filters and improving error handling in the new fallback mechanisms to make the code cleaner and easier to debug.

Comment thread unsloth/import_fixes.py
Comment on lines +204 to +207
warnings.filterwarnings(
"ignore",
message = r"Expected.*but got.*with value.*is not.*subclass",
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The regex r"Expected.*but got.*with value.*is not.*subclass" is very broad and could suppress useful warnings about actual type mismatches in other parts of the codebase. It's generally safer to be more specific when suppressing warnings. If this is intended to catch a specific Pydantic warning, consider anchoring the regex to that warning's message more tightly. If not, it might be better to remove this filter to avoid unintentionally hiding bugs.

Comment thread unsloth/import_fixes.py
Comment on lines +168 to +182
warnings.filterwarnings(
"ignore",
message = r"The `device` argument is deprecated",
category = DeprecationWarning,
)
warnings.filterwarnings(
"ignore",
message = r".*pin_memory.*device.*deprecated",
category = DeprecationWarning,
)
warnings.filterwarnings(
"ignore",
message = r".*is_pinned.*device.*deprecated",
category = DeprecationWarning,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The three warnings.filterwarnings calls for related torch deprecation warnings can be combined into a single call with a more general regex. This improves conciseness and maintainability.

Consider replacing this block with:

# torch 2.9+ pin_memory/is_pinned device arg deprecation
warnings.filterwarnings(
    "ignore",
    message=r".*(\`device\` argument|pin_memory.*device|is_pinned.*device).*deprecated",
    category=DeprecationWarning,
)

Comment thread unsloth/import_fixes.py Outdated
Comment on lines +185 to +194
warnings.filterwarnings(
"ignore",
message = r".*cuda\.cudart.*deprecated",
category = FutureWarning,
)
warnings.filterwarnings(
"ignore",
message = r".*cuda\.nvrtc.*deprecated",
category = FutureWarning,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These two warnings.filterwarnings calls for related CUDA deprecation warnings can be combined into a single call with a more general regex. This improves conciseness and maintainability.

Consider replacing this block with:

# cuda.cudart / cuda.nvrtc module deprecation (FutureWarning)
warnings.filterwarnings(
    "ignore",
    message=r".*cuda\.(cudart|nvrtc).*deprecated",
    category=FutureWarning,
)

Comment thread unsloth/models/vision.py
Comment on lines +355 to +356
tok_config = json.load(f)
# Set model-specific special tokens and their IDs

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The broad except Exception: pass can hide important errors during tokenizer config loading. For example, network issues or unexpected file formats would be silently ignored. It would be more robust to catch specific expected exceptions (like EntryNotFoundError from huggingface_hub if the config is optional) and log other exceptions for debugging purposes. This helps in diagnosing issues with new or custom models.

Comment thread unsloth/models/vision.py
Comment on lines +368 to +369
if not hasattr(tok, id_key):
setattr(tok, id_key, token_id)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous try-except block, this broad except Exception: pass can hide underlying issues when trying to load the model's config. It's better to catch specific exceptions and log unexpected ones to aid in debugging.

Comment thread unsloth/models/vision.py
Comment on lines +981 to 1004
model, tokenizer = patch_tokenizer(model, tokenizer)
except Exception as _patch_err:
# Some VLM processors (e.g., ERNIE VL) may fail during tokenizer patching.
# Try loading tokenizer separately via AutoTokenizer as fallback.
try:
from transformers import AutoTokenizer as _AutoTokenizer

_fallback_tok = _AutoTokenizer.from_pretrained(
tokenizer_name,
padding_side = "left",
token = token,
trust_remote_code = trust_remote_code,
)
model, _fallback_tok = patch_tokenizer(model, _fallback_tok)
# Re-attach as processor wrapper if original was a processor
if hasattr(tokenizer, "image_processor"):
tokenizer.tokenizer = _fallback_tok
else:
tokenizer = _fallback_tok
except Exception:
# If fallback also fails, raise the original error
raise _patch_err
model = post_patch_loss_function(model)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This nested try-except block for last-resort tokenizer loading could be refactored into a separate helper function to improve readability and make the main from_pretrained function easier to follow.

For example:

def _load_tokenizer_last_resort(tokenizer_name, token, trust_remote_code):
    try:
        from transformers import AutoTokenizer
        return AutoTokenizer.from_pretrained(
            tokenizer_name,
            padding_side="left",
            token=token,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        try:
            from transformers import PreTrainedTokenizerFast
            return PreTrainedTokenizerFast.from_pretrained(
                tokenizer_name,
                padding_side="left",
                token=token,
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            return None

# In from_pretrained:
if tokenizer is None:
    tokenizer = _load_tokenizer_last_resort(tokenizer_name, token, trust_remote_code)
    if tokenizer is None:
        del model
        raise RuntimeError(
            "Unsloth: The tokenizer is weirdly not loaded? Please check if there is one."
        )

danielhanchen and others added 3 commits February 8, 2026 09:39
- Patch KTO get_batch_logps to auto-align logits and labels when Unsloth
  model forward truncates input_ids beyond max_seq_length. TRL 0.27.2
  changed _process_tokens to only truncate completions (not prompts), so
  sequences with long prompts exceed max_seq_length and trigger model-side
  truncation. The original ValueError is replaced with min-length alignment.

- Also truncate attention_mask in LlamaModel forward when input_ids are
  truncated to max_seq_length, preventing shape mismatches in attention.

- Widen except clause in rl_replacements.py openenv import from
  `except ImportError` to `except (ImportError, NameError, Exception)` to
  handle vllm SamplingParams NameError in TRL 0.27.2.
…up warning filters

TRL 0.26+ thin wrapper resolution (rl.py):
- Filter _-prefixed private imports when discovering Trainer/Config classes
- Look up Config in separate *_config.py module when not found in trainer module
- Detect thin wrappers (<1000 chars source) and resolve to experimental parent
  via MRO walk; use resolved module for imports and create_new_function
- Enables all 15 trainers to patch successfully (was 5/15 before)

ModernBERT SDPA (loader.py):
- Remove "modernbert" from DISABLE_SDPA_MODEL_NAMES
- SDPA works correctly for both classification and sentence transformers
- Verified: 88.9% accuracy on emotion classification, correct domain-specific
  embeddings after sentence transformer fine-tuning

Warning filter cleanup (import_fixes.py):
- Remove cuda.cudart/cuda.nvrtc FutureWarning filters (no such warnings
  exist in torch 2.9.1+; proactive suppression is unnecessary)
@danielhanchen

Copy link
Copy Markdown
Member Author

Additional fixes pushed

Fix 1: TRL 0.26+ thin wrapper resolution (rl.py)

  • Filter _-prefixed private imports when discovering Trainer/Config classes
  • Look up Config in separate *_config.py module when not found in trainer module
  • Detect thin wrappers (<1000 chars source) and resolve to experimental parent via MRO walk
  • Result: All 15 trainers patch successfully (was 5/15 before this fix)

Fix 2: Enable ModernBERT SDPA (loader.py)

  • Remove "modernbert" from DISABLE_SDPA_MODEL_NAMES
  • Verified: 88.9% accuracy on emotion classification, correct domain-specific embeddings

Fix 3: Remove unnecessary warning filters (import_fixes.py)

  • Remove cuda.cudart/cuda.nvrtc FutureWarning filters (no such warnings exist in torch 2.9.1+)

Fix 4: KTO shape mismatch on TRL 0.27.2+ (rl_replacements.py, llama.py)

  • TRL 0.27.2 changed _process_tokens to only truncate completions, not prompts. Long prompts exceed max_seq_length and trigger Unsloth's model-side input_ids truncation, but labels are left at original length, causing shape mismatch in get_batch_logps.
  • Patch get_batch_logps to auto-align logits and labels to shorter sequence length instead of raising ValueError.
  • Also truncate attention_mask in LlamaModel_fast_forward when input_ids are truncated.

Fix 5: Widen openenv import except clause (rl_replacements.py)

  • trl.experimental.openenv.utils references vllm's SamplingParams at module level, causing NameError when vllm is not installed. Changed except ImportError to except (ImportError, NameError, Exception).

Verification: TRL version matrix -- 12/12 PASS

TRL Transformers ORPO DPO KTO
0.22.2 4.56.2 PASS PASS PASS
0.25.1 4.56.2 PASS PASS PASS
0.26.2 4.57.6 PASS PASS PASS
0.27.2 4.57.6 PASS PASS PASS

Full notebook runs also pass: ORPO (Llama-3-8B), DPO (Zephyr-7B), KTO (Qwen2.5-1.5B), ModernBERT classification (88.9% accuracy), ModernBERT sentence transformer (correct embeddings).

danielhanchen and others added 5 commits February 8, 2026 12:06
The LFM2.5-VL projector LayerNorm is properly initialized by
transformers and does not need to be excluded from the uninitialized
weight check. The original exclusion was added as a workaround but is
no longer needed after the upstream fix.
…on, BatchEncoding guard, try/except for TRL trainer source, push_to_hub_token compiler fix

- llama.py: Add _get_rope_theta() helper handling both config.rope_theta and rope_parameters dict
- llama.py: Handle BatchEncoding in unsloth_fast_generate (transformers 5.0+ returns BatchEncoding from apply_chat_template)
- gemma.py: Detect config passed as dim arg in GemmaFixedRotaryEmbedding
- tokenizer_utils.py: Add try/except for TRL trainer getsource in patch_sft_trainer_tokenizer
- rl_replacements.py: Add compiler fix replacing bare pop("push_to_hub_token") with pop(..., None)
… thin wrapper detection

The <1000 / >1000 char threshold was fragile -- XPOConfig's parent is only
994 chars and would be skipped. All thin wrappers in TRL 0.26+ contain
"trl.experimental" in their deprecation warning, while no real trainer or
config class does, making it a reliable detection marker.
@kabachuha

Copy link
Copy Markdown

Do you know if the Gemma3_(27B)_A100-Conversational.ipynb notebook works without errors after these fixes?

…sformer

The function-level import was redundant since loader.py is already imported
at module level. Move it to the existing loader import line.
@danielhanchen

Copy link
Copy Markdown
Member Author

@kabachuha Yes hopefully - I can re-check

@danielhanchen danielhanchen merged commit e1c0eda into main Feb 9, 2026
4 checks passed
@danielhanchen danielhanchen deleted the fix/transformers-4.57-notebook-compat branch February 9, 2026 13:11
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…nslothai#3998)

* Patch before compile?

* Fix notebook compatibility for transformers 4.57.6 and TRL 0.22-0.27

Fixes several notebook failures discovered during testing all 125
notebooks with transformers==4.57.6 + tRL 0.22.2 and TRL 0.27.1.

Warning suppression (import_fixes.py):
- Suppress torch 2.9+ pin_memory/is_pinned device deprecation warnings
- Suppress cuda.cudart/cuda.nvrtc module deprecation FutureWarning
- Filter vllm "Level is deprecated" stderr noise
- Filter PydanticSerializationUnexpectedValue warnings
- Filter Triton "df: No such file" stderr noise

VLM tokenizer loading (vision.py):
- Add _construct_vlm_processor_fallback() for models where
  AutoProcessor.from_pretrained fails (e.g., ERNIE 4.5 VL, LFM2.5-VL)
- Wrap processor loading in try/except with fallback to manual
  construction from separate image_processor + tokenizer components
- Add fallback to AutoTokenizer/PreTrainedTokenizerFast when tokenizer
  loading or patching fails

TRL 0.27.1 trainer compatibility (trainer.py):
- Add _resolve_trainer_params() to handle thin wrapper trainers that
  only have def __init__(self, *args, **kwargs) (e.g., ORPOTrainer
  in TRL 0.27.1) by walking MRO for real parameter signature

VLM _is_vlm detection (rl.py):
- Replace blanket _is_vlm=False override with model-architecture-based
  detection that checks vision_config or ForConditionalGeneration class
  name, fixing VLM training when bare tokenizer is passed as
  processing_class

ModernBERT SDPA compatibility (loader.py, sentence_transformer.py):
- Add "modernbert" to DISABLE_SDPA_MODEL_NAMES to avoid stride
  alignment issues with torch.compile backward pass
- Add DISABLE_SDPA check for sentence transformer models

Other fixes (_utils.py):
- Suppress false uninitialized weight warnings for VLM
  multi_modal_projector.layer_norm

Tested: 92/125 notebooks pass with TRL 0.22.2, 94/125 with TRL 0.27.1.
Remaining failures are infra (missing FFmpeg, network timeouts, GPU
arch) not code bugs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix KTO shape mismatch on TRL 0.27.2+ and truncation alignment

- Patch KTO get_batch_logps to auto-align logits and labels when Unsloth
  model forward truncates input_ids beyond max_seq_length. TRL 0.27.2
  changed _process_tokens to only truncate completions (not prompts), so
  sequences with long prompts exceed max_seq_length and trigger model-side
  truncation. The original ValueError is replaced with min-length alignment.

- Also truncate attention_mask in LlamaModel forward when input_ids are
  truncated to max_seq_length, preventing shape mismatches in attention.

- Widen except clause in rl_replacements.py openenv import from
  `except ImportError` to `except (ImportError, NameError, Exception)` to
  handle vllm SamplingParams NameError in TRL 0.27.2.

* Fix TRL 0.26+ thin wrapper resolution, enable ModernBERT SDPA, clean up warning filters

TRL 0.26+ thin wrapper resolution (rl.py):
- Filter _-prefixed private imports when discovering Trainer/Config classes
- Look up Config in separate *_config.py module when not found in trainer module
- Detect thin wrappers (<1000 chars source) and resolve to experimental parent
  via MRO walk; use resolved module for imports and create_new_function
- Enables all 15 trainers to patch successfully (was 5/15 before)

ModernBERT SDPA (loader.py):
- Remove "modernbert" from DISABLE_SDPA_MODEL_NAMES
- SDPA works correctly for both classification and sentence transformers
- Verified: 88.9% accuracy on emotion classification, correct domain-specific
  embeddings after sentence transformer fine-tuning

Warning filter cleanup (import_fixes.py):
- Remove cuda.cudart/cuda.nvrtc FutureWarning filters (no such warnings
  exist in torch 2.9.1+; proactive suppression is unnecessary)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove multi_modal_projector.layer_norm from uninitialized weight guard

The LFM2.5-VL projector LayerNorm is properly initialized by
transformers and does not need to be excluded from the uninitialized
weight check. The original exclusion was added as a workaround but is
no longer needed after the upstream fix.

* Add transformers 5.0 compat: rope_theta helper, config-as-dim detection, BatchEncoding guard, try/except for TRL trainer source, push_to_hub_token compiler fix

- llama.py: Add _get_rope_theta() helper handling both config.rope_theta and rope_parameters dict
- llama.py: Handle BatchEncoding in unsloth_fast_generate (transformers 5.0+ returns BatchEncoding from apply_chat_template)
- gemma.py: Detect config passed as dim arg in GemmaFixedRotaryEmbedding
- tokenizer_utils.py: Add try/except for TRL trainer getsource in patch_sft_trainer_tokenizer
- rl_replacements.py: Add compiler fix replacing bare pop("push_to_hub_token") with pop(..., None)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use trl.experimental string check instead of char-count heuristic for thin wrapper detection

The <1000 / >1000 char threshold was fragile -- XPOConfig's parent is only
994 chars and would be skipped. All thin wrappers in TRL 0.26+ contain
"trl.experimental" in their deprecation warning, while no real trainer or
config class does, making it a reliable detection marker.

* Move DISABLE_SDPA_MODEL_NAMES import to module level in sentence_transformer

The function-level import was redundant since loader.py is already imported
at module level. Move it to the existing loader import line.

---------

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants