{
"name": "CompilationError",
"message": "at 53:4:
loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
labels_ptr += row_idx
col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\"))
# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
^
AssertionError('initial value for `logits` is of type <[65536], bf16>, but the then block redefines it as <[65536], fp32>')",
"stack": "---------------------------------------------------------------------------
CompilationError Traceback (most recent call last)
Cell In[28], line 1
----> 1 trainer_stats = trainer.train()
File <string>:156, in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
File <string>:380, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
File <string>:31, in _unsloth_training_step(self, model, inputs, num_items_in_batch)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/_utils.py:945, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
943 pass
944 pass
--> 945 return self._old_compute_loss(model, inputs, *args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3631 loss_kwargs[\"num_items_in_batch\"] = num_items_in_batch
3632 inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
3634 # Save past state if it exists
3635 # TODO: this needs to be fixed and made cleaner later.
3636 if self.args.past_index >= 0:
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
822 def forward(*args, **kwargs):
--> 823 return model_forward(*args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
810 def __call__(self, *args, **kwargs):
--> 811 return convert_to_fp32(self.model_forward(*args, **kwargs))
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
41 @functools.wraps(func)
42 def decorate_autocast(*args, **kwargs):
43 with autocast_instance:
---> 44 return func(*args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_compile.py:32, in _disable_dynamo.<locals>.inner(*args, **kwargs)
29 disable_fn = torch._dynamo.disable(fn, recursive)
30 fn.__dynamo_disable = disable_fn
---> 32 return disable_fn(*args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:632, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
630 prior = _maybe_set_eval_frame(callback)
631 try:
--> 632 return fn(*args, **kwargs)
633 finally:
634 _maybe_set_eval_frame(prior)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/llama.py:1046, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, num_logits_to_keep, **kwargs)
1031 @torch._disable_dynamo
1032 def PeftModelForCausalLM_fast_forward(
1033 self,
(...)
1044 **kwargs,
1045 ):
-> 1046 return self.base_model(
1047 input_ids=input_ids,
1048 causal_mask=causal_mask,
1049 attention_mask=attention_mask,
1050 inputs_embeds=inputs_embeds,
1051 labels=labels,
1052 output_attentions=output_attentions,
1053 output_hidden_states=output_hidden_states,
1054 return_dict=return_dict,
1055 num_logits_to_keep=num_logits_to_keep,
1056 **kwargs,
1057 )
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
196 def forward(self, *args: Any, **kwargs: Any):
--> 197 return self.model.forward(*args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/llama.py:987, in CausalLM_fast_forward.<locals>._CausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep, *args, **kwargs)
984 pass
986 shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
--> 987 loss = fast_cross_entropy_loss(
988 logits = shift_logits,
989 labels = shift_labels,
990 logit_softcapping = logit_softcapping,
991 logit_scaling = logit_scaling,
992 n_items = kwargs.get(\"num_items_in_batch\", None) or kwargs.get(\"n_items\", None),
993 )
994 else:
995 if logit_scaling != 0:
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/kernels/cross_entropy_loss.py:386, in fast_cross_entropy_loss(logits, labels, logit_softcapping, logit_scaling, n_items)
383 batch, seq_len, d = logits.shape
384 assert(labels.shape == (batch, seq_len))
--> 386 loss = Fast_CrossEntropyLoss.apply(
387 logits.view(batch*seq_len, d),
388 labels.view(-1),
389 logit_softcapping,
390 logit_scaling,
391 )
392 if n_items is None:
393 n_items = torch.count_nonzero(labels != -100)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 \"In order to use an autograd.Function with functorch transforms \"
580 \"(vmap, grad, jvp, jacrev, ...), it must override the setup_context \"
581 \"staticmethod. For more details, please see \"
582 \"https://pytorch.org/docs/main/notes/extending.func.html\"
583 )
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/kernels/cross_entropy_loss.py:311, in Fast_CrossEntropyLoss.forward(ctx, logits, labels, logit_softcapping, logit_scaling)
307 else:
308 # For large vocabs > 65336 like Gemma 256K
309 logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = \"cuda:0\")
--> 311 _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
312 logits, logits.stride(0),
313 losses,
314 logsumexp,
315 labels,
316 VOCAB_SIZE = vocab_size,
317 N_CHUNKS = n_chunks,
318 BLOCK_SIZE = MAX_FUSED_SIZE,
319 DO_SOFTCAPPING = DO_SOFTCAPPING,
320 SOFTCAP = logit_softcapping,
321 DO_LOGIT_SCALING = DO_LOGIT_SCALING,
322 LOGIT_SCALE = logit_scaling,
323 num_warps = 32,
324 )
325 # logsumexp(chunked_logsumexp) - x
326 # Do the -x separately
327 logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/jit.py:345, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
339 def __getitem__(self, grid) -> T:
340 \"\"\"
341 A JIT function is launched with: fn[grid](*args, **kwargs).
342 Hence JITFunction.__getitem__ returns a callable proxy that
343 memorizes the grid.
344 \"\"\"
--> 345 return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/autotuner.py:338, in Heuristics.run(self, *args, **kwargs)
336 for v, heur in self.values.items():
337 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/jit.py:662, in JITFunction.run(self, grid, warmup, *args, **kwargs)
660 # compile the kernel
661 src = self.ASTSource(self, signature, constants, configs[0])
--> 662 kernel = self.compile(
663 src,
664 target=target,
665 options=options.__dict__,
666 )
667 self.cache[device][key] = kernel
669 # Check that used global values have not changed.
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/compiler/compiler.py:276, in compile(src, target, options)
274 codegen_fns = backend.get_codegen_implementation()
275 try:
--> 276 module = src.make_ir(options, codegen_fns, context)
277 except Exception as e:
278 filter_traceback(e)
File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/compiler/compiler.py:113, in ASTSource.make_ir(self, options, codegen_fns, context)
112 def make_ir(self, options, codegen_fns, context):
--> 113 return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
CompilationError: at 53:4:
loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
labels_ptr += row_idx
col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\"))
# Go logit scaling for Cohere: t * x
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
^
AssertionError('initial value for `logits` is of type <[65536], bf16>, but the then block redefines it as <[65536], fp32>')"
}
Package Version
accelerate 1.1.0
aiohappyeyeballs 2.4.3
aiohttp 3.10.10
aiosignal 1.3.1
asttokens 2.4.1
attrs 24.2.0
bitsandbytes 0.44.1
certifi 2024.8.30
charset-normalizer 3.4.0
comm 0.2.2
datasets 3.1.0
debugpy 1.8.7
decorator 5.1.1
dill 0.3.8
docstring_parser 0.16
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.13.1
frozenlist 1.5.0
fsspec 2024.9.0
gmpy2 2.1.2
hf_transfer 0.1.8
huggingface-hub 0.26.2
idna 3.10
importlib_metadata 8.5.0
ipykernel 6.29.5
ipython 8.29.0
jedi 0.19.1
Jinja2 3.1.4
jupyter_client 8.6.3
jupyter_core 5.7.2
markdown-it-py 3.0.0
MarkupSafe 2.1.3
matplotlib-inline 0.1.7
mdurl 0.1.2
mpmath 1.3.0
multidict 6.1.0
multiprocess 0.70.16
nest_asyncio 1.6.0
networkx 3.3
numpy 2.1.3
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
packaging 24.1
pandas 2.2.3
parso 0.8.4
peft 0.13.2
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.2.0
pip 24.3.1
platformdirs 4.3.6
prompt_toolkit 3.0.48
propcache 0.2.0
protobuf 3.20.3
psutil 6.1.0
ptyprocess 0.7.0
pure_eval 0.2.3
pyarrow 18.0.0
Pygments 2.18.0
python-dateutil 2.9.0
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
regex 2024.9.11
requests 2.32.3
rich 13.9.4
safetensors 0.4.5
sentencepiece 0.2.0
setuptools 75.3.0
shtab 1.7.1
six 1.16.0
stack-data 0.6.2
sympy 1.13.1
tokenizers 0.20.3
torch 2.5.1
tornado 6.4.1
tqdm 4.66.6
traitlets 5.14.3
transformers 4.46.2
triton 3.1.0
trl 0.12.0
typing_extensions 4.12.2
tyro 0.8.14
tzdata 2024.2
unsloth 2024.11.1
unsloth_zoo 2024.11.1
urllib3 2.2.3
wcwidth 0.2.13
wheel 0.44.0
xformers 0.0.28.post3
xxhash 3.5.0
yarl 1.17.1
zipp 3.20.2