Skip to content

Add CUDA_VERSION check for cudaGetDriverEntryPointByVersion#4447

Merged
wujingyue merged 3 commits intomainfrom
wjy/driver
May 14, 2025
Merged

Add CUDA_VERSION check for cudaGetDriverEntryPointByVersion#4447
wujingyue merged 3 commits intomainfrom
wjy/driver

Conversation

@wujingyue
Copy link
Collaborator

@wujingyue wujingyue commented May 14, 2025

@wujingyue
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented May 14, 2025

Review updated until commit 56f1b73

Description

  • Extracted getDriverEntryPoint function for better modularity

  • Added CUDA_VERSION check to use appropriate driver entry point function

  • Updated DEFINE_DRIVER_API_WRAPPER macro to use new getDriverEntryPoint function


Changes walkthrough 📝

Relevant files
Enhancement
driver_api.cpp
Refactor driver API wrapper with version check                     

csrc/driver_api.cpp

  • Extracted getDriverEntryPoint function to handle driver entry point
    loading
  • Added CUDA_VERSION check to use cudaGetDriverEntryPointByVersion if
    CUDA version is 12.5 or higher
  • Updated DEFINE_DRIVER_API_WRAPPER macro to utilize the new
    getDriverEntryPoint function
  • +43/-28 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The getDriverEntryPoint function uses cudaGetDriverEntryPointByVersion when CUDA_VERSION >= 12050, but the version parameter is not used in the cudaGetDriverEntryPoint call for older versions. This might lead to incorrect behavior if the version is expected to be used in some way.

    namespace {
    void getDriverEntryPoint(
        const char* symbol,
        unsigned int version,
        void** entry_point) {
    #if (CUDA_VERSION >= 12050)
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDriverEntryPointByVersion(
          symbol, entry_point, version, cudaEnableDefault));
    #else
      (void)version;
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaGetDriverEntryPoint(symbol, entry_point, cudaEnableDefault));
    #endif
    Code Duplication

    The DEFINE_DRIVER_API_WRAPPER macro has been redefined with significant changes. It would be beneficial to ensure that the new macro does not introduce code duplication and that the old logic is not unnecessarily repeated.

    #define DEFINE_DRIVER_API_WRAPPER(funcName, version)            \
      namespace {                                                   \
      template <typename ReturnType, typename... Args>              \
      struct funcName##Loader {                                     \
        static ReturnType lazilyLoadAndInvoke(Args... args) {       \
          static decltype(::funcName)* entry_point;                 \
          static std::once_flag once;                               \
          std::call_once(                                           \
              once,                                                 \
              getDriverEntryPoint,                                  \
              #funcName,                                            \
              version,                                              \
              reinterpret_cast<void**>(&entry_point));              \
          return entry_point(args...);                              \
        }                                                           \
        /* This ctor is just a CTAD helper, it is only used in a */ \
        /* non-evaluated environment*/                              \
        funcName##Loader(ReturnType(Args...)){};                    \
      };                                                            \
                                                                    \
      /* Use CTAD rule to deduct return and argument types */       \
      template <typename ReturnType, typename... Args>              \
      funcName##Loader(ReturnType(Args...))                         \
          ->funcName##Loader<ReturnType, Args...>;                  \
      } /* namespace */                                             \
                                                                    \
      decltype(::funcName)* funcName =                              \
    Documentation

    The new getDriverEntryPoint function and the changes in DEFINE_DRIVER_API_WRAPPER macro should be documented to explain the rationale behind the changes and how they address the performance or functionality issues.

    namespace {
    void getDriverEntryPoint(
        const char* symbol,
        unsigned int version,
        void** entry_point) {
    #if (CUDA_VERSION >= 12050)
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDriverEntryPointByVersion(
          symbol, entry_point, version, cudaEnableDefault));
    #else
      (void)version;
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaGetDriverEntryPoint(symbol, entry_point, cudaEnableDefault));
    #endif
    }
    } // namespace
    
    #define DEFINE_DRIVER_API_WRAPPER(funcName, version)            \
      namespace {                                                   \
      template <typename ReturnType, typename... Args>              \
      struct funcName##Loader {                                     \
        static ReturnType lazilyLoadAndInvoke(Args... args) {       \
          static decltype(::funcName)* entry_point;                 \
          static std::once_flag once;                               \
          std::call_once(                                           \
              once,                                                 \
              getDriverEntryPoint,                                  \
              #funcName,                                            \
              version,                                              \
              reinterpret_cast<void**>(&entry_point));              \
          return entry_point(args...);                              \
        }                                                           \
        /* This ctor is just a CTAD helper, it is only used in a */ \
        /* non-evaluated environment*/                              \
        funcName##Loader(ReturnType(Args...)){};                    \
      };                                                            \
                                                                    \
      /* Use CTAD rule to deduct return and argument types */       \
      template <typename ReturnType, typename... Args>              \
      funcName##Loader(ReturnType(Args...))                         \
          ->funcName##Loader<ReturnType, Args...>;                  \
      } /* namespace */                                             \
                                                                    \
      decltype(::funcName)* funcName =                              \

    @wujingyue
    Copy link
    Collaborator Author

    !test

    @wujingyue wujingyue merged commit 77b50ba into main May 14, 2025
    53 checks passed
    @wujingyue wujingyue deleted the wjy/driver branch May 14, 2025 16:10
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants