Running a notebook in the official Docker image with the sample code below (based on the Unsloth notebook Qwen3_(4B)-Instruct) fails on my system.
from unsloth import FastLanguageModel
from datasets import Dataset, load_dataset
from unsloth.chat_templates import get_chat_template
from unsloth.chat_templates import standardize_data_formats
import torch
# Catalogue of model repositories on the Unsloth hub, kept for reference only:
# the visible script does not read this list again.
# More models at https://huggingface.co/unsloth
fourbit_models = [
    # Qwen3 checkpoints
    "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit",
    "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit",
    "unsloth/Qwen3-8B-unsloth-bnb-4bit",
    "unsloth/Qwen3-14B-unsloth-bnb-4bit",
    "unsloth/Qwen3-32B-unsloth-bnb-4bit",
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/Phi-4",
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    # Text-to-speech model
    "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit",
]
# Load the base model together with its tokenizer. Only 8-bit loading is
# switched on; 4-bit, 16-bit and full fine-tuning are all disabled.
# NOTE(review): the traceback in this report fails inside bitsandbytes'
# Linear8bitLt / int8 matmul, so load_in_8bit=True is presumably required to
# reproduce the error -- confirm by re-running with 4-bit loading instead.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-4B-Instruct-2507",
    max_seq_length=4096,  # context length used for this run
    load_in_4bit=False,
    load_in_8bit=True,
    load_in_16bit=False,
    full_finetuning=False,
    # token="hf_...",  # only needed for gated models
    cache_dir="./hf_cache",  # downloads are cached next to the notebook
)
# Attach LoRA adapters (rank 32, alpha 32) to the seven projection layers
# named below, and rebind `model` to the adapter-wrapped result.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # any value > 0; suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=32,
    lora_dropout=0,  # any value is supported, 0 is the optimized setting
    bias="none",  # any value is supported, "none" is the optimized setting
    # True or "unsloth"; the "unsloth" mode is the one exercised by the
    # UnslothCheckpointFunction frame in the traceback of this report.
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,  # rank-stabilized LoRA left off
    loftq_config=None,  # LoftQ left off
)
# Switch the tokenizer over to the "qwen3-instruct" chat template.
tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")

# Fetch the train split of FineTome-100k and pass it through
# standardize_data_formats before any prompt formatting happens.
dataset = standardize_data_formats(
    load_dataset("mlabonne/FineTome-100k", split="train")
)
def formatting_prompts_func(examples):
    """Render each conversation of a batch into one chat-template string.

    Args:
        examples: Batch mapping whose "conversations" entry is an iterable of
            conversations accepted by ``tokenizer.apply_chat_template``.

    Returns:
        A dict with the single key "text", holding one rendered string per
        conversation, in the same order as the input.
    """
    rendered = []
    for conversation in examples["conversations"]:
        # Produce plain text (no token ids) and do not append a trailing
        # generation prompt.
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}
# Add the rendered "text" column, processing the dataset in batches.
dataset = dataset.map(formatting_prompts_func, batched=True)

# Hold out 1% of the shuffled data for evaluation; the seed is fixed so the
# split is the same on every run.
split_dataset = dataset.train_test_split(
    test_size=0.01,
    shuffle=True,
    seed=3407,
)
print(f"{split_dataset=}")
from trl import SFTTrainer, SFTConfig
# Build the supervised fine-tuning trainer on the train/test split created
# above, then start training.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    args=SFTConfig(
        dataset_text_field="text",
        # 2 sequences per device x 8 accumulation steps = 16 per optimizer step
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        warmup_steps=5,
        num_train_epochs=1,  # one full pass over the training split
        # max_steps=60,
        learning_rate=2e-5,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        report_to="none",  # no external experiment tracker
        # Evaluation: every 20 steps, on the held-out split.
        fp16_full_eval=True,
        per_device_eval_batch_size=2,
        eval_accumulation_steps=8,
        eval_strategy="steps",
        eval_steps=20,
    ),
)

# This is the call that raises the AssertionError shown in the traceback.
trainer_stats = trainer.train()
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1
\\ /| Num examples = 99,000 | Num Epochs = 1 | Total steps = 6,188
O^O/ \_/ \ Batch size per device = 2 | Gradient accumulation steps = 8
\ / Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
"-____-" Trainable parameters = 66,060,288 of 4,088,528,384 (1.62% trained)
[/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=180): UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=180): UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=180): UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=180): UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=180): UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=180): UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
[/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:181](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=180): UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[4], line 29
1 from trl import SFTTrainer, SFTConfig
2 trainer = SFTTrainer(
3 model = model,
4 tokenizer = tokenizer,
(...) 26 ),
27 )
---> 29 trainer_stats = trainer.train()
File [/workspace/work/unsloth_compiled_cache/UnslothSFTTrainer.py:53](http://localhost:8888/lab/tree/work/work/unsloth_compiled_cache/UnslothSFTTrainer.py#line=52), in prepare_for_training_mode.<locals>.wrapper(self, *args, **kwargs)
51 if hasattr(self, 'model') and hasattr(self.model, "for_training"):
52 self.model.for_training()
---> 53 output = f(self, *args, **kwargs)
54 # Return inference mode
55 if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
File [/opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2328](http://localhost:8888/opt/conda/lib/python3.11/site-packages/transformers/trainer.py#line=2327), in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2326 hf_hub_utils.enable_progress_bars()
2327 else:
-> 2328 return inner_training_loop(
2329 args=args,
2330 resume_from_checkpoint=resume_from_checkpoint,
2331 trial=trial,
2332 ignore_keys_for_eval=ignore_keys_for_eval,
2333 )
File <string>:323, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
File [/workspace/work/unsloth_compiled_cache/UnslothSFTTrainer.py:1040](http://localhost:8888/lab/tree/work/work/unsloth_compiled_cache/UnslothSFTTrainer.py#line=1039), in _UnslothSFTTrainer.training_step(self, *args, **kwargs)
1038 def training_step(self, *args, **kwargs):
1039 with self.maybe_activation_offload_context:
-> 1040 return super().training_step(*args, **kwargs)
File <string>:91, in _unsloth_training_step(***failed resolving arguments***)
File [/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py:2734](http://localhost:8888/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py#line=2733), in Accelerator.backward(self, loss, **kwargs)
2732 self.lomo_backward(loss, learning_rate)
2733 else:
-> 2734 loss.backward(**kwargs)
File [/opt/conda/lib/python3.11/site-packages/torch/_tensor.py:647](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_tensor.py#line=646), in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
637 if has_torch_function_unary(self):
638 return handle_torch_function(
639 Tensor.backward,
640 (self,),
(...) 645 inputs=inputs,
646 )
--> 647 torch.autograd.backward(
648 self, gradient, retain_graph, create_graph, inputs=inputs
649 )
File [/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py:354](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py#line=353), in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
349 retain_graph = create_graph
351 # The reason we repeat the same comment below is that
352 # some Python versions print out the first line of a multi-line function
353 # calls in the traceback and some print out the last line
--> 354 _engine_run_backward(
355 tensors,
356 grad_tensors_,
357 retain_graph,
358 create_graph,
359 inputs_tuple,
360 allow_unreachable=True,
361 accumulate_grad=True,
362 )
File [/opt/conda/lib/python3.11/site-packages/torch/autograd/graph.py:829](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/autograd/graph.py#line=828), in _engine_run_backward(t_outputs, *args, **kwargs)
827 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
828 try:
--> 829 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
830 t_outputs, *args, **kwargs
831 ) # Calls into the C++ engine to run the backward pass
832 finally:
833 if attach_logging_hooks:
File [/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py:311](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py#line=310), in BackwardCFunction.apply(self, *args)
305 raise RuntimeError(
306 "Implementing both 'backward' and 'vjp' for a custom "
307 "Function is not allowed. You should only implement one "
308 "of them."
309 )
310 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 311 return user_fn(self, *args)
File [/opt/conda/lib/python3.11/site-packages/unsloth_zoo/gradient_checkpointing.py:568](http://localhost:8888/opt/conda/lib/python3.11/site-packages/unsloth_zoo/gradient_checkpointing.py#line=567), in UnslothCheckpointFunction.backward(ctx, *args)
565 pass
567 with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
--> 568 outputs = ctx.run_function(*detached_inputs)
569 pass
570 pass
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1772), in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1783), in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File [/opt/conda/lib/python3.11/site-packages/transformers/utils/deprecation.py:172](http://localhost:8888/opt/conda/lib/python3.11/site-packages/transformers/utils/deprecation.py#line=171), in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
169 # DeprecationWarning is ignored by default, so we use FutureWarning instead
170 warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)
File [/opt/conda/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py:275](http://localhost:8888/opt/conda/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py#line=274), in Qwen3DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings, **kwargs)
273 residual = hidden_states
274 hidden_states = self.post_attention_layernorm(hidden_states)
--> 275 hidden_states = self.mlp(hidden_states)
276 hidden_states = residual + hidden_states
277 return hidden_states
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1772), in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1783), in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File [/workspace/work/unsloth_compiled_cache/unsloth_compiled_module_qwen3.py:210](http://localhost:8888/lab/tree/work/work/unsloth_compiled_cache/unsloth_compiled_module_qwen3.py#line=209), in Qwen3MLP.forward(self, x)
209 def forward(self, x):
--> 210 return Qwen3MLP_forward(self, x)
File [/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:736](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py#line=735), in _TorchDynamoContext.__call__.<locals>.compile_wrapper(*args, **kwargs)
733 _maybe_set_eval_frame(_callback_from_stance(callback))
735 try:
--> 736 return fn(*args, **kwargs)
737 except Unsupported as e:
738 if config.verbose:
File [/workspace/work/unsloth_compiled_cache/unsloth_compiled_module_qwen3.py:195](http://localhost:8888/lab/tree/work/work/unsloth_compiled_cache/unsloth_compiled_module_qwen3.py#line=194), in Qwen3MLP_forward(self, x)
193 @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
194 def Qwen3MLP_forward(self, x):
--> 195 down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
196 return down_proj
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1772), in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1783), in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File [/workspace/work/unsloth_compiled_cache/Linear8bitLt_peft_forward.py:73](http://localhost:8888/lab/tree/work/work/unsloth_compiled_cache/Linear8bitLt_peft_forward.py#line=72), in unsloth_forward(self, x, *args, **kwargs)
71 result = self.base_layer(x, *args, **kwargs)
72 else:
---> 73 result = self.base_layer(x, *args, **kwargs)
74 for active_adapter in self.active_adapters:
75 if active_adapter not in self.lora_A.keys():
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1773](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1772), in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File [/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1784](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py#line=1783), in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File [/opt/conda/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:1072](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/nn/modules.py#line=1071), in Linear8bitLt.forward(self, x)
1069 if self.bias is not None and self.bias.dtype != x.dtype:
1070 self.bias.data = self.bias.data.to(x.dtype)
-> 1072 out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
1074 if not self.state.has_fp16_weights and self.state.CB is not None:
1075 self.weight.data = self.state.CB
File [/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:424](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=423), in matmul(A, B, out, state, threshold, bias)
422 if A.device.type in ("cpu", "xpu"):
423 return MatMul8bitFp.apply(A, B, out, bias, state)
--> 424 return MatMul8bitLt.apply(A, B, out, bias, state)
File [/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py:576](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py#line=575), in Function.apply(cls, *args, **kwargs)
573 if not torch._C._are_functorch_transforms_active():
574 # See NOTE: [functorch vjp and autograd interaction]
575 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 576 return super().apply(*args, **kwargs) # type: ignore[misc]
578 if not is_setup_ctx_defined:
579 raise RuntimeError(
580 "In order to use an autograd.Function with functorch transforms "
581 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
582 "staticmethod. For more details, please see "
583 "https://pytorch.org/docs/main/notes/extending.func.html"
584 )
File [/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:154](http://localhost:8888/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py#line=153), in MatMul8bitLt.forward(ctx, A, B, out, bias, state)
153 class MatMul8bitLt(torch.autograd.Function):
--> 154 @staticmethod
155 def forward(
156 ctx: torch.autograd.function.FunctionCtx,
157 A: torch.Tensor,
158 B: torch.Tensor,
159 out: Optional[torch.Tensor] = None,
160 bias: Optional[torch.Tensor] = None,
161 state: Optional[MatmulLtState] = None,
162 ):
163 state = state or MatmulLtState()
165 # default of pytorch behavior if inputs are empty
File [/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:929](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py#line=928), in DisableContext.__call__.<locals>._fn(*args, **kwargs)
927 _maybe_set_eval_frame(_callback_from_stance(self.callback))
928 try:
--> 929 return fn(*args, **kwargs)
930 finally:
931 set_eval_frame(None)
File [/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py:1241](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py#line=1240), in aot_module_simplified.<locals>.forward(*runtime_args)
1239 full_args.extend(params_flat)
1240 full_args.extend(runtime_args)
-> 1241 return compiled_fn(full_args)
File [/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:384](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py#line=383), in _create_runtime_wrapper.<locals>.runtime_wrapper(args)
382 torch._C._set_grad_enabled(False)
383 record_runtime_wrapper_prologue_exit(cm)
--> 384 all_outs = call_func_at_runtime_with_args(
385 compiled_fn, args, disable_amp=disable_amp, steal_args=True
386 )
387 finally:
388 if grad_enabled:
File [/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:126](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py#line=125), in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
124 with context():
125 if getattr(f, "_boxed_call", False):
--> 126 out = normalize_as_list(f(args))
127 else:
128 # TODO: Please remove soon
129 # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
130 warnings.warn(
131 "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
132 "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
133 "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
134 )
File [/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:750](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py#line=749), in EffectTokensWrapper.post_compile.<locals>.inner_fn(args)
747 args = [*([None] * num_tokens), *args]
748 old_args.clear()
--> 750 outs = compiled_fn(args)
752 # Inductor cache DummyModule can return None
753 if outs is None:
File [/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:556](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py#line=555), in FunctionalizedRngRuntimeWrapper.post_compile.<locals>.wrapper(runtime_args)
549 out = self._functionalized_rng_runtime_epilogue(
550 runtime_metadata,
551 out,
552 # TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper
553 runtime_metadata.num_forward_returns,
554 )
555 return out
--> 556 return compiled_fn(runtime_args)
File [/opt/conda/lib/python3.11/site-packages/torch/_inductor/output_code.py:584](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_inductor/output_code.py#line=583), in CompiledFxGraph.__call__(self, inputs)
582 assert self.current_callable is not None
583 try:
--> 584 return self.current_callable(inputs)
585 finally:
586 get_runtime_metrics_context().finish()
File [/opt/conda/lib/python3.11/site-packages/torch/_inductor/utils.py:2716](http://localhost:8888/opt/conda/lib/python3.11/site-packages/torch/_inductor/utils.py#line=2715), in align_inputs_from_check_idxs.<locals>.run(new_inputs)
2712 def run(new_inputs: list[InputType]) -> Any:
2713 old_tensors, new_tensors = copy_misaligned_inputs(
2714 new_inputs, inputs_to_check, mutated_input_idxs
2715 )
-> 2716 out = model(new_inputs)
2718 # If a mutated tensor was cloned to be aligned, we need to reflect back the mutation to the
2719 # original tensor.
2720 if len(old_tensors):
File [/tmp/torchinductor_unsloth/67/c67ekh5dv3tzbugvqdncd2xm7vt3km63f42igb2w4jybfr4uccul.py:123](http://localhost:8888/tmp/torchinductor_unsloth/67/c67ekh5dv3tzbugvqdncd2xm7vt3km63f42igb2w4jybfr4uccul.py#line=122), in call(args)
121 assert_alignment(buf7, 16, 'torch.ops.bitsandbytes.int8_mixed_scaled_mm.default')
122 buf8 = buf6[1]
--> 123 assert_size_stride(buf8, (u1, ), (1, ), 'torch.ops.bitsandbytes.int8_mixed_scaled_mm.default')
124 assert_alignment(buf8, 16, 'torch.ops.bitsandbytes.int8_mixed_scaled_mm.default')
125 if not (u1 >= 0):
AssertionError: wrong number of dimensions2 for op: torch.ops.bitsandbytes.int8_mixed_scaled_mm.default
Running a notebook in the official Docker image with the sample code below (based on the Unsloth notebook
Qwen3_(4B)-Instruct) fails on my system, producing the error trace above.
🦥 You can also ask via our Reddit page: https://www.reddit.com/r/unsloth/