Skip to content

Fix auto padding free logic to respect user passed False#4128

Merged
danielhanchen merged 2 commits into
unslothai:mainfrom
mmathew23:fix/padding-free-false
Mar 2, 2026
Merged

Fix auto padding free logic to respect user passed False#4128
danielhanchen merged 2 commits into
unslothai:mainfrom
mmathew23:fix/padding-free-false

Conversation

@mmathew23

Copy link
Copy Markdown
Contributor

Qwen3-14B notebook currently OOM's on a T4 due to increase VRAM when padding free is turned on. SFTConfig currently defaults padding_free=False so the current logic can't differentiate between when it should autopad or the user specifically requested to turn off.

This PR patches SFTConfig to default padding_free to None. If it's None padding free will be auto-enabled (the default). If it's True it's enabled, and if it's False it's turned off.

Notebook before fix with padding_free=False
https://colab.research.google.com/drive/1u51CbHLntgBLUrWe4B4lZFRNtPGi1faG?usp=sharing

Working notebook after fix with padding_free=False
https://colab.research.google.com/drive/1GYBXNlm9yP8zP0XAT6LD0kmmL5NficK-?usp=sharing

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an Out-Of-Memory issue encountered when padding_free was explicitly set to False in the SFTConfig. The core problem stemmed from the inability to distinguish between a user intentionally disabling padding-free and the system's default behavior. By changing the default padding_free value to None, the system can now correctly interpret user intent, allowing auto-padding to be enabled only when not explicitly disabled, thereby preventing unexpected memory consumption.

Highlights

  • Padding-free logic fix: Resolved an Out-Of-Memory (OOM) issue occurring when padding_free was explicitly set to False in SFTConfig, which previously led to unintended auto-padding.
  • SFTConfig default change: Modified the default value of padding_free in SFTConfig from False to None to clearly differentiate between a user explicitly disabling it and the system's default behavior.
  • Auto-padding detection: Updated the internal logic to correctly interpret padding_free=None as a signal to auto-enable padding-free, while respecting True for explicit enablement and False for explicit disablement.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth/models/rl.py
    • Changed the default value of padding_free to None in the _patch_trl_rl_trainers function.
  • unsloth/trainer.py
    • Updated the _should_auto_padding_free function to check if padding_free is None for auto-enabling.
    • Added a comment in new_init to clarify the resolution logic for padding_free.
Activity
  • No specific activity (comments, reviews, etc.) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly addresses an issue where explicitly setting padding_free=False was not respected, by changing the default value to None and updating the detection logic. The changes look good. I've added one suggestion to improve the robustness of the new check in _should_auto_padding_free.

Comment thread unsloth/trainer.py Outdated

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e9981c3b5b

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/trainer.py Outdated
):
return False
return not getattr(config, "padding_free", False)
return getattr(config, "padding_free", False) is None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Treat missing padding_free as unset for auto mode

The new return getattr(config, "padding_free", False) is None check disables auto padding-free for configs that do not define padding_free at all, because the fallback now becomes False instead of “unset”. This regresses the backwards-compatibility path where callers pass plain transformers.TrainingArguments (handled in _backwards_compatible_trainer), since those configs typically lack this field and will no longer get the previous auto-enable behavior unless users manually inject padding_free.

Useful? React with 👍 / 👎.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

@danielhanchen danielhanchen left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested and verified. The fix correctly distinguishes between "user didn't set padding_free" (None, auto-enable) and "user explicitly set padding_free=False" (respect it).

Unit tests (7/7 passed): All 4 padding_free states (None, True, False, missing attribute) plus edge cases.

GPU benchmarks (Llama-3.2-1B-Instruct, 61 steps):

Scenario Peak Mem train_loss Stdout message
Baseline (main) 1.53 GB 1.3789 "Padding-free auto-enabled"
PR: default (None) 1.53 GB 1.3789 "Padding-free auto-enabled"
PR: explicit False 1.62 GB 1.3748 No padding-free message
PR: explicit True 1.53 GB 1.3789 "Padding-free enabled"
  • Baseline vs PR default: losses/grad-norms identical -- zero regression
  • Baseline vs PR True: losses/grad-norms identical
  • PR explicit False: correctly disables padding-free (slightly higher memory, slightly different losses from different batching)

No additional changes needed.

@danielhanchen danielhanchen merged commit 9b56d63 into unslothai:main Mar 2, 2026
1 check passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
)

* Fix auto padding free logic to respect user passed

* Update unsloth/trainer.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants