Skip to content

Commit e8b01af

Browse files
committed
Update on "rationalize specialize_int_float"
cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
2 parents 4f64178 + 3f8e4ac commit e8b01af

193 files changed

Lines changed: 5183 additions & 2149 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/pytorch/common_utils.sh

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -149,32 +149,6 @@ function clone_pytorch_xla() {
149149
fi
150150
}
151151

152-
function install_filelock() {
153-
pip_install filelock
154-
}
155-
156-
function install_triton() {
157-
local commit
158-
if [[ "${TEST_CONFIG}" == *rocm* ]]; then
159-
echo "skipping triton due to rocm"
160-
else
161-
commit=$(get_pinned_commit triton)
162-
if [[ "${BUILD_ENVIRONMENT}" == *gcc7* ]]; then
163-
# Trition needs gcc-9 to build
164-
sudo apt-get install -y g++-9
165-
CXX=g++-9 pip_install --user "git+https://github.com/openai/triton@${commit}#subdirectory=python"
166-
elif [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then
167-
# Trition needs <filesystem> which surprisingly is not available with clang-9 toolchain
168-
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
169-
sudo apt-get install -y g++-9
170-
CXX=g++-9 pip_install --user "git+https://github.com/openai/triton@${commit}#subdirectory=python"
171-
else
172-
pip_install --user "git+https://github.com/openai/triton@${commit}#subdirectory=python"
173-
fi
174-
pip_install --user jinja2
175-
fi
176-
}
177-
178152
function setup_torchdeploy_deps(){
179153
conda install -y -n "py_${ANACONDA_PYTHON_VERSION}" "libpython-static=${ANACONDA_PYTHON_VERSION}"
180154
local CC

.ci/pytorch/test.sh

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,6 @@ elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
855855
# TODO: run some C++ tests
856856
echo "no-op at the moment"
857857
elif [[ "$TEST_CONFIG" == distributed ]]; then
858-
install_filelock
859-
install_triton
860858
test_distributed
861859
# Only run RPC C++ tests on the first shard
862860
if [[ "${SHARD_NUMBER}" == 1 ]]; then
@@ -866,25 +864,19 @@ elif [[ "$TEST_CONFIG" == deploy ]]; then
866864
checkout_install_torchdeploy
867865
test_torch_deploy
868866
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
869-
install_filelock
870-
install_triton
871867
install_huggingface
872868
test_inductor_distributed
873869
elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
874870
test_without_numpy
875871
install_torchvision
876-
install_triton
877872
test_dynamo_shard 1
878873
test_aten
879874
elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then
880875
install_torchvision
881-
install_filelock
882-
install_triton
883876
test_dynamo_shard 2
884877
elif [[ "${TEST_CONFIG}" == *aot_eager_all* ]]; then
885878
install_torchtext
886879
install_torchvision
887-
install_filelock
888880
checkout_install_torchbench
889881
install_huggingface
890882
install_timm
@@ -897,7 +889,6 @@ elif [[ "${TEST_CONFIG}" == *aot_eager_all* ]]; then
897889
fi
898890
elif [[ "${TEST_CONFIG}" == *aot_eager_huggingface* ]]; then
899891
install_torchvision
900-
install_filelock
901892
install_huggingface
902893
if [[ "${TEST_CONFIG}" == *dynamic* ]]; then
903894
test_aot_eager_benchmark huggingface "" --dynamic-shapes
@@ -906,7 +897,6 @@ elif [[ "${TEST_CONFIG}" == *aot_eager_huggingface* ]]; then
906897
fi
907898
elif [[ "${TEST_CONFIG}" == *aot_eager_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then
908899
install_torchvision
909-
install_filelock
910900
install_timm
911901
id=$((SHARD_NUMBER-1))
912902
if [[ "${TEST_CONFIG}" == *dynamic* ]]; then
@@ -917,7 +907,6 @@ elif [[ "${TEST_CONFIG}" == *aot_eager_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then
917907
elif [[ "${TEST_CONFIG}" == *aot_eager_torchbench* ]]; then
918908
install_torchtext
919909
install_torchvision
920-
install_filelock
921910
checkout_install_torchbench
922911
if [[ "${TEST_CONFIG}" == *dynamic* ]]; then
923912
PYTHONPATH=$(pwd)/torchbench test_aot_eager_benchmark torchbench "" --dynamic-shapes
@@ -926,11 +915,6 @@ elif [[ "${TEST_CONFIG}" == *aot_eager_torchbench* ]]; then
926915
fi
927916
elif [[ "${TEST_CONFIG}" == *inductor_huggingface* ]]; then
928917
install_torchvision
929-
install_filelock
930-
if [[ "${TEST_CONFIG}" != *inductor_huggingface_cpu_accuracy* ]]; then
931-
# Cpp backend does not depend on triton
932-
install_triton
933-
fi
934918
install_huggingface
935919
if [[ "${TEST_CONFIG}" == *inductor_huggingface_perf* ]]; then
936920
test_inductor_huggingface_perf
@@ -941,11 +925,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_huggingface* ]]; then
941925
fi
942926
elif [[ "${TEST_CONFIG}" == *inductor_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then
943927
install_torchvision
944-
install_filelock
945-
if [[ "${TEST_CONFIG}" != *inductor_timm_cpu_accuracy* ]]; then
946-
# Cpp backend does not depend on triton
947-
install_triton
948-
fi
949928
install_timm
950929
id=$((SHARD_NUMBER-1))
951930
if [[ "${TEST_CONFIG}" == *inductor_timm_perf* && $NUM_TEST_SHARDS -gt 1 ]]; then
@@ -958,11 +937,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then
958937
elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then
959938
install_torchtext
960939
install_torchvision
961-
install_filelock
962-
if [[ "${TEST_CONFIG}" != *inductor_torchbench_cpu_accuracy* ]]; then
963-
# Cpp backend does not depend on triton
964-
install_triton
965-
fi
966940
if [[ "${TEST_CONFIG}" == *inductor_torchbench_perf* ]]; then
967941
checkout_install_torchbench
968942
test_inductor_torchbench_perf
@@ -978,19 +952,15 @@ elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then
978952
fi
979953
elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 1 ]]; then
980954
install_torchvision
981-
install_filelock
982-
install_triton
983955
test_inductor
984956
test_inductor_distributed
985957
elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
986958
test_without_numpy
987959
install_torchvision
988-
install_triton
989960
test_python_shard 1
990961
test_aten
991962
elif [[ "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then
992963
install_torchvision
993-
install_triton
994964
test_python_shard 2
995965
test_libtorch
996966
test_aot_compilation
@@ -1000,7 +970,6 @@ elif [[ "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then
1000970
elif [[ "${SHARD_NUMBER}" -gt 2 ]]; then
1001971
# Handle arbitrary number of shards
1002972
install_torchvision
1003-
install_triton
1004973
test_python_shard "$SHARD_NUMBER"
1005974
elif [[ "${BUILD_ENVIRONMENT}" == *vulkan* ]]; then
1006975
test_vulkan
@@ -1018,7 +987,6 @@ elif [[ "${TEST_CONFIG}" == *functorch* ]]; then
1018987
test_functorch
1019988
else
1020989
install_torchvision
1021-
install_triton
1022990
install_monkeytype
1023991
test_python
1024992
test_aten

.circleci/scripts/binary_ios_upload.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ fi
3333
cp ${PROJ_ROOT}/LICENSE ${ZIP_DIR}/
3434
# zip the library
3535
export DATE="$(date -u +%Y%m%d)"
36-
export IOS_NIGHTLY_BUILD_VERSION="2.0.0.${DATE}"
36+
export IOS_NIGHTLY_BUILD_VERSION="2.1.0.${DATE}"
3737
if [ "${BUILD_LITE_INTERPRETER}" == "1" ]; then
3838
# libtorch_lite_ios_nightly_1.11.0.20210810.zip
3939
ZIPFILE="libtorch_lite_ios_nightly_${IOS_NIGHTLY_BUILD_VERSION}.zip"

.circleci/scripts/binary_populate_env.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ PIP_UPLOAD_FOLDER='nightly/'
5959
# We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it
6060
export DATE="$(date -u +%Y%m%d)"
6161
#TODO: We should be pulling semver version from the base version.txt
62-
BASE_BUILD_VERSION="2.0.0.dev$DATE"
62+
BASE_BUILD_VERSION="2.1.0.dev$DATE"
6363
# Change BASE_BUILD_VERSION to git tag when on a git tag
6464
# Use 'git -C' to make doubly sure we're in the correct directory for checking
6565
# the git tag

.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
120e7af6466190b754cf3026c685a5d31561da90
1+
caf12f840037193fb3d1e6c60168c37dfa218f43
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
pytorch-triton-rocm>=2.0.0.dev
1+
pytorch-triton-rocm>=2.0.0,<2.1

.github/scripts/check_labels.py

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,34 @@
11
#!/usr/bin/env python3
2-
"""check_labels.py"""
2+
"""Check whether a PR has required labels."""
33

4-
from typing import Any, List
4+
from typing import Any
55

6-
from label_utils import gh_get_labels
76
from gitutils import (
87
get_git_remote_name,
98
get_git_repo_dir,
109
GitRepo,
1110
)
12-
from trymerge import (
13-
_fetch_url,
11+
from trymerge import GitHubPR
12+
from github_utils import (
13+
gh_delete_comment,
1414
gh_post_pr_comment,
15-
GitHubPR,
1615
)
17-
18-
19-
BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"]
20-
21-
ERR_MSG_TITLE = "This PR needs a label"
22-
ERR_MSG = (
23-
f"# {ERR_MSG_TITLE}\n"
24-
"If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.\n\n" # noqa: E501 pylint: disable=line-too-long
25-
"If not, please add the `topic: not user facing` label.\n\n"
26-
"For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work." # noqa: E501 pylint: disable=line-too-long
16+
from label_utils import (
17+
LABEL_ERR_MSG,
18+
is_label_err_comment,
19+
has_required_labels,
2720
)
2821

29-
30-
def get_release_notes_labels(org: str, repo: str) -> List[str]:
31-
return [label for label in gh_get_labels(org, repo) if label.lstrip().startswith("release notes:")]
32-
33-
34-
def delete_comment(comment_id: int) -> None:
35-
url = f"https://api.github.com/repos/pytorch/pytorch/issues/comments/{comment_id}"
36-
_fetch_url(url, method="DELETE")
37-
38-
39-
def has_required_labels(pr: GitHubPR) -> bool:
40-
pr_labels = pr.get_labels()
41-
# Check if PR is not user facing
42-
is_not_user_facing_pr = any(label.strip() == "topic: not user facing" for label in pr_labels)
43-
return (
44-
is_not_user_facing_pr or
45-
any(label.strip() in get_release_notes_labels(pr.org, pr.project) for label in pr_labels)
46-
)
47-
48-
49-
def delete_comments(pr: GitHubPR) -> None:
50-
# Delete all previous comments
22+
def delete_all_label_err_comments(pr: "GitHubPR") -> None:
5123
for comment in pr.get_comments():
52-
if comment.body_text.lstrip(" #").startswith(ERR_MSG_TITLE) and comment.author_login in BOT_AUTHORS:
53-
delete_comment(comment.database_id)
24+
if is_label_err_comment(comment):
25+
gh_delete_comment(pr.org, pr.project, comment.database_id)
5426

5527

56-
def add_comment(pr: GitHubPR) -> None:
28+
def add_label_err_comment(pr: "GitHubPR") -> None:
5729
# Only make a comment if one doesn't exist already
58-
for comment in pr.get_comments():
59-
if comment.body_text.lstrip(" #").startswith(ERR_MSG_TITLE) and comment.author_login in BOT_AUTHORS:
60-
return
61-
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, ERR_MSG)
30+
if not any(is_label_err_comment(comment) for comment in pr.get_comments()):
31+
gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG)
6232

6333

6434
def parse_args() -> Any:
@@ -79,10 +49,10 @@ def main() -> None:
7949
try:
8050
if not has_required_labels(pr):
8151
exit_code = 1
82-
print(ERR_MSG)
83-
add_comment(pr)
52+
print(LABEL_ERR_MSG)
53+
add_label_err_comment(pr)
8454
else:
85-
delete_comments(pr)
55+
delete_all_label_err_comments(pr)
8656
except Exception as e:
8757
pass
8858

.github/scripts/comment_on_pr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Any
2-
from trymerge import gh_post_pr_comment
2+
from github_utils import gh_post_pr_comment
33
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
44
from trymerge_explainer import BOT_COMMANDS_WIKI
55
import os

.github/scripts/github_utils.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""GitHub Utilities"""
2+
3+
import json
4+
import os
5+
6+
from dataclasses import dataclass
7+
from typing import Any, Callable, cast, Dict, List, Optional
8+
from urllib.error import HTTPError
9+
from urllib.parse import quote
10+
from urllib.request import Request, urlopen
11+
12+
13+
@dataclass
14+
class GitHubComment:
15+
body_text: str
16+
created_at: str
17+
author_login: str
18+
author_association: str
19+
editor_login: Optional[str]
20+
database_id: int
21+
22+
23+
def gh_fetch_url(
24+
url: str, *,
25+
headers: Optional[Dict[str, str]] = None,
26+
data: Optional[Dict[str, Any]] = None,
27+
method: Optional[str] = None,
28+
reader: Callable[[Any], Any] = lambda x: x.read()
29+
) -> Any:
30+
if headers is None:
31+
headers = {}
32+
token = os.environ.get("GITHUB_TOKEN")
33+
if token is not None and url.startswith('https://api.github.com/'):
34+
headers['Authorization'] = f'token {token}'
35+
data_ = json.dumps(data).encode() if data is not None else None
36+
try:
37+
with urlopen(Request(url, headers=headers, data=data_, method=method)) as conn:
38+
return reader(conn)
39+
except HTTPError as err:
40+
if err.code == 403 and all(key in err.headers for key in ['X-RateLimit-Limit', 'X-RateLimit-Used']):
41+
print(f"""Rate limit exceeded:
42+
Used: {err.headers['X-RateLimit-Used']}
43+
Limit: {err.headers['X-RateLimit-Limit']}
44+
Remaining: {err.headers['X-RateLimit-Remaining']}
45+
Resets at: {err.headers['x-RateLimit-Reset']}""")
46+
raise
47+
48+
49+
def gh_fetch_json(
50+
url: str,
51+
params: Optional[Dict[str, Any]] = None,
52+
data: Optional[Dict[str, Any]] = None
53+
) -> List[Dict[str, Any]]:
54+
headers = {'Accept': 'application/vnd.github.v3+json'}
55+
if params is not None and len(params) > 0:
56+
url += '?' + '&'.join(f"{name}={quote(str(val))}" for name, val in params.items())
57+
return cast(List[Dict[str, Any]], gh_fetch_url(url, headers=headers, data=data, reader=json.load))
58+
59+
def _gh_fetch_json_any(
60+
url: str,
61+
params: Optional[Dict[str, Any]] = None,
62+
data: Optional[Dict[str, Any]] = None
63+
) -> Any:
64+
headers = {'Accept': 'application/vnd.github.v3+json'}
65+
if params is not None and len(params) > 0:
66+
url += '?' + '&'.join(f"{name}={quote(str(val))}" for name, val in params.items())
67+
return gh_fetch_url(url, headers=headers, data=data, reader=json.load)
68+
69+
70+
def gh_fetch_json_list(
71+
url: str,
72+
params: Optional[Dict[str, Any]] = None,
73+
data: Optional[Dict[str, Any]] = None
74+
) -> List[Dict[str, Any]]:
75+
return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
76+
77+
78+
def gh_fetch_json_dict(
79+
url: str,
80+
params: Optional[Dict[str, Any]] = None,
81+
data: Optional[Dict[str, Any]] = None
82+
) -> Dict[str, Any] :
83+
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
84+
85+
86+
def _gh_post_comment(url: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
87+
if dry_run:
88+
print(comment)
89+
return []
90+
return gh_fetch_json_list(url, data={"body": comment})
91+
92+
93+
def gh_post_pr_comment(org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
94+
return _gh_post_comment(f'https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/comments', comment, dry_run)
95+
96+
97+
def gh_post_commit_comment(org: str, repo: str, sha: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
98+
return _gh_post_comment(f'https://api.github.com/repos/{org}/{repo}/commits/{sha}/comments', comment, dry_run)
99+
100+
101+
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
102+
url = f"https://api.github.com/repos/{org}/{repo}/issues/comments/{comment_id}"
103+
gh_fetch_url(url, method="DELETE")

0 commit comments

Comments
 (0)