Lazy load driver APIs using cudaGetDriverEntryPoint #4197

Merged
wujingyue merged 7 commits into main from wjy/writevalue
Apr 10, 2025

@wujingyue (Collaborator) commented Apr 4, 2025

This is apparently more robust than #4196 because it doesn't hard code the version.

Fixes #3907

cc @samnordmann

github-actions bot commented Apr 4, 2025

Review updated until commit 2c4c299

Description

  • Updated driver API loading to use cudaGetDriverEntryPoint

  • Applied changes to all driver APIs

  • Cleaned up and organized includes


Changes walkthrough 📝

Relevant files

Enhancement

driver_api.cpp: Update driver API loading mechanism (csrc/driver_api.cpp, +34/-25)

  • Updated macro to use cudaGetDriverEntryPoint for lazy loading
  • Added static variables and std::once_flag for thread-safe initialization
  • Organized and cleaned up includes

driver_api.h: Update driver API declarations (csrc/driver_api.h, +3/-3)

  • Updated macro to use PFN_ for function pointers
  • Corrected function name in CUDA 12+ macro
  • Organized and cleaned up includes
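
The walkthrough above describes resolving each driver function on first use and guarding that resolution with std::once_flag. A minimal sketch of that pattern follows, with the CUDA pieces swapped for stand-ins so it builds without the toolkit. The names `demoAdd`, `PFN_demoAdd`, `fakeGetDriverEntryPoint`, `DEFINE_LAZY_GETTER`, and `g_resolve_calls` are all hypothetical and are not taken from nvFuser; the stub only plays the role that cudaGetDriverEntryPoint plays in the PR.

```cpp
#include <cassert>
#include <cstring>
#include <mutex>
#include <stdexcept>

// Stand-in for a driver function (hypothetical, not a CUDA API).
int demoAdd(int a, int b) { return a + b; }
using PFN_demoAdd = int (*)(int, int);

// Counts resolver invocations so the "resolve once" behavior is observable.
int g_resolve_calls = 0;

// Hypothetical stub in the role of cudaGetDriverEntryPoint: looks up a
// symbol by name and stores its address in *fn. Returns 0 on success.
int fakeGetDriverEntryPoint(const char* name, void** fn) {
  ++g_resolve_calls;
  if (std::strcmp(name, "demoAdd") == 0) {
    *fn = reinterpret_cast<void*>(&demoAdd);
    return 0;
  }
  return 1;
}

// Defines get_<funcName>(), which resolves the pointer on the first call
// (thread-safe via std::call_once) and returns the cached pointer afterwards.
#define DEFINE_LAZY_GETTER(funcName)                                   \
  PFN_##funcName get_##funcName() {                                    \
    static PFN_##funcName f = nullptr;                                 \
    static std::once_flag once;                                        \
    std::call_once(once, [] {                                          \
      if (fakeGetDriverEntryPoint(                                     \
              #funcName, reinterpret_cast<void**>(&f)) != 0) {         \
        throw std::runtime_error("failed to resolve " #funcName);      \
      }                                                                \
    });                                                                \
    return f;                                                          \
  }

DEFINE_LAZY_GETTER(demoAdd)
```

Calling `get_demoAdd()(1, 2)` and then `get_demoAdd()(3, 4)` runs the resolver only for the first call; later calls reuse the cached pointer. Because the symbol is looked up by name at run time, nothing in the wrapper pins a specific versioned symbol at build time.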

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Possible Issue

The use of cudaGetDriverEntryPoint with cudaEnableDefault might not be the correct flag to use. The flag cudaEnableDefault is not a standard CUDA flag and might lead to undefined behavior.

    NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDriverEntryPoint(                \
        #funcName, reinterpret_cast<void**>(&f), cudaEnableDefault)); \

Typedef Consistency

The macro DECLARE_DRIVER_API_WRAPPER now uses PFN_##funcName instead of decltype(::funcName)*. Ensure that this change is consistent with the rest of the codebase and does not introduce any type mismatches.

    #define DECLARE_DRIVER_API_WRAPPER(funcName) extern PFN_##funcName funcName;

API Versioning

The macro ALL_DRIVER_API_WRAPPER now includes cuStreamWriteValue32 instead of cuStreamWriteValue32_v2. Verify that this change is compatible with the CUDA version requirements and does not introduce any versioning issues.

    fn(cuStreamWriteValue32);          \
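
The typedef and versioning items above both concern an X-macro: one list of function names is expanded once to declare the function pointers and once to define them. The sketch below shows that shape under stated assumptions. `demoInit`, `demoWrite`, their `PFN_` aliases, the `ALL_DEMO_API_WRAPPER` list, and the `wrappers` namespace are hypothetical stand-ins. In CUDA the `PFN_` typedefs come from the toolkit's cudaTypedefs.h header rather than being hand-written as they are here.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for driver entry points.
int demoInit(unsigned flags) { return flags == 0 ? 0 : 1; }
int demoWrite(int* dst, int value) {
  *dst = value;
  return 0;
}

// Hand-written pointer typedefs in the PFN_<name> naming style.
using PFN_demoInit = int (*)(unsigned);
using PFN_demoWrite = int (*)(int*, int);

// Single source of truth: the list of wrapped functions.
#define ALL_DEMO_API_WRAPPER(fn) \
  fn(demoInit);                  \
  fn(demoWrite)

namespace wrappers {

// Declaration pass: PFN_##funcName names the pointer type directly, so the
// declaration does not depend on decltype(::funcName) being visible.
#define DECLARE_DEMO_API_WRAPPER(funcName) extern PFN_##funcName funcName
ALL_DEMO_API_WRAPPER(DECLARE_DEMO_API_WRAPPER);

// Definition pass: here each pointer is bound eagerly to the stand-in.
#define DEFINE_DEMO_API_WRAPPER(funcName) \
  PFN_##funcName funcName = &::funcName
ALL_DEMO_API_WRAPPER(DEFINE_DEMO_API_WRAPPER);

// In this sketch the two spellings name the same type; the check fails to
// compile if a typedef and the function signature ever drift apart.
static_assert(
    std::is_same<PFN_demoInit, decltype(::demoInit)*>::value,
    "PFN_demoInit must match demoInit's signature");
static_assert(
    std::is_same<PFN_demoWrite, decltype(::demoWrite)*>::value,
    "PFN_demoWrite must match demoWrite's signature");

} // namespace wrappers
```

A static_assert of this kind is one way to address the type-mismatch concern at compile time. On the flag question raised above, the CUDA Runtime API documentation lists cudaEnableDefault among the cudaDriverEntryPointFlags values.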

wujingyue requested a review from zasdfgbnm April 4, 2025 22:57

@wujingyue (Collaborator, Author) commented

!test

@samnordmann (Collaborator) commented

> This is apparently more robust than #4196 because it doesn't hard code the version.
>
> Fixes #3907
>
> cc @samnordmann

Thank you for the fix!

Why apply it only to cuStreamWriteValue32? I'm afraid we're going to run into the same issue with other API calls (e.g., cuMemGetAddressRange).
Do you suggest we add each API function one by one here whenever it causes an issue?

@wujingyue (Collaborator, Author) commented

!test

wujingyue changed the title from "Enable lazy loading for cuStreamWriteValue32 using cudaGetDriverEntryPoint" to "Lazy load driver APIs using cudaGetDriverEntryPoint" Apr 9, 2025

@wujingyue (Collaborator, Author) commented

!test

wujingyue requested a review from zasdfgbnm April 9, 2025 23:16

@wujingyue (Collaborator, Author) commented

!test

wujingyue merged commit af63372 into main Apr 10, 2025
53 checks passed
wujingyue deleted the wjy/writevalue branch April 10, 2025 04:21
wujingyue added a commit that referenced this pull request Apr 10, 2025
    It's no longer needed after #4197
wujingyue added a commit that referenced this pull request Apr 10, 2025
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 17, 2025
    Fixes  #154073
    
    Reference: NVIDIA/Fuser#4197
    
    See PR #154097
    
    @nWEIdia is currently out of the office, so I’ve temporarily taken over his work.
    
    Pull Request resolved: #156097
    Approved by: https://github.com/ngimel, https://github.com/cyyever
    
    Co-authored-by: Wei Wang <weiwan@nvidia.com>
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 21, 2025
    Fixes  #154073
    
    Reference: NVIDIA/Fuser#4197
    
    See PR #154097
    
    @nWEIdia is currently out of the office, so I’ve temporarily taken over his work.
    
    Pull Request resolved: #156097
    Approved by: https://github.com/ngimel
    
    Co-authored-by: Wei Wang <weiwan@nvidia.com>
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 10, 2025
    Fixes  #154073
    
    Reference: NVIDIA/Fuser#4197
    
    See PR #154097
    
    @nWEIdia is currently out of the office, so I’ve temporarily taken over his work.
    
    Pull Request resolved: #156097
    Approved by: https://github.com/syed-ahmed, https://github.com/wujingyue, https://github.com/atalman
    
    Co-authored-by: Wei Wang <weiwan@nvidia.com>
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 16, 2025
pytorchbot pushed a commit to pytorch/pytorch that referenced this pull request Jul 17, 2025
    Reopen #156097
    
    Fixes #154073
    
    Reference: NVIDIA/Fuser#4197
    
    See PR #156097 and #154097
    
    Pull Request resolved: #158295
    Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/eqy, https://github.com/huydhn
    
    Co-authored-by: Wei Wang <weiwan@nvidia.com>
    (cherry picked from commit a9f902a)
atalman pushed a commit to pytorch/pytorch that referenced this pull request Jul 18, 2025
    [CUDA] Use runtime driver API for cuStreamWriteValue32 (#158295)
    
    Reopen #156097
    
    Fixes #154073
    
    Reference: NVIDIA/Fuser#4197
    
    See PR #156097 and #154097
    
    Pull Request resolved: #158295
    Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/eqy, https://github.com/huydhn
    
    
    (cherry picked from commit a9f902a)
    
    Co-authored-by: Frank Lin <eee4017@gmail.com>
    Co-authored-by: Wei Wang <weiwan@nvidia.com>
tvukovic-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 20, 2025
    [CUDA] Use runtime driver API for cuStreamWriteValue32 (pytorch#158295)
    
    Reopen pytorch#156097
    
    Fixes pytorch#154073
    
    Reference: NVIDIA/Fuser#4197
    
    See PR pytorch#156097 and pytorch#154097
    
    Pull Request resolved: pytorch#158295
    Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/eqy, https://github.com/huydhn
    
    
    (cherry picked from commit a9f902a)
    
    Co-authored-by: Frank Lin <eee4017@gmail.com>
    Co-authored-by: Wei Wang <weiwan@nvidia.com>
Linked issue: Error with driver API's lazy load of cuStream ops