File "/root/main.py", line 119, in train
outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/unsloth/models/llama.py", line 1579, in unsloth_fast_generate
output = self._old_generate(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2223, in generate
result = self._sample(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 3214, in _sample
outputs = model_forward(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/unsloth/models/llama.py", line 1026, in _CausalLM_fast_forward
outputs = fast_forward_inference(
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/unsloth/models/llama.py", line 927, in LlamaModel_fast_forward_inference
X = X.to(self.config.torch_dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Invalid device string: 'bfloat16'
I am fine-tuning Qwen2.5-Coder by following the code in the notebook. The following code in the Inference section raises an error:
Traceback:
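I think it is related to #404. The failure occurs at `X = X.to(self.config.torch_dtype)`. There, `config.torch_dtype` appears to be the string `'bfloat16'` rather than the dtype object `torch.bfloat16`, so `Tensor.to()` interprets the string as a device name and raises `Invalid device string`.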