Skip to content

Use versioned flavor of get driver entrypoint function#1835

Merged
ptrendx merged 6 commits intoNVIDIA:mainfrom
ptrendx:pr_entrypoint
Jun 5, 2025
Merged

Use versioned flavor of get driver entrypoint function#1835
ptrendx merged 6 commits intoNVIDIA:mainfrom
ptrendx:pr_entrypoint

Conversation

@ptrendx
Copy link
Member

@ptrendx ptrendx commented May 30, 2025

Description

Fixes the issue with cuStreamGetCtx pointing to cuStreamCtx_v2 in the CUDA 13 drivers.

Summary (mostly) by copilot:
This pull request updates the transformer_engine/common/util/cuda_driver.cpp and transformer_engine/common/util/cuda_driver.h files to enhance compatibility with different CUDA versions. The changes introduce a mechanism to query driver entry points based on the CUDA version, improving flexibility in handling CUDA driver symbols.

Enhancements for CUDA version compatibility:

  • Updated get_symbol function in transformer_engine/common/util/cuda_driver.cpp: Refactored the function to support querying driver entry points using either a versioned or non-versioned mechanism. The function now accepts a cuda_version parameter and dynamically resolves the appropriate entry point function (cudaGetDriverEntryPoint or cudaGetDriverEntryPointByVersion).
  • Modified get_symbol function declaration in transformer_engine/common/util/cuda_driver.h: Added an optional cuda_version parameter with a default value of 12010 (our oldest supported version) to allow backward compatibility while enabling version-specific queries.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from timmoon10 May 30, 2025 20:57
@ptrendx
Copy link
Member Author

ptrendx commented May 30, 2025

/te-ci

timmoon10
timmoon10 previously approved these changes May 30, 2025
Copy link
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

flx42 added a commit to flx42/TransformerEngine that referenced this pull request May 31, 2025
it was added

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
timmoon10
timmoon10 previously approved these changes Jun 2, 2025
@flx42
Copy link
Member

flx42 commented Jun 3, 2025

I verified that it works!

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Copy link
Member Author

ptrendx commented Jun 3, 2025

@timmoon10 sorry for the churn, I was looking at how others dealt with this issue and found this issue from cutlass: NVIDIA/cutlass#2079 - since we, just like them, link against libcudart.so.12, the check for CUDA 12.5 during the compilation is not enough and we actually need to dynamically load the symbols. Fortunately, since we already link against libcudart, we don't need to try to find the lib by name (so at least there is that). @flx42 Could you verify this new version?

@ptrendx
Copy link
Member Author

ptrendx commented Jun 3, 2025

/te-ci

ptrendx added 2 commits June 3, 2025 13:22
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@flx42
Copy link
Member

flx42 commented Jun 4, 2025

@timmoon10 sorry for the churn, I was looking at how others dealt with this issue and found this issue from cutlass: NVIDIA/cutlass#2079 - since we, just like them, link against libcudart.so.12, the check for CUDA 12.5 during the compilation is not enough and we actually need to dynamically load the symbols. Fortunately, since we already link against libcudart, we don't need to try to find the lib by name (so at least there is that). @flx42 Could you verify this new version?

Still works fine!

@ptrendx
Copy link
Member Author

ptrendx commented Jun 4, 2025

/te-ci

@ptrendx ptrendx merged commit 557f0cb into NVIDIA:main Jun 5, 2025
34 of 37 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants