A new index hoisting and CSE approach based on subexpression equivalence and dependency #2235

Merged
zasdfgbnm merged 36 commits into devel from expr_based_index_hoist on Dec 6, 2022
@zasdfgbnm
Collaborator

@zasdfgbnm zasdfgbnm commented Dec 2, 2022

This is a complete rewrite of the index hoisting. I no longer use IterDomain mapping for the analysis. Instead, the analysis is based only on the indexing expressions and on their subexpression equivalence and dependency. As a result, the code is shorter and cleaner, and it can find more hoisting opportunities. I don't have benchmarks yet; I will run them this weekend.
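To illustrate what "based on subexpression dependency" can mean, here is a minimal C++ sketch. It is not the code from this PR: `Expr`, `leaf`, `binary`, and `hoistPosition` are hypothetical names, leaves are matched to loops by name only, and the convention that -1 means "outside every loop" is an assumption made for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression node (hypothetical, not nvfuser's Val/Expr):
// a leaf has an empty `op`; a binary node has two children.
struct Expr {
  std::string op;    // "" for leaves, otherwise "+", "*", ...
  std::string name;  // leaf name, e.g. "i1" (loop variable) or "s1" (stride)
  std::shared_ptr<Expr> lhs;
  std::shared_ptr<Expr> rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(const std::string& name) {
  return std::make_shared<Expr>(Expr{"", name, nullptr, nullptr});
}

ExprPtr binary(const std::string& op, ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{op, "", std::move(lhs), std::move(rhs)});
}

// Returns the position of the innermost loop whose variable `e` depends on.
// `loop_vars` lists loop variables from outermost to innermost.
// -1 (assumed convention) means `e` depends on no loop variable at all.
int hoistPosition(const ExprPtr& e, const std::vector<std::string>& loop_vars) {
  assert(e != nullptr);
  if (e->op.empty()) {
    auto it = std::find(loop_vars.begin(), loop_vars.end(), e->name);
    return it == loop_vars.end() ? -1
                                 : static_cast<int>(it - loop_vars.begin());
  }
  // A compound expression is only valid once all of its inputs are valid,
  // so it must live at the deeper of the two child positions.
  return std::max(hoistPosition(e->lhs, loop_vars),
                  hoistPosition(e->rhs, loop_vars));
}
```

With loops `i1`, `i2`, `i3` (outermost first), the expression `i1*s1 + i2*s2` gets position 1 in this sketch, so it could be computed once inside the `i2` loop and reused across all iterations of `i3`.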

@zasdfgbnm zasdfgbnm changed the title [NOT READY FOR REVIEW] Expr based index hoisting [NOT READY FOR REVIEW] Expr based index hoisting and CSE Dec 2, 2022
@zasdfgbnm zasdfgbnm changed the title [NOT READY FOR REVIEW] Expr based index hoisting and CSE Expr based index hoisting and CSE Dec 2, 2022
@zasdfgbnm zasdfgbnm marked this pull request as ready for review December 2, 2022 14:37
@zasdfgbnm zasdfgbnm requested review from csarofeen and naoyam December 2, 2022 14:38
}

//! A utility class to check if an expression of a particular type exists
class ExprFinder : kir::ConstIrVisitor {
Collaborator Author

moved to kernel_ir.cpp

//! ParallelType::Group, but it isn't sufficient as the loop may be
//! for an initialization expression, for which the loop should not
//! be grouped. Make sure a GroupedGridReduction is found.
bool isGroupedLoop(const kir::ForLoop* loop) {
Collaborator Author

Renamed as ForLoop::isGroup

auto root_ind_i =
is_overriden ? override_it->second : index_map.at(root_dom[i]);

// index hoist must be done before the adjustments for halo
Collaborator Author

We no longer have this limitation.

// Version of hoisting without using reference tensor,
// should eventually deprecate the other one once reference
// tensor is completely deprecated.
Val* hoistConsumerIndex(
Collaborator Author

Consumer, producer, and predicate hoisting now use a single interface.

Comment on lines +1496 to +1497
strided_inds[i] = GpuLower::current()->commonIndexMap().hoistIndex(
strided_inds[i], loops);
Collaborator Author

I think we should take the sum of strided_inds first, and hoist the resulting val. Done in #2234.

Comment on lines +2802 to +2805
start_index = GpuLower::current()->commonIndexMap().hoistIndex(
start_magic_zero_info.index, loops);
stop_index = GpuLower::current()->commonIndexMap().hoistIndex(
stop_magic_zero_info.index, loops);
Collaborator Author

Instead of hoisting the start and stop indices, we should be able to hoist the entire predicate as a boolean expression. I will not do that in this PR, but I will investigate it later.

@zasdfgbnm
Collaborator Author

> When can we lift out of two exactly mapped loops? Wouldn't that mean the index of the loop is involved so we can't lift anything on that index out of the loop that defines it right?

Generally, no we can't, but if those loops are not materialized, it's still possible to reuse.

https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp#L166-L169

For a trivial loop, I do not treat its loop variable as a dependency, so the index can be safely lifted out of it.

if ((i165 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
int64_t i143;
i143 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
int64_t i94;
Collaborator Author

i94 is used three times, so it is hoisted.

Comment on lines +140 to +141
auto create_fn = def->newObjectFunc();
create_fn(index->container(), inputs, {index}, def->attributes());
Collaborator Author

I am so happy that now I can create an arbitrary expression without knowing what type it is.

Collaborator

This is really cool!

Owner

Very cool. Love it.
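The pattern praised in this thread, asking an existing node for a function that builds a new node of the same concrete type, can be sketched as follows. This is only an illustration under assumptions: `Node`, `Add`, `Mul`, `createNode`, and `rebuildWithInputs` are hypothetical, and the real `newObjectFunc()` shown in the diff takes a container, inputs, outputs, and attributes rather than a single vector.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node;
using NodePtr = std::unique_ptr<Node>;
// Type-erased factory: builds a node of some concrete type from its inputs.
using CreateFn = NodePtr (*)(std::vector<int> inputs);

// Hypothetical base class; inputs are plain ints to keep the sketch small.
struct Node {
  std::vector<int> inputs;
  explicit Node(std::vector<int> in) : inputs(std::move(in)) {}
  virtual ~Node() = default;
  virtual std::string kind() const = 0;
  // Each concrete type returns a factory for its own type.
  virtual CreateFn newObjectFunc() const = 0;
};

template <typename T>
NodePtr createNode(std::vector<int> inputs) {
  return std::make_unique<T>(std::move(inputs));
}

struct Add : Node {
  using Node::Node;
  std::string kind() const override { return "Add"; }
  CreateFn newObjectFunc() const override { return &createNode<Add>; }
};

struct Mul : Node {
  using Node::Node;
  std::string kind() const override { return "Mul"; }
  CreateFn newObjectFunc() const override { return &createNode<Mul>; }
};

// Rebuilds `def` with different inputs without knowing its concrete type.
NodePtr rebuildWithInputs(const Node& def, std::vector<int> new_inputs) {
  CreateFn create_fn = def.newObjectFunc();
  return create_fn(std::move(new_inputs));
}
```

The caller of `rebuildWithInputs` never names `Add` or `Mul`, which is what lets a generic pass re-create a definition with substituted inputs.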

if (current_level > 0) {
loop = loops[current_level - 1];
}
auto it = common_index_map_.find(loop);
Collaborator

A bit lost here. loop can be nullptr. Is it safe to do lookup with nullptr?

Collaborator

Why does the loop start with current_level=0? Doesn't level mean the loop position where all the loop dependency is resolved for this index? So, could there be an equivalent index in the loops from 0th to level-1-th positions? Isn't it just sufficient to check loops[level]?

Collaborator Author

@zasdfgbnm zasdfgbnm Dec 3, 2022

I think a nullptr lookup is fine. The key is just a pointer value (zero), and the lookup never dereferences it.
I need something to refer to the top-level exprs, which have no corresponding loop. I can either use 0 for the top-level exprs and 1 for loops[0], or use -1 for the top-level exprs and 0 for loops[0].
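As a small check of the claim about nullptr keys, here is a sketch with a standard `std::unordered_map` keyed by a loop pointer. `ForLoop`, `HoistMap`, and `countHoisted` are hypothetical stand-ins, not the types used in this PR, and the choice of `nullptr` as the key for top-level expressions is taken from the discussion above.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for kir::ForLoop.
struct ForLoop {
  std::string name;
};

// Maps a loop to the expressions hoisted into it; the nullptr key is
// assumed to hold expressions that live outside every loop.
using HoistMap =
    std::unordered_map<const ForLoop*, std::vector<std::string>>;

// Returns how many expressions are recorded for `loop`.
// std::hash<const ForLoop*> and operator== work on the pointer value only,
// so passing nullptr is well defined and nothing is dereferenced.
std::size_t countHoisted(const HoistMap& map, const ForLoop* loop) {
  auto it = map.find(loop);
  return it == map.end() ? 0 : it->second.size();
}
```

Looking up a loop that was never inserted, or `nullptr` on an empty map, simply returns `end()`.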

Collaborator

OK, I still don't follow why you need to look at loops that are outside the level loop. Is it because some subexpression of index may be hoisted at those outer loops?

Collaborator Author

> OK, I still don't follow why you need to look at loops that are outside the level loop. Is it because some subexpression of index may be hoisted at those outer loops?

Oh, I think you are right. We don't need it. I originally wrote this thinking that some subexpression of index might be hoisted at outer loops, and that I would look for subexpressions that are in both index and common_index_map_. But I eventually didn't do that. All this function does is find something sameAs index; it does not break index down into pieces (that breakdown is done in hoistIndex). So I think checking loops[pos] is sufficient.

Collaborator Author

I updated this PR and did some cleanup based on this new understanding. Now pos == -1 means top-level exprs, and pos == 0 means loops[0].
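The conclusion of this thread, that reuse only needs to consult the entries recorded at a single position, can be sketched like this. It is a simplified illustration, not the PR's `CommonIndexMap`: `CommonIndexTable` and `reuseOrInsert` are hypothetical, a canonical string stands in for a `sameAs` comparison, and generated names such as `i0` are made up for the example.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical table of already-computed indices, keyed by loop position.
// Position -1 holds top-level expressions; position p >= 0 holds
// expressions hoisted into loops[p].
class CommonIndexTable {
 public:
  // Returns {variable name, true} if an equal expression is already
  // recorded at `pos`; otherwise records `expr` there under a fresh name
  // and returns {new name, false}. Only position `pos` is searched.
  std::pair<std::string, bool> reuseOrInsert(int pos, const std::string& expr) {
    assert(pos >= -1);
    auto& entries = table_[pos];
    for (const auto& entry : entries) {
      if (entry.first == expr) {  // stand-in for a structural sameAs check
        return {entry.second, true};
      }
    }
    std::string name = "i" + std::to_string(next_id_++);
    entries.emplace_back(expr, name);
    return {name, false};
  }

 private:
  std::map<int, std::vector<std::pair<std::string, std::string>>> table_;
  int next_id_ = 0;
};
```

Inserting the same expression twice at position -1 reuses the first variable, while inserting it at position 0 creates a new one, because the two positions are kept independent in this sketch.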

@zasdfgbnm
Collaborator Author

Checked matmul perf:
[Screenshot: Screenshot_20221202_184243]
I think it is within normal variation.

The diff in the generated code: https://www.diffchecker.com/ekzrijW0

Will also check batchnorm.

@zasdfgbnm
Collaborator Author

This is the perf for batchnorm:
[Screenshot: Screenshot_20221202_225801]

Owner

@csarofeen csarofeen left a comment

This looks good to me, however I'm a bit fuzzy on how CommonIndexMap is produced so I think @naoyam should be the one to approve.

@naoyam
Collaborator

naoyam commented Dec 6, 2022

> This looks good to me, however I'm a bit fuzzy on how CommonIndexMap is produced so I think @naoyam should be the one to approve.

Seems like there's a lot of changes since I looked at this PR last week. Revisiting now.

int64_t position, // if `index` is given (i.e., is_give == true), then
// this is the position of the outer-most loop nest that
// contains all the dependencies of `index`. if `index`
// is a subexpression of the given index, then this is
Collaborator

A bit confused here: index is a subexpression of the given index. Which index are you referring to by "the given index"?

Collaborator Author

"the given index" refers to the index given to the public method hoistIndex. I have renamed this function to hoistIndexImpl and updated the comment here to be more clear.

//! try insert ((i1*1 + i2*2) + i3*3) to common_index_map_[FOR i3],
//! try insert ((i1*1 + i2*2) + i3*3) + i4*4 to common_index_map_[FOR i4],
//! Before insertion, this function recursively uses
//! eliminateCommonSubexpression to find existing expressions/subexpressions
Collaborator

The function was renamed?

Collaborator Author

Yes, it was renamed to reuseIndexIfAlreadyComputed. I have fixed the comment. Thanks for catching it!

Collaborator

@naoyam naoyam left a comment

LGTM. Thanks for the cleanup and enhancement!

@zasdfgbnm zasdfgbnm merged commit 7f1bb3f into devel Dec 6, 2022
@zasdfgbnm zasdfgbnm deleted the expr_based_index_hoist branch December 6, 2022 22:14

3 participants