
Commit cd7bf21

enable backward
1 parent b8df1aa commit cd7bf21

6 files changed

Lines changed: 54 additions & 18 deletions


bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
@@ -552,7 +552,7 @@ def backward(ctx, grad_output):
         # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
         if req_gradA:
             if getattr(ctx.state, "ipex", False):
-                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state)
+                grad_A = F.gemv_4bit(grad_output, B, None, state=ctx.state, backward=True)
             else:
                 grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
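On the IPEX weight-only-quantization path the weight only exists in packed form, so grad_A is now routed through F.gemv_4bit with the new backward=True flag instead of being dequantized. The math is the ordinary linear-layer input gradient; a minimal dense-tensor sketch (hypothetical sizes, not the IPEX kernel) of the shapes involved:

import torch

# Forward: out = A @ W.t() with W of shape [out_features, in_features].
# Backward w.r.t. the input: grad_A = grad_output @ W, i.e. a matmul against W
# in its [out_features, in_features] layout -- which is why a second weight
# packed in the transposed layout is prepared for the backward call.
batch, in_features, out_features = 4, 8, 16        # hypothetical sizes
A = torch.randn(batch, in_features)
W = torch.randn(out_features, in_features)

out = A @ W.t()                                    # [batch, out_features]
grad_output = torch.ones_like(out)                 # stand-in upstream gradient
grad_A = grad_output @ W                           # [batch, in_features]
assert grad_A.shape == A.shape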

bitsandbytes/backends/cpu.py

Lines changed: 2 additions & 1 deletion
@@ -163,12 +163,13 @@ def gemv_4bit(
         transposed_A=False,
         transposed_B=False,
         state: QuantState = None,
+        backward=False,
     ) -> torch.Tensor:
         assert_on_cpu([A, B, out])
         if state is None:
             raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
 
-        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
+        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state, backward)
 
     def dequantize_blockwise(
         self,

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 9 additions & 3 deletions
@@ -486,6 +486,7 @@ def gemm_4bit_impl(
     transposed_A=False,
     transposed_B=False,
     state: QuantState = None,
+    backward=False,
 ) -> torch.Tensor:
     """
     Matrix-matrix multiplication with 4-bit quantization.
@@ -511,9 +512,14 @@ def gemm_4bit_impl(
         GEMM output tensor.
     """
     if ipex_cpu and _ipex_cpu_version_prereq(2, 5) and getattr(state, "ipex", False):
-        output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
-            state.new_scales, state.new_zeros, None, None, state.blocksize,
-            ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
+        if backward:
+            output = torch.ops.torch_ipex.woq_linear(A, state.backward_weight, "nf4", torch.Size([state.shape[1], state.shape[0]]),
+                state.backward_new_scales, state.backward_new_zeros, None, None, state.blocksize,
+                ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.backward_compensation)
+        else:
+            output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
+                state.new_scales, state.new_zeros, None, None, state.blocksize,
+                ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
     else:
         dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))
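The backward flag only changes which packed tensors are handed to torch.ops.torch_ipex.woq_linear: the transposed packed weight with its matching scales, zeros and compensation, plus the swapped weight shape. A hypothetical helper (not part of the commit) restating that selection:

import torch

# Hypothetical restatement of the branch above (not part of the commit).
# `state` is the QuantState carrying both the forward and the backward_*
# packed tensors attached by enable_ipex_fusion().
def _select_woq_args(B, state, backward=False):
    if backward:
        # transposed layout: [in_features, out_features]
        shape = torch.Size([state.shape[1], state.shape[0]])
        return (state.backward_weight, shape, state.backward_new_scales,
                state.backward_new_zeros, state.backward_compensation)
    return (B, state.shape, state.new_scales, state.new_zeros, state.compensation)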

bitsandbytes/functional.py

Lines changed: 20 additions & 8 deletions
@@ -1530,16 +1530,28 @@ def gemv_4bit(
     transposed_A=False,
     transposed_B=False,
     state=None,
+    backward=False,
 ):
     ensure_backend_is_available(A.device.type)
-    return backends[A.device.type].gemv_4bit(
-        A,
-        B,
-        out=out,
-        transposed_A=transposed_A,
-        transposed_B=transposed_B,
-        state=state,
-    )
+    if A.device.type == "cpu":
+        return backends[A.device.type].gemv_4bit(
+            A,
+            B,
+            out=out,
+            transposed_A=transposed_A,
+            transposed_B=transposed_B,
+            state=state,
+            backward=backward,
+        )
+    else:
+        return backends[A.device.type].gemv_4bit(
+            A,
+            B,
+            out=out,
+            transposed_A=transposed_A,
+            transposed_B=transposed_B,
+            state=state,
+        )
 
 
 def igemm(
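Only the CPU backend accepts the new keyword, which is why the dispatcher duplicates the call. A sketch (assuming the backend signatures shown in this commit, illustrative only) of an equivalent dispatch that builds the keyword arguments conditionally instead:

# Equivalent dispatch sketch, assuming the backend signatures in this commit:
# only the CPU backend understands `backward`, so the flag is added
# conditionally instead of duplicating the whole call.
def gemv_4bit_dispatch(backends, A, B, out=None, transposed_A=False,
                       transposed_B=False, state=None, backward=False):
    kwargs = dict(out=out, transposed_A=transposed_A,
                  transposed_B=transposed_B, state=state)
    if A.device.type == "cpu":
        kwargs["backward"] = backward
    return backends[A.device.type].gemv_4bit(A, B, **kwargs)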

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 2 deletions
@@ -468,9 +468,8 @@ def forward(self, x: torch.Tensor):
             and not getattr(self.weight.quant_state, "ipex", False)
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
-            and x.requires_grad == False
         ):
-            enable_ipex_fusion(self)
+            enable_ipex_fusion(self, x.requires_grad)
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:

bitsandbytes/utils.py

Lines changed: 21 additions & 3 deletions
@@ -200,23 +200,41 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
-def enable_ipex_fusion(linear):
+def enable_ipex_fusion(linear, grad=False):
     from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
 
     if _ipex_cpu_version_prereq(2, 5):
         quant_state = linear.weight.quant_state
         new_weight, new_scales, new_zeros, _, compensation = \
+            torch.ops.ipex_prepack.woq_linear_pack_weight(
+                linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+                "nf4",
+                quant_state.shape,  # weight shape
+                quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
+                None,  # zero_points
+                None,  # bias
+                None,  # batch_size
+                quant_state.blocksize,
+                2,
+            )
+        if grad or True:
+            backward_new_weight, backward_new_scales, backward_new_zeros, _, backward_compensation = \
             torch.ops.ipex_prepack.woq_linear_pack_weight(
-                linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+                linear.weight.t().data.reshape([quant_state.shape[1], quant_state.shape[0] // 2]),
                 "nf4",
                 quant_state.shape,  # weight shape
-                quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
+                quant_state.absmax.view(quant_state.shape[1], quant_state.shape[0] // quant_state.blocksize),  # scales
                 None,  # zero_points
                 None,  # bias
                 None,  # batch_size
                 quant_state.blocksize,
                 2,
             )
+            setattr(linear.weight.quant_state, "backward_weight", backward_new_weight)
+            setattr(linear.weight.quant_state, "backward_new_scales", backward_new_scales)
+            setattr(linear.weight.quant_state, "backward_new_zeros", backward_new_zeros)
+            setattr(linear.weight.quant_state, "backward_compensation", backward_compensation)
+
         linear.weight.data = new_weight.data
         setattr(linear.weight.quant_state, "ipex", True)
         setattr(linear.weight.quant_state, "new_scales", new_scales)
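enable_ipex_fusion now packs the weight twice: once in its native [out_features, in_features] layout for the forward matmul and once transposed for the grad_A matmul, attaching the second set of packed tensors to the quant state as backward_* attributes. A dense-tensor sketch (hypothetical sizes, not the IPEX packing op) of the scale bookkeeping behind the two layouts:

import torch

# One absmax value per block of `blocksize` weights; the same flat buffer is
# viewed to match whichever weight layout is being packed.
out_features, in_features, blocksize = 16, 8, 4     # hypothetical sizes
W = torch.randn(out_features, in_features)
absmax = torch.rand(W.numel() // blocksize)

forward_scales = absmax.view(out_features, in_features // blocksize)   # [16, 2]
backward_scales = absmax.view(in_features, out_features // blocksize)  # [8, 4]
print(forward_scales.shape, backward_scales.shape)

As written, the `if grad or True:` guard packs the backward weight unconditionally, so the grad argument does not yet affect whether the second packing happens.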
