[Operator] Add a opaque operator base class#414
Merged
yaoyaoding merged 2 commits intohidet-org:mainfrom Jan 11, 2024
Merged
Conversation
yaoyaoding
added a commit
to yaoyaoding/hidet
that referenced
this pull request
Jan 12, 2024
When we do not want to gives the computation for some operator because
its too tedious or can not expressed using our computation defintion
DSL, we can define an opaque operator that only gives
1. the dtype and shape inference function that infer the output dtype
and shape given the inputs'
2. the implement function that implements the operator given the
input/output dtype and shape
An example to define an opaque operator to perform matrix
multiplication.
```python
from typing import List, Union
import hidet
from hidet import Tensor
from hidet.graph.ops.opaque import OpaqueOperator
from hidet.ir.dtypes import float32
from hidet.ir import IRModule
hidet.option.cache_dir('./outs/cache')
class OpaqueMatmul(OpaqueOperator):
def __init__(self, x: Tensor, y: Tensor):
super().__init__(
name='matmul',
inputs={
'x': x,
'y': y
},
)
def symbolic_forward(self, x: Tensor, y: Tensor):
assert x.dtype == y.dtype == float32
assert x.device.is_cuda()
m, k = x.shape
k, n = y.shape
return {
'z': self.symbol(
shape=[m, n],
dtype=x.dtype,
device=x.device
)
}
def implement_cuda(self, inputs: List[Tensor], outputs: List[Tensor]) -> Union[IRModule, List[IRModule]]:
import hidet
from hidet.lang import attrs
from hidet.lang.types import f32
from hidet.lang.cuda import threadIdx, blockIdx
m_size, k_size = inputs[0].shape
k_size, n_size = inputs[1].shape
with hidet.script_module() as script_module:
@hidet.script
def matmul(x: f32[m_size, k_size], y: f32[k_size, n_size], z: f32[m_size, n_size]):
attrs.func_kind = 'cuda_kernel'
attrs.cuda.block_dim = (32, 32)
attrs.cuda.grid_dim = ((n_size + 31) // 32, (m_size + 31) // 32)
i = threadIdx.x + blockIdx.x * 32
j = threadIdx.y + blockIdx.y * 32
if i < n_size and j < m_size:
z[j, i] = 0.0
for k in range(k_size):
z[j, i] += x[j, k] * y[k, i]
return script_module.ir_module()
def opaque_matmul(x: Tensor, y: Tensor) -> Tensor:
return OpaqueMatmul(x, y).outputs[0]
def test_opaque_operator():
a = hidet.randn([128, 128], dtype='float32', device='cuda')
b = hidet.randn([128, 128], dtype='float32', device='cuda')
c1 = opaque_matmul(a, b)
c2 = a @ b
print(hidet.ops.max(hidet.ops.abs(c1 - c2), dims=[0, 1]))
```
vadiklyutiy
pushed a commit
that referenced
this pull request
Dec 19, 2024
…barrier (#414) Add primitives: - `prmt` - `lop3` - `sub_f16x2`, `fma_f16x2` - `barrier` See the tests and function documentation for the usage of each primitive.
vadiklyutiy
pushed a commit
that referenced
this pull request
Dec 20, 2024
…barrier (#414) Add primitives: - `prmt` - `lop3` - `sub_f16x2`, `fma_f16x2` - `barrier` See the tests and function documentation for the usage of each primitive.
vadiklyutiy
pushed a commit
that referenced
this pull request
Dec 26, 2024
…barrier (#414) Add primitives: - `prmt` - `lop3` - `sub_f16x2`, `fma_f16x2` - `barrier` See the tests and function documentation for the usage of each primitive.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
When we do not want to gives the computation for some operator because its too tedious or can not expressed using our computation defintion DSL, we can define an opaque operator that only gives
An example to define an opaque operator to perform matrix multiplication.