Re-entrant GroupedGridReduction#1727

Merged: naoyam merged 3 commits into devel from re_entrant_horizontally_grouped_grid_reduction on May 25, 2022

@naoyam naoyam commented May 24, 2022

Enable re-entrance with GroupedGridReduction. Mostly just copied the logic already implemented for GridReduction to GroupedGridReduction.

See FusionGroupedReductionChannelsLastBatchNormLike. The two grid reductions with vectorized iteration domains are grouped as:

  #pragma unroll
  for(nvfuser_index_t i241 = 0; i241 < 2; ++i241) {
    // Allocate global tensor T16
    // Allocate global tensor T17
    // Allocate global tensor T18
    reduction::gridReduceGroup<false, true, false, false, true, false, false>(
      T5[i241],
      T15[i241],
      float(0),
      [](float &a, float b) { a = a + b; },
      &T16[0],
      T9[i241],
      T14[i241],
      float(0),
      [](float &a, float b) { a = a + b; },
      &T17[0],
      &T18[0],
      shared_mem,
      true,
      true,
      i241,
      2,
      T19[0],
      T19[1]);
  }
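
The re-entrance idea can be modeled on the host as below. This is a minimal sketch, not the nvFuser runtime code: `GroupedResult` and `groupedReduceEntrance` are hypothetical names, and the layout (one work-buffer slice per entrance, selected by an entrance index such as `i241` out of `2` entrances above) is an assumption made for illustration. It shows two reductions handled in a single call, with each entrance writing only to its own slice so a later loop iteration cannot overwrite partials from an earlier one.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of a grouped, re-entrant grid reduction.
// All names here are illustrative and are not part of nvFuser.
struct GroupedResult {
  float out0; // result of the first grouped reduction
  float out1; // result of the second grouped reduction
};

// Each "block" contributes one partial value per reduction. The work buffers
// are assumed to hold n_entrances * num_blocks slots; an entrance writes only
// to the slice starting at entrance_ind * num_blocks.
GroupedResult groupedReduceEntrance(
    const std::vector<float>& partial0, // per-block partials, reduction 0
    const std::vector<float>& partial1, // per-block partials, reduction 1
    std::vector<float>& work_buf0,
    std::vector<float>& work_buf1,
    std::size_t entrance_ind,
    std::size_t n_entrances) {
  const std::size_t num_blocks = partial0.size();
  assert(partial1.size() == num_blocks);
  assert(entrance_ind < n_entrances);
  assert(work_buf0.size() >= n_entrances * num_blocks);
  assert(work_buf1.size() >= n_entrances * num_blocks);

  const std::size_t offset = entrance_ind * num_blocks;

  // Step 1: every block stores its partials for both reductions.
  for (std::size_t b = 0; b < num_blocks; ++b) {
    work_buf0[offset + b] = partial0[b];
    work_buf1[offset + b] = partial1[b];
  }

  // Step 2: on a GPU a grid-wide sync would go here, shared by both
  // reductions. Afterwards the slice for this entrance is summed.
  GroupedResult result{0.0f, 0.0f};
  for (std::size_t b = 0; b < num_blocks; ++b) {
    result.out0 += work_buf0[offset + b];
    result.out1 += work_buf1[offset + b];
  }
  return result;
}
```

Calling the function once per entrance index, with the same pair of work buffers, mirrors the unrolled loop over `i241`: the partials from entrance 0 remain intact in the buffers after entrance 1 has run.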

Comment on lines -208 to -209
const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind;
const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 1 : n_entrances;
@naoyam (author):

Just removed unused variables

}
}

inline void clearL2Cache() {
@naoyam (author):

Just copied from the benchmark directory as I'm doing ad-hoc perf testing in the C++ test files.

Naoya Maruyama added 2 commits May 24, 2022 15:54
@naoyam naoyam changed the title from "[WIP] Re-entrant GroupedGridReduction" to "Re-entrant GroupedGridReduction" May 25, 2022
@naoyam naoyam marked this pull request as ready for review May 25, 2022 00:54
@naoyam naoyam requested a review from csarofeen May 25, 2022 00:57
@csarofeen csarofeen left a comment

LGTM


auto tv0_cache = tv0->cacheAfter();

const int vec = 2;
@csarofeen (owner):

Out of curiosity what does the heuristics select if it's not grouped? Can the heuristics/scheduler run after grouping?

@naoyam (author):

scheduler_params output:


===== Reduction Stats ========
total_reduction_numel: 99
total_iteration_numel: 999
vectorize_factor: 1
n_tensor_inputs: 1
max_input_dtype_size: 4
block(16, 16, 1)

===== Reduction Parameters ========

Red On Slow Dim

Iteration Domain: blockIdx.x / threadIdx.x / multiple reductions per block /
Inner Reduction Domain: cross block - threadIdx.y / unroll / factor 4
Launch Parameters: BlockDim.x = 16, BlockDim.y = 16, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================


===== Reduction Stats ========
total_reduction_numel: 99
total_iteration_numel: 999
vectorize_factor: 1
n_tensor_inputs: 1
max_input_dtype_size: 4
block(16, 16, 1)

===== Reduction Parameters ========

Red On Slow Dim

Iteration Domain: blockIdx.x / threadIdx.x / multiple reductions per block /
Inner Reduction Domain: cross block - threadIdx.y / unroll / factor 4
Launch Parameters: BlockDim.x = 16, BlockDim.y = 16, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0

====================================

Grouping before scheduling currently fails because the expressions are changed from ReductionOp to GroupedReductionOp. Making it work would likely require only some mechanical changes.
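
To illustrate the kind of failure described here, the following is a toy sketch only. The classes are simplified stand-ins that borrow the names `ReductionOp` and `GroupedReductionOp`; they are not the nvFuser IR classes, and `countReductionsStrict` / `countReductions` are hypothetical helpers. The sketch assumes the scheduler locates reductions by matching on the ungrouped expression type, so it finds none once the expressions have been grouped; the "mechanical" fix is to match the grouped type as well.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy stand-ins for IR expression classes. These are NOT the nvFuser types.
struct Expr {
  virtual ~Expr() = default;
};
struct ReductionOp : Expr {};
struct GroupedReductionOp : Expr {
  explicit GroupedReductionOp(int n) : num_grouped(n) {}
  int num_grouped; // how many reductions were merged into this expression
};

using ExprList = std::vector<std::unique_ptr<Expr>>;

// A pass that only recognizes ungrouped reductions. After grouping it
// reports zero reductions, which is the failure mode being illustrated.
int countReductionsStrict(const ExprList& exprs) {
  int n = 0;
  for (const auto& e : exprs) {
    if (dynamic_cast<const ReductionOp*>(e.get()) != nullptr) {
      ++n;
    }
  }
  return n;
}

// The same pass extended to accept the grouped expression type as well.
int countReductions(const ExprList& exprs) {
  int n = 0;
  for (const auto& e : exprs) {
    if (dynamic_cast<const ReductionOp*>(e.get()) != nullptr) {
      ++n;
    } else if (
        const auto* g = dynamic_cast<const GroupedReductionOp*>(e.get())) {
      n += g->num_grouped;
    }
  }
  return n;
}
```

With two `ReductionOp` entries both helpers report two reductions; after replacing them with one `GroupedReductionOp(2)`, only the extended helper still does.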

@naoyam naoyam merged commit 5247682 into devel May 25, 2022
@naoyam naoyam deleted the re_entrant_horizontally_grouped_grid_reduction branch May 25, 2022 16:34
jjsjann123 added a commit that referenced this pull request Jun 22, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Bug fixes and minor refactor

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
4c60e7d Add examples infrastructure for using nvFuser in a standalone program (#1725)
02a05d9 Fix issue #1751 (#1753)
8a69aa3 Refactor NvFuser transpose API to match eager mode behavior (#1746)
ffdf6b7 Remove BroadcastWithoutStride. (#1738)
02bab16 Fix flipping of a boolean flag (#1745)
465d668 cleanup (#1744)
26d354e fixing noncontig broadcast (#1742)
856b6b2 Add IterDomainBuilder (#1736)
1fd974f fixing warning for gcc7 (#1732)
de2740a disabling complex in python tests for #1730 (#1733)
fbbbe0a fixing MSVC build (#1728)
b5feee5 Fix the fused reduction runtime kernel (#1729)
5247682 Re-entrant GroupedGridReduction (#1727)
```

RUN_TORCHBENCH: nvfuser
Pull Request resolved: pytorch#79147
Approved by: https://github.com/davidberard98
jjsjann123 added a commit that referenced this pull request Jun 22, 2022
…h#79406)

Landing reverted PR pytorch#79147.

Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Bug fixes and minor refactor

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
4c60e7d Add examples infrastructure for using nvFuser in a standalone program (#1725)
02a05d9 Fix issue #1751 (#1753)
8a69aa3 Refactor NvFuser transpose API to match eager mode behavior (#1746)
ffdf6b7 Remove BroadcastWithoutStride. (#1738)
02bab16 Fix flipping of a boolean flag (#1745)
465d668 cleanup (#1744)
26d354e fixing noncontig broadcast (#1742)
856b6b2 Add IterDomainBuilder (#1736)
1fd974f fixing warning for gcc7 (#1732)
de2740a disabling complex in python tests for #1730 (#1733)
fbbbe0a fixing MSVC build (#1728)
b5feee5 Fix the fused reduction runtime kernel (#1729)
5247682 Re-entrant GroupedGridReduction (#1727)
```

RUN_TORCHBENCH: nvfuser
Pull Request resolved: pytorch#79406
Approved by: https://github.com/davidberard98