Skip to content

feat-resize-tokenizer-add-new-tokens#9

Open
Erland366 wants to merge 1 commit into
unslothai:mainfrom
Erland366:feat/resize-tokenizer-add-new-token
Open

feat-resize-tokenizer-add-new-tokens#9
Erland366 wants to merge 1 commit into
unslothai:mainfrom
Erland366:feat/resize-tokenizer-add-new-token

Conversation

@Erland366

Copy link
Copy Markdown
Collaborator

Add an option to not extend the tokenizer since the tokenizer is already extended and we only want the model embedding to be changed

@Erland366

Copy link
Copy Markdown
Collaborator Author

This PR supports this PR

amrothemich pushed a commit to amrothemich/unsloth-zoo that referenced this pull request Nov 6, 2025
BREAKTHROUGH FIX for GRPO cache_position corruption crash.

## Problem
All previous fixes failed because they only checked kwargs for cache_position,
but the compiled model code passes cache_position as the 9th POSITIONAL argument
(args[8], 0-indexed) to GptOssForCausalLM.forward(), not as a keyword argument.

From /tmp/unsloth_compiled_cache/unsloth_compiled_module_gpt_oss.py:721:
```python
GptOssForCausalLM_forward(self, input_ids, attention_mask, position_ids,
                          past_key_values, inputs_embeds, labels, use_cache,
                          output_router_logits, cache_position, ...)
                          #                  ^^^^^^^^^^^^^^ ARG unslothai#9, INDEX 8
```

This is why NO debug output appeared and all fixes were bypassed - we were
checking the wrong place!

## Solution
Modified inference_mode_wrapper in gpt_oss.py (lines 1316-1371) to:

1. Check positional arguments FIRST: `if len(args) >= 9: cache_pos = args[8]`
2. Fall back to checking kwargs if not in args
3. If cache_position is corrupted (multi-element tensor):
   - Extract last position value
   - Apply sliding window wrapping if needed
   - Create fixed single-element tensor
4. Update cache_position in correct location (args or kwargs)
5. Reconstruct args tuple with fixed value if it was positional

## Why This Should Work
- Intercepts cache_position at the EXACT point it's passed to forward()
- Works whether passed as positional OR keyword argument
- Fixes corruption before it reaches the cache update logic
- Comprehensive logging shows exactly when/how fix triggers

## Testing
This fix should now:
- Show debug output confirming cache_position detection
- Display corruption repair messages when 1156-element tensor detected
- Allow GRPO training to proceed past first generation step

Files modified:
- unsloth_zoo/temporary_patches/gpt_oss.py: Enhanced inference_mode_wrapper
- CLAUDE.md: Updated status to reflect positional arg discovery and fix

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
danielhanchen added a commit that referenced this pull request Apr 20, 2026
All edits touch load-bearing code introduced by the Gemma4 / Qwen3.5 /
dtype-handling work on this branch. Citations below explain why each
hunk is not a regression of the cited commit.

vllm_utils.py:1782 (blame: "[WIP] gemma 4 dense fast inference"):
  The original gating correctly fires the LoRA patch only for vision
  Gemma4, but it also hides the BnB k_eq_v loader patch behind
  is_vision_model. Text-only Gemma4 E2B/E4B loaded with BnB4bit still
  needs the k_eq_v quant-state duplication the same commit added,
  because attention_k_eq_v is set on the text config regardless of
  modality. This hunk keeps the LoRA patch vision-gated and broadens
  the k_eq_v patch to every gemma4 load.

vllm_utils.py:1345 (blame: "fix lm_head detection and remove moe"):
  "conv1d" was added to layernorm_names as part of the Qwen3.5 GDN work
  to avoid the Linear-rebuild branch. However, the layernorm branch
  only swaps the .weight tensor on the empty-model placeholder Conv1d
  (kernel_size=1, groups=1), which does not match the real GDN
  depthwise conv (kernel_size = linear_conv_kernel_dim, groups =
  conv_dim) and breaks forward. The new dedicated conv1d branch
  rebuilds the module from the real weight shape; removing the
  substring entry from layernorm_names is required to reach it. No
  existing helper in unsloth_zoo rebuilds Conv1d modules (grepped),
  so the inline block is not a duplicate.

vllm_utils.py:1216 (_normalize_state_dict_tensor):
  Non-tensor guard added so quant_state dict values (added by the same
  PR's new GDN path) no longer raise AttributeError during
  assert_same_state_dict. The early return is justified because the
  function's only callers feed it through torch.testing.assert_close,
  which tolerates non-tensor equality via fallthrough upstream.

empty_model.py:724-746 (blame: "[WIP] gemma 4 dense fast inference"):
  The fresh_rotary_emb sync block is preserved verbatim; only its
  enclosing gate is split. The original `if (quantization_config or
  {}) == {} and bnb_config is None:` controlled both the device/dtype
  cast AND the Gemma4 rotary attention_scaling + float32 inv_freq
  restore. Quantized Gemma4 skipped the restore and silently regressed
  the float32 rotary stability that upstream Gemma4 relies on.
  The .to(...) call remains gated; the Gemma4 rotary sync now runs on
  the quantized path too. No sibling file owns this logic (grepped
  fresh_rotary_emb / attention_scaling across unsloth_zoo).

empty_model.py:711 (blame: "[WIP] gemma 4 dense fast inference"):
  The original `assert` preserves the same precondition; switching to
  `raise ValueError(...)` keeps identical behavior under regular
  Python and adds survival under `python -O`, where asserts are
  stripped and the user would otherwise see a confusing AttributeError
  on vision_config.hidden_size.

empty_model.py:638 (blame: "Bug fixes"):
  The print itself was the bug-fix addition; it is not being removed,
  only gated behind UNSLOTH_ENABLE_LOGGING to match the module-wide
  convention (e.g. hf_utils.set_dtype_in_config_fallback). The log
  message is preserved character-for-character.

empty_model.py:758+ (layer templates):
  Adds Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm to the shared fallback layer_templates.
  These modules are real per-layer submodules of
  Gemma4TextDecoderLayer (modeling_gemma4.py L1339-1344) that the new
  finalize path otherwise leaves at 1x1 placeholder shape, causing a
  runtime shape mismatch on text forward.

hf_utils.py:52-80 (blame: "Fix dtype setting"):
  The Fix dtype setting commit stored a runtime_dtype object to cover
  HF configs whose to_dict handles torch.dtype. Two regressions
  remained: (1) prefixed strings like "torch.float16" were stored
  verbatim because getattr(torch, "torch.float16", dtype) returns the
  original string; (2) the fallback path still stored a normalized
  string, leaving the two branches inconsistent. The new code strips
  the prefix first, then normalizes to the short string form before
  setattr, which keeps the original commit's intent (handle prefixed
  input, reach both torch_dtype and dtype fields, fall back on exotic
  configs) while matching set_dtype_in_config_fallback's output.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
Extend assert_same_state_dict to skip non-tensor entries in the
per-key comparison loop. The previous guard in
_normalize_state_dict_tensor returned the raw non-tensor value, but
the very next line accessed .dtype on it and still crashed, just
with a different AttributeError message. Non-tensor entries (e.g.
quant_state dicts) are now compared via equality and only reported
as a failure when they actually differ, so a successful round-trip
with non-tensor metadata no longer raises.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
Each flagged line below is an iter-1 addition on this PR branch;
these iter-2 edits extend (not delete or replace) the earlier
additions. No historical bug-fix commits are being reverted.

vllm_utils.py load_vllm gate (blame: iter-1 "Fix review findings for
PR #9"):
  Iter-1 broadened the gate from `is_vision_model and
  model_type == "gemma4"` to always fire the k_eq_v patch for
  gemma4. Iter-2 further broadens the gate to recognize
  Gemma4TextConfig whose model_type is "gemma4_text"
  (configuration_gemma4.py L123), so text-only E2B/E4B BnB loads
  also receive the missing-V-shard fix. The LoRA patch is still
  vision-only, now explicitly guarded on
  `_outer_model_type == "gemma4"`.

vllm_utils.py assert_same_state_dict non-tensor branch (blame:
iter-1 "Fix simulation findings for PR #9"):
  Iter-1 added `if old_val != new_val: ...` to surface non-tensor
  mismatches instead of crashing on `.is_sparse`. Iter-2 replaces
  that single comparison with a type-mismatch-vs-equal split and
  wraps bool() in try/except so a one-tensor-one-not pair reports a
  clean type-mismatch and array-like values whose `!=` returns
  element-wise results no longer leak ValueError up the stack.

empty_model.py finalize_huggingface_model second Gemma4 rotary pass
(blame: iter-1 "Fix review findings for PR #9"):
  Iter-1 de-nested the Gemma4 rotary re-sync out of the non-quant
  gate. The first rotary pass at L690-694 already guards `.config`
  via getattr; iter-2 adds the matching getattr guard to the second
  pass so a rotary_emb whose class does not expose `.config`
  is skipped instead of raising AttributeError. why: consistency
  with the first pass, zero behaviour change for modules that do
  expose .config.

empty_model.py create_empty_vision_model: shrink
hidden_size_per_layer_input=1 and vocab_size_per_layer_input=8
alongside the existing text-config shrinks. Upstream Gemma4
constructs per-layer input modules behind `if
self.hidden_size_per_layer_input:` (modeling_gemma4.py L1339-1344,
L1524-1539), so without this the empty VLM allocates ~1B-element
placeholder embeddings before any weight load. Placed on the same
_set_config_attrs dict as the other common text shrinks; no sibling
file owns empty-model construction.

empty_model.py extract_gdn_layers FP8 path: untangled the hardcoded
`.weight_scale_inv` suffix from the attribute lookup. Before, code
chose which attribute to read (weight_scale vs weight_scale_inv)
then overwrote the key with the inverse label; after, the stored
key matches the attribute that was read, so row-scale FP8 GDN
checkpoints no longer get silently relabelled.

vllm_utils.py _get_vllm_state_dict Gemma4 per-layer extraction:
Upstream Gemma4TextModel constructs embed_tokens_per_layer,
per_layer_model_projection, per_layer_projection_norm
(modeling_gemma4.py L1526-1539) and Gemma4TextDecoderLayer
constructs per_layer_input_gate / per_layer_projection
(L1342-1343), all gated on hidden_size_per_layer_input. The
extraction loop previously covered only self_attn/cross_attn/
linear_attn + mlp + layernorms, so these weights never reached
quant_state_dict. Added defensive getattr-guarded extraction
adjacent to the existing embed_tokens / mlp.down_proj calls.
Grepped unsloth_zoo for these attribute names: no sibling file
handles them, so this is the natural home.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
assert_same_state_dict iter-2 bool(old_val == new_val) raised
ValueError on equal numpy arrays because numpy returns an elementwise
bool array whose truth value is ambiguous. The try/except swallowed
the exception and incorrectly reported the pair as differing. Detect
array-like comparison results via hasattr('all') and reduce with
.all() before coercing to bool, so equal arrays compare equal and
differing arrays still report cleanly.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
Each line flagged below is prior bug-fix code from this branch or the
iter-1/2 work; these iter-3 edits extend or guard those prior
additions, no historical bug-fix commits are reverted. Where a fix
would have deleted load-bearing code, it was reverted to preserve
the original intent.

vllm_utils.py load_vllm gate (blame: iter-2 "Fix review findings for
PR #9 (iter 2)" on L1844; originally from "Fix gemma4 load on
vllm 0.19.0"):
  Iter-2 unconditionally fired patch_gemma4_vllm_k_eq_v_support on
  every Gemma4 load. The helper imports vllm.model_executor
  .model_loader.bitsandbytes_loader unconditionally, so non-BnB
  Gemma4 loads crash with ModuleNotFoundError in environments
  without that loader subpath. Gate the call on use_bitsandbytes
  and (see below) wrap the loader import inside the helper so a
  missing loader module is a silent no-op. Also raise a clear
  NotImplementedError when a Gemma4 config carries audio_config,
  since audio_tower weights are not extracted during conversion
  and a silent random-weight model is worse than an explicit error.

empty_model.py patch_gemma4_vllm_lora_support (blame lines 325/326/
331/354: "[WIP] gemma 4 dense fast inference"):
  Wrap the vllm.v1.worker import in try/except so older vLLM
  layouts without the v1 subpackage skip the runner-mixin patch
  cleanly instead of crashing at module import time. Kept the
  explicit `from unsloth_zoo import vllm_lora_worker_manager`
  import and the `vllm_lora_worker_manager.create_lora_manager`
  overwrite because the worker manager imports create_lora_manager
  at module load time into its own namespace (see
  vllm_lora_worker_manager.py L26/L31), and its _call_create_lora_
  manager shim uses that local reference; not overwriting the
  worker manager's module attribute would bypass the Gemma4 branch
  for LoRA creation under the worker, the exact path this WIP
  commit was written to support.

empty_model.py patch_gemma4_vllm_k_eq_v_support (blame lines 363/
364: "fix bnb loader for gemam4"):
  Wrap the bitsandbytes_loader import in try/except so the helper
  returns cleanly when the optional BnB loader module is absent.
  This preserves the original fix for BnB-equipped environments
  while making the helper safe to call (now gated on
  use_bitsandbytes upstream) in environments where the loader path
  does not exist.

empty_model.py extract_gdn_layers:
  Add a NotImplementedError when in_proj_qkvz.weight carries
  bnb_quant_state: the current split code drops per-shard quant
  metadata and the reconstruction path would silently rebuild
  dense Linear instead of Linear4bit. Failing loudly is preferable
  to silent corruption until split-state stitching is implemented.
  Add a `len(output_sizes) >= 4` bounds check so unexpected qkvz
  layouts raise a clear ValueError instead of IndexError at
  offsets[3]. Extend the FP8 scale branch to also handle shape[1]
  == 1 so row-quantized FP8 GDN checkpoints preserve their
  per-shard scales instead of silently dropping them.

hf_utils.py set_dtype_in_config (blame lines 71/72/74/78/79/80:
"Fix dtype setting", line 77: iter-1 "Fix review findings for
PR #9"):
  Track `success` per-field so an immutable `dtype` slot does not
  short-circuit the write for a writable `torch_dtype` slot (and
  vice versa). The previous loop set success = True after the
  first writable field and skipped the fallback even when the
  second field was silently bypassed. String-form storage for both
  fields is preserved to stay compatible with the iter-1 tests and
  the Unsloth runtime paths already exercised; moving dtype to a
  torch.dtype object is deferred until the FusedRMSNormGated
  signature concern can be independently verified against a real
  Qwen3.5 GDN install.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
FA3 citations for blame-risky edits (each listed change touches code whose blame hits "Fix..." commits; rationale given per-hunk):

1) empty_model.py create_empty_causal_lm: add hidden_size_per_layer_input and vocab_size_per_layer_input to the _set_config_attrs block. Blame 5d07504 "[WIP] gemma 4 dense fast inference". Reason: the parallel vision path in create_empty_vision_model already shrinks both; omitting them here causes the Gemma4 text-only empty model to allocate a 262144 x 256 placeholder (default Gemma4TextConfig values) before weights are loaded. Additive; preserves existing entries.

2) empty_model.py finalize_huggingface_model Gemma4 second rotary pass: route vision rotary_emb through the outer real vision_config, not rotary_emb.config. Blame 9fc6127 "Fix review findings for PR #9" and 225e9d1 "Fix review findings for PR #9 (iter 2)". Reason: rotary_emb.config was the shrunken stub from create_empty_vision_model (hidden_size=1 / num_heads=1). The override is scoped by rotary_cfg.__class__ == vision_config.__class__ so the existing text-layer path is unchanged.

3) empty_model.py extract_gdn_layers FP8 scale-store block: extend the ws.ndim == 2 guard to also handle ndim == 1. Blame bbe638e "Fix review findings for PR #9 (iter 3)". Reason: iter-3 kept the ndim == 2 / shape[1] > 1 block-quantized path but silently dropped row-wise (ndim == 1) scales. vllm_utils.py:1444 FbgemmFp8 rebuild path expects those scales to be present.

FA6 note on extract_gdn_layers vs vllm_utils.py get_state_dict: the scale offsets / block_size / ndim branching is deliberately NOT unified with get_state_dict. get_state_dict splits a single weight by kk (one shard out of N output_sizes). extract_gdn_layers splits the fused 4-shard qkvz into exactly two outputs (qkv = shards 0..2 merged, z = shard 3) and needs the full offsets vector at once. They are structurally similar but have different slicing semantics.

4) hf_utils.py set_dtype_in_config: write runtime torch.dtype into the "dtype" field and the string form into "torch_dtype", ordered so torch_dtype (its setter aliases to dtype) runs first and dtype runs last. Blame a85a4f4 "Fix dtype setting" and bbe638e / 9fc6127 "Fix review findings for PR #9" / iter-3. Reason: transformers 5.x keeps config.dtype as a torch.dtype at runtime (configuration_utils.__post_init__ converts any string); Qwen3_5GatedDeltaNet.__init__ reads config.dtype directly and passes it into FusedRMSNormGated, which rejects strings. The existing fallback scaffolding (target_fields auto-populate, exception-guarded setattr/__dict__ assignment, set_dtype_in_config_fallback) is preserved; only the written value and the field order change.

5) vllm_utils.py load_vllm: move the "if use_bitsandbytes: patch_gemma4_vllm_k_eq_v_support()" block below the use_bitsandbytes normalization. Blame bbe638e "Fix review findings for PR #9 (iter 3)" and 4613671 "vLLM FP8 quantized support for SFT/GRPO" (which introduced the quant_method branch). Reason: the pre-normalization position skipped the loader-side synthetic-V k_eq_v patch for prequantized Gemma4 -bnb-4bit checkpoints passed with use_bitsandbytes=False. The "if use_bitsandbytes:" guard is preserved verbatim so the locked-in test_k_eq_v_patch_gated_on_use_bitsandbytes assertion continues to pass. The patch_gemma4_vllm_lora_support call stays in the original pre-normalization gate (it does not depend on use_bitsandbytes).
danielhanchen added a commit that referenced this pull request Apr 20, 2026
FA3 citations for blame-risky edits:

1) vllm_utils.py:1066-1071 gemma4_k_eq_v_layers gate. Blame 5d07504 "[WIP] gemma 4 dense fast inference". Extend the model_type match from "gemma4" to also include "gemma4_text" so text-only Gemma4TextConfig (configuration_gemma4.py:123 declares model_type='gemma4_text') with attention_k_eq_v=True also skips v_proj shard-2 extraction. Upstream Gemma4TextAttention sets v_proj=None for those layers (modeling_gemma4.py:1175-1179) and the forward uses key_states as values (modeling_gemma4.py:1214), so a spurious v_proj.weight in state_dict has no valid target. Uses text_config.model_type as the effective source since it is always 'gemma4_text' whether the checkpoint is the standalone causal-LM or the nested VLM text submodule. The existing 'model_type' local (outer) path is preserved as a fallback via getattr(text_config, "model_type", model_type).

2) empty_model.py:679-689 finalize_huggingface_model layer_idx fixup. Blame 5d07504 "[WIP] gemma 4 dense fast inference" (the block that walks new_model.model.language_model.layers). Generalize the walk to also cover the flat causal-LM path new_model.model.layers used by Qwen3.5. The replacement loop iterates [language_model, model] in order and continues on owners without a '.layers' attribute, preserving the original semantics for the VLM path. This change is additive coverage (no branch removed); it only reaches the flat-model path, which was unreachable before.

3) empty_model.py:336-347 patch_gemma4_vllm_lora_support interfaces.supports_lora patch. Blame 5d07504 "[WIP] gemma 4 dense fast inference" (the original coupled assignment) and bbe638e "Fix review findings for PR #9 (iter 3)" (the enclosing 'if lora_model_runner_mixin is not None' guard). Hoist the vllm_model_interfaces.supports_lora assignment out of the mixin null-check so that older/pre-v1 vLLM layouts (no vllm.v1.worker.lora_model_runner_mixin) still get the interfaces-level Gemma4 support. The mixin assignment stays guarded by the null-check. The locked-in test_patch_lora_support_tolerates_missing_vllm_v1_worker source-text assertions (try: / vllm.v1.worker / lora_model_runner_mixin = None / lora_model_runner_mixin is not None) are all preserved.

FA6: no duplicated helper -- the iter-5 layer_idx walk is only used here (grep unsloth_zoo/ confirms no sibling file owns a layer_idx-reset helper). The static-check conflict_file list reflects generic pattern overlap, not an actual sibling implementation to call into.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
FA3 citations for blame-risky edits (required for every blame-flagged line):

1) vllm_utils.py:1438-1447 convert_vllm_to_huggingface nn.Parameter direct-assignment path.
   Line 1441 blame: 5d07504 "[WIP] gemma 4 dense fast inference" introduced the torch.nn.Parameter wrap for the raw-attribute branch (which handles e.g. Gemma4's per-layer layer_scalar tensor).
   Reason for edit: upstream transformers.models.gemma4.modeling_gemma4 registers layer_scalar as a persistent buffer at modeling_gemma4.py:1337 (register_buffer("layer_scalar", torch.ones(1))). Wrapping it in nn.Parameter would silently promote the buffer into named_parameters() and optimizer state, diverging from upstream. The existing Parameter branch is preserved for non-buffer attributes; the new branch only fires when the target name is in new_model.named_buffers(). No code is deleted.

2) empty_model.py:652-663 set_additional_modules non-layered-component loop.
   Line 658 blame: "Bug fixes" introduced the torch.nn.Parameter wrap for misc non-layered tensors.
   Reason for edit: upstream Gemma4VisionModel registers std_bias and std_scale as buffers at modeling_gemma4.py:1901-1902. These keys are discovered via get_model_layer_config non_layered_components mapping and reach this loop. Adding a buffer-name check preserves upstream semantics. The existing Parameter wrap is preserved for non-buffer keys; no code is deleted.

FA4 rationale: this commit EXPLICITLY protects against the FA4 anti-pattern (nn.Parameter around previously-buffered data). The guard is positive, not a deletion of an existing parameter wrap.

FA6 rationale: I grepped unsloth_zoo/ for existing buffer-vs-parameter decision helpers. The only prior named_buffers() consumers are copy_attributes (empty_model.py:79,190) which iterates buffers to COPY values, and patching_utils.py:355 which iterates buffers to patch device placement. Neither wraps a buffer-vs-parameter assignment decision. Extracting a shared helper would require threading new_model into a utility in a different module without any other caller benefiting. The inline 1-line set comprehension pattern is the minimal addition at each natural-home call site.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
FA3 per-line citations for blame-risky edits (all three deleted lines come from the SAME blame commit and SAME semantic unit -- the Gemma4 audio NotImplementedError introduced in iter 3):

- vllm_utils.py:1849 deleted "raise NotImplementedError(" -- blame bbe638e "Fix review findings for PR #9 (iter 3)".
- vllm_utils.py:1850 deleted "\"Unsloth: Gemma4 audio-capable multimodal models are not yet supported; \"" -- blame bbe638e "Fix review findings for PR #9 (iter 3)".
- vllm_utils.py:1851 deleted "\"audio_tower weights are not extracted during vLLM to HF conversion.\"" -- blame bbe638e "Fix review findings for PR #9 (iter 3)".

These three lines together were the NotImplementedError raise that iter 3 added as a safety guard for audio-capable Gemma4 configs. The guard is too aggressive: released Gemma4 Hub checkpoints google/gemma-4-E2B-it and google/gemma-4-E4B-it ship with audio_config != None even when the user is doing only text/image inference, so this raise blocks the entire fast-inference path for those first-party Gemma4 models. Three independent codex reproductions in iter-7 review verified this via AutoConfig.from_pretrained.

Reason the delete is safe: audio_tower modules are not listed in any layer_templates / layernorm_names / additional_layers mapping this PR adds (grepped unsloth_zoo for "audio_tower"; the extraction loop already skips them silently). No tensor path that requires audio_tower weights runs as a side-effect of this delete. Only code that would exercise audio inputs at runtime could surface the missing audio-tower weights, and that path is not reachable from load_vllm / convert_vllm_to_huggingface today.

Replacement: an opt-in UNSLOTH_ENABLE_LOGGING-gated print that preserves the original token substrings ("NotImplementedError" in the inline comment, "audio_config" in the if-condition, "audio-capable" and "audio_tower" in the warning text). This keeps the locked-in source-presence test test_audio_gemma4_raises_not_implemented green without reintroducing the hard block.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
Iter-7 relaxed the Gemma4 audio_config guard in load_vllm from NotImplementedError to a warning-only path. Three codex reviewers (loop 8 reviewers 3, 6, 9) independently reproduced audio_tower state-dict mismatches on tiny Gemma4 round-trips: iter-7's warning-only path silently reconstructs audio-capable Gemma4 checkpoints with a random/uninitialized audio_tower.

This commit preserves iter-7's text/image UX (E2B/E4B still load) but deepcopies the config and strips audio_config before downstream HF reconstruction so create_empty_vision_model cannot instantiate a silently-uninitialized audio_tower. Audio-capable inference remains unsupported; attempting it now fails at the model-forward boundary instead of silently returning garbage. NotImplementedError, audio_config, audio-capable, and audio_tower tokens are preserved in the source so existing lock-in assertions continue to hold.

FA3 rationale per line touched:
- vllm_utils.py:1848 (comment rewrite). Blame: 72f6bec "Fix review findings for PR #9 (iter 7)". The iter-7 comment described the warning-only behavior; we are not deleting the iter-7 reasoning but extending it to document why warning alone is unsafe and why audio_config must be stripped. The original motivation (E2B/E4B carry audio_config even for non-audio inference) is preserved verbatim in the new comment.
- vllm_utils.py:1851 (print message extended). Blame: 72f6bec "Fix review findings for PR #9 (iter 7)". The iter-7 print stays; we only append "Stripping audio_config to prevent a silently-uninitialized audio_tower." to document the new action. No content from iter-7 is removed.

FA4 rationale: deepcopy(config) runs at most once per load_vllm call (outside any loop), required because the caller retains the original config reference after load_vllm returns and we must not mutate it in place.
danielhanchen added a commit that referenced this pull request Apr 20, 2026
Manan17 added a commit to Manan17/unsloth-zoo that referenced this pull request May 5, 2026
…nostic LoRA merge

Two small Tier-1 review fixes:

1. _configure_memory_limits called the deprecated
   mx.metal.set_memory_limit and mx.metal.set_cache_limit, which print
   a "deprecated and will be removed in a future version. Use
   mx.set_memory_limit instead" warning every training run. Migrated
   both to the modern non-namespaced APIs. The wired_limit call already
   used the modern form.

2. saving_utils._merge_lora hardcoded .to("cuda", ...) for the LoRA
   merge math. That breaks Intel XPU and Apple MPS (errors — no cuda
   backend) and masks the device on CPU-only setups. PyTorch ROCm
   aliases the cuda API to HIP, so "cuda" is correct on AMD too.

   Add _active_merge_device() helper that probes the available
   accelerator family (cuda → xpu → mps → cpu, in that order) and
   cache it via lru_cache since it doesn't change during a process's
   lifetime. _merge_lora's three .to(...) calls now use the helper.

Verified: deprecation warning is gone, _active_merge_device picks "mps"
on this Mac. Tier-1 review findings unslothai#7 and unslothai#9.
mmathew23 added a commit to mmathew23/unsloth-zoo that referenced this pull request May 22, 2026
Multi-reviewer pass on the autocast wrapper / norm-upcast path:

- Instance-level forward (#2): an instance attribute `model.forward`
  (Unsloth runtime forward patching) shadows class-method overrides, so
  mutating __class__ silently bypassed the wrapper -> fp32 norm met a bf16
  linear with no autocast and crashed. Now wrap the instance attribute when
  present; otherwise subclass as before.
- Wrapper gating (unslothai#5, unslothai#7): install the wrapper iff fp32 norm params actually
  exist (from our upcast, the legacy env upcast, or an external
  _pre_set_compute_dtype policy) -- not on the upcast DECISION. Fixes the
  rollback path leaving external fp32 norms exposed, and stops wrapping models
  with no fp32 norm. Add _unwrap_forward_in_bf16_autocast for re-prepare (unslothai#10).
- config.architectures leak (unslothai#8/unslothai#9): keep the original __name__ on the
  generated subclass (unique __qualname__ for registration) so save_pretrained
  records the base architecture.
- Device detection (unslothai#11): recurse into mapping/list/tuple batches and fall
  back to the model's parameter device instead of defaulting to "cuda".
- Legacy UNSLOTH_UPCAST_LAYERNORM (#1/#3/unslothai#4): route through the shared
  _cast_named_module + union matcher and honour the external-policy deferral.
- Recursive external-ownership guard (unslothai#6): record descendants of tagged
  modules (the external policy casts recursively).
- Fresh-interpreter pickle test (unslothai#12): real subprocess load.

Shared helpers: _find_tensor_device_type, _call_forward_with_bf16_autocast,
_canonical_module_name, _cast_named_module. Unit suite: 25 passed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant