Skip to content

Commit 766cfa8

Browse files
committed
Update on "[inductor] Fix index_reduce_ on view inputs raising AssertionError in assert_functional_graph"
The `_index_fill` decomposition used mutable `empty_like + copy_` to restore strides when `index_copy` returned a contiguous tensor, which broke the functional graph invariant. Replace with the functional `prims.copy_strided` prim that does the same thing as a single op. Fixes #144846. Authored with Claude. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo [ghstack-poisoned]
2 parents 7a3001d + ff8d256 commit 766cfa8

492 files changed

Lines changed: 11701 additions & 4847 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/docker/ci_commit_pins/nccl.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v2.28.9-1
1+
v2.29.3-1

.ci/lumen_cli/cli/build_cli/register_build.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
import argparse
1+
from __future__ import annotations
2+
3+
import argparse # noqa: TC003
24
import logging
35

46
from cli.lib.common.cli_helper import register_targets, RichHelp, TargetSpec

.ci/lumen_cli/cli/lib/common/cli_helper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
44
"""
55

6+
from __future__ import annotations
7+
68
import argparse
79
from abc import ABC, abstractmethod
810

@@ -11,7 +13,7 @@
1113
from collections.abc import Callable # Python 3.11+
1214
from typing import Any, Required, TypedDict
1315
except ImportError:
14-
from collections.abc import Callable
16+
from collections.abc import Callable # noqa: TC003
1517
from typing import Any, TypedDict
1618

1719
from typing_extensions import Required # Fallback for Python <3.11

.ci/lumen_cli/cli/lib/common/docker_helper.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
Docker Utility helpers for CLI tasks.
33
"""
44

5+
from __future__ import annotations
6+
57
import logging
6-
from typing import Optional
78

89
import docker
910
from docker.errors import APIError, NotFound
@@ -12,7 +13,7 @@
1213
logger = logging.getLogger(__name__)
1314

1415
# lazy singleton so we don't reconnect every call
15-
_docker_client: Optional[docker.DockerClient] = None
16+
_docker_client: docker.DockerClient | None = None
1617

1718

1819
def _get_client() -> docker.DockerClient:
@@ -23,7 +24,7 @@ def _get_client() -> docker.DockerClient:
2324

2425

2526
def local_image_exists(
26-
image_name: str, client: Optional[docker.DockerClient] = None
27+
image_name: str, client: docker.DockerClient | None = None
2728
) -> bool:
2829
"""Return True if a local Docker image exists."""
2930
if not image_name:

.ci/lumen_cli/cli/lib/common/envs_helper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
Environment Variables and Dataclasses Utility helpers for CLI tasks.
33
"""
44

5+
from __future__ import annotations
6+
57
import os
68
from dataclasses import field, fields, is_dataclass, MISSING
79
from pathlib import Path
810
from textwrap import indent
9-
from typing import Optional, Union
1011

1112
from cli.lib.common.utils import str2bool
1213

@@ -18,9 +19,9 @@ def get_env(name: str, default: str = "") -> str:
1819

1920
def env_path_optional(
2021
name: str,
21-
default: Optional[Union[str, Path]] = None,
22+
default: str | Path | None = None,
2223
resolve: bool = True,
23-
) -> Optional[Path]:
24+
) -> Path | None:
2425
"""Get environment variable as optional Path."""
2526
val = get_env(name) or default
2627
if not val:
@@ -32,7 +33,7 @@ def env_path_optional(
3233

3334
def env_path(
3435
name: str,
35-
default: Optional[Union[str, Path]] = None,
36+
default: str | Path | None = None,
3637
resolve: bool = True,
3738
) -> Path:
3839
"""Get environment variable as Path, raise if missing."""
@@ -61,7 +62,7 @@ def env_bool_field(
6162

6263
def env_path_field(
6364
name: str,
64-
default: Union[str, Path] = "",
65+
default: str | Path = "",
6566
*,
6667
resolve: bool = True,
6768
) -> Path:

.ci/lumen_cli/cli/lib/common/path_helper.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,31 @@
11
"""Path utility helpers for CLI tasks."""
22

3+
from __future__ import annotations
4+
35
import logging
46
import shutil
57
from pathlib import Path
6-
from typing import Union
78

89

910
logger = logging.getLogger(__name__)
1011

1112

12-
def get_path(path: Union[str, Path], resolve: bool = False) -> Path:
13+
def get_path(path: str | Path, resolve: bool = False) -> Path:
1314
"""Convert to Path object, optionally resolving to absolute path."""
1415
if not path:
1516
raise ValueError("Path cannot be None or empty")
1617
result = Path(path)
1718
return result.resolve() if resolve else result
1819

1920

20-
def ensure_dir_exists(path: Union[str, Path]) -> Path:
21+
def ensure_dir_exists(path: str | Path) -> Path:
2122
"""Create directory if it doesn't exist."""
2223
path_obj = get_path(path)
2324
path_obj.mkdir(parents=True, exist_ok=True)
2425
return path_obj
2526

2627

27-
def remove_dir(path: Union[str, Path, None]) -> None:
28+
def remove_dir(path: str | Path | None) -> None:
2829
"""Remove directory if it exists."""
2930
if not path:
3031
return
@@ -33,13 +34,13 @@ def remove_dir(path: Union[str, Path, None]) -> None:
3334
shutil.rmtree(path_obj)
3435

3536

36-
def force_create_dir(path: Union[str, Path]) -> Path:
37+
def force_create_dir(path: str | Path) -> Path:
3738
"""Remove directory if exists, then create fresh empty directory."""
3839
remove_dir(path)
3940
return ensure_dir_exists(path)
4041

4142

42-
def copy(src: Union[str, Path], dst: Union[str, Path]) -> None:
43+
def copy(src: str | Path, dst: str | Path) -> None:
4344
"""Copy file or directory from src to dst."""
4445
src_path = get_path(src, resolve=True)
4546
dst_path = get_path(dst, resolve=True)
@@ -57,6 +58,6 @@ def copy(src: Union[str, Path], dst: Union[str, Path]) -> None:
5758
raise ValueError(f"Unsupported path type: {src_path}")
5859

5960

60-
def is_path_exist(path: Union[str, Path, None]) -> bool:
61+
def is_path_exist(path: str | Path | None) -> bool:
6162
"""Check if path exists."""
6263
return bool(path and get_path(path).exists())

.ci/lumen_cli/cli/lib/common/pip_helper.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
from __future__ import annotations
2+
13
import glob
24
import logging
35
import shlex
46
import shutil
57
import sys
6-
from collections.abc import Iterable
8+
from collections.abc import Iterable # noqa: TC003
79
from importlib.metadata import PackageNotFoundError, version # noqa: UP035
8-
from typing import Optional, Union
910

1011
from cli.lib.common.utils import run_command
1112

@@ -17,8 +18,8 @@ def pip_install_packages(
1718
packages: Iterable[str] = (),
1819
env=None,
1920
*,
20-
requirements: Optional[str] = None,
21-
constraints: Optional[str] = None,
21+
requirements: str | None = None,
22+
constraints: str | None = None,
2223
prefer_uv: bool = False,
2324
) -> None:
2425
use_uv = prefer_uv and shutil.which("uv") is not None
@@ -37,14 +38,14 @@ def pip_install_packages(
3738
run_command(" ".join(map(shlex.quote, cmd)), env=env)
3839

3940

40-
def pip_install_first_match(pattern: str, extras: Optional[str] = None, pref_uv=False):
41+
def pip_install_first_match(pattern: str, extras: str | None = None, pref_uv=False):
4142
wheel = first_matching_pkg(pattern)
4243
target = f"{wheel}[{extras}]" if extras else wheel
4344
logger.info("Installing %s...", target)
4445
pip_install_packages([target], prefer_uv=pref_uv)
4546

4647

47-
def run_python(args: Union[str, list[str]], env=None):
48+
def run_python(args: str | list[str], env=None):
4849
"""
4950
Run the python in the current environment.
5051
"""

.ci/lumen_cli/cli/lib/common/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
General Utility helpers for CLI tasks.
33
"""
44

5+
from __future__ import annotations
6+
57
import logging
68
import os
79
import shlex
810
import subprocess
911
import sys
1012
from contextlib import contextmanager
1113
from pathlib import Path
12-
from typing import Optional
1314

1415

1516
logger = logging.getLogger(__name__)
@@ -19,8 +20,8 @@ def run_command(
1920
cmd: str,
2021
use_shell: bool = False,
2122
log_cmd: bool = True,
22-
cwd: Optional[str] = None,
23-
env: Optional[dict] = None,
23+
cwd: str | None = None,
24+
env: dict | None = None,
2425
check: bool = True,
2526
) -> int:
2627
"""Run a command with optional shell execution."""
@@ -61,7 +62,7 @@ def run_command(
6162
return proc.returncode
6263

6364

64-
def str2bool(value: Optional[str]) -> bool:
65+
def str2bool(value: str | None) -> bool:
6566
"""Convert environment variables to boolean values."""
6667
if not value:
6768
return False
@@ -120,7 +121,7 @@ def working_directory(path: str):
120121

121122
def get_wheels(
122123
output_dir: Path,
123-
max_depth: Optional[int] = None,
124+
max_depth: int | None = None,
124125
) -> list[str]:
125126
"""Return a list of wheels found in the given output directory."""
126127
root = Path(output_dir)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import logging
2+
from pathlib import Path
3+
from typing import Any
4+
5+
import yaml
6+
from cli.lib.common.git_helper import clone_external_repo
7+
from cli.lib.common.utils import run_command, temp_environ, working_directory
8+
9+
10+
logger = logging.getLogger(__name__)
11+
12+
_TORCHTITAN_TEST_LIBRARY_PATH = Path(__file__).parent / "torchtitan_test_library.yaml"
13+
14+
15+
def _load_torchtitan_test_library_yaml() -> dict[str, Any]:
    """Read and parse the torchtitan test library YAML file.

    Returns:
        The parsed top-level YAML content. An empty document yields an empty
        dict instead of ``None`` so the result matches the declared type.

    Raises:
        FileNotFoundError: If ``_TORCHTITAN_TEST_LIBRARY_PATH`` does not exist.
    """
    # EAFP: open directly rather than exists()-then-open, so there is no
    # window in which the file can vanish between the check and the read.
    try:
        f = open(_TORCHTITAN_TEST_LIBRARY_PATH, encoding="utf-8")
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"torchtitan test library YAML not found: {_TORCHTITAN_TEST_LIBRARY_PATH}"
        ) from err
    with f:
        data = yaml.safe_load(f)
    # yaml.safe_load returns None for an empty document; callers do
    # membership tests on the result, which would raise TypeError on None.
    return {} if data is None else data
22+
23+
24+
def load_torchtitan_test_library() -> dict[str, Any]:
    """Return the parsed torchtitan test library.

    Public entry point; delegates to ``_load_torchtitan_test_library_yaml``.
    """
    test_library = _load_torchtitan_test_library_yaml()
    return test_library
26+
27+
28+
def clone_torchtitan(dst: str = "torchtitan"):
    """Clone the torchtitan repository via ``clone_external_repo``.

    Args:
        dst: Destination forwarded to ``clone_external_repo``.

    Returns:
        The second element of the pair returned by ``clone_external_repo``
        (presumably the checked-out commit; confirm against that helper).
    """
    _repo, checked_out_commit = clone_external_repo(
        repo="https://github.com/pytorch/torchtitan.git",
        target="torchtitan",
        dst=dst,
    )
    return checked_out_commit
35+
36+
37+
def run_test_plan(
    test_plan: str,
    tests_map: dict[str, Any],
):
    """Run every step of the named torchtitan test plan.

    Each step is run as a shell command. A failing step does not stop the
    run; failed steps are collected and reported together at the end.

    Args:
        test_plan: Key of the plan to look up in ``tests_map``.
        tests_map: Mapping from plan name to plan definition. A definition
            must contain ``steps`` and may contain ``title``,
            ``working_directory`` and ``env_vars``.

    Raises:
        RuntimeError: If ``test_plan`` is not in ``tests_map``, or if at
            least one step returns a non-zero exit code.
    """
    logger.info("Running torchtitan test plan: %s", test_plan)
    if test_plan not in tests_map:
        raise RuntimeError(
            f"test plan '{test_plan}' not found in torchtitan test library"
        )

    plan = tests_map[test_plan]
    logger.info("Running tests: %s", plan.get("title", "unknown test"))

    # Enter the working directory first, then apply the environment
    # overrides, matching the order of the plan's context managers.
    with working_directory(plan.get("working_directory", "")):
        with temp_environ(plan.get("env_vars", {})):
            failed_steps = []
            for command in plan["steps"]:
                logger.info("Running step: %s", command)
                # check=False: keep going so every failure is reported at once.
                exit_code = run_command(cmd=command, check=False, use_shell=True)
                if exit_code != 0:
                    failed_steps.append(command)
                logger.info("Finished step: %s", command)
            if failed_steps:
                logger.error("Failed steps: %s", failed_steps)
                raise RuntimeError(
                    f"{len(failed_steps)} test steps failed: {failed_steps}"
                )
            logger.info("All tests passed for plan: %s", test_plan)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import logging
2+
from typing import Any
3+
4+
from cli.lib.common.cli_helper import BaseRunner
5+
from cli.lib.common.pip_helper import pip_install_packages
6+
from cli.lib.common.utils import working_directory
7+
from cli.lib.core.torchtitan.lib import (
8+
clone_torchtitan,
9+
load_torchtitan_test_library,
10+
run_test_plan,
11+
)
12+
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class TorchtitanTestRunner(BaseRunner):
18+
def __init__(self, args: Any):
19+
self.work_directory = "torchtitan"
20+
self.test_plan = args.test_plan
21+
22+
def prepare(self):
23+
clone_torchtitan(dst=self.work_directory)
24+
# torchao nightly is required by torchtitan
25+
pip_install_packages(
26+
packages=[
27+
"--pre",
28+
"torchao",
29+
"--index-url",
30+
"https://download.pytorch.org/whl/nightly/cu129",
31+
],
32+
)
33+
with working_directory(self.work_directory):
34+
pip_install_packages(packages=["-e", "."])
35+
pip_install_packages(packages=["pytest", "pytest-cov"])
36+
37+
def run(self):
38+
self.prepare()
39+
with working_directory(self.work_directory):
40+
run_test_plan(self.test_plan, load_torchtitan_test_library())

0 commit comments

Comments
 (0)