
Commit 35f02b9

ZichuWu authored and vadiklyutiy committed
[OPTIONS] Set mma as default in PassContext() (#530)
## PR Description

### Summary

This PR sets `mma` (matrix multiply-accumulate) as the default computation mode, replacing `simt`. The switch affects precision: `fp32` inputs are cast to `tf32`, which has the same number of mantissa bits as `fp16`. Test cases are updated to align with `fp16` precision where feasible.

### Details

1. **Default MMA computation mode**:
   - Set `mma` as the default mode in place of `simt`.
2. **Precision adjustments in tests**:
   - Adjusted test cases to align with `fp16` precision, since the `tf32` cast shares the same mantissa width as `fp16`.
3. **Tolerance modifications**:
   - Updated tolerances in tests affected by the precision difference, particularly `test_attention`.
   - Minimum observed tolerances:
     - `mma`: 0.2
     - `simt`: ~1e-4

### Testing Script for Tolerance

To determine the minimum tolerance for `test_attention`, the following script was used:

```python
import numpy as np
import pytest
import torch

import hidet
import hidet.testing


# Fixture that creates the shared weights
@pytest.fixture
def shared_weights(device):
    wte = torch.randn([50257, 768], device=device)
    wpe = torch.randn([1024, 768], device=device)
    w1 = torch.randn([768, 768 * 3], device=device)
    b1 = torch.randn([768 * 3], device=device)
    return wte, wpe, w1, b1


# Fixture that creates the Hidet models
@pytest.fixture
def hidet_model(shared_weights, device):
    wte, wpe, w1, b1 = shared_weights

    # Convert PyTorch tensors to Hidet tensors
    wte_hidet = hidet.from_torch(wte)
    wpe_hidet = hidet.from_torch(wpe)
    w1_hidet = hidet.from_torch(w1)
    b1_hidet = hidet.from_torch(b1)

    def get_graph(seq):
        n_head = 12
        ids = hidet.symbol([seq], dtype='int32', device=device)
        x = hidet.ops.take(wte_hidet, ids) + hidet.ops.take(
            wpe_hidet, hidet.ops.arange(ids.shape[0], device=device)
        )
        causal_mask = (1 - hidet.ops.tri(x.shape[0], dtype=x.dtype, device=x.device)) * -1e10
        x = hidet.ops.matmul(x, w1_hidet) + b1_hidet
        x = hidet.ops.reshape(x, [x.shape[0], 3, n_head, x.shape[1] // (3 * n_head)])
        x = hidet.ops.transpose(x, [1, 2, 0, 3])
        q, k, v = [t for t in hidet.ops.split(x, 3, axis=0)]
        x = hidet.ops.softmax(
            q @ hidet.ops.transpose(k, [-1, -2]) / float(np.sqrt(q.shape[-1])) + causal_mask, axis=-1
        ) @ v
        return hidet.trace_from(x)

    graph_dynamic = get_graph('seq')
    graph_dynamic_opt = hidet.graph.optimize(graph_dynamic)
    return graph_dynamic, graph_dynamic_opt, get_graph


# Fixture that creates the PyTorch model
@pytest.fixture
def torch_model(shared_weights, device):
    wte, wpe, w1, b1 = shared_weights

    def get_graph(seq, ids):
        n_head = 12
        x = torch.index_select(wte, 0, ids) + torch.index_select(
            wpe, 0, torch.arange(ids.shape[0], device=device)
        )
        causal_mask = (
            1 - torch.triu(torch.ones((x.shape[0], x.shape[0]), dtype=x.dtype, device=x.device))
        ) * -1e10
        x = torch.matmul(x, w1) + b1
        x = x.view(x.shape[0], 3, n_head, x.shape[1] // (3 * n_head))
        x = x.permute(1, 2, 0, 3)
        q, k, v = torch.chunk(x, 3, dim=0)
        x = torch.softmax((q @ k.transpose(-1, -2)) / np.sqrt(q.shape[-1]) + causal_mask, dim=-1) @ v
        return x

    return get_graph


# Test case for each sequence length
@pytest.mark.parametrize("seq", [1, 2, 3, 4, 8])
@pytest.mark.parametrize('device', ['cuda'])
def test_attention(seq, hidet_model, torch_model, device):
    graph_dynamic, graph_dynamic_opt, hidet_static_fn = hidet_model
    torch_fn = torch_model

    # Generate consistent ids for both frameworks
    ids_torch = torch.randint(0, 50257, (seq,), dtype=torch.int32, device=device)
    ids_hidet = hidet.from_torch(ids_torch)

    # Hidet results
    graph_static = hidet_static_fn(seq)
    y_static_hidet = graph_static(ids_hidet)
    y_dynamic_hidet = graph_dynamic(ids_hidet)
    y_dynamic_opt_hidet = graph_dynamic_opt(ids_hidet)

    # PyTorch results
    y_torch = torch_fn(seq, ids_torch)

    # Check for close matches; atol=rtol=0 makes the failure message report the
    # actual maximum absolute and relative differences
    for y_hidet in [y_dynamic_opt_hidet]:
        np.testing.assert_allclose(y_hidet.cpu().numpy(), y_torch.cpu().numpy(), atol=0, rtol=0)
```

This script runs `test_attention` across multiple sequence lengths, generating consistent `ids` for both Hidet and PyTorch, and reports the observed tolerance.

### Code Changes

- Set `mma` as the default instead of `simt`.
- Updated test case precision settings to `fp16` where applicable.
- Adjusted tolerances in `test_attention` and other sensitive tests based on the minimum tolerance values identified with the script above.

### Note

For `tests/unit_tests/test_dynamic_shape.py::test_attention[xxx-cuda]`, the absolute difference is occasionally observed to spike as high as 60.
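For context, the mantissa-width argument above can be checked directly in PyTorch, which exposes the same tf32 casting through a global switch. A minimal sketch (shapes are arbitrary, chosen only for illustration):

```python
import torch

# tf32 keeps 10 mantissa bits, the same count as fp16, so fp32 matmuls that
# are silently executed in tf32 lose roughly fp16-level precision.
x = torch.randn(1024, 1024, device='cuda')

torch.backends.cuda.matmul.allow_tf32 = True
y_tf32 = x @ x

torch.backends.cuda.matmul.allow_tf32 = False
y_fp32 = x @ x

# the gap is orders of magnitude larger than fp32 rounding error alone
print((y_tf32 - y_fp32).abs().max())
```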
1 parent f8c057b commit 35f02b9

22 files changed

Lines changed: 22 additions & 31 deletions


examples/gpt-2/gpt_model.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -102,7 +102,6 @@ def gpt2(model_size: str = "124M", seq_length: Optional[int] = 1000, use_fp16=Fa
     with hidet.graph.PassContext() as ctx:
         if use_fp16:
             ctx.set_precision('float16')
-            ctx.set_mma('mma')
         graph_opt = hidet.graph.optimize(graph)
 
     hidet.save_graph(graph_opt, hf_path)
```

gallery/developer-guides/add-torch-operator-mapping.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -96,7 +96,7 @@ def run_model():
     x = torch.randn(10, 10, device='cuda')
     y1 = model_opt(x)
     y2 = model(x)
-    torch.testing.assert_close(actual=y1, expected=y2)
+    torch.testing.assert_close(actual=y1, expected=y2, atol=3e-3, rtol=3e-3)
     print('success!')
 
 
```

gallery/getting-started/quick-start.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -45,7 +45,7 @@
 y2 = model(x)
 
 # check the correctness
-torch.testing.assert_close(actual=y1, expected=y2, rtol=1e-2, atol=1e-2)
+torch.testing.assert_close(actual=y1, expected=y2, rtol=2e-2, atol=2e-2)
 
 
 # benchmark the performance
```

gallery/hidet-script/3-kernel-functions.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -104,7 +104,7 @@ def matmul(a: f32[m_size, k_size], b: f32[k_size, n_size], c: f32[m_size, n_size
 module(a, b, c)
 
 # compare the result with torch.matmul
-hidet.utils.assert_close(c, a.torch() @ b.torch(), atol=1e-4, rtol=1e-4)
+hidet.utils.assert_close(c, a.torch() @ b.torch(), atol=1e-3, rtol=1e-3)
 
 # %%
 # We can check the generated source code:
```

gallery/tutorials/optimize-onnx-model.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -117,7 +117,7 @@ def bench_hidet_graph(graph: hidet.FlowGraph):
     cuda_graph = graph.cuda_graph()
     (output,) = cuda_graph.run([data])
     np.testing.assert_allclose(
-        actual=output.cpu().numpy(), desired=torch_output.cpu().numpy(), rtol=1e-2, atol=1e-2
+        actual=output.cpu().numpy(), desired=torch_output.cpu().numpy(), rtol=5e-2, atol=5e-2
     )
     print(' Hidet: {:.3f} ms'.format(benchmark_func(lambda: cuda_graph.run())))
 
```

python/hidet/graph/frontend/torch/dynamo_backends.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -107,14 +107,11 @@ def get_flow_graph(interpreter: Interpreter, example_inputs):
 
 def get_compiled_graph(flow_graph: FlowGraph):
     parallel_k = dynamo_config['parallel_k']
-    tensor_core = dynamo_config['use_tensor_core']
     save_dir = dynamo_config['dump_graph_ir']
     with PassContext() as ctx:
         if save_dir:
             graph_dir = resolve_save_dir_multigraph(save_dir)
             ctx.save_graph_instrument(graph_dir)
-        if tensor_core:
-            ctx.set_mma('mma' if tensor_core else 'simt')
         ctx.set_parallel_k(disabled=(parallel_k == 'disabled'), search=(parallel_k == 'search'))
         ctx.allow_source_graph_removal(True)
         logger.info('start to optimize the flow graph')
```
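For users of the Dynamo path, nothing needs to change: the backend now inherits the `mma` default from `PassContext` instead of reading `use_tensor_core` here. A hedged usage sketch (the model is arbitrary, for illustration only):

```python
import torch
import hidet  # importing hidet registers the 'hidet' dynamo backend

# After this change, torch.compile with the hidet backend lowers matmuls to
# tensor-core ('mma') kernels by default.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
model_opt = torch.compile(model, backend='hidet')
y = model_opt(torch.randn(8, 64, device='cuda'))
```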

python/hidet/graph/ops/matmul/batch_matmul.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -25,7 +25,7 @@
 
 
 class BatchMatmulTask(Task):
-    def __init__(self, a: TensorNode, b: TensorNode, mma: str = 'simt'):
+    def __init__(self, a: TensorNode, b: TensorNode, mma: str = 'mma'):
         batch_size, m_size, k_size = a.shape
         batch_size, k_size, n_size = b.shape
         self.batch_size = batch_size
@@ -779,7 +779,7 @@ def is_false():
 
 
 class BatchMatmulOp(Operator):
-    def __init__(self, a: Tensor, b: Tensor, mma: str = 'simt'):
+    def __init__(self, a: Tensor, b: Tensor, mma: str = 'mma'):
         # if is_false(a.shape[0] == b.shape[0]) or is_false(a.shape[2] == b.shape[1]):
         #     raise
         if not (
@@ -795,7 +795,7 @@ def __init__(self, a: Tensor, b: Tensor, mma: str = 'simt'):
         super().__init__(inputs=[a, b], attributes={'mma': mma}, task=task)
 
 
-def batch_matmul(a: Tensor, b: Tensor, mma: str = 'simt') -> Tensor:
+def batch_matmul(a: Tensor, b: Tensor, mma: str = 'mma') -> Tensor:
     """Batched matrix multiplication.
 
     Parameters
```
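The operator-level entry point keeps its explicit `mma` argument, so callers can still pin a mode per call; only the default moves. A small sketch with hypothetical shapes:

```python
import hidet

a = hidet.randn([2, 64, 64], device='cuda')
b = hidet.randn([2, 64, 64], device='cuda')

c_default = hidet.ops.batch_matmul(a, b)           # now lowers to 'mma' kernels
c_simt = hidet.ops.batch_matmul(a, b, mma='simt')  # per-call opt-out
```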

python/hidet/graph/ops/matmul/resolve.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -96,7 +96,7 @@ class MatmulResolveRule(ResolveRule):
 
     def run_batch_matmul(self, a: Tensor, b: Tensor) -> Tensor:
         parallel_k = self.get_config('parallel_k', default='default')  # 'default', 'search', 2, 4, ...
-        mma = self.get_config('mma', default='simt')  # 'simt', 'mma'
+        mma = self.get_config('mma', default='mma')  # 'simt', 'mma'
 
         if any(not isinstance(v, int) for v in a.shape + b.shape):
             nparts = 1
```

python/hidet/graph/transforms/base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -79,7 +79,7 @@ def __init__(self):
             'reduce_precision': None,
             # mma primitive:
             # ['simt', 'mma']
-            'mma': 'simt',
+            'mma': 'mma',
             # parallel k
             # ['default', 'disabled', 'search', 2, 4, ...]
             'parallel_k': 'default',
```
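Since this dict only supplies the default, the existing `set_mma` knob remains the escape hatch. A sketch of reverting to the old behaviour for a whole optimization pass (the traced graph is a made-up example):

```python
import hidet

# build a tiny graph just to have something to optimize
x = hidet.symbol([8, 64], dtype='float32', device='cuda')
w = hidet.randn([64, 64], device='cuda')
graph = hidet.trace_from(hidet.ops.matmul(x, w), inputs=[x])

with hidet.graph.PassContext() as ctx:
    ctx.set_mma('simt')  # restore the old default for this pass only
    graph_opt = hidet.graph.optimize(graph)
```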

tests/benchmarks/bench_op.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -151,6 +151,5 @@ def bench_reduce(params: str, *args, **kwargs) -> float:
 
     with hidet.graph.PassContext() as ctx:
         ctx.set_reduce_precision(dtype)
-        ctx.set_mma('mma')
         latency = bench_func(params, dtype)
     print(latency)
```
