Skip to content

Commit 92b8faf

Browse files
committed
Update on "[PyTorch] Second try: use c10::FastMap for memoizing in Pickler"
These maps don't rely on reference stability, so FastMap should be fine. First try (#96360) was reverted because it broke internal tests. Differential Revision: [D43995796](https://our.internmc.facebook.com/intern/diff/D43995796/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D43995796/)! [ghstack-poisoned]
2 parents 4aab764 + b8f9f9d commit 92b8faf

65 files changed

Lines changed: 2715 additions & 1170 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/scripts/build_triton_wheel.py

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#!/usr/bin/env python3
2-
from subprocess import check_call
2+
import shutil
3+
import sys
34
from pathlib import Path
5+
from subprocess import check_call
46
from tempfile import TemporaryDirectory
57
from typing import Optional
6-
import sys
7-
import shutil
8+
89
SCRIPT_DIR = Path(__file__).parent
910
REPO_DIR = SCRIPT_DIR.parent.parent
1011

12+
1113
def read_triton_pin() -> str:
1214
with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "triton.txt") as f:
1315
return f.read().strip()
@@ -19,7 +21,7 @@ def read_triton_version() -> str:
1921

2022

2123
def check_and_replace(inp: str, src: str, dst: str) -> str:
22-
""" Checks that `src` can be found in `input` and replaces it with `dst` """
24+
"""Checks that `src` can be found in `input` and replaces it with `dst`"""
2325
if src not in inp:
2426
raise RuntimeError(f"Can't find ${src} in the input")
2527
return inp.replace(src, dst)
@@ -29,9 +31,11 @@ def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None:
2931
with open(path) as f:
3032
orig = f.read()
3133
# Replace name
32-
orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",")
34+
orig = check_and_replace(orig, 'name="triton",', f'name="{name}",')
3335
# Replace version
34-
orig = check_and_replace(orig, f"version=\"{read_triton_version()}\",", f"version=\"{version}\",")
36+
orig = check_and_replace(
37+
orig, f'version="{read_triton_version()}",', f'version="{version}",'
38+
)
3539
with open(path, "w") as f:
3640
f.write(orig)
3741

@@ -40,39 +44,81 @@ def patch_init_py(path: Path, *, version: str) -> None:
4044
with open(path) as f:
4145
orig = f.read()
4246
# Replace version
43-
orig = check_and_replace(orig, f"__version__ = '{read_triton_version()}'", f"__version__ = \"{version}\"")
47+
orig = check_and_replace(
48+
orig, f"__version__ = '{read_triton_version()}'", f'__version__ = "{version}"'
49+
)
4450
with open(path, "w") as f:
4551
f.write(orig)
4652

4753

48-
def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path:
54+
def build_triton(
55+
*,
56+
version: str,
57+
commit_hash: str,
58+
build_conda: bool = False,
59+
py_version: Optional[str] = None,
60+
) -> Path:
4961
with TemporaryDirectory() as tmpdir:
5062
triton_basedir = Path(tmpdir) / "triton"
5163
triton_pythondir = triton_basedir / "python"
5264
check_call(["git", "clone", "https://github.com/openai/triton"], cwd=tmpdir)
5365
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
5466
if build_conda:
5567
with open(triton_basedir / "meta.yaml", "w") as meta:
56-
print(f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n", file=meta)
68+
print(
69+
f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n",
70+
file=meta,
71+
)
5772
print("source:\n path: .\n", file=meta)
58-
print("build:\n string: py{{py}}\n number: 1\n script: cd python; "
59-
"python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta)
60-
print("requirements:\n host:\n - python\n - setuptools\n run:\n - python\n"
61-
" - filelock\n - pytorch\n", file=meta)
62-
print("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
63-
" 'A language and compiler for custom Deep Learning operation'", file=meta)
64-
65-
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
73+
print(
74+
"build:\n string: py{{py}}\n number: 1\n script: cd python; "
75+
"python setup.py install --single-version-externally-managed --record=record.txt\n",
76+
file=meta,
77+
)
78+
print(
79+
"requirements:\n host:\n - python\n - setuptools\n run:\n - python\n"
80+
" - filelock\n - pytorch\n",
81+
file=meta,
82+
)
83+
print(
84+
"about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
85+
" 'A language and compiler for custom Deep Learning operation'",
86+
file=meta,
87+
)
88+
89+
patch_init_py(
90+
triton_pythondir / "triton" / "__init__.py",
91+
version=f"{version}+{commit_hash[:10]}",
92+
)
6693
if py_version is None:
6794
py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
68-
check_call(["conda", "build", "--python", py_version,
69-
"-c", "pytorch-nightly", "--output-folder", tmpdir, "."], cwd=triton_basedir)
95+
check_call(
96+
[
97+
"conda",
98+
"build",
99+
"--python",
100+
py_version,
101+
"-c",
102+
"pytorch-nightly",
103+
"--output-folder",
104+
tmpdir,
105+
".",
106+
],
107+
cwd=triton_basedir,
108+
)
70109
conda_path = list(Path(tmpdir).glob("linux-64/torchtriton*.bz2"))[0]
71110
shutil.copy(conda_path, Path.cwd())
72111
return Path.cwd() / conda_path.name
73112

74-
patch_setup_py(triton_pythondir / "setup.py", name="pytorch-triton", version=f"{version}+{commit_hash[:10]}")
75-
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
113+
patch_setup_py(
114+
triton_pythondir / "setup.py",
115+
name="pytorch-triton",
116+
version=f"{version}+{commit_hash[:10]}",
117+
)
118+
patch_init_py(
119+
triton_pythondir / "triton" / "__init__.py",
120+
version=f"{version}+{commit_hash[:10]}",
121+
)
76122
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
77123
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
78124
shutil.copy(whl_path, Path.cwd())
@@ -81,16 +127,19 @@ def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, p
81127

82128
def main() -> None:
83129
from argparse import ArgumentParser
130+
84131
parser = ArgumentParser("Build Triton binaries")
85132
parser.add_argument("--build-conda", action="store_true")
86133
parser.add_argument("--py-version", type=str)
87134
parser.add_argument("--commit-hash", type=str, default=read_triton_pin())
88135
parser.add_argument("--triton-version", type=str, default=read_triton_version())
89136
args = parser.parse_args()
90-
build_triton(commit_hash=args.commit_hash,
91-
version=args.triton_version,
92-
build_conda=args.build_conda,
93-
py_version=args.py_version)
137+
build_triton(
138+
commit_hash=args.commit_hash,
139+
version=args.triton_version,
140+
build_conda=args.build_conda,
141+
py_version=args.py_version,
142+
)
94143

95144

96145
if __name__ == "__main__":

.github/scripts/check_labels.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,12 @@
33

44
from typing import Any
55

6-
from gitutils import (
7-
get_git_remote_name,
8-
get_git_repo_dir,
9-
GitRepo,
10-
)
6+
from github_utils import gh_delete_comment, gh_post_pr_comment
7+
8+
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
9+
from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG
1110
from trymerge import GitHubPR
12-
from github_utils import (
13-
gh_delete_comment,
14-
gh_post_pr_comment,
15-
)
16-
from label_utils import (
17-
LABEL_ERR_MSG,
18-
is_label_err_comment,
19-
has_required_labels,
20-
)
11+
2112

2213
def delete_all_label_err_comments(pr: "GitHubPR") -> None:
2314
for comment in pr.get_comments():
@@ -33,6 +24,7 @@ def add_label_err_comment(pr: "GitHubPR") -> None:
3324

3425
def parse_args() -> Any:
3526
from argparse import ArgumentParser
27+
3628
parser = ArgumentParser("Check PR labels")
3729
parser.add_argument("pr_num", type=int)
3830

.github/scripts/collect_ciflow_labels.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#!/usr/bin/env python3
2+
import sys
23
from pathlib import Path
3-
from typing import Any, Dict, List, Set, cast
4+
from typing import Any, cast, Dict, List, Set
5+
46
import yaml
5-
import sys
67

78
GITHUB_DIR = Path(__file__).parent.parent
89

10+
911
def get_workflows_push_tags() -> Set[str]:
1012
"Extract all known push tags from workflows"
1113
rc: Set[str] = set()
@@ -22,8 +24,10 @@ def get_workflows_push_tags() -> Set[str]:
2224

2325

2426
def filter_ciflow_tags(tags: Set[str]) -> List[str]:
25-
" Return sorted list of ciflow tags"
26-
return sorted(tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*"))
27+
"Return sorted list of ciflow tags"
28+
return sorted(
29+
tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
30+
)
2731

2832

2933
def read_probot_config() -> Dict[str, Any]:
@@ -40,6 +44,7 @@ def update_probot_config(labels: Set[str]) -> None:
4044

4145
if __name__ == "__main__":
4246
from argparse import ArgumentParser
47+
4348
parser = ArgumentParser("Validate or update list of tags")
4449
parser.add_argument("--validate-tags", action="store_true")
4550
args = parser.parse_args()
@@ -51,9 +56,15 @@ def update_probot_config(labels: Set[str]) -> None:
5156
if config_tags != ciflow_tags:
5257
print("Tags mismatch!")
5358
if ciflow_tags.difference(config_tags):
54-
print("Reference in workflows but not in config", ciflow_tags.difference(config_tags))
59+
print(
60+
"Reference in workflows but not in config",
61+
ciflow_tags.difference(config_tags),
62+
)
5563
if config_tags.difference(ciflow_tags):
56-
print("Reference in config, but not in workflows", config_tags.difference(ciflow_tags))
64+
print(
65+
"Reference in config, but not in workflows",
66+
config_tags.difference(ciflow_tags),
67+
)
5768
print(f"Please run {__file__} to remediate the difference")
5869
sys.exit(-1)
5970
print("All tags are listed in pytorch-probot.yml")

.github/scripts/comment_on_pr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import os
12
from typing import Any
3+
24
from github_utils import gh_post_pr_comment
35
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
46
from trymerge_explainer import BOT_COMMANDS_WIKI
5-
import os
67

78

89
def parse_args() -> Any:

.github/scripts/convert_lintrunner_annotations_to_github.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
from typing import NamedTuple, Optional
88

9+
910
# From: https://docs.github.com/en/rest/reference/checks
1011
class GitHubAnnotationLevel(str, Enum):
1112
NOTICE = "notice"
@@ -24,7 +25,12 @@ class GitHubAnnotation(NamedTuple):
2425
title: Optional[str]
2526
raw_details: Optional[str]
2627

27-
PYTORCH_ROOT = Path(subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).decode('ascii').strip())
28+
29+
PYTORCH_ROOT = Path(
30+
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
31+
.decode("ascii")
32+
.strip()
33+
)
2834

2935
annotations = []
3036
for line in sys.stdin:
@@ -33,7 +39,6 @@ class GitHubAnnotation(NamedTuple):
3339
path = lint_message.get("path")
3440
line = lint_message.get("line")
3541

36-
3742
code = lint_message["code"]
3843
severity = lint_message["severity"]
3944
name = lint_message["name"]
@@ -48,16 +53,18 @@ class GitHubAnnotation(NamedTuple):
4853
# normalize path relative to git root
4954
path = Path(path).relative_to(PYTORCH_ROOT)
5055

51-
annotations.append(GitHubAnnotation(
52-
path=str(path),
53-
start_line=int(line),
54-
end_line=int(line),
55-
start_column=None,
56-
end_column=None,
57-
annotation_level=GitHubAnnotationLevel.FAILURE,
58-
message=description,
59-
title=f"({code}) {name}",
60-
raw_details=None,
61-
)._asdict())
56+
annotations.append(
57+
GitHubAnnotation(
58+
path=str(path),
59+
start_line=int(line),
60+
end_line=int(line),
61+
start_column=None,
62+
end_column=None,
63+
annotation_level=GitHubAnnotationLevel.FAILURE,
64+
message=description,
65+
title=f"({code}) {name}",
66+
raw_details=None,
67+
)._asdict()
68+
)
6269

6370
print(json.dumps(annotations), flush=True)

.github/scripts/ensure_actions_will_cancel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
import argparse
44
import sys
5-
import yaml
65

76
from pathlib import Path
87

8+
import yaml
9+
910

1011
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
1112
WORKFLOWS = REPO_ROOT / ".github" / "workflows"
12-
EXPECTED_GROUP = "${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}" \
13+
EXPECTED_GROUP = (
14+
"${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}"
1315
"-${{ github.event_name == 'workflow_dispatch' }}"
16+
)
1417

1518

1619
def should_check(filename: Path) -> bool:
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
'''
2+
"""
33
Test ownership was introduced in https://github.com/pytorch/pytorch/issues/66232.
44
55
As a part of enforcing test ownership, we want to maintain a list of existing PyTorch labels
@@ -8,17 +8,19 @@
88
99
This script assumes the correct env vars are set for AWS permissions.
1010
11-
'''
11+
"""
1212

13-
import boto3 # type: ignore[import]
1413
import json
14+
from typing import Any
15+
16+
import boto3 # type: ignore[import]
1517

1618
from label_utils import gh_get_labels
17-
from typing import Any
1819

1920

2021
def parse_args() -> Any:
2122
from argparse import ArgumentParser
23+
2224
parser = ArgumentParser("Export PR labels")
2325
parser.add_argument("org", type=str)
2426
parser.add_argument("repo", type=str)
@@ -30,9 +32,9 @@ def main() -> None:
3032
args = parse_args()
3133
print(f"Exporting labels for {args.org}/{args.repo}")
3234
labels_file_name = "pytorch_labels.json"
33-
obj = boto3.resource('s3').Object('ossci-metrics', labels_file_name)
35+
obj = boto3.resource("s3").Object("ossci-metrics", labels_file_name)
3436
obj.put(Body=json.dumps(gh_get_labels(args.org, args.repo)).encode())
3537

3638

37-
if __name__ == '__main__':
39+
if __name__ == "__main__":
3840
main()

0 commit comments

Comments
 (0)