Fix issues in reductions and thread predicates #470

Merged
naoyam merged 10 commits into 20_10_20_devel from fix-issue367
Nov 2, 2020

@naoyam
Collaborator

@naoyam naoyam commented Oct 30, 2020

Thread predicates are ignored when calling blockReduce and gridReduce. See #468 for a reproducer of the blockReduce case and #367 for a reproducer of the gridReduce case. This PR also adds both reproducers as new tests.

For gridReduce, the best way to apply thread predicates is, IMO, to set the TIDx/y/z template parameters to false. This assumes that the thread predicate of a GridReduction never includes BIDx/y/z, since we don't allow multiple calls to gridReduce in a single kernel.
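A minimal sketch of that flag handling, for readers following along. The enum, its ordering, and the helper name `gridReduceFlagsSketch` are hypothetical and are not the code in this PR; the sketch only shows a predicated thread dimension having its flag forced to false and a predicated block dimension being rejected.

```cpp
#include <array>
#include <bitset>
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical ordering of the six parallel dimensions (not the real enum).
enum ParallelDim { BIDx = 0, BIDy, BIDz, TIDx, TIDy, TIDz, kNumDims };

// Sketch only: `base_flags` holds the flags that would be emitted without any
// thread predicate; `thread_pred` marks the predicated dimensions.
std::string gridReduceFlagsSketch(
    const std::array<bool, kNumDims>& base_flags,
    const std::bitset<kNumDims>& thread_pred) {
  std::string out;
  for (int d = 0; d < kNumDims; ++d) {
    const bool is_block_dim = d < TIDx;
    // Assumption stated in the PR description: block dimensions are never
    // part of the thread predicate of a grid reduction.
    if (is_block_dim && thread_pred.test(d)) {
      throw std::logic_error("predicate on a block dimension is not supported");
    }
    // A predicated thread dimension gets its flag forced to false.
    const bool flag = base_flags[d] && !thread_pred.test(d);
    if (d != 0) {
      out += ", ";
    }
    out += flag ? "true" : "false";
  }
  return out;
}
```

The returned string is meant to be spliced between the angle brackets of a generated `gridReduce<...>` call; how the base flags are derived is left out here.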

To pass around the predicate info until the CUDA code for calling gridReduce is generated, a new field of type ParallelTypeBitmap is added to kir::GridReduction. To use ParallelTypeBitmap from kernel_ir.h, I also extracted the class from lower_utils.h into its own header file.
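The shape of that change can be sketched roughly as follows. The real ParallelTypeBitmap interface is not shown in this thread, so every name below (`ParallelTypeSketch`, `ParallelTypeBitmapSketch`, `GridReductionSketch`) is a stand-in chosen for illustration.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the parallel types that can be predicated.
enum class ParallelTypeSketch { BIDx, BIDy, BIDz, TIDx, TIDy, TIDz };

// Sketch of a small bitmap keyed by parallel type. Keeping it in its own
// header would let both the lowering utilities and the kernel IR include it.
class ParallelTypeBitmapSketch {
 public:
  bool get(ParallelTypeSketch pt) const {
    return bits_.test(index(pt));
  }
  void set(ParallelTypeSketch pt, bool value = true) {
    bits_.set(index(pt), value);
  }
  bool any() const {
    return bits_.any();
  }

 private:
  static std::size_t index(ParallelTypeSketch pt) {
    return static_cast<std::size_t>(pt);
  }
  std::bitset<6> bits_;
};

// Sketch of an IR node that carries the predicate until code generation.
struct GridReductionSketch {
  ParallelTypeBitmapSketch thread_pred;
};
```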

For blockReduce, the change is much simpler (see IndexLowering::visit(kir::ReductionOp*)).
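To illustrate why the predicate matters for a block reduction, here is a host-side stand-in. It is not the blockReduce device function; `blockSumSketch` and its parameters are hypothetical, and it only contrasts a sum that honors a per-thread predicate with one that ignores it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side stand-in for a block-wide sum. `values[i]` is the input held by
// thread i and `active[i]` is that thread's predicate.
float blockSumSketch(
    const std::vector<float>& values,
    const std::vector<bool>& active,
    bool honor_predicate) {
  float sum = 0.0f;
  for (std::size_t i = 0; i < values.size(); ++i) {
    // When the predicate is ignored, inactive threads contribute values
    // they should not, which changes the result.
    if (!honor_predicate || active[i]) {
      sum += values[i];
    }
  }
  return sum;
}
```

With inputs {1, 2, 3, 4} and only threads 0 and 2 active, honoring the predicate sums 1 and 3, while ignoring it sums all four values.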

Fixes #367
Fixes #468

Naoya Maruyama added 4 commits October 29, 2020 17:12
When TIDx/y/z are predicated, set the TIDx/y/z template flags as false

Closes #367
@naoyam naoyam requested review from csarofeen and tlemo October 30, 2020 08:07
Owner

@csarofeen csarofeen left a comment

This looks good to me. Do we assert somewhere if someone attempts to use two grid reductions in the same kernel? We may actually want to support this at some point, but until then I want to make sure we throw a hard error.

Can you also please update the issues with the code that is now generated?
Thanks!

@naoyam
Collaborator Author

naoyam commented Nov 2, 2020

Opened an issue for detecting multiple grid reductions (#475).
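One way such a check might look, as a sketch only. The expression kinds and the function `assertSingleGridReductionSketch` are hypothetical; nothing here reflects how #475 was eventually addressed.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical classification of the expressions in a lowered kernel.
enum class ExprKindSketch { Unary, Binary, BlockReduction, GridReduction };

// Throws a hard error when a kernel contains more than one grid reduction.
void assertSingleGridReductionSketch(
    const std::vector<ExprKindSketch>& exprs) {
  const auto count =
      std::count(exprs.begin(), exprs.end(), ExprKindSketch::GridReduction);
  if (count > 1) {
    throw std::runtime_error(
        "multiple grid reductions in one kernel are not supported");
  }
}
```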

@naoyam
Collaborator Author

naoyam commented Nov 2, 2020

Updated issue #367 with the generated kernel.

@naoyam
Collaborator Author

naoyam commented Nov 2, 2020

Updated issue #468 with the generated kernel.

Comment thread test/cpp/jit/test_gpu.cpp Outdated
}
}

std::string generateGridReduceTemplateFlags(
Collaborator

Please place this local helper in an anonymous namespace.

Collaborator Author

It's already in a class defined in an anonymous namespace.

Collaborator

Ah, right. That's one reason I don't like anonymous namespaces: you have to scroll around a lot to find them :(

Collaborator Author

Yeah, that's a downside.

Comment thread torch/csrc/jit/codegen/cuda/codegen.cpp Outdated
Comment thread torch/csrc/jit/codegen/cuda/kernel_ir.h Outdated
Comment thread torch/csrc/jit/codegen/cuda/kernel_ir.h Outdated
Comment thread torch/csrc/jit/codegen/cuda/lower_index.h Outdated
Comment thread torch/csrc/jit/codegen/cuda/kernel_ir.h Outdated
ReductionOp* reduction_op_ = nullptr;
Allocate* reduction_buffer_ = nullptr;
Allocate* sync_buffer_ = nullptr;
ParallelTypeBitmap thread_pred_;
Collaborator

Why can't we use the "normal" predicate (Expr::predicate_) instead of this?

Collaborator Author

We could. We would need to make some changes to gridReduce. In particular, indexing the work buffer written by each thread block needs some non-significant change.

However, gridReduce already has bool template parameters exactly for predicating on block and thread indices being zero, so using those template flags makes more sense. The normal predicate is not about thread/block indices, so it can't be mapped to the template flags, and that's why the two need to be kept separate for gridReduce. Note that for other expressions, we just combine them with &&.
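For the "other expressions" case, joining the two predicates can be pictured as string concatenation at code-generation time. The helper `guardSketch` below is hypothetical and only illustrates the && combination; for gridReduce the thread predicate would instead go into the template flags and stay out of this guard.

```cpp
#include <cassert>
#include <string>

// Sketch: build the guard condition for an ordinary expression from a
// thread predicate and a normal predicate, either of which may be empty.
std::string guardSketch(
    const std::string& thread_pred,
    const std::string& normal_pred) {
  if (thread_pred.empty()) {
    return normal_pred;
  }
  if (normal_pred.empty()) {
    return thread_pred;
  }
  // Both present: join them with && so the expression runs only when
  // both conditions hold.
  return "(" + thread_pred + ") && (" + normal_pred + ")";
}
```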

Collaborator Author

Added a comment to the code itself too.

Comment thread torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@naoyam naoyam merged commit a4d48c3 into 20_10_20_devel Nov 2, 2020
@naoyam naoyam deleted the fix-issue367 branch November 2, 2020 23:23

Development

Successfully merging this pull request may close these issues.

ReductionOp ignores thread predicates
Issue with block reduction then grid reduction

3 participants