-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathcompiler.py
More file actions
318 lines (284 loc) · 10.6 KB
/
compiler.py
File metadata and controls
318 lines (284 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
#
# MIT License
#
# Copyright (c) 2024, Mattias Aabmets
#
# The contents of this file are subject to the terms and conditions defined in the License.
# You may not use, modify, or distribute this file except in compliance with the License.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations
import re
import os
import ast
import sys
import shutil
import platform
import setuptools
import subprocess
from cffi import FFI
from typing import Generator
from pathlib import Path
from textwrap import dedent
from itertools import product
from dataclasses import dataclass
from contextlib import contextmanager
from quantcrypt.internal import constants as const
from quantcrypt.internal import pqclean, utils
@dataclass(frozen=True)
class Target:
    """
    One (algorithm, variant) pair together with everything needed to
    build its CFFI extension module: source location, C declarations
    and platform-specific compiler / linker settings.
    """

    # Algorithm specification; provides the cdef and module names.
    spec: const.AlgoSpec
    # Implementation variant of the algorithm.
    variant: const.PQAVariant
    # Directory that holds the variant's sources and its ``api.h``.
    source_dir: Path
    # Feature names that are turned into compiler flags per platform.
    required_flags: list[str]
    # Whether this target was selected for compilation.
    accepted: bool

    @property
    def _kem_cdefs(self) -> str:
        """C declarations handed to CFFI for a KEM algorithm."""
        template = """
            #define {cdef_name}_CRYPTO_SECRETKEYBYTES ...
            #define {cdef_name}_CRYPTO_PUBLICKEYBYTES ...
            #define {cdef_name}_CRYPTO_CIPHERTEXTBYTES ...
            #define {cdef_name}_CRYPTO_BYTES ...
            int {cdef_name}_crypto_kem_keypair(
            uint8_t *pk, uint8_t *sk
            );
            int {cdef_name}_crypto_kem_enc(
            uint8_t *ct, uint8_t *ss, const uint8_t *pk
            );
            int {cdef_name}_crypto_kem_dec(
            uint8_t *ss, const uint8_t *ct, const uint8_t *sk
            );
            """
        filled = template.format(cdef_name=self.cdef_name)
        return dedent(filled)

    @property
    def _dss_cdefs(self) -> str:
        """C declarations handed to CFFI for a signature algorithm."""
        template = """
            #define {cdef_name}_CRYPTO_SECRETKEYBYTES ...
            #define {cdef_name}_CRYPTO_PUBLICKEYBYTES ...
            #define {cdef_name}_CRYPTO_BYTES ...
            int {cdef_name}_crypto_sign_keypair(
            uint8_t *pk, uint8_t *sk
            );
            int {cdef_name}_crypto_sign_signature(
            uint8_t *sig, size_t *siglen,
            const uint8_t *m, size_t mlen, const uint8_t *sk
            );
            int {cdef_name}_crypto_sign_verify(
            const uint8_t *sig, size_t siglen,
            const uint8_t *m, size_t mlen, const uint8_t *pk
            );
            int {cdef_name}_crypto_sign(
            uint8_t *sm, size_t *smlen,
            const uint8_t *m, size_t mlen, const uint8_t *sk
            );
            int {cdef_name}_crypto_sign_open(
            uint8_t *m, size_t *mlen,
            const uint8_t *sm, size_t smlen, const uint8_t *pk
            );
            """
        filled = template.format(cdef_name=self.cdef_name)
        return dedent(filled)

    @property
    def cdef_name(self) -> str:
        """Prefix of the C symbols, as reported by the algorithm spec."""
        return self.spec.cdef_name(self.variant)

    @property
    def module_name(self) -> str:
        """Name of the extension module, as reported by the algorithm spec."""
        return self.spec.module_name(self.variant)

    @property
    def ffi_cdefs(self) -> str:
        """KEM declarations for KEM specs, signature declarations otherwise."""
        is_kem = self.spec.type == const.PQAType.KEM
        return self._kem_cdefs if is_kem else self._dss_cdefs

    @property
    def variant_files(self) -> list[str]:
        """POSIX-style paths of all C and assembly files below ``source_dir``."""
        wanted_suffixes = ['.c', '.S', '.s']
        collected: list[str] = []
        # NOTE(review): rglob() already recurses, so the "**/" prefix of the
        # pattern looks redundant; it is kept unchanged here.
        for entry in self.source_dir.rglob("**/*"):
            if not entry.is_file():
                continue
            if entry.suffix in wanted_suffixes:
                collected.append(entry.as_posix())
        return collected

    @property
    def include_directive(self) -> str:
        """``#include`` line that pulls in the variant's ``api.h`` header."""
        header_posix = (self.source_dir / "api.h").as_posix()
        return f'#include "{header_posix}"'

    def _windows_compiler_args(self) -> list[str]:  # pragma: no cover
        """Base flags plus one ``/arch:`` switch per required flag."""
        arch_switches = (f"/arch:{flag.upper()}" for flag in self.required_flags)
        return ["/O2", "/MD", "/nologo", *arch_switches]

    def _linux_compiler_args(self) -> list[str]:  # pragma: no cover
        """Base flags plus architecture-dependent feature flags."""
        machine = platform.machine().lower()
        args = [
            "-fdata-sections", "-ffunction-sections",
            "-O3", "-flto", "-std=c99", "-s"
        ]
        if machine in const.AMDArches:
            # One "-m<feature>" switch per required flag.
            args.extend(f"-m{flag.lower()}" for flag in self.required_flags)
        elif machine in const.ARMArches:
            # All required flags are folded into a single -march switch.
            features = "".join(f"+{flag.lower()}" for flag in self.required_flags)
            args.append("-march=armv8.5-a" + features)
        return args

    @staticmethod
    def _darwin_compiler_args() -> list[str]:  # pragma: no cover
        """Fixed flag set; the required flags are not consulted here."""
        return ["-fdata-sections", "-ffunction-sections", "-O3", "-flto", "-std=c99"]

    @property
    def compiler_args(self) -> list[str]:  # pragma: no cover
        """Compiler flags for the current OS, or an empty list if unknown."""
        builders = {
            "windows": self._windows_compiler_args,
            "linux": self._linux_compiler_args,
            "darwin": self._darwin_compiler_args,
        }
        builder = builders.get(platform.system().lower())
        if builder is None:
            return []
        return builder()

    @property
    def linker_args(self) -> list[str]:  # pragma: no cover
        """Extra linker arguments; only set on Windows."""
        on_windows = platform.system().lower() == "windows"
        return ["/NODEFAULTLIB:MSVCRTD"] if on_windows else []

    @property
    def libraries(self) -> list[str]:  # pragma: no cover
        """Libraries to link against; only set on Windows."""
        on_windows = platform.system().lower() == "windows"
        return ["advapi32"] if on_windows else []
class Compiler:
    """
    Builds the CFFI extension modules for the selected algorithm
    variants, either in the current process or in a child process.
    """

    @staticmethod
    def patch_distutils():  # pragma: no cover
        """
        Add the ``.S`` and ``.s`` suffixes to the ``src_extensions`` list
        inside the installed setuptools unix compiler module.

        The module file on disk is rewritten, and only when at least one
        suffix was actually missing. Opening the file raises if it does
        not exist at the expected location inside setuptools.
        """
        setuptools_path = Path(setuptools.__file__).parent
        distutils_path = "_distutils/compilers/C/unix.py"
        compiler_path = setuptools_path / distutils_path

        with compiler_path.open("r", encoding="utf-8") as f:
            lines = f.readlines()

        # Group 1 is the (indented) assignment prefix, group 2 the list literal.
        pattern = re.compile(r'^( {0,4}src_extensions\s*=\s*)(\[[^]]*])')
        did_append = False

        for i, line in enumerate(lines):
            match = pattern.search(line)
            if match:
                prefix, list_str = match.group(1), match.group(2)
                # literal_eval only accepts Python literals, never arbitrary code.
                ext_list: list[str] = ast.literal_eval(list_str)  # NOSONAR
                for suffix in ['.S', '.s']:
                    if suffix not in ext_list:
                        ext_list.append(suffix)
                        did_append = True
                lines[i] = prefix + repr(ext_list) + "\n"
                # Only the first matching assignment is patched.
                break

        if did_append:
            with compiler_path.open("w", encoding="utf-8") as f:
                f.writelines(lines)

    @staticmethod
    def get_compile_targets(
            target_variants: list[const.PQAVariant],
            target_algos: list[const.AlgoSpec]
    ) -> tuple[list[Target], list[Target]]:
        """
        Split every supported (algorithm, variant) combination into
        accepted and rejected targets.

        A combination is accepted only when the platform check returned
        both a source directory and required flags, and the variant and
        algorithm were both requested by the caller. Everything else,
        including combinations that were simply not requested, ends up
        in the rejected list.

        :param target_variants: Variants the caller wants to compile.
        :param target_algos: Algorithms the caller wants to compile.
        :return: Tuple of (accepted targets, rejected targets).
        """
        accepted: list[Target] = []
        rejected: list[Target] = []
        specs = const.SupportedAlgos
        variants = const.PQAVariant.members()

        for spec, variant in product(specs, variants):
            source_dir, required_flags = pqclean.check_platform_support(spec, variant)
            acceptable = (
                source_dir is not None
                and required_flags is not None
                and variant in target_variants
                and spec in target_algos
            )
            # Rejected targets get placeholder values for the missing parts.
            (accepted if acceptable else rejected).append(Target(
                spec=spec,
                variant=variant,
                source_dir=source_dir or Path(),
                required_flags=required_flags or [],
                accepted=acceptable
            ))
        return accepted, rejected

    @classmethod
    @contextmanager
    def build_path(cls) -> Generator[None, None, None]:
        """
        Context manager that runs its body inside a fresh ``bin/build``
        directory.

        On entry every existing entry of the ``bin`` directory is removed,
        ``bin/build`` is created and made the current working directory.
        When the body finishes without an exception, all ``.pyd`` and
        ``.so`` files found directly in the build directory are copied
        into ``bin``. In every case, including when the body raises, the
        previous working directory is restored and the build directory
        is deleted.
        """
        old_cwd = os.getcwd()
        bin_path = utils.search_upwards("bin")

        # Start from a clean bin directory.
        for path in bin_path.iterdir():
            if path.is_file():
                path.unlink()
            else:
                shutil.rmtree(path, ignore_errors=True)

        new_cwd = bin_path / "build"
        new_cwd.mkdir(parents=True, exist_ok=True)
        try:
            os.chdir(new_cwd)
            yield
            # Reached only on success: keep the compiled binaries.
            for path in new_cwd.iterdir():  # type: Path
                if path.is_file() and path.suffix in (".pyd", ".so"):
                    shutil.copyfile(path, bin_path / path.name)
        finally:
            # Leave the build directory before deleting it, and never
            # leave the process stranded in it after a failed build.
            os.chdir(old_cwd)
            shutil.rmtree(new_cwd, ignore_errors=True)

    @staticmethod
    def compile(target: Target, debug: bool) -> None:
        """
        Compile a single target into an extension module with CFFI.

        The output is written relative to the current working directory.

        :param target: The target to build.
        :param debug: Passed to CFFI as both its verbose and debug switch.
        """
        com_dir, com_files = pqclean.get_common_filepaths(target.variant)
        ffi = FFI()
        ffi.cdef(target.ffi_cdefs)
        ffi.set_source(
            module_name=target.module_name,
            source=target.include_directive,
            sources=[*com_files, *target.variant_files],
            include_dirs=[com_dir, target.source_dir.as_posix()],
            extra_compile_args=target.compiler_args,
            extra_link_args=target.linker_args,
            libraries=target.libraries,
        )
        ffi.compile(verbose=debug, debug=debug)

    @staticmethod
    def log_progress(target: Target) -> None:  # pragma: no cover
        """
        Print a one-line progress message for the given target.

        When this module runs as a script, the message is prefixed with
        ``const.SubprocTag`` and wrapped in markup tags.
        NOTE(review): the markup is presumably rendered by whoever reads
        the child process output - confirm against the caller.
        """
        algo = target.spec.armor_name()
        variant = target.variant.value
        prefix, suffix = '', "..."
        if __name__ == "__main__":
            prefix = const.SubprocTag
            algo = f"[bold sky_blue2]{algo}[/]"
            variant = f"[italic tan]{variant}[/]"
            suffix = f"[grey46]{suffix}[/]"
        msg = f"{prefix}Compiling {variant} variant of {algo}{suffix}"
        print(msg, flush=True)

    @classmethod
    def run(cls,
            target_variants: list[const.PQAVariant] | None = None,
            target_algos: list[const.AlgoSpec] | None = None,
            *,
            in_subprocess: bool = False,
            verbose: bool = False,
            debug: bool = False,
            ) -> subprocess.Popen | list[Target]:
        """
        Compile the requested algorithm variants.

        :param target_variants: Variants to build; all variants when None.
        :param target_algos: Algorithms to build; all supported when None.
        :param in_subprocess: When True, start this file as a child process
            with the selections passed as arguments and return immediately.
        :param verbose: Print a progress line for each compiled target.
        :param debug: Print progress lines and enable CFFI debug output.
        :return: The started ``subprocess.Popen`` object when
            ``in_subprocess`` is True, otherwise the list of rejected targets.
        """
        if target_variants is None:  # pragma: no cover
            target_variants = const.PQAVariant.members()
        if target_algos is None:  # pragma: no cover
            target_algos = const.SupportedAlgos

        if in_subprocess:  # pragma: no cover
            # The child's stderr is merged into its stdout pipe.
            return subprocess.Popen(
                args=[
                    sys.executable, __file__,
                    utils.b64pickle(target_variants),
                    utils.b64pickle(target_algos)
                ],
                stderr=subprocess.STDOUT,
                stdout=subprocess.PIPE,
                text=True
            )

        # Fetch the PQClean sources first if they are not present yet.
        pqclean_dir = pqclean.find_pqclean_dir(src_must_exist=False)
        if not pqclean.check_sources_exist(pqclean_dir):
            pqclean.download_extract_pqclean(pqclean_dir)

        accepted, rejected = cls.get_compile_targets(
            target_variants, target_algos
        )
        if not accepted:
            # Nothing to build: skip patching setuptools and wiping bin.
            return rejected

        cls.patch_distutils()
        with cls.build_path():
            for target in accepted:
                if verbose or debug:  # pragma: no cover
                    cls.log_progress(target)
                cls.compile(target, debug)
        return rejected
if __name__ == "__main__":
    # Script entry point: the first two command-line arguments carry the
    # variant and algorithm selections for this run.
    # NOTE(review): b64pickle is applied here to the argv strings, the same
    # helper Compiler.run uses to produce them - presumably it decodes when
    # given a str; confirm against utils.
    _target_variants = utils.b64pickle(sys.argv[1])
    _target_algos = utils.b64pickle(sys.argv[2])
    Compiler.run(
        target_variants=_target_variants,
        target_algos=_target_algos,
        verbose=True,
    )