Commit f72eb5a

drisspg authored and pytorchmergebot committed
__grid_constant__ is only supported on CUDA version >= 11.8 (#121275)
Summary: Update the macros so that __grid_constant__ is not used when compiling for sm80 or newer devices with a CUDA version older than 11.8.

Test Plan: buck2 build --keep-going --config buck2.log_configured_graph_size=true --flagfile fbcode//mode/dev fbcode//sigrid/predictor/client/python:ig_sigrid_client_pybinding

Differential Revision: D54556796

Co-authored-by: Driss Guessous <drisspg@meta.com>

Pull Request resolved: #121275

Approved by: https://github.com/drisspg
1 parent dad1b76 commit f72eb5a

2 files changed

Lines changed: 8 additions & 0 deletions

aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,10 @@ namespace pytorch_flash {
 // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 #define ARCH_SUPPORTS_FLASH
+#endif
+
+#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
+    defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
 #define KERNEL_PARAM_MODIFIER __grid_constant__
 #else
 #define KERNEL_PARAM_MODIFIER
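
The hunk above only defines KERNEL_PARAM_MODIFIER; the places that consume it are outside the diff. The following is a minimal sketch of how such a macro could be applied, assuming it is placed in front of a const, by-value kernel parameter. DemoParams, demo_kernel, DEMO_KERNEL and kernel_param_modifier_text are hypothetical names introduced for illustration and do not appear in the commit. The sketch is written so that it also compiles with a plain host C++ compiler, where the modifier expands to nothing.

```cpp
// Same macro logic as in the diff, closed with its #endif.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#endif

#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
    defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Hypothetical helper macro: __global__ under nvcc, nothing on a host-only build.
#if defined(__CUDACC__)
#define DEMO_KERNEL __global__
#else
#define DEMO_KERNEL
#endif

// Hypothetical parameter struct standing in for the real kernel parameters.
struct DemoParams {
  int seqlen;
  float scale;
};

// The modifier sits in front of a const, by-value parameter. When the macro
// is empty this is an ordinary by-value parameter.
DEMO_KERNEL void demo_kernel(KERNEL_PARAM_MODIFIER const DemoParams params) {
  (void)params;  // a real kernel would read its launch parameters here
}

// Hypothetical diagnostic: stringify the macro to see what it expanded to.
#define DEMO_STR_IMPL(x) #x
#define DEMO_STR(x) DEMO_STR_IMPL(x)
inline const char* kernel_param_modifier_text() {
  return DEMO_STR(KERNEL_PARAM_MODIFIER);
}
```

On a host-only build neither __CUDA_ARCH__ nor the __CUDACC_VER_*__ macros are defined, so the #else branch is taken and kernel_param_modifier_text() returns an empty string.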

aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,10 @@ namespace pytorch_flash {
 // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 #define ARCH_SUPPORTS_FLASH
+#endif
+
+#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
+    defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
 #define KERNEL_PARAM_MODIFIER __grid_constant__
 #else
 #define KERNEL_PARAM_MODIFIER
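
One property of the condition in both headers is worth spelling out: the major and minor version numbers are tested independently, so a compiler version such as 12.0 (minor 0) does not satisfy the minor >= 8 test even though it is newer than 11.8. Whether that matters depends on which toolkits the build targets. The sketch below models the two comparisons with hypothetical constexpr helpers (independent_check, combined_check) and a hypothetical macro DEMO_CUDACC_AT_LEAST_11_8; none of these names come from the commit, and the combined form is shown only as one possible alternative.

```cpp
// Hypothetical helpers, not part of the commit. They model two ways of
// comparing a (major, minor) compiler version against 11.8.

// Mirrors the condition in the diff: major and minor are tested independently.
constexpr bool independent_check(int major, int minor) {
  return major >= 11 && minor >= 8;
}

// Alternative: compare the pair lexicographically, so any major version
// above 11 passes regardless of its minor version.
constexpr bool combined_check(int major, int minor) {
  return major > 11 || (major == 11 && minor >= 8);
}

// Both forms agree around the 11.x boundary.
static_assert(independent_check(11, 8) && combined_check(11, 8), "11.8 passes both");
static_assert(!independent_check(11, 7) && !combined_check(11, 7), "11.7 fails both");

// They disagree once the major version moves past 11 with a small minor.
static_assert(!independent_check(12, 0), "12.0 fails the independent form");
static_assert(combined_check(12, 0), "12.0 passes the combined form");

// The same lexicographic comparison written for the preprocessor.
// On a host-only build the macros are undefined and the result is 0.
#if defined(__CUDACC_VER_MAJOR__) && defined(__CUDACC_VER_MINOR__) && \
    (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8))
#define DEMO_CUDACC_AT_LEAST_11_8 1
#else
#define DEMO_CUDACC_AT_LEAST_11_8 0
#endif
```

The static_assert lines are evaluated at compile time, so the block documents the difference between the two forms without needing a CUDA toolkit.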
