@@ -200,23 +200,41 @@ def unpack_tensor_to_dict(tensor_data):
200200 return unpacked_dict
201201
202202
203- def enable_ipex_fusion (linear ):
203+ def enable_ipex_fusion (linear , grad = False ):
204204 from bitsandbytes .backends .cpu_xpu_common import _ipex_cpu_version_prereq
205205
206206 if _ipex_cpu_version_prereq (2 , 5 ):
207207 quant_state = linear .weight .quant_state
208208 new_weight , new_scales , new_zeros , _ , compensation = \
209+ torch .ops .ipex_prepack .woq_linear_pack_weight (
210+ linear .weight .data .reshape ([quant_state .shape [0 ], quant_state .shape [1 ] // 2 ]),
211+ "nf4" ,
212+ quant_state .shape , # weight shape
213+ quant_state .absmax .view (quant_state .shape [0 ], quant_state .shape [1 ] // quant_state .blocksize ), # scales
214+ None , # zero_points
215+ None , # bias
216+ None , # batch_size
217+ quant_state .blocksize ,
218+ 2 ,
219+ )
220+ if grad :
221+ backward_new_weight , backward_new_scales , backward_new_zeros , _ , backward_compensation = \
209222 torch .ops .ipex_prepack .woq_linear_pack_weight (
210- linear .weight .data .reshape ([quant_state .shape [0 ], quant_state .shape [1 ] // 2 ]),
223+ linear .weight .t (). data .reshape ([quant_state .shape [1 ], quant_state .shape [0 ] // 2 ]),
211224 "nf4" ,
212225 quant_state .shape , # weight shape
213- quant_state .absmax .view (quant_state .shape [0 ], quant_state .shape [1 ] // quant_state .blocksize ), # scales
226+ quant_state .absmax .view (quant_state .shape [1 ], quant_state .shape [0 ] // quant_state .blocksize ), # scales
214227 None , # zero_points
215228 None , # bias
216229 None , # batch_size
217230 quant_state .blocksize ,
218231 2 ,
219232 )
233+ setattr (linear .weight .quant_state , "backward_weight" , backward_new_weight )
234+ setattr (linear .weight .quant_state , "backward_new_scales" , backward_new_scales )
235+ setattr (linear .weight .quant_state , "backward_new_zeros" , backward_new_zeros )
236+ setattr (linear .weight .quant_state , "backward_compensation" , backward_compensation )
237+
220238 linear .weight .data = new_weight .data
221239 setattr (linear .weight .quant_state , "ipex" , True )
222240 setattr (linear .weight .quant_state , "new_scales" , new_scales )
0 commit comments