[Quantization] Quantization API #309
Conversation
…ant-static-matmul
yaoyaoding left a comment:
Thanks @Aalanli ! Good progress!
I left some comments on minor issues.
python/hidet/graph/nn/linear.py
Outdated
class SymQuantLinearTransposed(Module):
    def __init__(self, weight: Tensor, bias: Optional[Tensor] = None, quant_type: str = 'int8'):
        super().__init__()
        self.in_features = weight.shape[0]
        self.out_features = weight.shape[1]
        qweight, scale = ops.symmetric_quantize(weight, quant_type=quant_type, dims=[-1])
        self.qweight = qweight
        self.scale = scale
        self.bias = bias

    def extra_str(self) -> str:
        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)

    def forward(self, x: Tensor) -> Tensor:
        x = ops.matmul(x, ops.symmetric_dequantize(ops.barrier(self.qweight), self.scale, dims=[-1]))
        if self.bias is not None:
            x = ops.add(x, self.bias)
        return x
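For intuition, the math behind the symmetric per-channel int8 scheme used above (quantize along the last dimension, as `ops.symmetric_quantize(..., dims=[-1])` does) can be sketched in plain Python. This is an illustrative sketch of the standard technique, not Hidet's actual implementation; the helper names merely mirror the ops above:

```python
def symmetric_quantize(weight, qmax=127):
    """Symmetric per-row int8 quantization of a 2-D weight matrix.

    For each row: scale = max(|w|) / qmax, qweight = clamp(round(w / scale)).
    Symmetric means zero-point is 0, so dequantization is just q * scale.
    """
    qweight, scales = [], []
    for row in weight:
        scale = max(abs(w) for w in row) / qmax or 1.0  # avoid div-by-zero for all-zero rows
        scales.append(scale)
        qweight.append([max(-qmax, min(qmax, round(w / scale))) for w in row])
    return qweight, scales


def symmetric_dequantize(qweight, scales):
    """Recover an approximation of the original weights: w ≈ q * scale."""
    return [[q * s for q in row] for row, s in zip(qweight, scales)]
```

The round-trip error per element is bounded by half a scale step, which is why per-channel (per-row) scales lose much less accuracy than a single per-tensor scale when row magnitudes differ.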
Maybe we can also put all the quantization nn layers into a sub-namespace like hidet.graph.nn.quantized (similar to PyTorch's torch.nn.quantized) or hidet.graph.nn.quant (what you used in ops).
Right, I think this module is currently not needed, since quantization is applied during a graph pass anyway. Also, the weight-copying mechanism won't work here when converting from torch.
@@ -15,7 +15,7 @@
from hidet.ir.compute import reduce
@xinli-git, could you help to have a look at the change of norm? Thanks!
In the future, let's try to unify the schedule templates for different data types, which will reduce the complexity of maintenance.
Sorry, this change in norm is exactly the same as the earlier one; I needed to apply the same fix for some tests to pass.
Hi @Aalanli, I forgot one thing. It is recommended to put some of the code in the …
…. ) (#294)
[Ir][Primitives] add vectorized conversion instructions
[Ir][CuTe] add reduce primitives in cute (#295)
[Ir][CuTe] add mma primitives (#296)
[Ir][CuTe] add other primitives in cute (#297)
[Transforms][CuTe] add instruction selection pass (#298)
[Transforms][CuTe] add resolve bank conflict pass (#299)
[Transforms][CuTe] add resolve auto keywords pass (#300)
[Transforms][CuTe] add shared memory allocation pass (#301)
[Transforms][CuTe] add vectorize elementwise operation pass (#302)
[Transforms][CuTe] add analysis pass (#303)
[Transforms][CuTe] add canonicalization pass (#304)
[Transforms][CuTe] add deadcode elimination pass (#305)
[Transforms][CuTe] refactor cute lowering pass (#306)
[Graph][Ops] matmul cute (#307)
[Ir] cute miscs (#308)
[Tests] cute tests (#309)
[Chore] fix ci (#313)
Co-authored-by: xiaocenxiaocen <xiao.zhang@centml.ai>
Add extensible quantization API.
See examples/quantization/gpt2.py for a usage example.
On gpt2, evaluated on the first 500 samples of the wikitext-2-raw-v1 test split:
original f32 ppl: 129.88427568662286
original f32 acc: [top-1: 0.291, top-5: 0.486, top-10: 0.561]
quantized f16 ppl: 131.41456528937462
quantized f16 acc: [top-1: 0.288, top-5: 0.482, top-10: 0.556]
quantized f16 -> int8 ppl: 131.11489348364347
quantized f16 -> int8 acc: [top-1: 0.284, top-5: 0.481, top-10: 0.554]
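To put the numbers above in perspective, the quality cost of quantization is small; the relative perplexity increase over the f32 baseline works out to under 1%:

```python
# Perplexity figures copied from the evaluation above.
f32_ppl = 129.88427568662286
f16_ppl = 131.41456528937462
int8_ppl = 131.11489348364347

# Relative increase over the f32 baseline, in percent.
f16_rel = (f16_ppl / f32_ppl - 1) * 100
int8_rel = (int8_ppl / f32_ppl - 1) * 100
print(f"f16 ppl increase:  {f16_rel:.2f}%")   # about 1.18%
print(f"int8 ppl increase: {int8_rel:.2f}%")  # about 0.95%
```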
Currently supported: