Skip to content

AssertionError: "initial value for `logits` is of type bf16, but the then block redefines it as fp32" [FIXED] #1248

@daegonYu

Description

@daegonYu
{
	"name": "CompilationError",
	"message": "at 53:4:
    loss_ptr      += row_idx
    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
    labels_ptr    += row_idx

    col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\"))

    # Go logit scaling for Cohere: t * x
    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
    ^
AssertionError('initial value for `logits` is of type <[65536], bf16>, but the then block redefines it as <[65536], fp32>')",
	"stack": "---------------------------------------------------------------------------
CompilationError                          Traceback (most recent call last)
Cell In[28], line 1
----> 1 trainer_stats = trainer.train()

File <string>:156, in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)

File <string>:380, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File <string>:31, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/_utils.py:945, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
    943     pass
    944 pass
--> 945 return self._old_compute_loss(model, inputs, *args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3631         loss_kwargs[\"num_items_in_batch\"] = num_items_in_batch
   3632     inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
   3634 # Save past state if it exists
   3635 # TODO: this needs to be fixed and made cleaner later.
   3636 if self.args.past_index >= 0:

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    822 def forward(*args, **kwargs):
--> 823     return model_forward(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    810 def __call__(self, *args, **kwargs):
--> 811     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     41 @functools.wraps(func)
     42 def decorate_autocast(*args, **kwargs):
     43     with autocast_instance:
---> 44         return func(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_compile.py:32, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     29     disable_fn = torch._dynamo.disable(fn, recursive)
     30     fn.__dynamo_disable = disable_fn
---> 32 return disable_fn(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:632, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    630 prior = _maybe_set_eval_frame(callback)
    631 try:
--> 632     return fn(*args, **kwargs)
    633 finally:
    634     _maybe_set_eval_frame(prior)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/llama.py:1046, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, num_logits_to_keep, **kwargs)
   1031 @torch._disable_dynamo
   1032 def PeftModelForCausalLM_fast_forward(
   1033     self,
   (...)
   1044     **kwargs,
   1045 ):
-> 1046     return self.base_model(
   1047         input_ids=input_ids,
   1048         causal_mask=causal_mask,
   1049         attention_mask=attention_mask,
   1050         inputs_embeds=inputs_embeds,
   1051         labels=labels,
   1052         output_attentions=output_attentions,
   1053         output_hidden_states=output_hidden_states,
   1054         return_dict=return_dict,
   1055         num_logits_to_keep=num_logits_to_keep,
   1056         **kwargs,
   1057     )

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/llama.py:987, in CausalLM_fast_forward.<locals>._CausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep, *args, **kwargs)
    984     pass
    986     shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
--> 987     loss = fast_cross_entropy_loss(
    988         logits = shift_logits,
    989         labels = shift_labels,
    990         logit_softcapping = logit_softcapping,
    991         logit_scaling     = logit_scaling,
    992         n_items           = kwargs.get(\"num_items_in_batch\", None) or kwargs.get(\"n_items\", None),
    993     )
    994 else:
    995     if logit_scaling != 0:

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/kernels/cross_entropy_loss.py:386, in fast_cross_entropy_loss(logits, labels, logit_softcapping, logit_scaling, n_items)
    383 batch, seq_len, d = logits.shape
    384 assert(labels.shape == (batch, seq_len))
--> 386 loss = Fast_CrossEntropyLoss.apply(
    387     logits.view(batch*seq_len, d),
    388     labels.view(-1),
    389     logit_softcapping,
    390     logit_scaling,
    391 )
    392 if n_items is None:
    393     n_items = torch.count_nonzero(labels != -100)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         \"In order to use an autograd.Function with functorch transforms \"
    580         \"(vmap, grad, jvp, jacrev, ...), it must override the setup_context \"
    581         \"staticmethod. For more details, please see \"
    582         \"https://pytorch.org/docs/main/notes/extending.func.html\"
    583     )

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/kernels/cross_entropy_loss.py:311, in Fast_CrossEntropyLoss.forward(ctx, logits, labels, logit_softcapping, logit_scaling)
    307 else:
    308     # For large vocabs > 65336 like Gemma 256K
    309     logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = \"cuda:0\")
--> 311     _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
    312         logits, logits.stride(0),
    313         losses,
    314         logsumexp,
    315         labels,
    316         VOCAB_SIZE       = vocab_size,
    317         N_CHUNKS         = n_chunks,
    318         BLOCK_SIZE       = MAX_FUSED_SIZE,
    319         DO_SOFTCAPPING   = DO_SOFTCAPPING,
    320         SOFTCAP          = logit_softcapping,
    321         DO_LOGIT_SCALING = DO_LOGIT_SCALING,
    322         LOGIT_SCALE      = logit_scaling,
    323         num_warps        = 32,
    324     )
    325     # logsumexp(chunked_logsumexp) - x
    326     # Do the -x separately
    327     logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/jit.py:345, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    339 def __getitem__(self, grid) -> T:
    340     \"\"\"
    341     A JIT function is launched with: fn[grid](*args, **kwargs).
    342     Hence JITFunction.__getitem__ returns a callable proxy that
    343     memorizes the grid.
    344     \"\"\"
--> 345     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/autotuner.py:338, in Heuristics.run(self, *args, **kwargs)
    336 for v, heur in self.values.items():
    337     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/jit.py:662, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    660     # compile the kernel
    661     src = self.ASTSource(self, signature, constants, configs[0])
--> 662     kernel = self.compile(
    663         src,
    664         target=target,
    665         options=options.__dict__,
    666     )
    667     self.cache[device][key] = kernel
    669 # Check that used global values have not changed.

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/compiler/compiler.py:276, in compile(src, target, options)
    274 codegen_fns = backend.get_codegen_implementation()
    275 try:
--> 276     module = src.make_ir(options, codegen_fns, context)
    277 except Exception as e:
    278     filter_traceback(e)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/compiler/compiler.py:113, in ASTSource.make_ir(self, options, codegen_fns, context)
    112 def make_ir(self, options, codegen_fns, context):
--> 113     return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)

CompilationError: at 53:4:
    loss_ptr      += row_idx
    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
    labels_ptr    += row_idx

    col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\"))

    # Go logit scaling for Cohere: t * x
    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
    ^
AssertionError('initial value for `logits` is of type <[65536], bf16>, but the then block redefines it as <[65536], fp32>')"
}

Package Version


accelerate 1.1.0
aiohappyeyeballs 2.4.3
aiohttp 3.10.10
aiosignal 1.3.1
asttokens 2.4.1
attrs 24.2.0
bitsandbytes 0.44.1
certifi 2024.8.30
charset-normalizer 3.4.0
comm 0.2.2
datasets 3.1.0
debugpy 1.8.7
decorator 5.1.1
dill 0.3.8
docstring_parser 0.16
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.13.1
frozenlist 1.5.0
fsspec 2024.9.0
gmpy2 2.1.2
hf_transfer 0.1.8
huggingface-hub 0.26.2
idna 3.10
importlib_metadata 8.5.0
ipykernel 6.29.5
ipython 8.29.0
jedi 0.19.1
Jinja2 3.1.4
jupyter_client 8.6.3
jupyter_core 5.7.2
markdown-it-py 3.0.0
MarkupSafe 2.1.3
matplotlib-inline 0.1.7
mdurl 0.1.2
mpmath 1.3.0
multidict 6.1.0
multiprocess 0.70.16
nest_asyncio 1.6.0
networkx 3.3
numpy 2.1.3
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
packaging 24.1
pandas 2.2.3
parso 0.8.4
peft 0.13.2
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.2.0
pip 24.3.1
platformdirs 4.3.6
prompt_toolkit 3.0.48
propcache 0.2.0
protobuf 3.20.3
psutil 6.1.0
ptyprocess 0.7.0
pure_eval 0.2.3
pyarrow 18.0.0
Pygments 2.18.0
python-dateutil 2.9.0
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
regex 2024.9.11
requests 2.32.3
rich 13.9.4
safetensors 0.4.5
sentencepiece 0.2.0
setuptools 75.3.0
shtab 1.7.1
six 1.16.0
stack-data 0.6.2
sympy 1.13.1
tokenizers 0.20.3
torch 2.5.1
tornado 6.4.1
tqdm 4.66.6
traitlets 5.14.3
transformers 4.46.2
triton 3.1.0
trl 0.12.0
typing_extensions 4.12.2
tyro 0.8.14
tzdata 2024.2
unsloth 2024.11.1
unsloth_zoo 2024.11.1
urllib3 2.2.3
wcwidth 0.2.13
wheel 0.44.0
xformers 0.0.28.post3
xxhash 3.5.0
yarl 1.17.1
zipp 3.20.2

Metadata

Assignees

No one assigned

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions