Skip to content

Commit 7ba6fb6

Browse files
leslie-fang-intel authored and pytorchmergebot committed
[Inductor][CPP] Enable vectorized fp8 E5M2 quant dequant (#153365)
**Summary**

This PR enables vectorized codegen in the Inductor CPP backend for the `FP8_E5M2` dtype: quantization (`quant`) from `float32` and dequantization (`dequant`) back to `float32`.

**Test Plan**

```
python test/inductor/test_cpu_repro.py -k test_dequant_quant_lowering_fp8_e5m2
```

Pull Request resolved: #153365
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #152417, #152418, #153364
1 parent 84b657d commit 7ba6fb6

3 files changed

Lines changed: 35 additions & 2 deletions

File tree

aten/src/ATen/cpu/vec/vec512/vec512_convert.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,28 @@ struct VecConvert<float, 1, Float8_e4m3fn, 1> {
312312
}
313313
};
314314

315+
template <>
316+
struct VecConvert<Float8_e5m2, 1, float, 1> {
317+
static inline VectorizedN<Float8_e5m2, 1> apply(
318+
const VectorizedN<float, 1>& src_n) {
319+
at::vec::Vectorized<float> src = src_n[0];
320+
__m128i res128 = cvtfp32_fp8e5m2(src);
321+
return at::vec::Vectorized<Float8_e5m2>(_mm512_castsi128_si512(res128));
322+
}
323+
};
324+
325+
template <>
326+
struct VecConvert<float, 1, Float8_e5m2, 1> {
327+
static inline VectorizedN<float, 1> apply(
328+
const VectorizedN<Float8_e5m2, 1>& src_n) {
329+
// cvt first 16x8 bits from Float8_e5m2 to float
330+
at::vec::Vectorized<Float8_e5m2> src = src_n[0];
331+
__m512 result;
332+
cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result);
333+
return at::vec::Vectorized<float>(result);
334+
}
335+
};
336+
315337
#endif
316338

317339
} // namespace CPU_CAPABILITY

test/inductor/test_cpu_repro.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,10 +1418,15 @@ def fn(
14181418
use_quant_list = [False, True]
14191419
use_tensor_overload_list = [False, True]
14201420

1421-
assert dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]
1421+
assert dtype in [
1422+
torch.uint8,
1423+
torch.int8,
1424+
torch.float8_e4m3fn,
1425+
torch.float8_e5m2,
1426+
]
14221427
quant_min = 0 if dtype == torch.uint8 else -128
14231428
quant_max = 255 if dtype == torch.uint8 else 127
1424-
if dtype == torch.float8_e4m3fn:
1429+
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
14251430
quant_min = int(torch.finfo(dtype).min)
14261431
quant_max = int(torch.finfo(dtype).max)
14271432
use_tensor_overload_list = [
@@ -1486,6 +1491,10 @@ def test_dequant_quant_lowering_int8(self):
14861491
def test_dequant_quant_lowering_fp8_e4m3(self):
14871492
self._test_dequant_quant_lowering_helper(torch.float8_e4m3fn)
14881493

1494+
@requires_vectorization
1495+
def test_dequant_quant_lowering_fp8_e5m2(self):
1496+
self._test_dequant_quant_lowering_helper(torch.float8_e5m2)
1497+
14891498
def _test_dequant_maxpool2d_lowering_helper(self, dtype):
14901499
def fn(x, scale, zero_point, quant_min, quant_max, dtype):
14911500
x = torch.ops.quantized_decomposed.dequantize_per_tensor(

torch/_inductor/codegen/cpp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def get_export_declaration():
155155
torch.int32,
156156
torch.int64,
157157
torch.float8_e4m3fn,
158+
torch.float8_e5m2,
158159
]
159160

160161
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
@@ -1609,6 +1610,7 @@ def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True):
16091610
torch.int32,
16101611
torch.int64,
16111612
torch.float8_e4m3fn,
1613+
torch.float8_e5m2,
16121614
], f"{__name__} does not support {dtype}"
16131615
assert isinstance(x, CppCSEVariable)
16141616
src_dtype = x.dtype

0 commit comments

Comments
 (0)