Handle torch ver in flexattn #37400
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
  # see https://github.com/pytorch/pytorch/issues/146260 for training
  self.training = training
- if _torch_version.split("+")[0] == "2.6.0" and training:
+ if is_torch_greater_or_equal("2.6.0") and training:
pytorch/pytorch#143299 should've fixed this issue, so it makes more sense to check for 2.6.0 directly and not <=.
I think it's fine to use `from packaging import version` ... instead of creating another function in the utils 👀 but that's up for debate
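A quick illustration of the `packaging` semantics at play here (a sketch, not the PR's code): `base_version` replaces the manual `.split("+")[0]`, and note that plain `==` on parsed versions would not match local builds:

```python
from packaging import version

v = version.parse("2.6.0+cu124")
print(v.base_version)               # "2.6.0" -> replaces the manual .split("+")[0]
print(v == version.parse("2.6.0"))  # False: the local "+cu124" tag breaks equality
print(v.base_version == "2.6.0")    # True: compare the base version instead
```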
@vasqu not sure what you mean here 🤔 this will check ">= 2.6.0", and `is_torch_greater_or_equal` is already part of the package.
I think it makes the check more future-proof.
Am I missing something?
The torch version guard was introduced for torch==2.6.0 explicitly.
The PR I linked fixed some issues which should remove the need for this check, i.e. we no longer need to compile with "max-autotune-no-cudagraphs". This means that future versions should also not need it, which is why I suggested an == and not a >= 2.6.0.
Edit: the wording before was probably less than ideal :D
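A small sketch of the behavioral difference being argued here (the version strings are illustrative):

```python
from packaging import version

def hits_eq(v: str) -> bool:   # the == guard suggested above
    return version.parse(v).base_version == "2.6.0"

def hits_ge(v: str) -> bool:   # a >= guard, for comparison
    return version.parse(v) >= version.parse("2.6.0")

for v in ("2.5.1", "2.6.0+cu124", "2.7.0"):
    print(v, hits_eq(v), hits_ge(v))
# 2.5.1        False False
# 2.6.0+cu124  True  True
# 2.7.0        False True   <- >= would keep applying the 2.6.0-only workaround
```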
Updated the PR to use `version` instead of manual str checking, lmk how it looks to you
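How the updated guard might read in context (a sketch under assumptions: the function and kwargs handling are hypothetical; only the exact-version check and the compile mode come from this thread):

```python
from packaging import version
import torch

def flex_attention_compile_kwargs(training: bool) -> dict:
    # torch 2.6.0 specifically needs "max-autotune-no-cudagraphs" when training;
    # see https://github.com/pytorch/pytorch/issues/146260
    if version.parse(torch.__version__).base_version == "2.6.0" and training:
        return {"mode": "max-autotune-no-cudagraphs"}
    return {}
```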
* Handle torch ver in flexattn
* update
Follow-up to #37399
@ArthurZucker