
[Bug] Fine-tuning unsloth/Qwen3-4B-Instruct-2507 fails when loading in 8-bit #3501

@mapa17

Description


Running a notebook in the official Docker image with the sample code below (based on the Unsloth notebook Qwen3_(4B)-Instruct) fails on my system.

==((====))==  Unsloth 2025.9.11: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.10.2.
   \\   /|    NVIDIA GeForce RTX 3060. Num GPUs = 1. Max memory: 11.631 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33+5146f2a.d20251002. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
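The banner above lists the Torch, Transformers, vLLM, Triton and Xformers versions, but not bitsandbytes, PEFT or TRL, all of which appear in the failing code path. A small helper like the one below can collect them for the report. It is a sketch that uses only the standard library's `importlib.metadata`; the package list is our own choice, not something Unsloth requires.

```python
# Collect versions of the packages involved in the failing 8-bit training path.
# bitsandbytes is not listed in the Unsloth banner, so report it explicitly.
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["unsloth", "unsloth_zoo", "torch", "transformers",
            "bitsandbytes", "peft", "trl", "accelerate", "triton"]


def collect_versions(packages=PACKAGES):
    """Return {package: installed version string, or 'not installed'}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name:>14}: {ver}")
```

Missing packages are reported as `not installed` rather than raising, so the helper can run in any environment.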
from unsloth import FastLanguageModel
from datasets import Dataset, load_dataset
from unsloth.chat_templates import get_chat_template
from unsloth.chat_templates import standardize_data_formats

import torch

fourbit_models = [
    "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit", # Qwen3 4B Instruct, 2x faster
    "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit",
    "unsloth/Qwen3-8B-unsloth-bnb-4bit",
    "unsloth/Qwen3-14B-unsloth-bnb-4bit",
    "unsloth/Qwen3-32B-unsloth-bnb-4bit",

    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/Phi-4",
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" # [NEW] We support TTS models!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-4B-Instruct-2507",
    max_seq_length = 4096, # Choose any for long context!
    load_in_4bit = False,  # 4 bit quantization to reduce memory
    load_in_8bit = True, # [NEW!] A bit more accurate, uses 2x memory
    load_in_16bit = False,
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
    cache_dir="./hf_cache"
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)


tokenizer = get_chat_template(
    tokenizer,
    chat_template = "qwen3-instruct",
)

dataset = load_dataset("mlabonne/FineTome-100k", split = "train")


dataset = standardize_data_formats(dataset)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True)
split_dataset = dataset.train_test_split(test_size = 0.01, shuffle = True, seed = 3407)

print(f"{split_dataset=}")

from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = split_dataset['train'],
    eval_dataset = split_dataset['test'], # Can set up evaluation!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 8, # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 1, # Set this for 1 full training run.
        #max_steps = 60,
        learning_rate = 2e-5, # 2e-5 is the suggested value for long training runs
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
        fp16_full_eval = True,
        per_device_eval_batch_size = 2,
        eval_accumulation_steps = 8,
        eval_strategy = "steps",
        eval_steps = 20
    ),
)

trainer_stats = trainer.train()
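As a sanity check that the LoRA setup is applied as intended, the "Trainable parameters" figure the trainer prints (66,060,288 in the banner below) can be recomputed by hand. Each LoRA adapter on a projection of shape (in, out) adds r × (in + out) parameters. This sketch assumes the Qwen3-4B architecture values hidden_size 2560, intermediate_size 9728, 36 layers, and 32 query heads plus 8 key/value heads of dimension 128; these are not stated in the report and should be checked against `model.config`.

```python
# Recompute the LoRA trainable-parameter count for the get_peft_model call.
# ASSUMPTION: Qwen3-4B config values; verify against model.config.
HIDDEN = 2560          # hidden_size
INTERMEDIATE = 9728    # intermediate_size
N_LAYERS = 36          # num_hidden_layers
N_HEADS, N_KV_HEADS, HEAD_DIM = 32, 8, 128
R = 32                 # r passed to get_peft_model

# (in_features, out_features) of each targeted projection in one decoder layer
shapes = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_f, out_f, r):
    # LoRA adds A (r x in_f) and B (out_f x r); bias = "none" adds nothing
    return r * (in_f + out_f)


per_layer = sum(lora_params(i, o, R) for i, o in shapes.values())
total = per_layer * N_LAYERS
print(f"{per_layer=:,} {total=:,}")  # total matches the banner: 66,060,288
```

The match suggests the adapters themselves are set up as expected, which points the failure at the 8-bit base-layer matmul rather than at the LoRA configuration.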

Running the notebook produces the following error trace:

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 99,000 | Num Epochs = 1 | Total steps = 6,188
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
 "-____-"     Trainable parameters = 66,060,288 of 4,088,528,384 (1.62% trained)
/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[4], line 29
      1 from trl import SFTTrainer, SFTConfig
      2 trainer = SFTTrainer(
      3     model = model,
      4     tokenizer = tokenizer,
   (...)     26     ),
     27 )
---> 29 trainer_stats = trainer.train()

File /workspace/work/unsloth_compiled_cache/UnslothSFTTrainer.py:53, in prepare_for_training_mode.<locals>.wrapper(self, *args, **kwargs)
     51 if hasattr(self, 'model') and hasattr(self.model, "for_training"):
     52     self.model.for_training()
---> 53 output = f(self, *args, **kwargs)
     54 # Return inference mode
     55 if hasattr(self, 'model') and hasattr(self.model, "for_inference"):

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2328, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2326         hf_hub_utils.enable_progress_bars()
   2327 else:
-> 2328     return inner_training_loop(
   2329         args=args,
   2330         resume_from_checkpoint=resume_from_checkpoint,
   2331         trial=trial,
   2332         ignore_keys_for_eval=ignore_keys_for_eval,
   2333     )

File <string>:323, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File /workspace/work/unsloth_compiled_cache/UnslothSFTTrainer.py:1040, in _UnslothSFTTrainer.training_step(self, *args, **kwargs)
   1038 def training_step(self, *args, **kwargs):
   1039     with self.maybe_activation_offload_context:
-> 1040         return super().training_step(*args, **kwargs)

File <string>:91, in _unsloth_training_step(***failed resolving arguments***)

File /opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py:2734, in Accelerator.backward(self, loss, **kwargs)
   2732     self.lomo_backward(loss, learning_rate)
   2733 else:
-> 2734     loss.backward(**kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/_tensor.py:647, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    637 if has_torch_function_unary(self):
    638     return handle_torch_function(
    639         Tensor.backward,
    640         (self,),
   (...)    645         inputs=inputs,
    646     )
--> 647 torch.autograd.backward(
    648     self, gradient, retain_graph, create_graph, inputs=inputs
    649 )

File /opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py:354, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    349     retain_graph = create_graph
    351 # The reason we repeat the same comment below is that
    352 # some Python versions print out the first line of a multi-line function
    353 # calls in the traceback and some print out the last line
--> 354 _engine_run_backward(
    355     tensors,
    356     grad_tensors_,
    357     retain_graph,
    358     create_graph,
    359     inputs_tuple,
    360     allow_unreachable=True,
    361     accumulate_grad=True,
    362 )

File /opt/conda/lib/python3.11/site-packages/torch/autograd/graph.py:829, in _engine_run_backward(t_outputs, *args, **kwargs)
    827     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    828 try:
--> 829     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    830         t_outputs, *args, **kwargs
    831     )  # Calls into the C++ engine to run the backward pass
    832 finally:
    833     if attach_logging_hooks:

File /opt/conda/lib/python3.11/site-packages/torch/autograd/function.py:311, in BackwardCFunction.apply(self, *args)
    305     raise RuntimeError(
    306         "Implementing both 'backward' and 'vjp' for a custom "
    307         "Function is not allowed. You should only implement one "
    308         "of them."
    309     )
    310 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 311 return user_fn(self, *args)

File /opt/conda/lib/python3.11/site-packages/unsloth_zoo/gradient_checkpointing.py:568, in UnslothCheckpointFunction.backward(ctx, *args)
    565     pass
    567     with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
--> 568         outputs = ctx.run_function(*detached_inputs)
    569     pass
    570 pass

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File /opt/conda/lib/python3.11/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
    168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
    169     # DeprecationWarning is ignored by default, so we use FutureWarning instead
    170     warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py:275, in Qwen3DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings, **kwargs)
    273 residual = hidden_states
    274 hidden_states = self.post_attention_layernorm(hidden_states)
--> 275 hidden_states = self.mlp(hidden_states)
    276 hidden_states = residual + hidden_states
    277 return hidden_states

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File /workspace/work/unsloth_compiled_cache/unsloth_compiled_module_qwen3.py:210, in Qwen3MLP.forward(self, x)
    209 def forward(self, x):
--> 210     return Qwen3MLP_forward(self, x)

File /opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:736, in _TorchDynamoContext.__call__.<locals>.compile_wrapper(*args, **kwargs)
    733 _maybe_set_eval_frame(_callback_from_stance(callback))
    735 try:
--> 736     return fn(*args, **kwargs)
    737 except Unsupported as e:
    738     if config.verbose:

File /workspace/work/unsloth_compiled_cache/unsloth_compiled_module_qwen3.py:195, in Qwen3MLP_forward(self, x)
    193 @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
    194 def Qwen3MLP_forward(self, x):
--> 195     down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    196     return down_proj

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File /workspace/work/unsloth_compiled_cache/Linear8bitLt_peft_forward.py:73, in unsloth_forward(self, x, *args, **kwargs)
     71     result = self.base_layer(x, *args, **kwargs)
     72 else:
---> 73     result = self.base_layer(x, *args, **kwargs)
     74     for active_adapter in self.active_adapters:
     75         if active_adapter not in self.lora_A.keys():

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File /opt/conda/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:1072, in Linear8bitLt.forward(self, x)
   1069 if self.bias is not None and self.bias.dtype != x.dtype:
   1070     self.bias.data = self.bias.data.to(x.dtype)
-> 1072 out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
   1074 if not self.state.has_fp16_weights and self.state.CB is not None:
   1075     self.weight.data = self.state.CB

File /opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:424, in matmul(A, B, out, state, threshold, bias)
    422     if A.device.type in ("cpu", "xpu"):
    423         return MatMul8bitFp.apply(A, B, out, bias, state)
--> 424 return MatMul8bitLt.apply(A, B, out, bias, state)

File /opt/conda/lib/python3.11/site-packages/torch/autograd/function.py:576, in Function.apply(cls, *args, **kwargs)
    573 if not torch._C._are_functorch_transforms_active():
    574     # See NOTE: [functorch vjp and autograd interaction]
    575     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 576     return super().apply(*args, **kwargs)  # type: ignore[misc]
    578 if not is_setup_ctx_defined:
    579     raise RuntimeError(
    580         "In order to use an autograd.Function with functorch transforms "
    581         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    582         "staticmethod. For more details, please see "
    583         "https://pytorch.org/docs/main/notes/extending.func.html"
    584     )

File /opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:154, in MatMul8bitLt.forward(ctx, A, B, out, bias, state)
    153 class MatMul8bitLt(torch.autograd.Function):
--> 154     @staticmethod
    155     def forward(
    156         ctx: torch.autograd.function.FunctionCtx,
    157         A: torch.Tensor,
    158         B: torch.Tensor,
    159         out: Optional[torch.Tensor] = None,
    160         bias: Optional[torch.Tensor] = None,
    161         state: Optional[MatmulLtState] = None,
    162     ):
    163         state = state or MatmulLtState()
    165         # default of pytorch behavior if inputs are empty

File /opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:929, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    927 _maybe_set_eval_frame(_callback_from_stance(self.callback))
    928 try:
--> 929     return fn(*args, **kwargs)
    930 finally:
    931     set_eval_frame(None)

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py:1241, in aot_module_simplified.<locals>.forward(*runtime_args)
   1239 full_args.extend(params_flat)
   1240 full_args.extend(runtime_args)
-> 1241 return compiled_fn(full_args)

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:384, in _create_runtime_wrapper.<locals>.runtime_wrapper(args)
    382         torch._C._set_grad_enabled(False)
    383     record_runtime_wrapper_prologue_exit(cm)
--> 384     all_outs = call_func_at_runtime_with_args(
    385         compiled_fn, args, disable_amp=disable_amp, steal_args=True
    386     )
    387 finally:
    388     if grad_enabled:

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:126, in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
    124 with context():
    125     if getattr(f, "_boxed_call", False):
--> 126         out = normalize_as_list(f(args))
    127     else:
    128         # TODO: Please remove soon
    129         # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
    130         warnings.warn(
    131             "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
    132             "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
    133             "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
    134         )

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:750, in EffectTokensWrapper.post_compile.<locals>.inner_fn(args)
    747     args = [*([None] * num_tokens), *args]
    748     old_args.clear()
--> 750 outs = compiled_fn(args)
    752 # Inductor cache DummyModule can return None
    753 if outs is None:

File /opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:556, in FunctionalizedRngRuntimeWrapper.post_compile.<locals>.wrapper(runtime_args)
    549     out = self._functionalized_rng_runtime_epilogue(
    550         runtime_metadata,
    551         out,
    552         # TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper
    553         runtime_metadata.num_forward_returns,
    554     )
    555     return out
--> 556 return compiled_fn(runtime_args)

File /opt/conda/lib/python3.11/site-packages/torch/_inductor/output_code.py:584, in CompiledFxGraph.__call__(self, inputs)
    582 assert self.current_callable is not None
    583 try:
--> 584     return self.current_callable(inputs)
    585 finally:
    586     get_runtime_metrics_context().finish()

File /opt/conda/lib/python3.11/site-packages/torch/_inductor/utils.py:2716, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
   2712 def run(new_inputs: list[InputType]) -> Any:
   2713     old_tensors, new_tensors = copy_misaligned_inputs(
   2714         new_inputs, inputs_to_check, mutated_input_idxs
   2715     )
-> 2716     out = model(new_inputs)
   2718     # If a mutated tensor was cloned to be aligned, we need to reflect back the mutation to the
   2719     # original tensor.
   2720     if len(old_tensors):

File /tmp/torchinductor_unsloth/67/c67ekh5dv3tzbugvqdncd2xm7vt3km63f42igb2w4jybfr4uccul.py:123, in call(args)
    121     assert_alignment(buf7, 16, 'torch.ops.bitsandbytes.int8_mixed_scaled_mm.default')
    122     buf8 = buf6[1]
--> 123     assert_size_stride(buf8, (u1, ), (1, ), 'torch.ops.bitsandbytes.int8_mixed_scaled_mm.default')
    124     assert_alignment(buf8, 16, 'torch.ops.bitsandbytes.int8_mixed_scaled_mm.default')
    125 if not (u1 >= 0):

AssertionError: wrong number of dimensions2 for op: torch.ops.bitsandbytes.int8_mixed_scaled_mm.default
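The trainer banner at the start of the trace reports 99,000 training examples, a total batch size of 16 and 6,188 total steps. The arithmetic below reproduces those figures from the `SFTConfig` values. It assumes that a final, partial gradient-accumulation window still counts as one optimizer step, which is what the 6,188 figure implies, since 99,000 / 16 = 6,187.5.

```python
import math

# Values from the SFTConfig in the report and from the Unsloth banner.
num_examples = 99_000   # train split after train_test_split(test_size = 0.01)
per_device_batch = 2    # per_device_train_batch_size
grad_accum = 8          # gradient_accumulation_steps
num_gpus = 1
epochs = 1              # num_train_epochs

total_batch = per_device_batch * grad_accum * num_gpus                        # 16
batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))  # 49,500
# ASSUMPTION: a trailing partial accumulation window counts as a step
steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)                  # 6,188
total_steps = steps_per_epoch * epochs
print(total_batch, batches_per_epoch, total_steps)
```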

