INT4 XPU enabling #1577
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1577
✅ No failures as of commit 00c742c with merge base 6726b0b (comment automatically generated by Dr. CI).
| layout_list = [] | ||
| if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: | ||
| layout_list.append(Int4CPULayout()) | ||
| elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6: |
There was a problem hiding this comment.
here as well, 2_6 or 2_7?
| __torch_function__ = torch._C._disabled_torch_function_impl | ||
| def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
There was a problem hiding this comment.
btw for this one, we have some unpacking op for tensor core tiled layout that we should really be using:
ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
Lines 311 to 312 in cf45336
might be better to do the same instead of hacking with quantize ops
There was a problem hiding this comment.
Sure, I will check.
|
btw why is the op added in pytorch/pytorch#137566 instead of in torchao? any plans to move it to torchao?
@mingfeima @EikanWang can you comment?
The situation for XPU (Intel GPUs) is different from CPU and CUDA here. Not sure whether providing SYCL or oneDNN XPU ops in ao is a feasible solution.
91067e2 to 895376f
|
@jerryzh168 please review again.
| _ = torch.load(f, weights_only=False) | ||
| @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
| # TODO(#1690): delete this once config migration is done |
| if self.scale_and_zero is not None: | ||
| return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] | ||
| else: | ||
| return ["packed_weight", "scale", "zero"], [self.transposed, self._layout] |
There was a problem hiding this comment.
why do we have two formats here? maybe should split into multiple layouts?
There was a problem hiding this comment.
The two formats are for integer zp and floating zp.
I didn't split this into 2 layouts because it would be confusing from the user side.
current:
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.FLOAT))
but with different layouts:
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutIntZP(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutFloatZP(), zero_point_domain=ZeroPointDomain.FLOAT))
I think the current implementation is more straightforward for users.
There was a problem hiding this comment.
but layout defines how we store the packed weights actually, using a single layout for multiple things is breaking this abstraction I feel
is the concern around specifying zero_point_domain multiple times? we could remove that and just infer the zero_point_domain from layout I think (the latter API)
There was a problem hiding this comment.
Since only XPU supports integer zp, can I move it in the next PR?
layout defines how we store the packed weights actually
it should include the layout of scales and zeros, right?
There was a problem hiding this comment.
Since only XPU supports integer zp, can I move it in the next PR?
what is this referring to?
layout defines how we store the packed weights actually it should include the layout of scales and zeros, right?
yeah that's correct. Ideally I think we should not use layout to control whether we have packed weight / scale_and_zero / scale, zero; the duplication should actually happen at the tensor level (we create different tensor subclass tensors), not in the layout. Feel free to go that route if you want.
There was a problem hiding this comment.
Can I separate into different layouts, and bind the zero point domain into each layout in the next PR?
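The option discussed in this thread, where each layout carries its own zero point domain so that callers no longer pass it separately, could look roughly like the sketch below. `Int4XPULayoutIntZP` and `Int4XPULayoutFloatZP` are the hypothetical names used earlier in the thread; the `ZeroPointDomain` enum is a local stand-in rather than the torchao class, and `infer_zero_point_domain` is a made-up helper.

```python
from dataclasses import dataclass
from enum import Enum, auto


class ZeroPointDomain(Enum):
    # Local stand-in for illustration; not imported from torchao.
    INT = auto()
    FLOAT = auto()


@dataclass(frozen=True)
class Int4XPULayoutIntZP:
    # Hypothetical layout for integer zero points.
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT


@dataclass(frozen=True)
class Int4XPULayoutFloatZP:
    # Hypothetical layout for floating zero points.
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT


def infer_zero_point_domain(layout) -> ZeroPointDomain:
    """Read the zero point domain off the layout instead of taking it as an argument."""
    return layout.zero_point_domain


print(infer_zero_point_domain(Int4XPULayoutIntZP()))
print(infer_zero_point_domain(Int4XPULayoutFloatZP()))
```

With this shape, a caller would only choose the layout, and the quantization entry point would look the domain up from it.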
| ) | ||
| if is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7 \ | ||
| and not isinstance(self.scales_and_zeros, torch.Tensor): | ||
| c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( |
There was a problem hiding this comment.
is this supposed to match line 550 in GPTQ.py?
There was a problem hiding this comment.
removed hqq support in this PR to simplify the logic
| scales_and_zeros.to(scales_precision), | ||
| ).to(dtype=x.dtype) | ||
| elif is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7: | ||
| c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( |
There was a problem hiding this comment.
also why do we have this function? can the slicing (scales_and_zeros[0] and scales_and_zeros[1]) be done in the _weight_int4pack_mm itself?
There was a problem hiding this comment.
removed GPTQ support in this PR to simplify the logic. I will open another PR after I separate the layouts
| self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( | ||
| W_q_torch, self.inner_k_tiles | ||
| ) | ||
| if is_device(W_q.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7: |
| shape_after_reduction = shape_for_reduction | ||
| for i in reduction_dims: | ||
| shape_after_reduction[i] = 1 | ||
| if shape_after_reduction[0] == 12288: |
| if preserve_zero: | ||
| zero_point = quant_min - torch.round(min_val_neg / scale) | ||
| zero_point = torch.clamp(zero_point, quant_min, quant_max) | ||
| zero_point_dtype = torch.int32 |
There was a problem hiding this comment.
In fact, preserve_zero and the INT zero point domain are coupled here; I think they are redundant in some way.
The reason for setting this parameter as an int is that many places calling this function use the default floating parameter.
There was a problem hiding this comment.
preserve_zero is about whether zero (in the original floating point domain) is exactly representable or not. It's not coupled with zero point domain I think: even if zero is exactly representable, we can still choose zero_point_domain to be float
There was a problem hiding this comment.
yes, from the math side they are not related, but the code here implies this; see the conditional dispatch from line 954 to line 966. We need a refactor here.
There was a problem hiding this comment.
I think it would be better to do assert instead of changing condition if there is coupling?
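To make the point concrete, here is a small worked sketch of the integer zero point formula from the diff above (`zero_point = quant_min - round(min_val_neg / scale)`, then clamped), with the coupling written as an assert rather than as a changed condition. The scale formula, the function name `choose_int_zero_point`, and the string-valued `zero_point_domain` argument are assumptions for illustration, not the torchao implementation, and the group is assumed not to be all zeros.

```python
def choose_int_zero_point(values, quant_min=0, quant_max=15,
                          preserve_zero=True, zero_point_domain="INT"):
    """Asymmetric affine qparams for one group; computes an integer zero point only."""
    # Assumption: surface the coupling as an explicit check instead of
    # silently overriding preserve_zero.
    if zero_point_domain == "INT":
        assert preserve_zero, "integer zero point domain requires preserve_zero=True"

    # Extend the range to include 0.0 so that zero can be represented.
    min_val_neg = min(min(values), 0.0)
    max_val_pos = max(max(values), 0.0)

    # Assumed scale formula (standard asymmetric quantization).
    scale = (max_val_pos - min_val_neg) / (quant_max - quant_min)

    # Formula from the diff: quant_min - round(min_val_neg / scale), clamped.
    zero_point = quant_min - round(min_val_neg / scale)
    zero_point = max(quant_min, min(quant_max, zero_point))
    return scale, zero_point


scale, zero_point = choose_int_zero_point([-1.0, 0.0, 0.5, 2.0])
# Quantize and dequantize 0.0 to check that it round-trips exactly.
q_zero = round(0.0 / scale) + zero_point
print(scale, zero_point, (q_zero - zero_point) * scale)
```

For the group `[-1.0, 0.0, 0.5, 2.0]` with a 4-bit range of 0 to 15, the scale is 0.2 and the zero point is 5, so the real value 0.0 maps to the integer 5 and back to exactly 0.0.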
| def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): | ||
| def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16, zero_point_domain_is_int=False): |
There was a problem hiding this comment.
nit: I'm wondering if we should just expose zero_point_domain as an arg directly here
There was a problem hiding this comment.
which way do you prefer? how about
def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, data_dtype=torch.bfloat16, scale_dtype=torch.bfloat16, zero_dtype=torch.bfloat16):
There was a problem hiding this comment.
sure, this sounds good. what about preserve_zero and zero_point_domain? I don't think these can be fully inferred?
There was a problem hiding this comment.
is this going to be updated? I feel it's a bit weird to introduce a boolean flag for zero_point_domain when we can just pass zero_point_domain itself around
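One way to read this suggestion: a boolean such as `zero_point_domain_is_int` can only distinguish two cases, while the enum can carry every domain mentioned in this review (FLOAT, INT, NONE). A minimal sketch, using a locally defined stand-in enum and two hypothetical helpers (`domain_from_legacy_flag` for migrating existing call sites, `describe_qparams` as a placeholder callee):

```python
from enum import Enum, auto


class ZeroPointDomain(Enum):
    # Local stand-in for illustration; not imported from torchao.
    INT = auto()
    FLOAT = auto()
    NONE = auto()


def domain_from_legacy_flag(zero_point_domain_is_int: bool) -> ZeroPointDomain:
    """Hypothetical shim: map the boolean flag to the enum value it stands for."""
    return ZeroPointDomain.INT if zero_point_domain_is_int else ZeroPointDomain.FLOAT


def describe_qparams(zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT) -> str:
    """Hypothetical callee that takes the enum directly instead of a boolean."""
    if zero_point_domain is ZeroPointDomain.NONE:
        return "scale only, no zero point"
    return f"scale + zero point in the {zero_point_domain.name} domain"


print(describe_qparams(domain_from_legacy_flag(True)))
print(describe_qparams(ZeroPointDomain.NONE))
```

The boolean has no value that maps to `NONE`, which is the practical argument for passing the enum through unchanged.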
| zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] | ||
| ), f"Layout only supports {LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]}" | ||
| preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] if zero_point_domain!=ZeroPointDomain.INT else True |
There was a problem hiding this comment.
does zero_point_dtype need to change after this is set
There was a problem hiding this comment.
same as above. In fact, preserve_zero and the INT zero point domain are coupled.
There was a problem hiding this comment.
there are three things, preserve_zero = {True, False}, zero_point_domain = {FLOAT, INT, NONE} and zero_point_dtype = {float, int, ...}
it's true that not all combinations are valid, but I don't think they are coupled, see
ao/torchao/quantization/quant_primitives.py
Lines 755 to 770 in 64bcf4c
There was a problem hiding this comment.
how about 946f530, expose it as an independent argument?
jerryzh168
left a comment
There was a problem hiding this comment.
looks fine for now, but at some point we will probably add new tensor subclass tensors for these
e8357d5 to b3d985d
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
4a7bb7b to 075a34a
remove zero_point_dtype assigning
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
fix import lint
enable zp dtype: u8/s8/s16/s32/s64
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
| ) | ||
| def check_cpu_version(device, version="2.6.0"): |
There was a problem hiding this comment.
nit: I feel a more descriptive name would be better; this one is a bit vague. It returns a boolean, so the name should be is_xxx. Also the version arg seems to be unused at every callsite, so it can be removed I think
e.g. is_cpu_device_and_after_torch_2_6, similar for the xpu check
There was a problem hiding this comment.
also version arg seems to be not used in any of the callsite, it can be removed I think
There might be other checks, for example checking cpu and torch version 2.8. And there are a lot of checks of this kind, especially on CUDA: CUDA+2.4, CUDA+2.5, and so on
Line 399 in 31f119e
ao/test/integration/test_integration.py
Line 1805 in 31f119e
There was a problem hiding this comment.
checking cuda is fine, but this check is specific for cpu and xpu right, and related to the change of the int4mm operator
There was a problem hiding this comment.
checking cpu and torch version 2.8
when do we need this?
There was a problem hiding this comment.
ao/torchao/dtypes/uintx/int4_cpu_layout.py
Lines 114 to 122 in 31f119e
sorry, it should be 2_6 and 2_5. So even for the same device, there should be checks for different versions
There was a problem hiding this comment.
this is getting too complicated I feel, can we just drop the support for some pytorch versions?
There was a problem hiding this comment.
it's fine if you want to keep the version number, but still it would be good to change the function name to make it clearer I think, also this discussion doesn't have to block the PR, please feel free to merge and fix later
There was a problem hiding this comment.
this is getting too complicated I feel, can we just drop the support for some pytorch versions?
there are still regression tests in the current CI going back as far as torch 2.3. But I agree we should only support the latest version since AO is an experimental(?) innovation project
There was a problem hiding this comment.
yeah we're only committed to supporting the 2 most recent versions of pytorch, but here I meant dropping support for the CPU layout on older pytorch versions and keeping just one. torchao started as an experimental project but now it's more official.
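A possible shape for the renamed helper, following the `is_xxx` naming suggested in this thread. This is a sketch only: the torch version is passed in as a string so the block runs without torch installed, and `is_device_and_torch_at_least` is a hypothetical generalization, not an existing torchao utility.

```python
import re


def _parse_version(version: str) -> tuple:
    """Turn '2.6.0+cpu' or '2.7.0.dev20250101' into a comparable (major, minor) tuple."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_device_and_torch_at_least(device_type: str, target: str,
                                 torch_version: str, min_version: str) -> bool:
    """True when running on `target` with a torch version >= `min_version`."""
    return (device_type == target
            and _parse_version(torch_version) >= _parse_version(min_version))


def is_cpu_device_and_after_torch_2_6(device_type: str, torch_version: str) -> bool:
    # Name taken from the review suggestion; real code would pass torch.__version__.
    return is_device_and_torch_at_least(device_type, "cpu", torch_version, "2.6")


print(is_cpu_device_and_after_torch_2_6("cpu", "2.6.0+cpu"))
print(is_cpu_device_and_after_torch_2_6("cpu", "2.5.1"))
print(is_device_and_torch_at_least("xpu", "xpu", "2.7.0.dev20250101", "2.7"))
```

Comparing integer tuples rather than strings keeps a version such as 2.10 ordered after 2.6, and one parameterized helper covers the per-device, per-version checks discussed above while the named wrappers keep call sites readable.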
| input_tmp = input | ||
| if not ( | ||
| is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6 | ||
| if (not (check_cpu_version(input.device))) and ( |
There was a problem hiding this comment.
my comment was actually meant to say that we can encapsulate these format changes, e.g. input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8), in some function (together with the device and version checks), so we don't need to repeat them across the codebase; that is pretty error prone, and it is not clear when to use them
but this can be done in a separate PR
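The packing expression quoted above places each even-indexed column in the high nibble and the following odd-indexed column in the low nibble of one byte. A plain-Python sketch of that format change and its inverse, using lists instead of tensors so it runs without torch; both function names are hypothetical:

```python
def pack_int4_pairs(row):
    """Pack consecutive int4 values (0..15) into bytes: even index -> high nibble, odd -> low."""
    if len(row) % 2 != 0:
        raise ValueError("row length must be even")
    if any(not 0 <= v <= 15 for v in row):
        raise ValueError("values must fit in 4 bits")
    # Mirrors (input[::, ::2] << 4 | input[::, 1::2]) for a single row.
    return [(hi << 4) | lo for hi, lo in zip(row[0::2], row[1::2])]


def unpack_int4_pairs(packed):
    """Inverse of pack_int4_pairs."""
    out = []
    for byte in packed:
        out.extend([(byte >> 4) & 0xF, byte & 0xF])
    return out


row = [1, 2, 3, 4, 15, 0]
packed = pack_int4_pairs(row)
print(packed)  # [18, 52, 240]
print(unpack_int4_pairs(packed) == row)
```

Keeping the pack and unpack steps in one place, next to the device and version check that decides whether they apply, is the encapsulation the comment asks for.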
|
@jerryzh168
* enable floating zero point with little numerical issue
* unify weight packing
* review to view
* add justfy contiguous
* fix torch.compile
* remove typos in tests
* overload copy_ for torch.load
* copy_ need the 2nd args to be int4
* expose preserve_zero
* refactor zero_point_domain dispatch
* export zero_point_domain and preserve_zero as the top arguments
* format
* fix parameter initialization in UT
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
* encapsulate version check as helpers
remove zero_point_dtype assigning
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
enable zp dtype: u8/s8/s16/s32/s64
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
* fix zero_point_dtype
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
---------
Signed-off-by: Meng, Hengyu <hengyu.meng@intel.com>
The PR is currently a draft. It will add 2 kinds of INT4 support on XPU, floating zero points and integer zero points, following the discussion in #1264.
Integer zero points, which are natively supported via oneDNN (pytorch/pytorch#137566).
Floating zero points, the default behaviour in this repo, supported by intel/torch-xpu-ops#1130; more implementations are on the way.