feat: Validate PEFT target modules #1747
Maybe this can be coordinated with #1799 so that Canonical LoRA also supports targeting at MLA linear layers.
7f2fe2a to 5ec088e
/ok to test 5ec088e
📝 Walkthrough
This pull request introduces a validation mechanism for PEFT target module matching. The system now tracks which target modules successfully match during model traversal and validates that all requested targets are satisfied, raising an error if any fail to match.
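The tracking-and-validation idea in the walkthrough can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: `TargetTracker`, its `match`/`validate` methods, and the example module names are all hypothetical stand-ins, and the matching rule (leaf-name equality or an `fnmatch` wildcard on the full path) is an assumption.

```python
from fnmatch import fnmatch


class TargetTracker:
    """Hypothetical stand-in for a module matcher that records which targets hit."""

    def __init__(self, target_modules):
        self.target_modules = list(target_modules)
        self._matched = set()  # targets that matched at least one module

    def match(self, full_name):
        """Record every target that matches this module name; return the hits."""
        leaf = full_name.rsplit(".", 1)[-1]
        hits = [t for t in self.target_modules if leaf == t or fnmatch(full_name, t)]
        self._matched.update(hits)
        return hits

    def validate(self):
        """Raise if any requested target never matched during traversal."""
        unmatched = [t for t in self.target_modules if t not in self._matched]
        if unmatched:
            raise ValueError(f"PEFT target modules not found in model: {unmatched}")


# Assumed module names, used only for illustration.
module_names = [
    "decoder.layers.0.self_attention.linear_qkv",
    "decoder.layers.0.mlp.linear_fc1",
]

tracker = TargetTracker(["linear_qkv", "typo"])
for name in module_names:
    tracker.match(name)

message = ""
try:
    tracker.validate()
except ValueError as err:
    message = str(err)
```

Here `"linear_qkv"` matches one module and `"typo"` matches none, so `validate()` raises and names only the unmatched target.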
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/megatron/bridge/peft/base.py (1)
99-120: ⚠️ Potential issue | 🟠 Major
Preflight target validation before mutating the model.
Running `_validate_target_matches()` only after `self._walk_model(model, self.transform)` makes the error path unsafe: a config like `["linear_qkv", "typo"]` will already have frozen/wrapped the valid targets before the `ValueError` is raised. It also breaks reapplying matcher-based PEFT configs, because `CanonicalLoRA.transform()` and `DoRA.transform()` both return early for already wrapped modules before calling `match()`, so the second pass records no hits and fails validation. A read-only `match` walk before `freeze_model()`/`self.transform` avoids both behaviors.
💡 Suggested approach
```diff
 def __call__(self, model: ModelType, training: bool = True) -> ModelType:
     if isinstance(self, ModuleMatcher):
         self._reset_target_match_state()
+
+        def _validate_only(
+            module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None
+        ) -> nn.Module:
+            self.match(module, name, prefix)
+            return module
+
+        self._walk_model(model, _validate_only)
+        self._validate_target_matches()
+        self._reset_target_match_state()

     self.freeze_model(model, training=training)
     self._walk_model(model, self.transform)
@@
-    if isinstance(self, ModuleMatcher):
-        self._validate_target_matches()
-
     return model
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/peft/base.py` around lines 99 - 120, Run a read-only preflight match pass for ModuleMatcher before mutating the model: if isinstance(self, ModuleMatcher) call a non-mutating walk that invokes the matcher (i.e., calls match() via the same walk logic) and then call self._validate_target_matches() immediately after that preflight; only after validation proceed to call self.freeze_model(model, training=training), self._walk_model(model, self.transform) and the rest (including maybe_enable_recompute_inputs_grad and setting training modes). This ensures CanonicalLoRA.transform and DoRA.transform won’t short-circuit earlier matches and prevents partial mutation before validation.
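To make the ordering argument concrete, here is a small self-contained sketch of validate-then-mutate. It is a toy model of the behavior described in the comment above, not the Megatron Bridge implementation: `FakeModule`, `PreflightAdapter`, and the flat list standing in for a model are all hypothetical, and `transform` only mimics the "return early if already wrapped" behavior the review attributes to `CanonicalLoRA.transform()` and `DoRA.transform()`.

```python
class FakeModule:
    """Hypothetical stand-in for an nn.Module; `wrapped` marks an applied adapter."""

    def __init__(self, name):
        self.name = name
        self.wrapped = False


class PreflightAdapter:
    """Toy PEFT object that validates targets before touching the model."""

    def __init__(self, targets):
        self.targets = list(targets)
        self._matched = set()

    def match(self, module):
        # Name-based and read-only: independent of whether the module is wrapped.
        leaf = module.name.rsplit(".", 1)[-1]
        if leaf in self.targets:
            self._matched.add(leaf)
            return True
        return False

    def transform(self, module):
        if module.wrapped:  # early return, so match() is never reached on re-apply
            return
        if self.match(module):
            module.wrapped = True

    def __call__(self, modules):
        # 1) Read-only preflight walk: record hits without mutating anything.
        self._matched.clear()
        for module in modules:
            self.match(module)
        missing = [t for t in self.targets if t not in self._matched]
        if missing:
            raise ValueError(f"Unmatched PEFT targets: {missing}")
        # 2) Only now mutate the model.
        for module in modules:
            self.transform(module)
        return modules


model = [FakeModule("attn.linear_qkv"), FakeModule("mlp.linear_fc1")]

# A config with a typo fails before any module is wrapped.
try:
    PreflightAdapter(["linear_qkv", "typo"])(model)
    failed = False
except ValueError:
    failed = True
untouched_after_error = not any(m.wrapped for m in model)

# Re-applying a valid config does not raise, because the preflight walk
# records hits even for modules that are already wrapped.
adapter = PreflightAdapter(["linear_qkv"])
adapter(model)
adapter(model)
wrapped_names = [m.name for m in model if m.wrapped]
```

If validation were instead run after the `transform` loop, the first call would wrap `attn.linear_qkv` before raising, and the second `adapter(model)` call would record no hits.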
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/peft/canonical_lora.py`:
- Around line 273-293: The code is doing substring checks and replace() across
the entire target string which can rewrite earlier path segments; instead split
the target into its leaf token (e.g., leaf = target.rsplit(".", 1)[-1] or use
the appropriate separator), perform the canonicalization on that leaf only to
set canonical_component and canonical_leaf (e.g., replace "linear_q" ->
"linear_qkv" on the leaf), then reconstruct canonical_target by joining the
original prefix (if any) with the canonical_leaf; finally call
register_target_alias(original_target, canonical_target) and update
self.canonical_mapping[canonical_target].add(canonical_component) so wildcard
patterns and parent path segments are preserved.
---
Outside diff comments:
In `@src/megatron/bridge/peft/base.py`:
- Around line 99-120: Run a read-only preflight match pass for ModuleMatcher
before mutating the model: if isinstance(self, ModuleMatcher) call a
non-mutating walk that invokes the matcher (i.e., calls match() via the same
walk logic) and then call self._validate_target_matches() immediately after that
preflight; only after validation proceed to call self.freeze_model(model,
training=training), self._walk_model(model, self.transform) and the rest
(including maybe_enable_recompute_inputs_grad and setting training modes). This
ensures CanonicalLoRA.transform and DoRA.transform won’t short-circuit earlier
matches and prevents partial mutation before validation.
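The leaf-only canonicalization requested in the inline comment on `canonical_lora.py` might look roughly like the sketch below. The `LEAF_TO_FUSED` table and the `canonicalize_target` helper are assumptions made for illustration; only the `linear_q` -> `linear_qkv` example comes from the review text, and the real code also registers aliases and updates `canonical_mapping`, which is omitted here.

```python
# Assumed mapping from unfused leaf names to their fused module names.
LEAF_TO_FUSED = {
    "linear_q": "linear_qkv",
    "linear_k": "linear_qkv",
    "linear_v": "linear_qkv",
    "linear_fc1_up": "linear_fc1",
    "linear_fc1_gate": "linear_fc1",
}


def canonicalize_target(target):
    """Rewrite only the last dotted segment of `target`.

    Returns (canonical_target, canonical_component), where the component is
    the original leaf name and the prefix (including wildcards) is untouched.
    """
    prefix, sep, leaf = target.rpartition(".")
    canonical_leaf = LEAF_TO_FUSED.get(leaf, leaf)
    return f"{prefix}{sep}{canonical_leaf}", leaf


# A hypothetical pattern whose parent segment happens to contain "linear_q".
pattern = "linear_q_block.*.linear_q"

# Whole-string replace() also rewrites the parent segment.
naive = pattern.replace("linear_q", "linear_qkv")

# Leaf-only canonicalization leaves the parent segment and wildcard alone.
leaf_only, component = canonicalize_target(pattern)
```

This sketch does not handle a wildcard inside the leaf itself (for example `linear_q*`); that case would need a separate decision.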
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b11d6812-d386-4063-a823-e2132ce8a3f6
📒 Files selected for processing (5)
src/megatron/bridge/peft/base.py
src/megatron/bridge/peft/canonical_lora.py
src/megatron/bridge/peft/dora.py
src/megatron/bridge/peft/module_matcher.py
tests/unit_tests/peft/test_lora.py
Move target-module validation to a read-only preflight pass that runs before freeze_model/transform. This fixes two issues:
1. Idempotent re-application: transform() returns early for already-wrapped modules without calling match(), so post-transform validation saw zero matches and raised ValueError on second pass.
2. Partial mutation on error: the previous flow froze and wrapped valid targets before discovering unmatched ones, leaving the model in a half-transformed state.
The preflight walk calls match() on every module (name-based, type-independent) so it correctly records hits even for wrapped modules.
Also fixes test_canonical_lora_training_vs_inference_mode which used default CanonicalLoRA targets on SimpleModel (individual layers, not fused linear_qkv/linear_fc1).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/claude review
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Add regression test ensuring CanonicalLoRA raises ValueError when a target module is not found in the model, covering the alias→pattern→match chain end-to-end (complements the existing LoRA equivalent).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test d666375
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
…A-NeMo#3583)
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Summary
Summary by CodeRabbit
Bug Fixes
Tests