Hello! I wanted to test the int8 performance benefit, but ran into this error (CUDA 12.1, with the matching PyTorch CUDA 12.1 build):
python3 generate.py --quantize llm.int8 --prompt "Hello, my name is"
->
Loading model ...
bin /usr/local/lib/python3.8/dist-packages/bitsandbytes/libbitsandbytes_cuda121.so
Time to load model: 29.66 seconds.
Global seed set to 1234
cuBLAS API failed with status 15
A: torch.Size([6, 4096]), B: torch.Size([12288, 4096]), C: (6, 12288); (lda, ldb, ldc): (c_int(192), c_int(393216), c_int(192)); (m, n, k): (c_int(6), c_int(12288), c_int(4096))
error detected
Traceback (most recent call last):
File "generate.py", line 172, in <module>
CLI(main)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
return _run_component(component, cfg_init)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
return component(**cfg)
File "generate.py", line 147, in main
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "generate.py", line 65, in generate
logits = model(x, max_seq_length, input_pos)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
output = self._forward_module(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 114, in forward
x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 159, in forward
h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 191, in forward
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/bitsandbytes/nn/modules.py", line 320, in forward
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
File "/usr/local/lib/python3.8/dist-packages/bitsandbytes/autograd/_functions.py", line 500, in matmul
return MatMul8bitLt.apply(A, B, out, bias, state)
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.8/dist-packages/bitsandbytes/autograd/_functions.py", line 397, in forward
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
File "/usr/local/lib/python3.8/dist-packages/bitsandbytes/functional.py", line 1436, in igemmlt
raise Exception('cublasLt ran into an error!')
Exception: cublasLt ran into an error!
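
In case it helps with triage: cuBLAS status 15 is CUBLAS_STATUS_NOT_SUPPORTED, and the failing shapes match the attention qkv projection (4096 -> 12288, i.e. 3 x 4096). Below is a minimal repro sketch that exercises the same int8 matmul through bitsandbytes directly, with the shapes copied from the error message above; it assumes bitsandbytes' public Linear8bitLt module, and the snippet itself is illustrative, not code from this repo.

```python
import torch
import bitsandbytes as bnb

# Hypothetical standalone repro: same shapes as the failing call above
# (A: [6, 4096] activations, B: [12288, 4096] weight), going through
# bitsandbytes' Linear8bitLt directly instead of lit-llama.
linear = bnb.nn.Linear8bitLt(
    4096, 12288, bias=False, has_fp16_weights=False, threshold=6.0
)
linear = linear.cuda()  # weights are quantized to int8 on the .cuda() call

x = torch.randn(1, 6, 4096, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = linear(x)  # should hit the same igemmlt/cublasLt path
print(y.shape)  # expected: torch.Size([1, 6, 12288])
```

If this standalone snippet raises the same "cublasLt ran into an error!", the problem would sit between bitsandbytes and cuBLASLt on CUDA 12.1 rather than in lit-llama's quantization wrapper.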