CPU AVX implementation for Softmax, Norm #357
Conversation
Commits:
- initial commit
- works on 8x8 at least but bad exp
- save for omp changes
- working and faster than pytorch
- works and is fast but exp is WIP
- remove useless files
- minor changes for rebase
- delete trash
- fix trash
- fix trash
- change imports
- fix for diff size, compiledmodule error fix
yaoyaoding left a comment:
Thanks @fishingguy456 for your first PR to hidet!
I left some comments. In general:
- Do not forget to run tests, lint, and format before submitting the PR.
- For the new operators, we should add some tests for them. See the examples in tests/operators/....
- Our current design allows one task to have both a CPU and a CUDA implementation, and they share the same properties for whether prologue/epilogue fusion is allowed. When we want to change these allow-properties, it is better to create a new task that overrides the original one, so that it does not interfere with our original operator. In the future, we might want to add a device parameter to these functions (like allow_prologue(self, device) -> bool) so that we do not need to create a new class. But for now, let's create a new class and add a resolve rule.
```python
def run_batch_matmul(self, a: Tensor, b: Tensor, is_cpu: bool) -> Tensor:
    if is_cpu:
```
We can directly check the device of the tensor, without the need to pass is_cpu as a parameter:

```python
is_cpu = a.device.is_cpu()
```
```python
def allow_epilogue(self) -> bool:
    return True
```
If the CPU and CUDA tasks have different behavior, it is better to create a subclass of the task and override the methods in the subclass:

```python
class CPUNormalizeTask(NormalizeTask):
    ...
```
And implement the resolve rule to convert the normalize operator to corresponding cpu_normalize operator.
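The suggested pattern can be sketched in plain Python. This is only an illustration of the idea (a device-specific subclass overriding the allow-properties, plus a resolve step that picks it); the class and function names below are hypothetical stand-ins, not hidet's actual API.

```python
# Illustrative sketch only: a CPU-specific task subclass overrides the
# allow-properties, and a resolve rule maps the generic task to it when
# the operator runs on CPU. Names are hypothetical, not hidet's real API.

class NormalizeTask:
    # The generic (CUDA) task allows fusing surrounding ops.
    def allow_prologue(self) -> bool:
        return True

    def allow_epilogue(self) -> bool:
        return True


class CPUNormalizeTask(NormalizeTask):
    # The CPU implementation cannot fuse surrounding ops, so disable both
    # without touching the original task's behavior.
    def allow_prologue(self) -> bool:
        return False

    def allow_epilogue(self) -> bool:
        return False


def resolve_normalize_task(device: str) -> NormalizeTask:
    # A resolve rule picks the device-specific task at scheduling time.
    return CPUNormalizeTask() if device == "cpu" else NormalizeTask()
```

This keeps the CUDA path untouched while letting the CPU path opt out of prologue/epilogue fusion.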
```python
norm_cpu_kernel.kind = "cpu_kernel"
avx_f32x8_find_sum.kind = "cpu_internal"
```
Better to avoid setting the function attributes outside the function's definition. Instead, use:

```python
from hidet.lang import attrs

@hidet.script
def norm_cpu_kernel(...):
    attrs.func_kind = "cpu_kernel"
    ...
```
python/hidet/graph/ops/softmax.py
Outdated
```python
from hidet.ir.builders import StmtBuilder
from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync
from .utils import Task, TensorNode, compute, reduce
from typing import List, Union
```
Move these imports to the top of the file.
Remember to run format & lint; see https://docs.hidet.org/stable/developer-guides/contributing.html#contributing
python/hidet/graph/ops/softmax.py
Outdated
```python
        return ir_module

    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        # if not all(is_constant(dim) for dim in self.inputs[0].shape)\
```
Suggested change: remove the commented-out line.
```python
def allow_epilogue(self) -> bool:
    return False

def allow_prologue(self) -> bool:
    return False
```
Create a CPU version of the operator, because the CUDA version allows prologue & epilogue fusion.
python/hidet/graph/ops/softmax.py
Outdated
```python
softmax_cpu_kernel.kind = "cpu_kernel"
apply_exponent.kind = "cpu_internal"
```
python/hidet/ir/expr.py
Outdated
```python
if not (isinstance(func_var, Var) and isinstance(args, tuple)):
    print(func_var, args)
    print(type(args[0]))
    print(type(func_var), type(args))
```
Suggested change: remove these debug print statements.
```python
from hidet.ir.func import Function

@script
def avx_x86_f32x8_find_sum(x: f32x8) -> f32:
```
Is there any convention behind using "find" in the function name?
If not, I would prefer to name them directly "avx_x86_f32x8_sum" and "avx_x86_f32x8_max".
Thanks @fishingguy456 ! Could you also add a test for softmax? Hi @BolinSNLHM, could you take a look at this PR? I did not check the kernel implementation details.
```python
def avx_x86_f32x8_sum(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_sum"
    sum_vec = call_primitive_func(
        'avx_x86_float32x4_add',
        [
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b0]),
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b1]),
        ],
    )
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    return call_primitive_func('avx_x86_float32x4_extract_last', [sum_vec])


assert isinstance(avx_x86_f32x8_sum, Function)
register_primitive_function(avx_x86_f32x8_sum.name, avx_x86_f32x8_sum)


@script
def avx_x86_f32x8_scalar_max(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_scalar_max"
    y = call_primitive_func('avx_x86_float32x8_permute_2f128', [x, x, 1])
    m1 = call_primitive_func('avx_x86_float32x8_max', [x, y])
    m2 = call_primitive_func('avx_x86_float32x8_permute', [m1, 0b01001110])
    m3 = call_primitive_func('avx_x86_float32x8_max', [m1, m2])
    m4 = call_primitive_func('avx_x86_float32x8_permute', [m3, 0b10110001])
    m = call_primitive_func('avx_x86_float32x8_max', [m3, m4])
    return call_primitive_func('avx_x86_float32x8_extract_last', [m])
```
Would it be possible to only declare the primitives (e.g., avx_x86_f32x8_extract_half) in this file, and then define functions like avx_x86_f32x8_sum as helper functions in Hidet Script in a separate file where they are needed? The code should work as it is, but it looks a bit odd to have the hidet.script decorator and multiple calls to call_primitive_func here...
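For readers unfamiliar with the intrinsics, the two reduction sequences quoted above can be checked against a pure-Python model of the AVX shuffles. The lane semantics below follow Intel's definitions of `_mm256_permute2f128_ps`, `_mm256_permute_ps`, and `_mm_hadd_ps`; this is a verification sketch, not part of the PR.

```python
# Pure-Python model of the AVX lane shuffles used in the quoted kernels,
# to check that the sequences really compute a full horizontal sum / max.

def extract_half(v, hi):            # _mm256_extractf128_ps: pick a 128-bit half
    return v[4:] if hi else v[:4]

def add4(a, b):                     # _mm_add_ps: elementwise add of two 4-vectors
    return [x + y for x, y in zip(a, b)]

def hadd4(a, b):                    # _mm_hadd_ps: pairwise horizontal add
    return [a[0] + a[1], a[2] + a[3], b[0] + b[1], b[2] + b[3]]

def permute2f128_swap(v):           # _mm256_permute2f128_ps(x, x, 1): swap halves
    return v[4:] + v[:4]

def permute8(v, imm):               # _mm256_permute_ps: same shuffle in each lane
    idx = [(imm >> (2 * i)) & 3 for i in range(4)]
    return [v[idx[i]] for i in range(4)] + [v[4 + idx[i]] for i in range(4)]

def max8(a, b):                     # _mm256_max_ps: elementwise max
    return [max(x, y) for x, y in zip(a, b)]

def f32x8_sum(x):
    # extract halves, add them, then two hadds broadcast the total
    s = add4(extract_half(x, 0), extract_half(x, 1))
    s = hadd4(s, s)
    s = hadd4(s, s)
    return s[3]                     # extract_last

def f32x8_max(x):
    # three max/shuffle rounds leave the global max in every element
    m1 = max8(x, permute2f128_swap(x))
    m3 = max8(m1, permute8(m1, 0b01001110))
    m = max8(m3, permute8(m3, 0b10110001))
    return m[7]                     # extract_last
```

Tracing `f32x8_max` on `[1, 5, 3, 9, 2, 8, 4, 6]` shows each round doubling the number of values every element has "seen": 2, then 4, then all 8.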
```diff
 def allow_epilogue(self) -> bool:
-    return True
+    return False
```
Why should we change this to False? 🤔
```python
from hidet.lang import script, attrs
from hidet.ir.dtypes import f32x8, f32
from hidet.ir.func import Function


@script
def avx_x86_f32x8_sum(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_sum"
    sum_vec = call_primitive_func(
        'avx_x86_float32x4_add',
        [
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b0]),
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b1]),
        ],
    )
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    return call_primitive_func('avx_x86_float32x4_extract_last', [sum_vec])


assert isinstance(avx_x86_f32x8_sum, Function)
register_primitive_function(avx_x86_f32x8_sum.name, avx_x86_f32x8_sum)


@script
def avx_x86_f32x8_scalar_max(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_scalar_max"
    y = call_primitive_func('avx_x86_float32x8_permute_2f128', [x, x, 1])
    m1 = call_primitive_func('avx_x86_float32x8_max', [x, y])
    m2 = call_primitive_func('avx_x86_float32x8_permute', [m1, 0b01001110])
    m3 = call_primitive_func('avx_x86_float32x8_max', [m1, m2])
    m4 = call_primitive_func('avx_x86_float32x8_permute', [m3, 0b10110001])
    m = call_primitive_func('avx_x86_float32x8_max', [m3, m4])
    return call_primitive_func('avx_x86_float32x8_extract_last', [m])
```
I recommend moving these user-defined functions over AVX (not the ones provided by the underlying vector library), like avx_x86_f32x8_sum, to another file called avx_helpers.py.
For functions like avx_x86_float32x4_extract_last, we also need to define a wrapper function like:

```python
def avx_x86_float32x4_extract_last(x: Expr) -> Call:
    return call_primitive_func('avx_x86_float32x4_extract_last', [x])
```

In the new file, we then use avx_x86_float32x4_extract_last(...) directly in the Hidet Script, instead of calling call_primitive_func.
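The wrapper idea can be sketched generically. The registry and call helper below are simplified stand-ins for hidet's register_primitive_function / call_primitive_func machinery, used only to illustrate why thin typed wrappers make call sites cleaner than string-keyed calls.

```python
# Generic sketch of the wrapper pattern: thin wrappers hide the
# string-keyed primitive-call mechanism from user code. The registry
# and call helper are illustrative stand-ins, not hidet's real code.

_PRIMITIVES = {}

def register_primitive(name, fn):
    # Register an implementation under a string key.
    _PRIMITIVES[name] = fn

def call_primitive(name, args):
    # Look up a primitive by name and invoke it.
    return _PRIMITIVES[name](*args)

# Wrapper: call sites use an ordinary function instead of a string name,
# so typos are caught at import time and IDEs can find usages.
def f32x4_extract_last(x):
    return call_primitive('f32x4_extract_last', [x])

# Stand-in implementation: last element of a 4-vector.
register_primitive('f32x4_extract_last', lambda v: v[3])
```

With the wrapper in place, kernel code reads `f32x4_extract_last(v)` rather than `call_primitive('f32x4_extract_last', [v])`.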
Thanks @fishingguy456 !
Batch matmul is working but inefficient: it takes the path of matmul_f32_x86 instead of the CPU autoscheduler.