Commit 35f02b9
[OPTIONS] Set mma as default in PassContext() (#530)
## PR Description
### Summary
This PR sets `mma` (matrix multiply-accumulate) as the default
computation mode, replacing `simt`. The switch affects precision:
`fp32` inputs are cast to `tf32`, which has the same number of
mantissa bits as `fp16`. Test cases are updated to align with `fp16`
precision where feasible.
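Users who prefer the previous behavior can still select the computation mode explicitly when optimizing a graph. The snippet below is a minimal sketch, assuming `PassContext` exposes a `set_mma` setter as in earlier Hidet releases; the exact method name and accepted values may differ, and the small graph is only a stand-in.

```python
import hidet

# Build a small graph to optimize (symbolic input, matmul + relu as a stand-in).
x = hidet.symbol([16, 768], dtype='float32', device='cuda')
w = hidet.randn([768, 768], dtype='float32', device='cuda')
y = hidet.ops.relu(hidet.ops.matmul(x, w))
graph = hidet.trace_from(y, inputs=[x])

# Assumed API: explicitly request the simt schedule instead of the new
# default ('mma'), avoiding the fp32 -> tf32 cast described in this PR.
with hidet.graph.PassContext() as ctx:
    ctx.set_mma('simt')
    graph_opt = hidet.graph.optimize(graph)
```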
### Details
1. **Default MMA Computation Mode**:
- Set `mma` as the default mode in place of `simt`.
2. **Precision Adjustments in Tests**:
- Adjusted test cases to align with `fp16` precision, since the `tf32` cast
keeps the same number of mantissa bits as `fp16` (see the numerical sketch
after this list).
3. **Tolerance Modifications**:
- Updated tolerances in tests affected by the precision difference,
particularly `test_attention`.
- Minimum observed tolerances:
- `mma`: 0.2
- `simt`: ~1e-4
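To make the precision argument concrete, the following standalone sketch emulates the `tf32` cast by truncating a `float32` mantissa to 10 bits (the count `tf32` keeps, ignoring rounding mode) and compares the resulting relative error against a round trip through `fp16`; both land on the order of 1e-3, which is why the tests are aligned to `fp16`-level tolerances. This is an illustration, not code from the PR.

```python
import numpy as np

def emulate_tf32(x: np.ndarray) -> np.ndarray:
    # tf32 keeps 10 mantissa bits (same count as fp16); zero the low 13 bits
    # of the float32 mantissa to approximate the cast (rounding mode ignored).
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

x = np.random.uniform(0.5, 2.0, size=10000).astype(np.float32)
rel_err_tf32 = np.abs(emulate_tf32(x) - x) / np.abs(x)
rel_err_fp16 = np.abs(x.astype(np.float16).astype(np.float32) - x) / np.abs(x)

# Both errors are on the order of 1e-3 (roughly 2**-10 vs 2**-11), hence the
# fp16-level tolerances used in the updated tests.
print(f'tf32-like max rel err: {rel_err_tf32.max():.2e}')
print(f'fp16      max rel err: {rel_err_fp16.max():.2e}')
```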
### Testing Script for Tolerance
To determine the minimum tolerance for `test_attention`, the following
script was used:
```python
import torch
from typing import Optional, Union
import pytest
import numpy as np
import numpy.testing
import hidet
import hidet.testing
from hidet.graph import FlowGraph
from hidet import ops


# Fixture that creates the weights shared by both models
@pytest.fixture
def shared_weights(device):
    wte = torch.randn([50257, 768], device=device)
    wpe = torch.randn([1024, 768], device=device)
    w1 = torch.randn([768, 768 * 3], device=device)
    b1 = torch.randn([768 * 3], device=device)
    return wte, wpe, w1, b1


# Fixture that creates the Hidet models (dynamic, optimized dynamic, and a static builder)
@pytest.fixture
def hidet_model(shared_weights, device):
    wte, wpe, w1, b1 = shared_weights
    # Convert PyTorch tensors to Hidet tensors
    wte_hidet = hidet.from_torch(wte)
    wpe_hidet = hidet.from_torch(wpe)
    w1_hidet = hidet.from_torch(w1)
    b1_hidet = hidet.from_torch(b1)

    def get_graph(seq):
        n_head = 12
        ids = hidet.symbol([seq], dtype='int32', device=device)
        x = hidet.ops.take(wte_hidet, ids) + hidet.ops.take(wpe_hidet, hidet.ops.arange(ids.shape[0], device=device))
        causal_mask = (1 - hidet.ops.tri(x.shape[0], dtype=x.dtype, device=x.device)) * -1e10
        x = hidet.ops.matmul(x, w1_hidet) + b1_hidet
        x = hidet.ops.reshape(x, [x.shape[0], 3, n_head, x.shape[1] // (3 * n_head)])
        x = hidet.ops.transpose(x, [1, 2, 0, 3])
        q, k, v = [t for t in hidet.ops.split(x, 3, axis=0)]
        x = hidet.ops.softmax(q @ hidet.ops.transpose(k, [-1, -2]) / float(np.sqrt(q.shape[-1])) + causal_mask, axis=-1) @ v
        return hidet.trace_from(x)

    graph_dynamic = get_graph('seq')  # symbolic sequence length
    graph_dynamic_opt = hidet.graph.optimize(graph_dynamic)
    return graph_dynamic, graph_dynamic_opt, get_graph


# Fixture that creates the reference PyTorch model
@pytest.fixture
def torch_model(shared_weights, device):
    wte, wpe, w1, b1 = shared_weights

    def get_graph(seq, ids):
        n_head = 12
        x = torch.index_select(wte, 0, ids) + torch.index_select(wpe, 0, torch.arange(ids.shape[0], device=device))
        # tril (lower-triangular) matches hidet.ops.tri, so both graphs mask future positions
        causal_mask = (1 - torch.tril(torch.ones((x.shape[0], x.shape[0]), dtype=x.dtype, device=x.device))) * -1e10
        x = torch.matmul(x, w1) + b1
        x = x.view(x.shape[0], 3, n_head, x.shape[1] // (3 * n_head))
        x = x.permute(1, 2, 0, 3)
        q, k, v = torch.chunk(x, 3, dim=0)
        x = torch.softmax((q @ k.transpose(-1, -2)) / np.sqrt(q.shape[-1]) + causal_mask, dim=-1) @ v
        return x

    return get_graph


# Test case for each sequence length
@pytest.mark.parametrize("seq", [1, 2, 3, 4, 8])
@pytest.mark.parametrize('device', ['cuda'])
def test_attention(seq, hidet_model, torch_model, device):
    graph_dynamic, graph_dynamic_opt, hidet_static_fn = hidet_model
    torch_fn = torch_model
    # Generate consistent ids for both Hidet and PyTorch
    ids_torch = torch.randint(0, 50257, (seq,), dtype=torch.int32, device=device)
    ids_hidet = hidet.from_torch(ids_torch)
    # Hidet results
    graph_static = hidet_static_fn(seq)
    y_static_hidet = graph_static(ids_hidet)
    y_dynamic_hidet = graph_dynamic(ids_hidet)
    y_dynamic_opt_hidet = graph_dynamic_opt(ids_hidet)
    # PyTorch results
    y_torch = torch_fn(seq, ids_torch)
    # Zero tolerances force a failure whose report shows the maximum mismatch
    for y_hidet in [y_dynamic_opt_hidet]:
        np.testing.assert_allclose(y_hidet.cpu().numpy(), y_torch.cpu().numpy(), atol=0, rtol=0)
```
This script runs `test_attention` across multiple sequence lengths,
generates consistent `ids` for both Hidet and PyTorch, and calls
`assert_allclose` with zero tolerances so that the failure report shows
the maximum absolute and relative differences, which are taken as the
observed tolerance.
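As a small illustration of how the observed tolerance is read off, the sketch below (not part of the PR) shows that `assert_allclose` with `atol=0, rtol=0` always fails and prints the maximum mismatch between the two arrays; the exact wording of the message varies across NumPy versions.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = a + np.array([0.0, 1e-3, -2e-3], dtype=np.float32)

try:
    # Zero tolerances guarantee a failure whose message reports the
    # maximum absolute and relative differences between the two arrays.
    np.testing.assert_allclose(a, b, atol=0, rtol=0)
except AssertionError as e:
    print(e)  # includes lines such as "Max absolute difference: 0.002"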
### Code Changes
- Set `mma` as the default instead of `simt`.
- Updated test case precision settings to `fp16` where applicable.
- Adjusted tolerance in `test_attention` and other sensitive tests based
on the minimum tolerance values identified with the above script.
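For illustration only, the shape of the resulting tolerance change in an affected assertion might look like the sketch below; the arrays are dummies and the exact `atol`/`rtol` combination used in each updated test comes from the diff, not from this example.

```python
import numpy as np

# Dummy stand-ins for the Hidet and PyTorch attention outputs.
y_torch = np.random.randn(8, 12, 64).astype(np.float32)
y_hidet = y_torch + np.random.uniform(-0.15, 0.15, y_torch.shape).astype(np.float32)

# Previously (simt, fp32 accumulation) a tight bound was sufficient:
#     np.testing.assert_allclose(y_hidet, y_torch, atol=1e-4)
# With mma (tf32 inputs), test_attention uses the looser observed bound of 0.2:
np.testing.assert_allclose(y_hidet, y_torch, atol=0.2)
```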
### Note
For `tests/unit_tests/test_dynamic_shape.py::test_attention[xxx-cuda]`,
the absolute difference is occasionally observed to spike as high as 60.1.

Parent: f8c057b
22 files changed: 22 additions and 31 deletions
File tree
- examples/gpt-2
- gallery
- developer-guides
- getting-started
- hidet-script
- tutorials
- python/hidet/graph
- frontend/torch
- ops/matmul
- transforms
- tests
- benchmarks
- frontends/torch
- models
- ir/parser
- lang/cute
- operators
- unit_tests