Skip to content

Commit 4cb9b24

Browse files
committed
Update on "Rewrite torch.broadcast_shapes to be unbacked SymInt friendly"
This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9 but I have to do this everywhere we have a broadcasting implementation. If you want me to spend some BE time deduping these, please holler. Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
2 parents c7fa416 + 23f18f5 commit 4cb9b24

88 files changed

Lines changed: 8270 additions & 6359 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/pytorch/win-build.sh

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,6 @@ source "$SCRIPT_PARENT_DIR/common.sh"
1515
# shellcheck source=./common-build.sh
1616
source "$SCRIPT_PARENT_DIR/common-build.sh"
1717

18-
IMAGE_COMMIT_ID=$(git rev-parse HEAD)
19-
export IMAGE_COMMIT_ID
20-
export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID}
21-
if [[ ${JOB_NAME} == *"develop"* ]]; then
22-
export IMAGE_COMMIT_TAG=develop-${IMAGE_COMMIT_TAG}
23-
fi
24-
2518
export TMP_DIR="${PWD}/build/win_tmp"
2619
TMP_DIR_WIN=$(cygpath -w "${TMP_DIR}")
2720
export TMP_DIR_WIN
@@ -59,7 +52,4 @@ set -ex
5952

6053
assert_git_not_dirty
6154

62-
if [ ! -f "${TMP_DIR}"/"${IMAGE_COMMIT_TAG}".7z ] && [ ! "${BUILD_ENVIRONMENT}" == "" ]; then
63-
exit 1
64-
fi
6555
echo "BUILD PASSED"

.ci/pytorch/win-test-helpers/build_pytorch.bat

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,7 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps
138138
if "%BUILD_ENVIRONMENT%"=="" (
139139
echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3` in Command Prompt before running Git Bash.
140140
) else (
141-
if "%USE_CUDA%"=="1" (
142-
7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\nvfuser && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
143-
) else (
144-
7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
145-
)
146-
147-
if errorlevel 1 exit /b
148-
if not errorlevel 0 exit /b
141+
copy /Y "dist\*.whl" "%PYTORCH_FINAL_PACKAGE_DIR%"
149142

150143
:: export test times so that potential sharded tests that'll branch off this build will use consistent data
151144
python tools/stats/export_test_times.py

.ci/pytorch/win-test-helpers/setup_pytorch_env.bat

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ call %INSTALLER_DIR%\activate_miniconda3.bat
1414
if errorlevel 1 exit /b
1515
if not errorlevel 0 exit /b
1616

17+
:: PyTorch is now installed using the standard wheel on Windows into the conda environment.
18+
:: However, the test scripts are still frequently referring to the workspace temp directory
19+
:: build\torch. Rather than changing all these references, making a copy of torch folder
20+
:: from conda to the current workspace is easier. The workspace will be cleaned up after
21+
:: the job anyway
22+
xcopy /s %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %TMP_DIR_WIN%\build\torch\
23+
1724
pushd .
1825
if "%VC_VERSION%" == "" (
1926
call "C:\Program Files (x86)\Microsoft Visual Studio\%VC_YEAR%\%VC_PRODUCT%\VC\Auxiliary\Build\vcvarsall.bat" x64
@@ -48,16 +55,6 @@ set NUMBAPRO_NVVM=%CUDA_PATH%\nvvm\bin\nvvm64_32_0.dll
4855

4956
set PYTHONPATH=%TMP_DIR_WIN%\build;%PYTHONPATH%
5057

51-
if NOT "%BUILD_ENVIRONMENT%"=="" (
52-
pushd %TMP_DIR_WIN%\build
53-
copy /Y %PYTORCH_FINAL_PACKAGE_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %TMP_DIR_WIN%\
54-
:: 7z: -aos skips if exists because this .bat can be called multiple times
55-
7z x %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z -aos
56-
popd
57-
) else (
58-
xcopy /s %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %TMP_DIR_WIN%\build\torch\
59-
)
60-
6158
@echo off
6259
echo @echo off >> %TMP_DIR_WIN%/ci_scripts/pytorch_env_restore.bat
6360
for /f "usebackq tokens=*" %%i in (`set`) do echo set "%%i" >> %TMP_DIR_WIN%/ci_scripts/pytorch_env_restore.bat

.ci/pytorch/win-test.sh

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,6 @@ SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
55
# shellcheck source=./common.sh
66
source "$SCRIPT_PARENT_DIR/common.sh"
77

8-
IMAGE_COMMIT_ID=$(git rev-parse HEAD)
9-
export IMAGE_COMMIT_ID
10-
export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID}
11-
if [[ ${JOB_NAME} == *"develop"* ]]; then
12-
export IMAGE_COMMIT_TAG=develop-${IMAGE_COMMIT_TAG}
13-
fi
14-
158
export TMP_DIR="${PWD}/build/win_tmp"
169
TMP_DIR_WIN=$(cygpath -w "${TMP_DIR}")
1710
export TMP_DIR_WIN
@@ -21,13 +14,12 @@ export PROJECT_DIR_WIN
2114
export TEST_DIR="${PWD}/test"
2215
TEST_DIR_WIN=$(cygpath -w "${TEST_DIR}")
2316
export TEST_DIR_WIN
24-
export PYTORCH_FINAL_PACKAGE_DIR="${PYTORCH_FINAL_PACKAGE_DIR:-/c/users/circleci/workspace/build-results}"
17+
export PYTORCH_FINAL_PACKAGE_DIR="${PYTORCH_FINAL_PACKAGE_DIR:-/c/w/build-results}"
2518
PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}")
2619
export PYTORCH_FINAL_PACKAGE_DIR_WIN
2720

2821
mkdir -p "$TMP_DIR"/build/torch
2922

30-
3123
# This directory is used only to hold "pytorch_env_restore.bat", called via "setup_pytorch_env.bat"
3224
CI_SCRIPTS_DIR=$TMP_DIR/ci_scripts
3325
mkdir -p "$CI_SCRIPTS_DIR"
@@ -36,7 +28,6 @@ if [ -n "$(ls "$CI_SCRIPTS_DIR"/*)" ]; then
3628
rm "$CI_SCRIPTS_DIR"/*
3729
fi
3830

39-
4031
export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers
4132

4233
if [[ "$TEST_CONFIG" = "force_on_cpu" ]]; then

.github/ci_commit_pins/triton.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
c8bfe3f548b164f745ada620a560f87f41ab8465
1+
3aa3d7024e88e9b18e3ff54eab681adfda37298b

.github/scripts/check_labels.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,19 @@ def main() -> None:
7575
org, project = repo.gh_owner_and_name()
7676
pr = GitHubPR(org, project, args.pr_num)
7777

78+
exit_code = 0
7879
try:
7980
if not has_required_labels(pr):
81+
exit_code = 1
8082
print(ERR_MSG)
8183
add_comment(pr)
82-
exit(1)
8384
else:
8485
delete_comments(pr)
8586
except Exception as e:
8687
pass
8788

89+
exit(exit_code)
90+
8891

8992
if __name__ == "__main__":
9093
main()

.github/scripts/export_pytorch_labels.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,24 @@
1414
import json
1515

1616
from label_utils import gh_get_labels
17+
from typing import Any
18+
19+
20+
def parse_args() -> Any:
21+
from argparse import ArgumentParser
22+
parser = ArgumentParser("Export PR labels")
23+
parser.add_argument("org", type=str)
24+
parser.add_argument("repo", type=str)
25+
26+
return parser.parse_args()
1727

1828

1929
def main() -> None:
30+
args = parse_args()
31+
print(f"Exporting labels for {args.org}/{args.repo}")
2032
labels_file_name = "pytorch_labels.json"
2133
obj = boto3.resource('s3').Object('ossci-metrics', labels_file_name)
22-
obj.put(Body=json.dumps(gh_get_labels()).encode())
34+
obj.put(Body=json.dumps(gh_get_labels(args.org, args.repo)).encode())
2335

2436

2537
if __name__ == '__main__':

.github/scripts/trymerge.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,11 @@ def _fetch_url(url: str, *,
456456
return reader(conn)
457457
except HTTPError as err:
458458
if err.code == 403 and all(key in err.headers for key in ['X-RateLimit-Limit', 'X-RateLimit-Used']):
459-
print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}")
459+
print(f"""Rate limit exceeded:
460+
Used: {err.headers['X-RateLimit-Used']}
461+
Limit: {err.headers['X-RateLimit-Limit']}
462+
Remaining: {err.headers['X-RateLimit-Remaining']}
463+
Resets at: {err.headers['x-RateLimit-Reset']}""")
460464
raise
461465

462466
def _fetch_json_any(

.github/workflows/_win-test.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ jobs:
190190
export COMMIT_MESSAGES="${COMMIT_MESSAGES//[\'\"]}"
191191
export PR_BODY="${PR_BODY//[\'\"]}"
192192
193+
pushd "${PYTORCH_FINAL_PACKAGE_DIR}"
194+
# shellcheck disable=SC2046
195+
python3 -mpip install $(echo *.whl)[opt-einsum]
196+
popd
197+
193198
.ci/pytorch/win-test.sh
194199
195200
- name: Print remaining test logs

.github/workflows/update_pytorch_labels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ jobs:
2424
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }}
2525
run: |
2626
python3 -m pip install boto3==1.19.12
27-
.github/scripts/export_pytorch_labels.py
27+
.github/scripts/export_pytorch_labels.py pytorch pytorch

0 commit comments

Comments
 (0)