Params4bit.to does not keep bnb_quantized status #1664

@mklabunde

Description

System Info

Python version: 3.13.2
PyTorch version (GPU?): 2.7.0+cu126 (True)
Bitsandbytes version: 0.46.0

Reproduction

import torch
from bitsandbytes.nn import Params4bit

module = torch.nn.Linear(2, 2)
module = Params4bit(module.weight)
module.to("cuda:5")

module.to(device="cuda:0")
# This works only because the returned Params4bit instance with
# bnb_quantized=False is discarded rather than assigned.

module = module.to(device="cuda:0")
# This leads to an error later: the assigned object is a new Params4bit
# instance with bnb_quantized=False, although its data is already quantized.
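
# Optional diagnostic (sketch): the state is now inconsistent. The data is
# already quantized to uint8, but the flag says the parameter is not quantized.
print(module.data.dtype)     # torch.uint8
print(module.bnb_quantized)  # False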

module.to(device="cuda")  # RuntimeError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8
Full error message
Traceback (most recent call last):
  File "/ceph-ssd/bnb_fix/main.py", line 16, in <module>
    module.to(device="cuda")  # RuntimeError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8
    ~~~~~~~~~^^^^^^^^^^^^^^^
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/bitsandbytes/nn/modules.py", line 343, in to
    return self._quantize(device)
           ~~~~~~~~~~~~~~^^^^^^^^
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/bitsandbytes/nn/modules.py", line 302, in _quantize
    w_4bit, quant_state = bnb.functional.quantize_4bit(
                          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        w,
        ^^
    ...<3 lines>...
        quant_storage=self.quant_storage,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/bitsandbytes/functional.py", line 1008, in quantize_4bit
    _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        A,
        ^^
    ...<2 lines>...
        quant_storage,
        ^^^^^^^^^^^^^^
    )
    ^
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/torch/library.py", line 719, in func_no_dynamo
    return func(*args, **kwargs)
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/bitsandbytes/backends/cuda/ops.py", line 299, in _
    torch._check(
    ~~~~~~~~~~~~^
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/torch/__init__.py", line 1660, in _check
    _check_with(RuntimeError, cond, message)
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ceph-ssd/bnb_fix/.venv/lib/python3.13/site-packages/torch/__init__.py", line 1642, in _check_with
    raise error_type(message_evaluated)
RuntimeError: Blockwise 4bit quantization only supports 16/32-bit floats, but got torch.uint8

Expected behavior

The return value of Params4bit.to should correctly reflect the current quantization status, but the bnb_quantized attribute is currently not carried over to the returned instance. This leads to an attempt at re-quantization when the parameter is moved to a GPU again, even though it is already quantized.
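
Until this is fixed, restoring the flag by hand appears to work around the crash. This is a sketch against 0.46.0; it assumes the instance returned by .to still carries a valid quant_state, so only the boolean flag is lost:

import torch
from bitsandbytes.nn import Params4bit

param = Params4bit(torch.nn.Linear(2, 2).weight)
param.to("cuda")                 # first move to a GPU quantizes in place

param = param.to(device="cuda")  # returns a new instance with bnb_quantized=False
param.bnb_quantized = True       # workaround: restore the lost flag

param.to(device="cuda")          # no longer tries to re-quantize the uint8 data

Presumably the proper fix is for Params4bit.to to carry bnb_quantized over to the instance it returns.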
