Skip to content

Commit 6e6231f

Browse files
Hector Yuen authored and facebook-github-bot committed
unit test for fc parallelization aot (#50056)
Summary:
Pull Request resolved: #50056

buck test //caffe2/caffe2/contrib/fakelowp/test:test_chunkingnnpi -- --fallback-classic

Test Plan: https://our.intern.facebook.com/intern/testinfra/testrun/7036874446100155

Reviewed By: venkatacrc

Differential Revision: D25731079

fbshipit-source-id: 4aa4ffc641659cd90bf4670d28cb43e43ae76dcd
1 parent ee80b45 commit 6e6231f

1 file changed

Lines changed: 142 additions & 0 deletions

File tree

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Must happen before importing caffe2.python.*
2+
import caffe2.python.fakelowp.init_shared_libs # noqa
3+
import datetime
4+
import numpy as np
5+
from hypothesis import given, settings, example
6+
from hypothesis import strategies as st
7+
from caffe2.python import core, workspace
8+
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
9+
from caffe2.python.fakelowp.test_utils import print_test_debug_info
10+
import caffe2.python.serialized_test.serialized_test_util as serial
11+
12+
# Test that parallel chunks behave the same way as the serial one
13+
14+
# Argument vector handed to caffe2's global initialisation. The first entry
# plays the role of the program name; the rest are Glow options. Going by the
# flag names they select fp16 behaviour and two parallel chunks -- the exact
# semantics are defined by Glow, not by this file.
_GLOW_INIT_ARGS = [
    "caffe2",
    "--glow_global_fp16=1",
    "--glow_global_fused_scale_offset_fp16=1",
    "--glow_global_force_sls_fp16_accum=1",
    "--glow_nnpi_num_parallel_chunks=2",
    "--glow_use_dag_optimizer=false",
    "--glow_dump_graph=true",
]
workspace.GlobalInit(_GLOW_INIT_ARGS)
25+
26+
class Fusions(serial.SerializedTestCase):
    """Compare an int8 FC + ReLU net run through ONNXIFI with a reference run.

    The same four-operator net (quantize -> FC -> ReLU -> dequantize) is
    executed twice: once with the ``*NNPI`` / ``FakeAcc32`` reference
    operators, and once after being rewritten by ``onnxifi_caffe2_net``.
    The two ``Y`` outputs must agree under ``np.allclose``.
    """

    def _get_scale_zp(self, tensor):
        """Derive 8-bit quantization parameters for ``tensor``.

        The quantized range is ``[min(0, tensor.min()), tensor.max()]``
        divided into 255 steps. The scale is rounded through float16 and
        floored at ``1e-6`` so the zero-point division cannot blow up.

        Args:
            tensor: numeric array to be quantized.

        Returns:
            ``(scale, zero_point)`` where ``zero_point`` is an ``int``
            clipped to ``[0, 255]``.
        """
        tensor_max = np.max(tensor)
        # Force the range to include zero on the low side.
        tensor_min = min(0, np.min(tensor))
        scale = np.float32(np.float16((tensor_max - tensor_min) / 255.0))
        if scale < 1e-6:
            scale = 1e-6
        zero_point = 0 - tensor_min / scale
        zero_point = int(round(np.clip(zero_point, 0, 255.0)))
        return (scale, zero_point)

    @given(
        scale=st.floats(1e-4, 1e2),
        zp=st.integers(-128, 128),
        rand_seed=st.integers(0, 65534),
        m=st.integers(32, 64),
        k=st.integers(1000, 6000),
        n=st.integers(200, 600),
    )
    # @example(m=64, k=5423, n=553, scale=1e-3, zp=120, rand_seed=1)
    @settings(deadline=datetime.timedelta(seconds=1000), max_examples=1)
    def test_ParallelFC(self, m, k, n, scale, zp, rand_seed):
        """Run the FC net on both paths and require matching outputs.

        Args:
            m: number of rows of the activation matrix ``X``.
            k: shared inner dimension of ``X`` and ``W``.
            n: number of rows of the weight matrix ``W`` (output width).
            scale: drawn by hypothesis but not used by the test body.
            zp: drawn by hypothesis but not used by the test body.
            rand_seed: seed for ``np.random`` so a failure is reproducible.

        Raises:
            AssertionError: if exactly one ``Onnxifi`` op was not produced,
                or if the two outputs differ beyond ``np.allclose``
                tolerances.
        """
        np.random.seed(rand_seed)
        workspace.ResetWorkspace()

        # Y = W_T * X + b
        # X is rounded through float16 before being stored as float32.
        X_fp32 = (
            np.random.uniform(-1, 1, size=(m, k))
            .astype(np.float16)
            .astype(np.float32)
        )

        W_fp32 = np.random.uniform(-1, 1, size=(n, k)).astype(np.float32)
        b_fp32 = np.zeros((n,), dtype=np.float32)

        X_scale, X_zero_point = self._get_scale_zp(X_fp32)

        workspace.FeedBlob("X", X_fp32)
        workspace.FeedBlob("W", W_fp32)
        workspace.FeedBlob("b", b_fp32)

        # Produce the "W_int8" blob consumed by both nets below.
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "Int8FCPackWeight",
                ["W"],
                ["W_int8"],
                engine="DNNLOWP",
                save_unpacked_weights=True,
                in_scale=X_scale,
            )
        )

        ref_net = core.Net("net")
        ref_net.Int8QuantizeNNPI(
            ["X"],
            ["X_int8"],
            Y_scale=X_scale,
            Y_zero_point=X_zero_point
        )
        # The activation's quantization parameters are reused for the outputs.
        ref_net.Int8FCFakeAcc32NNPI(
            ["X_int8", "W_int8", "b"],
            ["Y_int8"],
            Y_scale=X_scale,
            Y_zero_point=X_zero_point,
        )
        ref_net.Int8Relu(
            ["Y_int8"],
            ["Y_relu"],
            Y_zero_point=X_zero_point,
            Y_scale=X_scale,
        )
        ref_net.Int8DequantizeNNPI(
            ["Y_relu"],
            ["Y"]
        )
        ref_net.Proto().external_output.append("Y")

        # run ref_net
        workspace.RunNetOnce(ref_net)
        Y_fbgemm = workspace.FetchBlob("Y")

        # run onnxifi net
        # The reference output is already captured, so the same proto is
        # rewritten in place to the plain Int8 operator types before lowering.
        ref_net.Proto().op[0].type = "Int8Quantize"
        ref_net.Proto().op[1].type = "Int8FC"
        ref_net.Proto().op[2].type = "Int8Relu"
        ref_net.Proto().op[3].type = "Int8Dequantize"
        net_onnxified = onnxifi_caffe2_net(
            ref_net.Proto(),
            {},
            debug=True,
            adjust_batch=False,
            use_onnx=False,
            weight_names=["W_int8", "b"],
        )
        # Booleans sum as 0/1, giving the number of Onnxifi ops.
        num_onnxified_ops = sum(o.type == "Onnxifi" for o in net_onnxified.op)
        print(net_onnxified)
        # All four operators are expected to collapse into a single Onnxifi op.
        np.testing.assert_equal(num_onnxified_ops, 1)
        workspace.CreateNet(net_onnxified)
        workspace.RunNet(net_onnxified.name)
        Y_glow = workspace.FetchBlob("Y")

        if not np.allclose(Y_glow, Y_fbgemm):
            diff_Y = np.abs(Y_glow - Y_fbgemm)
            print_test_debug_info(
                "int8_fc",
                {
                    "seed": rand_seed,
                    "n": n,
                    "X": X_fp32,
                    "W": W_fp32,
                    "b": b_fp32,
                    "Y_fbgemm": Y_fbgemm,
                    "Y_glow": Y_glow,
                    "diff": diff_Y,
                    "maxdiff": diff_Y.max(axis=1),
                },
            )
            # Raise explicitly rather than `assert 0`: a bare assert is
            # removed under `python -O` and would let a mismatch pass.
            raise AssertionError(
                "Onnxifi output differs from the reference output "
                "(max abs diff = {})".format(diff_Y.max())
            )

0 commit comments

Comments
 (0)