
Commit 3fb1255

Update on "[pt2 bug bash] Fix nn.functional.pad compile crash with deterministic mode + replication padding"
Fixes #170079

## Context

`torch.compile(ReplicationPad1d(...), fullgraph=True)` crashes when `torch.use_deterministic_algorithms(True)` is set on CUDA. The error: Dynamo can't trace through `importlib.import_module`. The deterministic code path exists because the native `replication_pad1d_backward` CUDA kernel uses `atomicAdd` (non-deterministic). `functional.py` calls `_replication_pad` — a Python decomposition using `_unsafe_index`, whose backward uses `index_put` (deterministic).

## Dynamo limitations encountered

Three separate Dynamo tracing barriers prevented calling `_replication_pad` directly:

### 1. `importlib.import_module` is marked as skipped

```python
@torch.compile(fullgraph=True)
def fn(x):
    import importlib
    return importlib.import_module("torch").sin(x)

fn(torch.randn(3))  # Unsupported: function marked as skipped
```

### 2. `elementwise_dtypes` returns non-Tensor (from `pw_cast_for_opmath`)

```python
@torch.compile(fullgraph=True)
def fn(x):
    from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
    dt, _ = elementwise_dtypes(x, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
    return x.to(dt)

fn(torch.randn(3))  # Unsupported: torch.* op returned non-Tensor
```

### 3. `torch._check` with closure lambda

```python
@torch.compile(fullgraph=True)
def fn(x):
    dim = x.dim()
    torch._check(dim in (2, 3), lambda: f"expected 2D or 3D, got {dim}D")
    return x + 1

fn(torch.randn(3, 3))  # Unsupported: Can't extract message from torch._check()
```

## Iteration log

| # | Approach | Who | Tests | Reviewer pushback | Why it failed |
|---|----------|-----|-------|-------------------|---------------|
| 1 | Replace `importlib` with `from...import` | Claude | bilinear/trilinear pass, replicate fails | "why do we need bilinear/trilinear tests?" — scoped fix to reported bug only | Hit limitation #2: `pw_cast_for_opmath` |
| 2 | Skip decomposition under compile via `is_compiling()`, rely on AOTAutograd's `register_decomposition` | Claude | forward-only `backend="eager"` passes | "can you verify at inductor level this is actually deterministic?" — inspect AOT graph | No backward decomposition registered; backward still uses native `replication_pad1d_backward` (non-deterministic) |
| 3 | Unwrap `pw_cast_for_opmath` via `__wrapped__` | Claude | N/A — fails immediately | N/A | Hit limitation #3: `torch._check()` closure |
| 4 | `nonstrict_trace` — Dynamo skips the body, AOTAutograd traces through | Reviewer suggestion | `backend="aot_eager"`, forward + backward under `DeterministicGuard(True)` | N/A — fix is correct | N/A |

## Key insight

The fix isn't about making Dynamo trace the decomposition or skipping it entirely — it's about putting the boundary in the right place. Dynamo doesn't need to see inside; AOTAutograd does. `nonstrict_trace` is exactly this boundary.

Each "obvious" fix had passing tests that weren't testing the right thing. Only when the reviewer pushed for backward determinism verification and AOT graph inspection did the weaknesses surface. The backward completing without error under `DeterministicGuard(True)` proves determinism — PyTorch explicitly raises `RuntimeError` if any non-deterministic CUDA kernel executes under this mode.

Authored with Claude.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
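The commit message describes the fix and its verification but not the patch itself. Below is a minimal sketch, not the actual diff: the `torch._dynamo.nonstrict_trace` placement on `_replication_pad` is an assumption based on the description above, and the test mirrors what the iteration log says about approach #4 (`backend="aot_eager"`, forward and backward under `DeterministicGuard(True)`).

```python
# Sketch only: assumes the fix marks the Python decomposition with
# torch._dynamo.nonstrict_trace so Dynamo records an opaque call while
# AOTAutograd traces the _unsafe_index body and derives the deterministic
# index_put backward.
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import DeterministicGuard

# Hypothetical shape of the change in torch/nn/functional.py:
#
#     @torch._dynamo.nonstrict_trace
#     def _replication_pad(input, pad): ...
#
# Verification: DeterministicGuard(True) makes PyTorch raise RuntimeError if
# any non-deterministic CUDA kernel (such as the atomicAdd-based
# replication_pad1d_backward) executes, so a clean backward pass is evidence
# the deterministic decomposition was actually used.

def verify_replication_pad_deterministic() -> None:
    pad = torch.compile(nn.ReplicationPad1d(2), backend="aot_eager", fullgraph=True)
    x = torch.randn(1, 3, 8, device="cuda", requires_grad=True)
    with DeterministicGuard(True):
        out = pad(x)
        out.sum().backward()  # must complete without RuntimeError
    print("replication pad forward + backward ran in deterministic mode")

if __name__ == "__main__":
    if torch.cuda.is_available():
        verify_replication_pad_deterministic()
```

The point of the boundary choice is visible here: Dynamo never needs to understand the decomposition's internals, it only needs a single traceable call site, while AOTAutograd still sees the full body and therefore the deterministic backward.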
2 parents ccd55a3 + f7c6a9b commit 3fb1255

487 files changed

Lines changed: 20854 additions & 6502 deletions


.bazelrc

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-build --cxxopt=--std=c++17
+build --cxxopt=--std=c++20
 build --copt=-I.
 # Bazel does not support including its cc_library targets as system
 # headers. We work around this for generated code

.ci/docker/build.sh

Lines changed: 6 additions & 2 deletions
@@ -195,12 +195,16 @@ case "$tag" in
     NINJA_VERSION=1.9.0
     TRITON=yes
     ;;
-  pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
+  pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-client | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
     ANACONDA_PYTHON_VERSION=3.10
     GCC_VERSION=13
     VISION=yes
     XPU_VERSION=2025.3
-    XPU_DRIVER_TYPE=LTS
+    if [[ $tag =~ "client" ]]; then
+      XPU_DRIVER_TYPE=CLIENT
+    else
+      XPU_DRIVER_TYPE=LTS
+    fi
     NINJA_VERSION=1.9.0
     TRITON=yes
     if [[ $tag =~ "benchmarks" ]]; then

.ci/docker/ci_commit_pins/nccl.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-v2.29.3-1
+v2.29.7-1

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-307748db7742a0f8259a7ea0336909eb55d2051a
+33f782efa9464adebb448ea1f1df1a64ec37ceb0

.ci/docker/common/install_cuda.sh

Lines changed: 4 additions & 4 deletions
@@ -111,7 +111,7 @@ function install_126 {
 }
 
 function install_129 {
-    CUDNN_VERSION=9.17.1.4
+    CUDNN_VERSION=9.20.0.48
     echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
     # install CUDA 12.9.1 in the same container
     install_cuda 12.9.1 cuda_12.9.1_575.57.08_linux
@@ -129,7 +129,7 @@ function install_129 {
 }
 
 function install_128 {
-    CUDNN_VERSION=9.19.0.56
+    CUDNN_VERSION=9.20.0.48
     echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
     # install CUDA 12.8.1 in the same container
     install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux
@@ -147,7 +147,7 @@ function install_128 {
 }
 
 function install_130 {
-    CUDNN_VERSION=9.19.0.56
+    CUDNN_VERSION=9.20.0.48
     echo "Installing CUDA 13.0 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.8.0"
     # install CUDA 13.0 in the same container
     install_cuda 13.0.2 cuda_13.0.2_580.95.05_linux
@@ -165,7 +165,7 @@ function install_130 {
 }
 
 function install_132 {
-    CUDNN_VERSION=9.19.0.56
+    CUDNN_VERSION=9.20.0.48
     echo "Installing CUDA 13.2 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.8.0"
     # install CUDA 13.2 in the same container
     install_cuda 13.2.0 cuda_13.2.0_595.45.04_linux

.ci/docker/requirements-docs.txt

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@ sphinx==7.2.6
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 7.2.6
 
-pytorch_sphinx_theme2==0.4.3
+pytorch_sphinx_theme2==0.4.6
 #Description: This is needed to generate PyTorch docs
-#Pinned versions: 0.4.3
+#Pinned versions: 0.4.6
 
 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
 # but it doesn't seem to work and hangs around idly. The initial thought that it is probably

.ci/lumen_cli/cli/lib/core/torchtitan/torchtitan_test.py

Lines changed: 2 additions & 1 deletion
@@ -21,11 +21,12 @@ def __init__(self, args: Any):
 
     def prepare(self):
         clone_torchtitan(dst=self.work_directory)
-        # torchao nightly is required by torchtitan
+        # torchao and torchcomms nightlies are required by torchtitan
         pip_install_packages(
             packages=[
                 "--pre",
                 "torchao",
+                "torchcomms",
                 "--index-url",
                 "https://download.pytorch.org/whl/nightly/cu129",
             ],

.ci/magma/Makefile

Lines changed: 7 additions & 0 deletions
@@ -16,6 +16,7 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
 	magma/build_magma.sh
 
 .PHONY: all
+all: magma-cuda132
 all: magma-cuda130
 all: magma-cuda129
 all: magma-cuda128
@@ -26,6 +27,12 @@ clean:
 	$(RM) -r magma-*
 	$(RM) -r output
 
+.PHONY: magma-cuda132
+magma-cuda132: DESIRED_CUDA := 13.2
+magma-cuda132: CUDA_ARCH_LIST := -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120
+magma-cuda132:
+	$(DOCKER_RUN)
+
 .PHONY: magma-cuda130
 magma-cuda130: DESIRED_CUDA := 13.0
 magma-cuda130: CUDA_ARCH_LIST := -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120

.ci/pytorch/common_utils.sh

Lines changed: 2 additions & 2 deletions
@@ -320,7 +320,7 @@ function install_flash_attn_cute() {
     git checkout "${flash_attn_commit}"
 
     # Install only the 'cute' sub-directory
-    pip_install -e flash_attn/cute/
+    pip_install flash_attn/cute/
     popd
 
     # remove the local repo
@@ -367,7 +367,7 @@ function install_cutlass_api() {
     git checkout "${cutlass_commit}"
 
     # Install cutlass_api with torch extras
-    pip_install -e "python/cutlass_api[torch]"
+    pip_install "python/cutlass_api[torch]"
     popd
 
     rm -rf cutlass-build

.ci/pytorch/smoke_test/smoke_test.py

Lines changed: 86 additions & 0 deletions
@@ -25,6 +25,7 @@
 package_type = os.getenv("MATRIX_PACKAGE_TYPE")
 target_os = os.getenv("TARGET_OS", sys.platform)
 BASE_DIR = Path(__file__).parent.parent.parent
+PYTORCH_ROOT = BASE_DIR.parent
 
 is_cuda_system = gpu_arch_type == "cuda"
 NIGHTLY_ALLOWED_DELTA = 3
@@ -216,6 +217,89 @@ def find_pypi_package_version(package: str) -> str | None:
     return None
 
 
+def get_expected_cudnn_version_linux(cuda_version: str) -> str | None:
+    """Parse expected cuDNN version from generate_binary_build_matrix.py for Linux.
+
+    Reads PYTORCH_EXTRA_INSTALL_REQUIREMENTS and extracts the cudnn version
+    for the given CUDA version (e.g. "12.6").
+    """
+    matrix_script = (
+        PYTORCH_ROOT / ".github" / "scripts" / "generate_binary_build_matrix.py"
+    )
+    if not matrix_script.exists():
+        print(f"Warning: {matrix_script} not found, skipping cuDNN version check")
+        return None
+
+    content = matrix_script.read_text()
+    # Match the full cudnn package version like nvidia-cudnn-cu12==9.10.2.21
+    # and extract major.minor.patch (dropping the build number)
+    pattern = (
+        rf'"{re.escape(cuda_version)}":\s*\(\s*'
+        r"[\s\S]*?nvidia-cudnn-cu\d+==(\d+\.\d+\.\d+)\.\d+"
+    )
+    match = re.search(pattern, content)
+    if match:
+        return match.group(1)
+    return None
+
+
+def get_expected_cudnn_version_windows(cuda_version: str) -> str | None:
+    """Parse expected cuDNN version from cuda_install.bat for Windows.
+
+    Reads the batch file and extracts EXPECTED_CUDNN_VERSION for the given
+    CUDA version (e.g. "12.6" maps to CUDA_VER 126).
+    """
+    bat_file = (
+        PYTORCH_ROOT / ".ci" / "pytorch" / "windows" / "internal" / "cuda_install.bat"
+    )
+    if not bat_file.exists():
+        print(f"Warning: {bat_file} not found, skipping cuDNN version check")
+        return None
+
+    content = bat_file.read_text()
+    # Convert "12.6" to "126" to match batch file's CUDA_VER format
+    cuda_ver_nodot = cuda_version.replace(".", "")
+    # Match: if %CUDA_VER% EQU 126 ( ... set EXPECTED_CUDNN_VERSION=9.10.2 )
+    pattern = (
+        rf"if %CUDA_VER% EQU {re.escape(cuda_ver_nodot)}\s*\("
+        r"[\s\S]*?set EXPECTED_CUDNN_VERSION=(\d+\.\d+\.\d+)"
+    )
+    match = re.search(pattern, content)
+    if match:
+        return match.group(1)
+    return None
+
+
+def check_cudnn_version(cuda_version: str, actual_cudnn_version: str) -> None:
+    """Validate cuDNN version matches expected version from build config files."""
+    if sys.platform in ["linux", "linux2"]:
+        expected = get_expected_cudnn_version_linux(cuda_version)
+        source = "generate_binary_build_matrix.py"
+    elif sys.platform == "win32":
+        expected = get_expected_cudnn_version_windows(cuda_version)
+        source = "cuda_install.bat"
+    else:
+        print(f"cuDNN version check not supported on platform {sys.platform}")
+        return
+
+    if expected is None:
+        print(
+            f"Warning: Could not determine expected cuDNN version for CUDA {cuda_version} "
+            f"from {source}, skipping validation"
+        )
+        return
+
+    if not actual_cudnn_version.startswith(expected):
+        raise RuntimeError(
+            f"cuDNN version mismatch for CUDA {cuda_version}. "
+            f"Loaded: {actual_cudnn_version} Expected: {expected} (from {source})"
+        )
+    print(
+        f"cuDNN version check passed: {actual_cudnn_version} matches "
+        f"expected {expected} from {source}"
+    )
+
+
 def cudnn_to_version_str(cudnn_version: int) -> str:
     patch = int(cudnn_version % 10)
     minor = int((cudnn_version / 100) % 100)
@@ -294,6 +378,8 @@ def smoke_test_cuda(
             f"Expected: {torch_cudnn_compile_version}"
         )
 
+    check_cudnn_version(gpu_arch_ver, torch_cudnn_version)
+
     if sys.platform in ["linux", "linux2"]:
         torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
         print(f"Torch nccl; version: {torch_nccl_version}")
