Swizzle op formulation for non-affine swizzles#1441
Conversation
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}

void handle(const kir::Swizzle2DInt* swizzle_2d) {
Details on the new IR nodes in kernel_ir.h
Apologies for the size of this PR, but most of the code change is introducing the new kernel IR nodes for generating swizzle expressions.
csarofeen
left a comment
Pushing some comments while still reviewing. The only significant comment is whether the integer type associated with swizzle should consistently be Index, not Int.
// Replay producer dimensions.
ReplayTransformations replay_PasC(
-    consumer_CA_ids, forwarded_replay_map, false);
+    consumer_CA_ids, forwarded_replay_map, false, replay_swizzle);
// replay swizzles are skipped while the mapping
// makes progress. This makes sure that, for example
// different tensors can still be inlined despite
// different local swizzle patterns.
Inlined but outside the swizzle axes, correct?
Actually within the swizzle axes as well, just not swizzled. Example:

for i in ...
  float T0[8];
  for j in 0..8:
    T0[j] = ...; // or maybe a vectorized load
  for j in ...
    T1[swizzle(i,j)] = T0[j] + ...;

This is what we will be able to do when we can choose to ignore swizzles on the replay and mapping. There are cases where we don't want to ignore them as well, but those cases will be enabled in a follow-up to keep this PR from getting too large.
This pattern is useful to support; in this case T0 is inlined into T1 beyond the swizzled point, but T0 is not swizzled the way T1 is.
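A minimal Python model of this inlining pattern (the XOR swizzle and the 8x8 tile size are illustrative assumptions here, not nvfuser's actual swizzle implementation):

```python
N = 8  # illustrative tile size

def swizzle(i, j):
    # Hypothetical invertible column swizzle: XOR is its own inverse.
    return i ^ j

src = [[10 * i + j for j in range(N)] for i in range(N)]
T1 = [[None] * N for _ in range(N)]

for i in range(N):
    # T0 is inlined into T1's i-loop and is NOT swizzled...
    T0 = [src[i][j] for j in range(N)]
    # ...while T1 stores through the swizzled index.
    for j in range(N):
        T1[i][swizzle(i, j)] = T0[j] + 1

# Reading back through the same swizzle recovers the expected values.
assert all(T1[i][swizzle(i, j)] == src[i][j] + 1
           for i in range(N) for j in range(N))
```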
So here we're just trying to make sure we can locally inline part of the swizzle, but it seems to me like the validation pass would prevent this from working.
The validation would prevent T1 from inlining into anything, but T0 can inline into T1.
Merging more details in this thread with:
#1441 (comment)
csarofeen
left a comment
Overall looks good to me. Going to give some time to respond to comments, but I don't see any hard blockers to this PR.
inlined_swizzles.empty(), "No support for inlined swizzles");

// Make sure thread swizzling only on sharedmem
if (tv->getMemoryType() != MemoryType::Shared) {
Shouldn't it be:
- Thread swizzling allowed on global memory and shared memory
- Block swizzling allowed on global memory
This would match the logic for the parallelization validation. Maybe you're just waiting until later to build this out, as the implication for these rules is really about when we shift swizzle layouts from expression to expression.
Yes, currently not yet supporting persistent block swizzling. Will build that out in follow-ups.
// be on the right of computeAt axes.
// This has been ensured in previous checks.
// Sharing swizzled iterdomains is much more complex
// and is not a near term priority.
Related to the previous comment. Does this mean we can't zCurve blocks yet?
Not yet; in a follow-up I was going to add an extra attribute on the swizzle op to support zCurve on blocks as well as register access swizzle.
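For context, zCurve block swizzling refers to visiting tiles along a Z-order (Morton) curve by interleaving the bits of the two block coordinates. A small sketch of that index math (illustrative only, not this PR's implementation):

```python
def z_order(i, j, bits=4):
    # Interleave the bits of (i, j) into a single Morton index so that
    # nearby (i, j) tiles get nearby linear indices.
    out = 0
    for b in range(bits):
        out |= ((i >> b) & 1) << (2 * b + 1)  # i supplies the odd bits
        out |= ((j >> b) & 1) << (2 * b)      # j supplies the even bits
    return out

# A bijection over the 16x16 grid: every linear index appears exactly once.
indices = sorted(z_order(i, j) for i in range(16) for j in range(16))
assert indices == list(range(256))
```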
}

// TODO: merge the two swizzle compute logic once the new one is ready.
// will need to replace cyclic shift swizzle with xor since swizzle2d
Are there any advantages or disadvantages to doing so for transpose kernels? Seems it would be straightforward to add cyclic shift to swizzle2d.
I think there's not much difference in the transpose use case. Xor just saves one instruction per swizzled index.
I can just add the cyclic shift swizzle too when removing the legacy swizzle pass, to keep the same behavior as before.
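To illustrate the trade-off being discussed: both an XOR swizzle and a cyclic-shift swizzle permute each row so that a column access touches every shared-memory bank exactly once, but XOR needs a single instruction where the shift needs an add plus a mod. A sketch with an assumed bank/tile width of 8 (names are hypothetical, not nvfuser's):

```python
BANKS = 8  # assumed tile width / bank count for illustration

def xor_swizzle(row, col):
    return row ^ col            # single XOR per swizzled index

def cyclic_shift_swizzle(row, col):
    return (col + row) % BANKS  # add + mod: one extra instruction

# For a fixed column, each swizzle maps the rows onto all BANKS banks
# exactly once, which is what avoids bank conflicts in a transpose.
for col in range(BANKS):
    assert sorted(xor_swizzle(r, col) for r in range(BANKS)) == list(range(BANKS))
    assert sorted(cyclic_shift_swizzle(r, col) for r in range(BANKS)) == list(range(BANKS))
```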
I'm thinking of asking @zasdfgbnm to start looking at the point @rdspring1 got transpose scheduling. Seems we can try to see if it makes any difference. Xor swizzle instead of cyclic swizzle seems fine to me if there's no perf impact.
Sure. Would be happy to support whichever way is decided on the transpose end.
// Input of swizzle ops will not be mapped to any
// by BestEffortReplay, as BestEffortReplay has to be
// one to one. IdGraph will further map them together.
Why are inputs not mapped but outputs are? That seems backwards to my intuition.
Either way should work out when skipping swizzles, as the inputs and outputs should be treated like the same node in this case. It's just more convenient to map the outputs so that matching the subsequent splits and merges on the swizzle outputs progresses as usual, such as mapping:

I1o, I1i = swizzle(I0o, I0i);
I2 = merge(I1o, I1i);

with

I4 = merge(I3o, I3i)

where the root expression is

T0[I0o, I0i] = T1[I3o, I3i];

so when skipping swizzles we want I4 to map to I2, and it's therefore easier when the I1's map to the I3's.
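The forwarding described above can be modeled with a tiny union-find over the iterdomain names from this example (purely illustrative; IdGraph's real machinery is much richer):

```python
# Model mapping-with-swizzle-skipping as a union-find: the swizzle's
# outputs are treated as aliases of its inputs, then merges whose input
# pairs already map produce mapped outputs.
parent = {}

def find(x):
    parent.setdefault(x, x)
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # path halving
        x = parent[x]
    return x

def union(a, b):
    parent[find(a)] = find(b)

# I1o, I1i = swizzle(I0o, I0i): forward inputs to outputs.
union("I0o", "I1o"); union("I0i", "I1i")
# Root mapping from T0[I0o, I0i] = T1[I3o, I3i].
union("I0o", "I3o"); union("I0i", "I3i")
# I2 = merge(I1o, I1i) vs I4 = merge(I3o, I3i): the merge inputs map
# pairwise, so the merge outputs map too.
if find("I1o") == find("I3o") and find("I1i") == find("I3i"):
    union("I2", "I4")

assert find("I2") == find("I4")  # I4 maps to I2, as desired
```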
Ahh so best effort replay is just forwarding the inputs to outputs, so yeah makes sense.
GitLab is down, so CI cannot run through at the moment. Will need to re-run when it's up.
csarofeen
left a comment
Just a bit confused about the validation still.
void validateSwizzle(Fusion* fusion) {
  auto used_vals = fusion->usedMathVals();
  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
I'm having trouble parsing this validation pass. It seems to me:
(1) I can't inline a swizzle into anything.
(2) If memory type isn't shared memory I can't inline a swizzle involving thread dims.
I feel like I'm missing something in this validation pass. To me it seems like (2) is redundant to (1). I'm still uncertain what swizzle support exists in this PR and what's validated; I just want to be confident we're validating the cases we currently support. I guess I'm comfortable if (2) is redundant, just want to make sure that understanding is correct, because I don't understand the block that follows if not.
What exact pattern do we need for MatMul? For transpose we wouldn't try to inline the swizzle in shared memory (if this accidentally happened I think it would be a parallelization/inlining error).
I will put this note on the comment in this PR:
There will be 2 types of swizzles, using the same infrastructure, tentatively named persistent swizzle and non-persistent swizzle (a second candidate pair, to avoid overloading terms: loop swizzle (non-persistent) and data swizzle (persistent)). This PR currently only has the persistent swizzle in shared memory or registers enabled.
Conceptually, persistent swizzles are the swizzle ops that affect how data is stored, like in shared or global memory. It's very similar to the transpose swizzle pattern, e.g.

for i in ...
  for j in ...
    Tswizzled[swizzle(i,j)] = T0[i,j]
// at this point, Tswizzled holds a swizzled data pattern.
for i in ...
  for j in ...
    T1[i,j] = Tswizzled[swizzle(i,j)]

A follow-up will add the non-persistent swizzle, which will actually be replayed in the reference index pass and checked by inline and parallel type validations; additional checks are also needed. In this case, only the loop order is permuted, but no swizzled data pattern should show up on any tensor, e.g.:

for i in ...
  for j in ...
    T1[swizzle(i,j)] = T0[swizzle(i,j)]

These are the two patterns needed in MatMul: the persistent swizzle in the prolog, and the non-persistent swizzle in block zCurve and register access.
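A quick Python model of the non-persistent pattern (hypothetical XOR swizzle on the inner index, small tile): because both sides index through the same swizzle, only the traversal order changes and no tensor ends up holding a swizzled layout.

```python
N = 4

def swizzle(i, j):
    return i ^ j  # hypothetical invertible swizzle of the inner index

T0 = [[i * N + j for j in range(N)] for i in range(N)]
T1 = [[None] * N for _ in range(N)]

# Non-persistent ("loop") swizzle: the swizzled index appears on both the
# producer and consumer side, so within each row only the visit order of
# the elements is permuted.
for i in range(N):
    for j in range(N):
        sj = swizzle(i, j)
        T1[i][sj] = T0[i][sj]

assert T1 == T0  # identical data; no swizzled pattern in memory
```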
Back to inlining checks: all the existing inlining checks should apply for non-persistent swizzles, and they will be turned on in a follow-up.
Inlining persistent swizzles is extra complexity that we don't yet have a use case to motivate extending for. It'd essentially need to generate some code that looks like:

for i in ...
  for j in ...
    Tswizzled[0] = T0[inverse_swizzle(i,j)];
    // Tswizzled is inlined, but if we collect all data realized in each
    // iteration of (i,j) as an array Tswizzled[i,j], it'd be the same as the
    // non-inlined version below.
    T[swizzle(i,j)] = Tswizzled[0];

if we want to inline:

for i in ...
  for j in ...
    Tswizzled[swizzle(i,j)] = T0[i,j]
// at this point, Tswizzled holds a swizzled data pattern.
for i in ...
  for j in ...
    T1[i,j] = Tswizzled[swizzle(i,j)]

The particularly complex part is that in this case we'd need to replay the producer swizzle onto the consumer indexing.
But I remember we decided not to pursue this feature in the near term.
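For completeness, the inverse_swizzle in the sketch above would have to undo the forward swizzle row by row. A hypothetical cyclic-shift pair shows the required identity (illustrative names, not the actual ops):

```python
N = 8

def swizzle(i, j):
    # Hypothetical forward swizzle: cyclic shift of the inner index by i.
    return (j + i) % N

def inverse_swizzle(i, j):
    # Its inverse: shift back by i.
    return (j - i) % N

# inverse_swizzle must undo swizzle for every row so the inlined version
# realizes the same data as the non-inlined one.
for i in range(N):
    for j in range(N):
        assert inverse_swizzle(i, swizzle(i, j)) == j
        assert swizzle(i, inverse_swizzle(i, j)) == j
```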
On the validation pass in this PR: it is only validating persistent swizzles, non-parallelized or block-level parallel into shared memory, so we want to check:

- It is not inlined; we won't support that in the near term based on the discussion above. This is checking Tswizzled in the example above.
- If it is block parallel, we need to force it into shared memory. I will move this second part into the sync_info pass.
int in_x_size = maybe_in_x_size.value();
int in_y_size = maybe_in_y_size.value();
if (swizzle_type == Swizzle2DType::Transpose ||
    swizzle_type == Swizzle2DType::XOR) {
CC @zasdfgbnm this is the XOR swizzle I was mentioning.
csarofeen
left a comment
LGTM; added one minor note on cleanup.
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:
- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor:
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree:
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update:
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API. Commits that are actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (csarofeen#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (csarofeen#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (csarofeen#1811)
03180aa improve broadcast resolution (csarofeen#1792)
bee6c69 bug fix (csarofeen#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (csarofeen#1812)
de6b7ca Fix negative position in InlinePropagator (csarofeen#1813)
10a996c Remove redundant check in schedulePointwise (csarofeen#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (csarofeen#1441)
3ed8330 Kernel args patch to show zero_init buffer (csarofeen#1809)
037a75a Dropout prob extremal patch (csarofeen#1804)
282c429 spam nvrtc options (csarofeen#1783)
3ba6a5f Broadcast in dim with expand (csarofeen#1794)
fd4be12 remove dead indexing code (csarofeen#1806)
fa4e6a4 Check siblings in getMaxPosAll (csarofeen#1805)
025c840 Grouping grid allreduces across iterations (csarofeen#1755)
37c579e Temporarily disable test requring large shared memory. (csarofeen#1802)
5f375d0 More cleanup on InlinePropagator (csarofeen#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (csarofeen#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (csarofeen#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (csarofeen#1756)
ef04f6c Coding style cleanups (csarofeen#1798)
38c7f3c InlinePropagator please don't replay (csarofeen#1797)
3f2c263 validateDomain in TransformPropagator (csarofeen#1796)
c077085 Use TransformPropagatorWithCheck in many tests (csarofeen#1795)
d0d0908 Some further cleanup for the new computeAt interface (csarofeen#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (csarofeen#1791)
28cbaf9 New compute at interface (csarofeen#1743)
635ebfc Add SpanningTreePrinter (csarofeen#1786)
59f3c32 Output allocate patch (csarofeen#1790)
fe93bf5 Transform propagator skip replay when possible (csarofeen#1782)
ebf23a5 Fix isIntegralType error msg (csarofeen#1789)
0c82ecf Disable register reuse across serial broadcast ops (csarofeen#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (csarofeen#1776)
86f46aa Fix div(Val, TensorView) (csarofeen#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (csarofeen#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (csarofeen#1761)
```

[ghstack-poisoned]
RUN_TORCHBENCH: nvfuser
Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)
Pull Request resolved: pytorch#81861
Approved by: https://github.com/davidberard98
This PR is a continuation of #1439 and #1440. The previous two PRs handle the affine part of the thread swizzle that is needed for mma integration. This PR provides support for non-affine swizzles, which will be represented as SwizzleOp's on iterdomains. Current swizzle op support is very minimal, with several restrictions; enabling each of the restricted cases adds complexity to the infrastructure, so they will be handled when required by concrete use cases.
This PR includes:
- IntPair and 2 ops, Swizzle2DInt and PairSelect, to facilitate generating inlined swizzle math.