[Distributed] add pack-check method for float8_e5m2 #136115

kwen2501 wants to merge 2 commits into gh/kwen2501/61/base from
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136115

✅ No failures as of commit c907c0c with merge base 0216936. This comment was automatically generated by Dr. CI.
```cpp
// We want to check 8 x FP8 simultaneously, hence this template definition.
template <typename T>
struct HasNanFP8x8 {
  // I am a dumb implementation. You should never call in here, unless the check
```
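The comment above describes a deliberately simple generic fallback. A hedged sketch of what such a one-lane-at-a-time fallback could look like — `DemoE5M2` and its `isnan()` are illustrative stand-ins, not PyTorch's actual `Float8_e5m2` class:

```cpp
#include <cstdint>

// Sketch of a "dumb" generic fallback: peel off each of the 8 packed bytes,
// view it as the FP8 element type T, and test each lane individually.
template <typename T>
struct HasNanFP8x8 {
  static bool check(uint64_t fp8x8) {
    for (int i = 0; i < 8; ++i) {
      T val{static_cast<uint8_t>(fp8x8 >> (8 * i))};  // raw bit pattern of lane i
      if (val.isnan()) return true;  // one lane at a time: simple but slow
    }
    return false;
  }
};

// Minimal stand-in for an FP8 e5m2 element (1 sign, 5 exponent, 2 mantissa bits).
struct DemoE5M2 {
  uint8_t x;
  // NaN: exponent all ones and mantissa nonzero, i.e. (x & 0x7F) > 0x7C.
  bool isnan() const { return (x & 0x7F) > 0x7C; }
};
```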
`static_assert(false)` to raise a compile error if we call in here?
Or at least can we issue warning somehow?
I was torn between providing basic functionality so that user code can run without breaking (the current code) and speed, and eventually chose the former :)
But your point seems better: a compile error can force a developer to implement the template if they want to add a new data type to AT_DISPATCH_FLOATING_TYPES_AND4 below.
Hmm, it seems `static_assert(false)` doesn't work well inside a template definition prior to C++23. I may try `= delete` as suggested here.
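A sketch of the `= delete` idea: deleting the primary template's method makes any call for an unspecialized dtype a hard compile error, forcing whoever adds a new dtype to the dispatch macro to also provide a pack-check specialization. The tag type and member names below are illustrative, not PyTorch's actual API:

```cpp
#include <cstdint>

template <typename T>
struct HasNanFP8x8 {
  // No generic implementation on purpose: calling this for an
  // unspecialized dtype fails at compile time.
  static bool check(uint64_t) = delete;
};

struct Float8_e5m2_tag {};  // hypothetical stand-in for c10::Float8_e5m2

// Each supported dtype must provide an explicit specialization.
template <>
struct HasNanFP8x8<Float8_e5m2_tag> {
  static bool check(uint64_t fp8x8) {
    // e5m2 NaN: exponent bits all ones and mantissa nonzero,
    // i.e. (byte & 0x7F) > 0x7C once the sign bit is masked off.
    for (int i = 0; i < 8; ++i) {
      if ((static_cast<uint8_t>(fp8x8 >> (8 * i)) & 0x7F) > 0x7C) return true;
    }
    return false;
  }
};
```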
@kwen2501 You could also at least unroll the loop, but the speed gains would be minimal unless the compiler decides to inline isnan.
You could also use the `self != self` trick to check for NaN, depending on the dtype, but I am not sure if that's faster / more portable.
> @kwen2501 You could also at least unroll the loop, but the speed gains would be minimal unless the compiler realizes to inline isnan

Yeah, we tried that. Since the final result is a reduction, i.e.
`packHasNan = isnan(byte0) || isnan(byte1) || ... || isnan(byte7)`,
the compiler does not seem willing to unroll the loop.
We could also manually unroll since it's only 8 elements (as painful as that would be).
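A hypothetical manually unrolled version of that 8-lane reduction, with the e5m2 bit test `(byte & 0x7F) > 0x7C` written out per lane so the compiler has nothing left to decide about unrolling or inlining (names are illustrative, not the kernel's actual code):

```cpp
#include <cstdint>

// Test lane i of the packed 8-byte value for an e5m2 NaN pattern.
inline bool lane_nan_e5m2(uint64_t p, int i) {
  return (static_cast<uint8_t>(p >> (8 * i)) & 0x7F) > 0x7C;
}

// packHasNan = isnan(b0) || isnan(b1) || ... || isnan(b7), fully unrolled.
inline bool packHasNanE5M2Unrolled(uint64_t p) {
  return lane_nan_e5m2(p, 0) || lane_nan_e5m2(p, 1) ||
         lane_nan_e5m2(p, 2) || lane_nan_e5m2(p, 3) ||
         lane_nan_e5m2(p, 4) || lane_nan_e5m2(p, 5) ||
         lane_nan_e5m2(p, 6) || lane_nan_e5m2(p, 7);
}
```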
Add support for Float8_e5m2, following a similar algorithm to the one used for Float8_e4m3fn (i.e. an overflow check). Made `HasNanFP8x8` a template so that it is extendable by dtype.

cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
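One branch-free way to realize the "overflow check" idea across all 8 lanes at once is a SWAR (SIMD-within-a-register) pass — a sketch of the general technique, not necessarily the exact kernel code: mask off the sign bits, then add 3 to every byte; only NaN patterns (0x7D–0x7F after masking) overflow into bit 7 of their byte, since the largest non-NaN pattern 0x7C only reaches 0x7F.

```cpp
#include <cstdint>

// SWAR overflow check for 8 packed e5m2 bytes.
inline bool packHasNanE5M2(uint64_t v) {
  const uint64_t kSignMask = 0x7F7F7F7F7F7F7F7FULL;  // clear each sign bit
  const uint64_t kBias     = 0x0303030303030303ULL;  // push NaNs (>0x7C) past 0x7F
  const uint64_t kHighBits = 0x8080808080808080ULL;  // per-byte overflow flags
  // Max masked byte is 0x7F; 0x7F + 3 = 0x82 < 0x100, so the add can set
  // bit 7 of a byte but never carries into the neighboring byte.
  return (((v & kSignMask) + kBias) & kHighBits) != 0;
}
```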
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check). Made `HasNanFP8x8` a template so that it is extendable based on dtype.

Pull Request resolved: pytorch#136115
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#135891, pytorch#135961