mx: triton kernel to cast to mx and write in col-major#1932
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1932
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 7ecd79f with merge base 3fb1665 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| # example transformation (specifics depend on tile sizes): | ||
| # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13] | ||
| col_scale_indices = col_scale_indices + ( | ||
| tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col |
There was a problem hiding this comment.
i think we should just be doing integer division instead of floor + /
eellison
left a comment
There was a problem hiding this comment.
looks good ! next step: compile to generate this
| ).to(tl.int32) | ||
|
|
||
| # TODO(future): mask this store | ||
| tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) |
There was a problem hiding this comment.
in the launcher, should we assert divisibility of block sizes, so we hard error for this case ?
There was a problem hiding this comment.
yes, for now I hackly assert that on L1319:L1322
| ) | ||
|
|
||
| return ( | ||
| output_col_major.t(), |
There was a problem hiding this comment.
since only the data_ptr of output_col_major is used when you pass it into triton, you could initialize it with the correct strides
|
|
||
| return ( | ||
| output_col_major.t(), | ||
| col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), |
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
Summary:
Implements a triton kernel for a cast to mxfp8 from a row-major input across dim1, which is 3.5x to 4.5x faster than what compile can generate today. Note that this is a prototype kernel, and I expect to (a) improve it in future PRs and (b) delete it in ~weeks when we have compile support for this.
An integration into
MXLinearwill follow in a separate PR.Example of tiling (simplified for small example size):
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: