RuntimeError("Invalid device string: 'bfloat16'") #2090

@lurf21

I am fine-tuning Qwen2.5_Coder following the code in the notebook. The following call in the Inference section raises an error:

outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True, temperature = 1.5, min_p = 0.1)

Traceback:

File "/root/main.py", line 119, in train
    outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/unsloth/models/llama.py", line 1579, in unsloth_fast_generate
    output = self._old_generate(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 3214, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/unsloth/models/llama.py", line 1026, in _CausalLM_fast_forward
    outputs = fast_forward_inference(
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/unsloth/models/llama.py", line 927, in LlamaModel_fast_forward_inference
    X = X.to(self.config.torch_dtype)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Invalid device string: 'bfloat16'

I think it is related to #404.
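
The last frame of the traceback calls `X.to(self.config.torch_dtype)`. One plausible reading, not confirmed against the Unsloth source, is that `config.torch_dtype` holds the string `'bfloat16'` rather than `torch.bfloat16`, because `Tensor.to()` parses a bare string as a device name. The sketch below reproduces that behavior in isolation; `as_torch_dtype` is a hypothetical helper written for illustration and is not part of Unsloth or Transformers.

```python
import torch


def as_torch_dtype(value):
    """Return a torch.dtype for a torch.dtype or a name such as 'bfloat16'.

    Hypothetical helper for illustration only.
    """
    if isinstance(value, torch.dtype):
        return value
    if isinstance(value, str):
        # Accept both 'bfloat16' and 'torch.bfloat16'.
        name = value.removeprefix("torch.")
        dtype = getattr(torch, name, None)
        if isinstance(dtype, torch.dtype):
            return dtype
    raise TypeError(f"Cannot interpret {value!r} as a torch dtype")


X = torch.zeros(2, 2)

# A plain string is treated as a device by Tensor.to(), which raises RuntimeError.
try:
    X.to("bfloat16")
except RuntimeError as err:
    print(err)

# Converting the string to a real dtype first gives the intended cast.
X = X.to(as_torch_dtype("bfloat16"))
print(X.dtype)  # torch.bfloat16
```

If this is indeed the cause, a possible workaround (untested here) would be to assign a real dtype, for example `model.config.torch_dtype = torch.bfloat16`, before calling `model.generate`.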
