|
25 | 25 | package_type = os.getenv("MATRIX_PACKAGE_TYPE") |
26 | 26 | target_os = os.getenv("TARGET_OS", sys.platform) |
27 | 27 | BASE_DIR = Path(__file__).parent.parent.parent |
| 28 | +PYTORCH_ROOT = BASE_DIR.parent |
28 | 29 |
|
29 | 30 | is_cuda_system = gpu_arch_type == "cuda" |
30 | 31 | NIGHTLY_ALLOWED_DELTA = 3 |
@@ -216,6 +217,89 @@ def find_pypi_package_version(package: str) -> str | None: |
216 | 217 | return None |
217 | 218 |
|
218 | 219 |
|
| 220 | +def get_expected_cudnn_version_linux(cuda_version: str) -> str | None: |
| 221 | + """Parse expected cuDNN version from generate_binary_build_matrix.py for Linux. |
| 222 | +
|
| 223 | + Reads PYTORCH_EXTRA_INSTALL_REQUIREMENTS and extracts the cudnn version |
| 224 | + for the given CUDA version (e.g. "12.6"). |
| 225 | + """ |
| 226 | + matrix_script = ( |
| 227 | + PYTORCH_ROOT / ".github" / "scripts" / "generate_binary_build_matrix.py" |
| 228 | + ) |
| 229 | + if not matrix_script.exists(): |
| 230 | + print(f"Warning: {matrix_script} not found, skipping cuDNN version check") |
| 231 | + return None |
| 232 | + |
| 233 | + content = matrix_script.read_text() |
| 234 | + # Match the full cudnn package version like nvidia-cudnn-cu12==9.10.2.21 |
| 235 | + # and extract major.minor.patch (dropping the build number) |
| 236 | + pattern = ( |
| 237 | + rf'"{re.escape(cuda_version)}":\s*\(\s*' |
| 238 | + r"[\s\S]*?nvidia-cudnn-cu\d+==(\d+\.\d+\.\d+)\.\d+" |
| 239 | + ) |
| 240 | + match = re.search(pattern, content) |
| 241 | + if match: |
| 242 | + return match.group(1) |
| 243 | + return None |
| 244 | + |
| 245 | + |
| 246 | +def get_expected_cudnn_version_windows(cuda_version: str) -> str | None: |
| 247 | + """Parse expected cuDNN version from cuda_install.bat for Windows. |
| 248 | +
|
| 249 | + Reads the batch file and extracts EXPECTED_CUDNN_VERSION for the given |
| 250 | + CUDA version (e.g. "12.6" maps to CUDA_VER 126). |
| 251 | + """ |
| 252 | + bat_file = ( |
| 253 | + PYTORCH_ROOT / ".ci" / "pytorch" / "windows" / "internal" / "cuda_install.bat" |
| 254 | + ) |
| 255 | + if not bat_file.exists(): |
| 256 | + print(f"Warning: {bat_file} not found, skipping cuDNN version check") |
| 257 | + return None |
| 258 | + |
| 259 | + content = bat_file.read_text() |
| 260 | + # Convert "12.6" to "126" to match batch file's CUDA_VER format |
| 261 | + cuda_ver_nodot = cuda_version.replace(".", "") |
| 262 | + # Match: if %CUDA_VER% EQU 126 ( ... set EXPECTED_CUDNN_VERSION=9.10.2 ) |
| 263 | + pattern = ( |
| 264 | + rf"if %CUDA_VER% EQU {re.escape(cuda_ver_nodot)}\s*\(" |
| 265 | + r"[\s\S]*?set EXPECTED_CUDNN_VERSION=(\d+\.\d+\.\d+)" |
| 266 | + ) |
| 267 | + match = re.search(pattern, content) |
| 268 | + if match: |
| 269 | + return match.group(1) |
| 270 | + return None |
| 271 | + |
| 272 | + |
| 273 | +def check_cudnn_version(cuda_version: str, actual_cudnn_version: str) -> None: |
| 274 | + """Validate cuDNN version matches expected version from build config files.""" |
| 275 | + if sys.platform in ["linux", "linux2"]: |
| 276 | + expected = get_expected_cudnn_version_linux(cuda_version) |
| 277 | + source = "generate_binary_build_matrix.py" |
| 278 | + elif sys.platform == "win32": |
| 279 | + expected = get_expected_cudnn_version_windows(cuda_version) |
| 280 | + source = "cuda_install.bat" |
| 281 | + else: |
| 282 | + print(f"cuDNN version check not supported on platform {sys.platform}") |
| 283 | + return |
| 284 | + |
| 285 | + if expected is None: |
| 286 | + print( |
| 287 | + f"Warning: Could not determine expected cuDNN version for CUDA {cuda_version} " |
| 288 | + f"from {source}, skipping validation" |
| 289 | + ) |
| 290 | + return |
| 291 | + |
| 292 | + if not actual_cudnn_version.startswith(expected): |
| 293 | + raise RuntimeError( |
| 294 | + f"cuDNN version mismatch for CUDA {cuda_version}. " |
| 295 | + f"Loaded: {actual_cudnn_version} Expected: {expected} (from {source})" |
| 296 | + ) |
| 297 | + print( |
| 298 | + f"cuDNN version check passed: {actual_cudnn_version} matches " |
| 299 | + f"expected {expected} from {source}" |
| 300 | + ) |
| 301 | + |
| 302 | + |
219 | 303 | def cudnn_to_version_str(cudnn_version: int) -> str: |
220 | 304 | patch = int(cudnn_version % 10) |
221 | 305 | minor = int((cudnn_version / 100) % 100) |
@@ -294,6 +378,8 @@ def smoke_test_cuda( |
294 | 378 | f"Expected: {torch_cudnn_compile_version}" |
295 | 379 | ) |
296 | 380 |
|
| 381 | + check_cudnn_version(gpu_arch_ver, torch_cudnn_version) |
| 382 | + |
297 | 383 | if sys.platform in ["linux", "linux2"]: |
298 | 384 | torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) |
299 | 385 | print(f"Torch nccl; version: {torch_nccl_version}") |
|
0 commit comments