Swizzle op formulation for non-affine swizzles #1441

Merged
shmsong merged 18 commits into devel from swizzle_op_implv2 on Jul 8, 2022
Conversation

@shmsong

@shmsong shmsong commented Feb 8, 2022

This PR is a continuation of #1439 and #1440. The previous two PRs handle the affine part of the thread swizzle needed for mma integration. This PR adds support for non-affine swizzles, which are represented as SwizzleOps on iterdomains. Current swizzle op support is very minimal, with restrictions including:

  1. Only 2D swizzles are supported.
  2. Swizzled iterdomains must be of compile-time-constant size.
  3. No swizzle inlining (all swizzles must be on the right of their CA axis).
  4. No swizzle composition (no swizzle can be a consumer of another swizzle).
  5. No reduction or broadcast on swizzled iterdomains.

Enabling each of the above cases adds complexity to the infrastructure, so they will be handled when required by concrete use cases.
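As a concrete illustration of what a 2D swizzle op represents (a sketch using an assumed XOR swizzle formula, not this PR's C++ API), a 2D swizzle is a bijective remapping of a compile-time-sized 2D tile of iterdomain indices:

```python
# Illustrative sketch (not the PR's API): a 2D swizzle is a bijection on a
# compile-time-constant 2D tile of indices. An XOR swizzle, for example,
# permutes columns within each row so that reading a "column" of the tile
# touches distinct physical columns (e.g. shared-memory banks).

def xor_swizzle_2d(i, j, size):
    """Map tile coordinates (i, j) on a size x size tile to swizzled coords."""
    return i, j ^ i  # row unchanged, column XOR'ed with the row index

def is_bijective(swizzle, size):
    """A valid swizzle must be a bijection on the tile (no data is lost)."""
    images = {swizzle(i, j, size) for i in range(size) for j in range(size)}
    return len(images) == size * size

assert is_bijective(xor_swizzle_2d, 8)

# Column 0 of the swizzled layout spreads across all 8 physical columns:
cols = [xor_swizzle_2d(i, 0, 8)[1] for i in range(8)]
assert sorted(cols) == list(range(8))
```

This also shows why restriction 2 matters in the sketch: the remap is defined on a fixed-size tile, so the swizzled iterdomains need a compile-time-constant extent.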

This PR includes:

  • A swizzle2D op on iterdomains (Fusion IR).
  • A new data type IntPair and two ops, Swizzle2DInt and PairSelect, to facilitate generating inlined swizzle math.
  • Preliminary swizzle validation and checks on the use restrictions.

Base automatically changed from ampere_mma_op to devel May 23, 2022 23:50
@shmsong shmsong force-pushed the swizzle_op_implv2 branch from 203158f to f1d506f on June 28, 2022
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}

void handle(const kir::Swizzle2DInt* swizzle_2d) {
Author

Details on the new IR nodes in kernel_ir.h

@shmsong shmsong force-pushed the swizzle_op_implv2 branch from 48bb5c6 to d61a26b on June 29, 2022
@shmsong
Author

shmsong commented Jun 29, 2022

Apologies for the size of this PR, but most of the code change is introducing the new kernel IR nodes for generating swizzle expressions.

@shmsong shmsong changed the title WIP: swizzle op formulation for non-affine swizzles Swizzle op formulation for non-affine swizzles Jun 29, 2022
@shmsong shmsong requested review from csarofeen and naoyam June 29, 2022 07:09
Owner

@csarofeen csarofeen left a comment


Pushing some comments while still reviewing. The only significant comment is whether the integer type associated with swizzles should consistently be Index rather than Int.

Comment thread torch/csrc/jit/codegen/cuda/type.cpp Outdated
// Replay producer dimensions.
ReplayTransformations replay_PasC(
consumer_CA_ids, forwarded_replay_map, false);
consumer_CA_ids, forwarded_replay_map, false, replay_swizzle);
Owner

Might conflict with #1782 CC @zasdfgbnm

Comment thread torch/csrc/jit/codegen/cuda/transform_iter.h Outdated
// replay swizzles are skipped while the mapping
// makes progress. This makes sure that, for example
// different tensors can still be inlined despite
// different local swizzle patterns.
Owner

Inlined but outside the swizzle axes, correct?

Author

Actually within the swizzle axes as well, just not swizzled. Example:

for i in ...
  float T0[8];
  for j in 0..8:
   T0[j] = ...; // or maybe a vectorized load
  for j in ...
    T1[swizzle(i,j)] = T0[j] + ...;

This is what we will be able to do when we can choose to ignore swizzles in the replay and mapping. There are cases where we don't want to ignore them as well, but those will be enabled in a follow-up to keep this PR from growing too large.

This pattern is useful to support: in this case T0 is inlined into T1 beyond the swizzled point, but T0 is not swizzled the way T1 is.
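The pattern above can be sketched in plain Python (hypothetical sizes and an assumed XOR swizzle; not nvFuser code) to show that the unswizzled producer T0 inlines cleanly under the swizzled consumer T1:

```python
# Sketch of the inlined-producer pattern: T0 is a per-row scratch buffer
# inside the i-loop (inlined, unswizzled), while T1's writes go through
# swizzle(i, j). Names mirror the pseudocode in the thread.

N = 8

def swizzle(i, j):
    return j ^ i  # assumed XOR swizzle: column index after swizzling

src = [[10 * i + j for j in range(N)] for i in range(N)]
T1 = [[None] * N for _ in range(N)]

for i in range(N):
    T0 = [src[i][j] for j in range(N)]      # unswizzled producer, inlined
    for j in range(N):
        T1[i][swizzle(i, j)] = T0[j] + 1    # swizzled consumer write

# Row 0 is untouched by the swizzle (j ^ 0 == j); other rows are permuted:
assert T1[0] == [v + 1 for v in src[0]]
assert sorted(T1[2]) == sorted(v + 1 for v in src[2])
```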

Owner

So here we're just trying to make sure we can locally inline part of the swizzle, but it seems to me like the validation pass would prevent this from working.

Author

The validation would prevent T1 from inlining into anything, but T0 can inline into T1.

Merging more details in this thread with:
#1441 (comment)

Comment thread torch/csrc/jit/codegen/cuda/transform_iter.cpp
Comment thread torch/csrc/jit/codegen/cuda/transform_iter.cpp Outdated
Comment thread torch/csrc/jit/codegen/cuda/lower_validation.cpp Outdated
Comment thread torch/csrc/jit/codegen/cuda/transform_iter.cpp Outdated
Comment thread torch/csrc/jit/codegen/cuda/transform_iter.cpp Outdated
Comment thread torch/csrc/jit/codegen/cuda/kernel_ir.cpp Outdated
Owner

@csarofeen csarofeen left a comment

Overall looks good to me. Going to give some time to respond to comments, but I don't see any hard blockers to this PR.

Comment thread torch/csrc/jit/codegen/cuda/codegen.cpp
Comment thread torch/csrc/jit/codegen/cuda/lower_validation.cpp
inlined_swizzles.empty(), "No support for inlined swizzles");

// Make sure thread swizzling only on sharedmem
if (tv->getMemoryType() != MemoryType::Shared) {
Owner

Shouldn't it be:

  • Thread swizzling allowed on global memory and shared memory
  • Block swizzling allowed on global memory

This would match the logic for the parallelization validation. Maybe you're just waiting until later to build this out, as the implication of these rules is really about when we shift swizzle layouts from expression to expression.

Author

Yes, persistent block swizzling is not yet supported. Will build it out in follow-ups.

// be on the right of computeAt axes.
// This has been ensured in previous checks.
// Sharing swizzled iterdomains is much more complex
// and is not a near term priority.
Owner

Related to the previous comment. Does this mean we can't zCurve blocks yet?

Author

Not yet. In a follow-up I was going to add an extra attribute on the swizzle op to support zCurve on blocks as well as register-access swizzle.

Comment thread torch/csrc/jit/codegen/cuda/index_compute.cpp
Comment thread torch/csrc/jit/codegen/cuda/index_compute.cpp
}

// TODO: merge the two swizzle compute logic once the new one is ready.
// will need to replace cyclic shift swizzle with xor since swizzle2d
Owner

Are there any advantages or disadvantages to doing so for transpose kernels? Seems it would be straightforward to add cyclic shift to swizzle2d.

Author

@shmsong shmsong Jul 2, 2022

I think there is not much difference in the transpose use case; XOR just saves one instruction per swizzled index.

I can just add cyclic shift swizzle too when removing the legacy swizzle pass to keep the same behavior as before.
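For reference, the two column permutations being compared can be sketched as follows (illustrative formulas, not the codegen output); both spread a column across distinct physical columns, and XOR is cheaper by one instruction:

```python
# Sketch comparing the cyclic-shift swizzle and the XOR swizzle on an 8-wide
# tile. Both are row-wise column permutations; after either one, reading
# "column j" touches a distinct physical column (bank) in every row.

SIZE = 8

def cyclic_shift(i, j):
    return (j + i) % SIZE  # add + wrap: typically one more instruction

def xor_swizzle(i, j):
    return j ^ i           # single XOR instruction

for j in range(SIZE):
    assert len({cyclic_shift(i, j) for i in range(SIZE)}) == SIZE
    assert len({xor_swizzle(i, j) for i in range(SIZE)}) == SIZE
```

(The equal-spread property above is what "not much difference in the transpose use case" refers to; the two differ only in instruction count and in which permutation each row gets.)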

Owner

I'm thinking of asking @zasdfgbnm to start looking from the point @rdspring1 got transpose scheduling to. Seems we can try to see if it makes any difference. XOR swizzle instead of cyclic swizzle seems fine to me if there's no perf impact.

Author

Sure. Happy to support whichever way is decided on the transpose end.


// Inputs of swizzle ops will not be mapped to anything
// by BestEffortReplay, as BestEffortReplay has to be
// one to one. IdGraph will further map them together.
Owner

Why are inputs not mapped but outputs are? That seems backwards to my intuition.

Author

@shmsong shmsong Jul 2, 2022

Either way should work out when skipping swizzles, as the inputs and outputs should be treated as the same node in this case. It's just more convenient to map the outputs, so that matching the subsequent splits and merges on the swizzle outputs progresses as usual, e.g. mapping:

I1o, I1i = swizzle(I0o, I0i);
I2 = merge(I1o, I1i);

with

I4 = merge(I3o, I3i)

where the root expression is

T0[I0o, I0i] = T1[I3o, I3i];

so when skipping swizzles we want I4 to map to I2, which is easier when the I1's map to the I3's.

Owner

Ahh so best effort replay is just forwarding the inputs to outputs, so yeah makes sense.

Comment thread torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp Outdated
@shmsong
Author

shmsong commented Jul 2, 2022

fatal: unable to access 'https://gitlab.com/libeigen/eigen.git/': The requested URL returned error: 503

GitLab is down, so CI cannot run through at the moment. Will need to re-run when it's back up.

Owner

@csarofeen csarofeen left a comment

Just a bit confused about the validation still.

Comment thread torch/csrc/jit/codegen/cuda/index_compute.cpp
Comment thread torch/csrc/jit/codegen/cuda/transform_iter.cpp

void validateSwizzle(Fusion* fusion) {
auto used_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
Owner

I'm having trouble parsing this validation pass. It seems to me:
(1) I can't inline a swizzle into anything.
(2) If the memory type isn't shared memory, I can't inline a swizzle involving thread dims.

I feel like I'm missing something in this validation pass. To me it seems like (2) is redundant with (1). I'm still uncertain what swizzle support exists in this PR and what's validated. I just want to be confident we're validating that the cases that come in are ones we currently support. I guess I'm comfortable if (2) is redundant; I just want to make sure that understanding is correct, because I don't understand the block that follows if not.

What exact pattern do we need for MatMul? For transpose we wouldn't try to inline the swizzle in shared memory (if this accidentally happened I think it would be a parallelization/inlining error).

Author

@shmsong shmsong Jul 5, 2022

I will put this note in a comment in this PR:

There will be two types of swizzles using the same infrastructure, tentatively named persistent swizzle and non-persistent swizzle (a second candidate pair, to avoid overloading terms: loop swizzle (non-persistent) and data swizzle (persistent)). This PR currently only enables the persistent swizzle, in shared memory or registers.

Conceptually, persistent swizzles are the swizzle ops that affect how data is stored, e.g. in shared or global memory. It's very similar to the transpose swizzle pattern, e.g.

for i in ...
  for j in ...
   Tswizzled[swizzle(i,j)] = T0[i,j]

// at this point, Tswizzled holds swizzled data pattern. 

for i in ...
  for j in ...
   T1[i,j] = Tswizzled[swizzle(i,j)]

A follow-up will add the non-persistent swizzle, which will actually be replayed in the reference indexing pass and checked by the inlining and parallel-type validations; additional checks are also needed there. In this case, only the loop order will be permuted, and no swizzled data pattern should show up on any tensor, e.g.:

for i in ...
  for j in ...
   T1[swizzle(i,j)] = T0[swizzle(i,j)]

These are the two patterns needed for MatMul: the persistent swizzle in the prolog, and the non-persistent swizzle for the block zCurve and register access.
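The persistent pattern above can be sketched as runnable Python (assumed XOR swizzle and tile size; names mirror the pseudocode): storing and loading through the same swizzle round-trips the data while the physical layout is permuted.

```python
# Persistent-swizzle sketch: the producer stores through swizzle(i, j) and
# the consumer loads through the same swizzle, so the consumer recovers the
# original values even though Tswizzled physically holds a permuted layout.

N = 8

def swizzle(i, j):
    return i, j ^ i  # assumed XOR swizzle for illustration

T0 = [[10 * i + j for j in range(N)] for i in range(N)]
Tswizzled = [[None] * N for _ in range(N)]
T1 = [[None] * N for _ in range(N)]

for i in range(N):                      # producer: swizzled store
    for j in range(N):
        si, sj = swizzle(i, j)
        Tswizzled[si][sj] = T0[i][j]

# at this point, Tswizzled holds the swizzled data pattern

for i in range(N):                      # consumer: swizzled load
    for j in range(N):
        si, sj = swizzle(i, j)
        T1[i][j] = Tswizzled[si][sj]

assert T1 == T0          # the round trip recovers the data
assert Tswizzled != T0   # but the stored layout is permuted
```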

Author

@shmsong shmsong Jul 5, 2022

Back to inlining checks: all the existing inlining checks should apply to non-persistent swizzles, and they will be turned on in a follow-up.

Inlining persistent swizzles is extra complexity that we don't yet have a use case to motivate. It'd essentially need to generate code that looks like:

for i in ...
  for j in ...
    Tswizzled[0] = T0[inverse_swizzle(i,j)];
    // Tswizzled is inlined, but if we collect all data realized in each
    // iteration of (i,j) as an array Tswizzled[i,j], it'd be the same as the
    // non-inlined version below.
    T[swizzle(i,j)] = Tswizzled[0];

if we want to inline:

for i in ...
  for j in ...
   Tswizzled[swizzle(i,j)] = T0[i,j]

// at this point, Tswizzled holds the swizzled data pattern.

for i in ...
  for j in ...
   T1[i,j] = Tswizzled[swizzle(i,j)]

The particularly complex part is that in this case we'd need to replay the producer swizzle onto the consumer indexing.

But I remember we decided not to pursue this feature near term.
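A small sketch of the inverse-swizzle requirement mentioned above (assumed XOR formula, which happens to be its own inverse; `inverse_swizzle` is the hypothetical helper from the pseudocode, not an existing API):

```python
# Why inlining a persistent swizzle needs the inverse map: the inlined
# producer must read T0 at inverse_swizzle(i, j) so that the consumer's
# subsequent swizzled store leaves the overall data movement unchanged,
# i.e. swizzle and inverse_swizzle must compose to the identity.

def swizzle(i, j):
    return i, j ^ i

def inverse_swizzle(i, j):
    return i, j ^ i  # XOR swizzle is an involution: applying it twice is id

for i in range(8):
    for j in range(8):
        assert inverse_swizzle(*swizzle(i, j)) == (i, j)
```

For swizzle types that are not involutions (e.g. a cyclic shift), the inverse is a genuinely different map, which is part of the extra complexity being deferred here.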

Author

On the validation pass in this PR: it only validates persistent swizzles, either non-parallelized or block-level parallel into shared memory, so we want to check:

  1. It is not inlined; we won't support that near term based on the discussion above. This is checking Tswizzled in the example above.

  2. If it is block parallel, we need to force it into shared memory. I will move this second part into the sync_info pass.


Comment thread torch/csrc/jit/codegen/cuda/tensor_view.cpp Outdated
int in_x_size = maybe_in_x_size.value();
int in_y_size = maybe_in_y_size.value();
if (swizzle_type == Swizzle2DType::Transpose ||
swizzle_type == Swizzle2DType::XOR) {
Owner

CC @zasdfgbnm this is the XOR swizzle I was mentioning.

Owner

@csarofeen csarofeen left a comment

LGTM. Added one minor note on cleanup.

@shmsong shmsong merged commit acd5ed4 into devel Jul 8, 2022
@shmsong shmsong deleted the swizzle_op_implv2 branch July 8, 2022 03:14
shmsong pushed a commit to shmsong/pytorch that referenced this pull request Jul 24, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that are actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (csarofeen#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (csarofeen#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (csarofeen#1811)
03180aa improve broadcast resolution (csarofeen#1792)
bee6c69 bug fix (csarofeen#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (csarofeen#1812)
de6b7ca Fix negative position in InlinePropagator (csarofeen#1813)
10a996c Remove redundant check in schedulePointwise (csarofeen#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (csarofeen#1441)
3ed8330 Kernel args patch to show zero_init buffer (csarofeen#1809)
037a75a Dropout prob extremal patch (csarofeen#1804)
282c429 spam nvrtc options (csarofeen#1783)
3ba6a5f Broadcast in dim with expand (csarofeen#1794)
fd4be12 remove dead indexing code (csarofeen#1806)
fa4e6a4 Check siblings in getMaxPosAll (csarofeen#1805)
025c840 Grouping grid allreduces across iterations (csarofeen#1755)
37c579e Temporarily disable test requring large shared memory. (csarofeen#1802)
5f375d0 More cleanup on InlinePropagator (csarofeen#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (csarofeen#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (csarofeen#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (csarofeen#1756)
ef04f6c Coding style cleanups (csarofeen#1798)
38c7f3c InlinePropagator please don't replay (csarofeen#1797)
3f2c263 validateDomain in TransformPropagator (csarofeen#1796)
c077085 Use TransformPropagatorWithCheck in many tests (csarofeen#1795)
d0d0908 Some further cleanup for the new computeAt interface (csarofeen#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (csarofeen#1791)
28cbaf9 New compute at interface (csarofeen#1743)
635ebfc Add SpanningTreePrinter (csarofeen#1786)
59f3c32 Output allocate patch (csarofeen#1790)
fe93bf5 Transform propagator skip replay when possible (csarofeen#1782)
ebf23a5 Fix isIntegralType error msg (csarofeen#1789)
0c82ecf Disable register reuse across serial broadcast ops (csarofeen#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (csarofeen#1776)
86f46aa Fix div(Val, TensorView) (csarofeen#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (csarofeen#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (csarofeen#1761)
```

[ghstack-poisoned]
shmsong pushed a commit to shmsong/pytorch that referenced this pull request Jul 24, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ (same change summary and squashed commit list as the commit above)

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
csarofeen pushed a commit that referenced this pull request Aug 4, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ (same change summary and squashed commit list as the commit above)

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)
Pull Request resolved: pytorch#81861
Approved by: https://github.com/davidberard98