Fix: quantize and target moe layers in transformers v5 for adapters and many misc fixes #3439
Conversation
require_grad=true on init
This reverts commit 1d54518.
📝 Walkthrough

This PR introduces MoE expert quantization support, adds GLM-4.7-Flash model fine-tuning configurations, extends FSDP2 with DTensor handling for QLoRA, updates the Cut Cross Entropy integration, adds MoE architecture mappings, and refines checkpoint saving for context-parallel setups.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/core/trainers/base.py (1)
Lines 753-757: ⚠️ Potential issue | 🔴 Critical: Unconditional `is_main_process` kwarg will break custom model saves.

Lines 753-757 and 768-772 unconditionally pass `is_main_process` to `save_pretrained()`. Custom model overrides that don't accept this parameter (e.g., `src/axolotl/models/mamba/modeling_mamba.py:110-114`) will raise `TypeError` during checkpoint save. The mamba model's `save_pretrained` signature accepts only `save_directory` and `state_dict`, with no `**kwargs` to absorb additional parameters.

This requires conditionally passing the kwarg only when the method signature supports it, such as using `inspect.signature()` to check before calling.
inspect.signature()to check before calling.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/base.py` around lines 753 - 757, The save_pretrained call is passing is_main_process unconditionally which breaks custom implementations (e.g., modeling_mamba's save_pretrained) that don't accept that kwarg; update the code that calls save_pretrained (the call where state_dict and is_main_process are passed) to first inspect the save_pretrained signature (use inspect.signature on the model's save_pretrained method) and only include is_main_process when the parameter is accepted (otherwise call with only save_directory and state_dict); ensure you reference the same save_pretrained method used in the trainer and use self.accelerator.is_main_process for the flag.
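A minimal sketch of that guard, assuming a trainer-style call site; the helper name is ours, not code from the PR:

```python
import inspect


def save_pretrained_compat(model, output_dir, state_dict, is_main_process):
    """Pass is_main_process only when the model's save_pretrained accepts it.

    Hypothetical helper illustrating the inspect.signature() check
    suggested above; not code from the PR.
    """
    params = inspect.signature(model.save_pretrained).parameters
    accepts_flag = "is_main_process" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    kwargs = {"state_dict": state_dict}
    if accepts_flag:
        kwargs["is_main_process"] = is_main_process
    model.save_pretrained(output_dir, **kwargs)
```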
🧹 Nitpick comments (2)
examples/glm4.7-flash/qlora.yaml (1)
Lines 26-34: Clarify whether this example adapts experts or only attention projections.

Right now the active config targets `q_proj`/`v_proj`/`k_proj`/`o_proj`, while expert parameter targets are commented out. A brief inline note would prevent confusion for users expecting MoE expert adaptation in this example.

📝 Suggested clarification

```diff
 lora_target_modules:
   - q_proj
   - v_proj
   - k_proj
   - o_proj
+# This example fine-tunes attention projections only.
+# To also adapt MoE expert tensors, use lora_target_parameters:
 # lora_target_parameters:
 #   - mlp.experts.gate_up_proj
 #   - mlp.experts.down_proj
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/glm4.7-flash/qlora.yaml` around lines 26 - 34, The config currently only adapts attention projection modules (lora_target_modules: q_proj, v_proj, k_proj, o_proj) while the MoE expert parameter targets (lora_target_parameters: mlp.experts.gate_up_proj, mlp.experts.down_proj) are commented out; update the YAML to make intent explicit by either uncommenting the lora_target_parameters entries (and enabling lora_mlp_kernel if needed) when you intend to adapt experts, or add a one-line inline comment above lora_target_modules explaining this example only adapts attention projections and does not modify MoE expert parameters.

src/axolotl/loaders/patch_manager.py (1)
Lines 380-387: Limit PEFT monkeypatch to adapter flows and apply once.

Line 386 applies a global PEFT method patch whenever `quantize_moe_experts` is true, even if no adapter path is active. Consider gating this on adapter usage and adding a one-time guard to reduce global side effects.

♻️ Proposed fix

```diff
     patch_moe_quantization_on_load(self.cfg)
-    patch_peft_target_parameters_matching()
+    if self.cfg.adapter:
+        patch_peft_target_parameters_matching()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/loaders/patch_manager.py` around lines 380 - 387, Only apply the PEFT monkeypatch when adapters are actually in use and ensure it's applied once: check the config flag that indicates an adapter path is active (e.g., self.cfg.adapter_path or equivalent adapter-enabled field) before calling patch_peft_target_parameters_matching(), and guard the call with a one-time flag (module-level or attribute like self._peft_patch_applied) so subsequent loads don't reapply the global patch; keep patch_moe_quantization_on_load(self.cfg) behavior unchanged but move or wrap patch_peft_target_parameters_matching() behind the adapter check and one-time guard.
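A minimal sketch of the suggested gating, assuming a module-level flag; the flag and wrapper names are ours, and `patch_peft_target_parameters_matching` is the existing axolotl helper named above:

```python
# Hypothetical one-time guard for the global PEFT patch; illustrative
# only, not code from the PR.
_PEFT_PATCH_APPLIED = False


def maybe_patch_peft_target_parameters(cfg):
    global _PEFT_PATCH_APPLIED
    if not getattr(cfg, "adapter", None):
        # No adapter flow active; skip the global monkeypatch entirely.
        return
    if _PEFT_PATCH_APPLIED:
        # Already applied in an earlier model load; avoid re-patching.
        return
    patch_peft_target_parameters_matching()  # existing axolotl helper
    _PEFT_PATCH_APPLIED = True
```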
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/glm4.7-flash/README.md`:
- Line 63: The README's full-finetune tip is incomplete: when instructing users
to remove adapter: qlora and load_in_4bit: true from the FSDP2 config, also
disable or remove quantize_moe_experts because it currently requires adapter to
be lora or qlora and will fail validation; update the sentence to say to set
quantize_moe_experts: false (or remove that key) in the FSDP2 config when doing
a full finetune so validation passes.
In `@scripts/cutcrossentropy_install.py`:
- Line 32: The pip-install line in scripts/cutcrossentropy_install.py currently
pins the repository with a short commit hash ("a668583") in the string
f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @
git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'; replace
that abbreviated SHA with the repository's full 40-character commit SHA to
ensure deterministic dependency resolution (look up the full SHA for the
intended commit in the ml-cross-entropy repo and update the string accordingly).
In `@src/axolotl/monkeypatch/fsdp2_qlora.py`:
- Around line 166-193: The runtime monkeypatch re-wraps methods on every call;
modify apply_linear8bitlt_save_patch to be idempotent by detecting and skipping
if already applied: capture and store the original function once (e.g., a
module-level variable or by checking an attribute on
bnb.nn.Linear8bitLt._save_to_state_dict), return immediately if the method is
already the patched wrapper, and when installing the wrapper, mark the patched
function with a sentinel attribute (e.g., _is_axolotl_patched = True) so
repeated calls do nothing; apply the same pattern to
apply_init_dtype_attrs_patch (use unique sentinel names and preserved originals
like original_save / original_init_dtype_attrs) to avoid stacking wrappers.
In `@src/axolotl/monkeypatch/moe_quant.py`:
- Around line 82-107: Compute and write the current runtime quantization state
into _moe_load_state before the early "already patched" return: determine mode
from cfg (same logic using getattr(cfg, "load_in_8bit", False) to select "8bit"
or "4bit"), set _moe_load_state["mode"] = mode and _moe_load_state["count"] = 0
(and if applicable pre-populate
_moe_load_state["quant_type"]/_moe_load_state["compress_statistics"] when mode
== "4bit") before the if _moe_load_state["patched"] check so later loads don't
see stale state from prior loads.
In `@src/axolotl/utils/schemas/config.py`:
- Around line 632-641: Extend the existing Pydantic validation that currently
checks adapter + load_in_4bit/load_in_8bit to also enforce that
quantize_moe_experts can only be true when the runtime backend is CUDA (reject
when backend is ROCm/other); specifically, in the same validator that references
quantize_moe_experts and the load_in_4bit/load_in_8bit flags, add a check that
the configured backend/device backend (or runtime torch backend detection)
indicates CUDA and raise a ValidationError if quantize_moe_experts is true but
the backend is not CUDA so the config fails validation early.
---
Outside diff comments:
In `@src/axolotl/core/trainers/base.py`:
- Around line 753-757: The save_pretrained call is passing is_main_process
unconditionally which breaks custom implementations (e.g., modeling_mamba's
save_pretrained) that don't accept that kwarg; update the code that calls
save_pretrained (the call where state_dict and is_main_process are passed) to
first inspect the save_pretrained signature (use inspect.signature on the
model's save_pretrained method) and only include is_main_process when the
parameter is accepted (otherwise call with only save_directory and state_dict);
ensure you reference the same save_pretrained method used in the trainer and use
self.accelerator.is_main_process for the flag.
---
Nitpick comments:
In `@examples/glm4.7-flash/qlora.yaml`:
- Around line 26-34: The config currently only adapts attention projection
modules (lora_target_modules: q_proj, v_proj, k_proj, o_proj) while the MoE
expert parameter targets (lora_target_parameters: mlp.experts.gate_up_proj,
mlp.experts.down_proj) are commented out; update the YAML to make intent
explicit by either uncommenting the lora_target_parameters entries (and enabling
lora_mlp_kernel if needed) when you intend to adapt experts, or add a one-line
inline comment above lora_target_modules explaining this example only adapts
attention projections and does not modify MoE expert parameters.
In `@src/axolotl/loaders/patch_manager.py`:
- Around line 380-387: Only apply the PEFT monkeypatch when adapters are
actually in use and ensure it's applied once: check the config flag that
indicates an adapter path is active (e.g., self.cfg.adapter_path or equivalent
adapter-enabled field) before calling patch_peft_target_parameters_matching(),
and guard the call with a one-time flag (module-level or attribute like
self._peft_patch_applied) so subsequent loads don't reapply the global patch;
keep patch_moe_quantization_on_load(self.cfg) behavior unchanged but move or
wrap patch_peft_target_parameters_matching() behind the adapter check and
one-time guard.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- examples/colab-notebooks/colab-axolotl-example.ipynb
- examples/glm4.7-flash/README.md
- examples/glm4.7-flash/lora.yaml
- examples/glm4.7-flash/lora_fsdp.yaml
- examples/glm4.7-flash/qlora.yaml
- examples/glm4.7-flash/qlora_fsdp.yaml
- examples/trinity/README.md
- examples/trinity/trinity-nano-preview-qlora.yaml
- scripts/cutcrossentropy_install.py
- src/axolotl/common/architectures.py
- src/axolotl/core/trainers/base.py
- src/axolotl/integrations/cut_cross_entropy/README.md
- src/axolotl/integrations/cut_cross_entropy/__init__.py
- src/axolotl/integrations/kernels/README.md
- src/axolotl/loaders/model.py
- src/axolotl/loaders/patch_manager.py
- src/axolotl/monkeypatch/accelerate/fsdp2.py
- src/axolotl/monkeypatch/fsdp2_qlora.py
- src/axolotl/monkeypatch/moe_quant.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/schemas/peft.py
💤 Files with no reviewable changes (1)
- examples/trinity/trinity-nano-preview-qlora.yaml
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the FSDP2 config. This is heavy, so we have not tested this.
Update the full-finetune tip to disable quantize_moe_experts.
Line 63 currently tells users to remove adapter: qlora and load_in_4bit: true, but the config still contains quantize_moe_experts: true, which now requires adapter to be lora or qlora and will fail validation.
📝 Suggested doc fix
```diff
-- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the FSDP2 config. This is heavy, so we have not tested this.
+- You can run a full finetuning by removing `adapter: qlora`, removing `load_in_4bit: true`, and setting `quantize_moe_experts: false` in the FSDP2 config. This is heavy, so we have not tested this.
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- You can run a full finetuning by removing `adapter: qlora`, removing `load_in_4bit: true`, and setting `quantize_moe_experts: false` in the FSDP2 config. This is heavy, so we have not tested this.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/glm4.7-flash/README.md` at line 63, The README's full-finetune tip
is incomplete: when instructing users to remove adapter: qlora and load_in_4bit:
true from the FSDP2 config, also disable or remove quantize_moe_experts because
it currently requires adapter to be lora or qlora and will fail validation;
update the sentence to say to set quantize_moe_experts: false (or remove that
key) in the FSDP2 config when doing a full finetune so validation passes.
```python
def apply_linear8bitlt_save_patch():
    """Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.

    After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
    BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
    doesn't proxy custom attribute access to its _local_tensor. This patch
    temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
    """
    import bitsandbytes as bnb
    from torch.distributed.tensor import DTensor

    original_save = bnb.nn.Linear8bitLt._save_to_state_dict

    def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
        # Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
        weight = self._parameters["weight"]
        unwrapped = False
        if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
            self._parameters["weight"] = weight._local_tensor
            unwrapped = True
        try:
            original_save(self, destination, prefix, keep_vars)
        finally:
            if unwrapped:
                self._parameters["weight"] = weight

    bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
    LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
```
Add idempotency guards for runtime monkeypatches.

`apply_linear8bitlt_save_patch()` and `apply_init_dtype_attrs_patch()` re-wrap methods on every call. In repeated model-load flows this stacks wrappers and can cause hard-to-debug behavior/perf regressions.

♻️ Proposed fix

```diff
 def apply_linear8bitlt_save_patch():
@@
     import bitsandbytes as bnb
     from torch.distributed.tensor import DTensor

+    if getattr(bnb.nn.Linear8bitLt, "_axolotl_save_patch_applied", False):
+        return
+
     original_save = bnb.nn.Linear8bitLt._save_to_state_dict
@@
     bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
+    bnb.nn.Linear8bitLt._axolotl_save_patch_applied = True
     LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
@@
 def apply_init_dtype_attrs_patch():
@@
     from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam

+    if getattr(FSDPParam, "_axolotl_init_dtype_attrs_patch_applied", False):
+        return
+
     original_init_dtype_attrs = FSDPParam.init_dtype_attrs
@@
     FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
+    FSDPParam._axolotl_init_dtype_attrs_patch_applied = True
     LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")
```

Also applies to: 196-224
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/monkeypatch/fsdp2_qlora.py` around lines 166 - 193, The runtime
monkeypatch re-wraps methods on every call; modify apply_linear8bitlt_save_patch
to be idempotent by detecting and skipping if already applied: capture and store
the original function once (e.g., a module-level variable or by checking an
attribute on bnb.nn.Linear8bitLt._save_to_state_dict), return immediately if the
method is already the patched wrapper, and when installing the wrapper, mark the
patched function with a sentinel attribute (e.g., _is_axolotl_patched = True) so
repeated calls do nothing; apply the same pattern to
apply_init_dtype_attrs_patch (use unique sentinel names and preserved originals
like original_save / original_init_dtype_attrs) to avoid stacking wrappers.
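The same pattern generalizes. A small decorator sketch (our own, not from the PR) that makes any patch installer idempotent:

```python
import functools


def apply_once(patch_fn):
    """Run a monkeypatch installer at most once per process.

    Generic sketch of the idempotency pattern suggested above; the
    sentinel lives on the wrapper rather than on the patched class.
    """

    @functools.wraps(patch_fn)
    def wrapper(*args, **kwargs):
        if wrapper._applied:
            return  # patch already installed; avoid stacking wrappers
        patch_fn(*args, **kwargs)
        wrapper._applied = True

    wrapper._applied = False
    return wrapper


# Usage: decorate the installers so repeated model loads are safe.
# @apply_once
# def apply_linear8bitlt_save_patch(): ...
```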
```python
    if _moe_load_state["patched"]:
        LOG.debug("MoE loading-time quantization patch already active")
        return

    import transformers.core_model_loading
    import transformers.modeling_utils

    if getattr(cfg, "load_in_8bit", False):
        mode = "8bit"
    else:
        mode = "4bit"

    _moe_load_state["mode"] = mode
    _moe_load_state["count"] = 0

    if mode == "4bit":
        from bitsandbytes.nn.parametrize import replace_parameter_4bit

        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
        if compress_statistics is None:
            compress_statistics = True

        _moe_load_state["quant_type"] = quant_type
        _moe_load_state["compress_statistics"] = compress_statistics
```
Reset runtime quantization state before the "already patched" early return.

At line 82, returning immediately skips updating `_moe_load_state["mode"]` and `_moe_load_state["count"]`. In multi-load processes, later loads can run with stale mode/count from a prior model load.

🐛 Proposed fix

```diff
 def patch_moe_quantization_on_load(cfg):
@@
-    if _moe_load_state["patched"]:
-        LOG.debug("MoE loading-time quantization patch already active")
-        return
-
     import transformers.core_model_loading
     import transformers.modeling_utils

     if getattr(cfg, "load_in_8bit", False):
         mode = "8bit"
@@
     _moe_load_state["mode"] = mode
     _moe_load_state["count"] = 0

     if mode == "4bit":
         from bitsandbytes.nn.parametrize import replace_parameter_4bit
@@
         _moe_load_state["quant_type"] = quant_type
         _moe_load_state["compress_statistics"] = compress_statistics
+
+    if _moe_load_state["patched"]:
+        LOG.debug(
+            "MoE loading-time quantization patch already active; refreshed runtime state"
+        )
+        return
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    import transformers.core_model_loading
    import transformers.modeling_utils

    if getattr(cfg, "load_in_8bit", False):
        mode = "8bit"
    else:
        mode = "4bit"

    _moe_load_state["mode"] = mode
    _moe_load_state["count"] = 0

    if mode == "4bit":
        from bitsandbytes.nn.parametrize import replace_parameter_4bit

        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
        if compress_statistics is None:
            compress_statistics = True

        _moe_load_state["quant_type"] = quant_type
        _moe_load_state["compress_statistics"] = compress_statistics

    if _moe_load_state["patched"]:
        LOG.debug(
            "MoE loading-time quantization patch already active; refreshed runtime state"
        )
        return
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/monkeypatch/moe_quant.py` around lines 82 - 107, Compute and
write the current runtime quantization state into _moe_load_state before the
early "already patched" return: determine mode from cfg (same logic using
getattr(cfg, "load_in_8bit", False) to select "8bit" or "4bit"), set
_moe_load_state["mode"] = mode and _moe_load_state["count"] = 0 (and if
applicable pre-populate
_moe_load_state["quant_type"]/_moe_load_state["compress_statistics"] when mode
== "4bit") before the if _moe_load_state["patched"] check so later loads don't
see stale state from prior loads.
```python
    quantize_moe_experts: bool = Field(
        default=False,
        json_schema_extra={
            "description": "Quantize MoE expert weights on load to reduce VRAM. "
            "Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
            "Requires CUDA (not compatible with ROCm or other backends). "
            "Note: total parameter count may be reported incorrectly when enabled "
            "(trainable param count is correct)."
        },
    )
```
`quantize_moe_experts` is documented as CUDA-only but the backend is not validated.

Line 637 states this is not compatible with ROCm/other backends, but lines 1305-1313 only validate adapter + bit-loading flags. Unsupported backends can pass config validation and fail later at runtime.

✅ Suggested validator extension

```diff
 @model_validator(mode="before")
 @classmethod
 def check_quantize_moe_experts(cls, data):
     if data.get("quantize_moe_experts"):
+        capabilities = data.get("capabilities") or {}
+        compute_capability = str(capabilities.get("compute_capability", ""))
+        if not compute_capability.startswith("sm_"):
+            raise ValueError(
+                "quantize_moe_experts requires CUDA/NVIDIA (compute capability sm_*)."
+            )
         if data.get("adapter") not in ("lora", "qlora"):
             raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
         if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
             raise ValueError(
                 "quantize_moe_experts requires load_in_4bit or load_in_8bit"
             )
     return data
```

Also applies to: 1305-1313
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/utils/schemas/config.py` around lines 632 - 641, Extend the
existing Pydantic validation that currently checks adapter +
load_in_4bit/load_in_8bit to also enforce that quantize_moe_experts can only be
true when the runtime backend is CUDA (reject when backend is ROCm/other);
specifically, in the same validator that references quantize_moe_experts and the
load_in_4bit/load_in_8bit flags, add a check that the configured backend/device
backend (or runtime torch backend detection) indicates CUDA and raise a
ValidationError if quantize_moe_experts is true but the backend is not CUDA so
the config fails validation early.
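As a standalone illustration of that early rejection, a toy Pydantic sketch; field names mirror the axolotl config, but the `backend` field is our stand-in for runtime backend detection, not an actual config key:

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class ToyConfig(BaseModel):
    quantize_moe_experts: bool = False
    adapter: Optional[str] = None
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    backend: str = "cuda"  # hypothetical stand-in for backend detection

    @model_validator(mode="before")
    @classmethod
    def check_quantize_moe_experts(cls, data):
        if data.get("quantize_moe_experts"):
            if data.get("backend", "cuda") != "cuda":
                raise ValueError("quantize_moe_experts requires a CUDA backend")
            if data.get("adapter") not in ("lora", "qlora"):
                raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
            if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
                raise ValueError("quantize_moe_experts requires load_in_4bit or load_in_8bit")
        return data


# Fails at config time, before any model loading:
# ToyConfig(quantize_moe_experts=True, adapter="qlora", load_in_4bit=True, backend="rocm")
```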
📖 Documentation Preview: https://69a573539f289059635108bd--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit 3aafa4a)
Codecov Report: ❌ Patch coverage is …
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single-GPU training. We suspect not all layers are properly sharded across ranks.
- **FSDP initial spike**: FSDP LoRA (8-bit) may show a large initial VRAM spike in the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2; otherwise it causes `AttributeError: e_score_correction_bias is not an nn.Parameter` due to the modeling source.
btw, can we log the stack trace for this and open an issue for fixing in axolotl?
```python
    return state_dict


def patch_peft_param_wrapper_for_fsdp2():
```
I looked at this. This patch is created to resolve an issue that occurred from our quantize patch, so maybe doesn't make sense to be upstreamed.
Description
Closes #3374 #3370
Problem: experts aren't properly targeted with the normal `lora_target_modules`; we need to use `lora_target_parameters`. Additionally, expert layers aren't being quantized by bnb.

Based on #3395 work by ved (without the lora kernels changes), with additional fixes.
How to use
See the included example YAMLs and README.
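For a rough picture of what those configs do at the PEFT level, a hedged sketch; hyperparameter values are illustrative, and `target_parameters` requires a recent peft release:

```python
from peft import LoraConfig

# Adapt attention modules via target_modules and MoE expert tensors via
# target_parameters. Module paths follow the example configs; r/alpha
# values here are our own, not from the PR.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
)
```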
Results
From a QLoRA training run that previously peaked at 127 GiB of memory, we reduced peak usage to 23 GiB.

The loss curve differs because we swapped the optimizer (`adamw_bnb_8bit` -> `adamw_torch_8bit`) and nodes while working on this. We verified that without our fix, the prior curve is consistent. We also incorporated these changes into LoRA and FSDP2 LoRA/QLoRA trainings.
Limitations

- `AttributeError: e_score_correction_bias is not an nn.Parameter` due to how it's instantiated
- bf16 LoRA not tested

Misc:
Motivation and Context
How has this been tested?
Training loss with our on-load quantization is consistent with the previous post-load quantization approach. Tested on a single GPU for LoRA and QLoRA respectively. Also ensured FSDP/DDP loss was within reasonable expectations.
TODO (ongoing):
AI Usage Disclaimer
Used Claude heavily, while checking results via manual runs.
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Thanks to ved for the original PR to base off of and for helping test throughout.
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Documentation