A new index hoisting and CSE approach based on subexpression equivalence and dependency#2235
This reverts commit 52c7ded.
| } | ||
| //! A utility class to check if an expression of a particular type exists | ||
| class ExprFinder : kir::ConstIrVisitor { |
moved to kernel_ir.cpp
| //! ParallelType::Group, but it isn't sufficient as the loop may be | ||
| //! for an initialization expression, for which the loop should not | ||
| //! be grouped. Make sure a GroupedGridReduction is found. | ||
| bool isGroupedLoop(const kir::ForLoop* loop) { |
Renamed as ForLoop::isGroup
| auto root_ind_i = | ||
| is_overriden ? override_it->second : index_map.at(root_dom[i]); | ||
| // index hoist must be done before the adjustments for halo |
We no longer have this limitation.
| // Version of hoisting without using reference tensor, | ||
| // should eventually deprecate the other one once reference | ||
| // tensor is completely deprecated. | ||
| Val* hoistConsumerIndex( |
Consumer, producer, and predicate hoisting now use a single interface.
| strided_inds[i] = GpuLower::current()->commonIndexMap().hoistIndex( | ||
| strided_inds[i], loops); |
I think we should take the sum of strided_inds first, and hoist the resulting val. Done in #2234.
| start_index = GpuLower::current()->commonIndexMap().hoistIndex( | ||
| start_magic_zero_info.index, loops); | ||
| stop_index = GpuLower::current()->commonIndexMap().hoistIndex( | ||
| stop_magic_zero_info.index, loops); |
Instead of hoisting the start and stop indices, we should be able to hoist the entire predicate as a boolean expression. I will not do that in this PR; I will investigate it later.
For a trivial loop, I do not treat its loop variable as a dependency, so an index can safely be lifted out of that loop.
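The trivial-loop rule can be sketched roughly as follows. This is a hypothetical model, not nvFuser code: `Loop`, `is_trivial`, and `hoistPosition` are names invented for illustration, and an index is reduced to the set of loop variables it reads.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for one loop of the nest (outermost first).
struct Loop {
  std::string var;  // name of the loop variable
  bool is_trivial;  // assumed flag: the lowering treats this loop as trivial
};

// Returns the position of the innermost loop the index really depends on,
// or -1 if it depends on none and can live at the top level.
// Variables of trivial loops are skipped, so an index that only reads them
// is free to move out of those loops.
int hoistPosition(const std::set<std::string>& used_vars,
                  const std::vector<Loop>& loops) {
  int pos = -1;
  for (int i = 0; i < static_cast<int>(loops.size()); ++i) {
    if (!loops[i].is_trivial && used_vars.count(loops[i].var) > 0) {
      pos = i;
    }
  }
  assert(pos >= -1 && pos < static_cast<int>(loops.size()));
  return pos;
}
```

With loops `i0`, `i1` (trivial), and `i2`, an index reading only `i0` and `i1` would be placed in loop 0, and one reading only `i1` would go to the top level.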
| if ((i165 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { | ||
| int64_t i143; | ||
| i143 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); | ||
| int64_t i94; |
i94 is used three times, so it is hoisted.
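The reuse count behind a decision like this can be illustrated with a small sketch. Everything here is an assumption made for illustration: `Node`, `leaf`, `op`, `key`, and `countUses` are hypothetical names, and structural equivalence is approximated by comparing canonical strings (assuming leaf names contain no spaces or parentheses), which is not how nvFuser compares expressions.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy expression node; hypothetical, not nvFuser's IR.
struct Node {
  std::string label;                        // operator ("+", "*") or leaf name
  std::vector<std::shared_ptr<Node>> args;  // empty for leaves
};
using NodePtr = std::shared_ptr<Node>;

NodePtr leaf(const std::string& name) {
  return std::make_shared<Node>(Node{name, {}});
}

NodePtr op(const std::string& label, const NodePtr& a, const NodePtr& b) {
  assert(a && b);
  return std::make_shared<Node>(Node{label, {a, b}});
}

// Canonical text form: trees with equal keys are structurally equivalent.
std::string key(const NodePtr& n) {
  if (n->args.empty()) return n->label;
  std::string s = "(" + n->label;
  for (const auto& a : n->args) s += " " + key(a);
  return s + ")";
}

// Counts every non-leaf subexpression of an index by its canonical key.
// A key whose count exceeds one is a candidate for computing once and reusing.
void countUses(const NodePtr& n, std::map<std::string, int>& uses) {
  if (n->args.empty()) return;
  ++uses[key(n)];
  for (const auto& a : n->args) countUses(a, uses);
}
```

Calling `countUses` on three indices that each contain `blockIdx.x * 128 + threadIdx.x` leaves that subexpression with a count of three, which mirrors the situation described for `i94`.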
| auto create_fn = def->newObjectFunc(); | ||
| create_fn(index->container(), inputs, {index}, def->attributes()); |
I am so happy that now I can create an arbitrary expression without knowing what type it is.
| if (current_level > 0) { | ||
| loop = loops[current_level - 1]; | ||
| } | ||
| auto it = common_index_map_.find(loop); |
A bit lost here. loop can be nullptr. Is it safe to do lookup with nullptr?
Why does the loop start with current_level=0? Doesn't level mean the loop position where all the loop dependencies of this index are resolved? So, could there be an equivalent index in the loops from the 0th to the (level-1)-th position? Isn't it sufficient to check just loops[level]?
I think a nullptr lookup is fine. The key is just a pointer value equal to zero; the lookup will not access the object.
I need something to refer to the top-level exprs, which have no corresponding loop. I can either use 0 to refer to the top-level exprs and 1 to refer to loops[0], or use -1 to refer to the top-level exprs and 0 to refer to loops[0].
OK, I still don't follow why you need to check loops that are outside the level loop. Is it because some subexpression of index may be hoisted at those outer loops?
OK, I still don't follow why you need to check loops that are outside the level loop. Is it because some subexpression of index may be hoisted at those outer loops?
Oh, I think you are right. We don't need it. I originally wrote this thinking that some subexpression of index might be hoisted at outer loops, so I expected to look for subexpressions that are in both index and common_index_map_. But it seems I never ended up doing that. All this function does is find something sameAs index; it does not break index down into pieces (the breakdown of index is done in hoistIndex). So I think checking loops[pos] is sufficient.
I updated this PR and did some cleanup with this new understanding. Now pos == -1 means top-level exprs, and pos == 0 means loops[0].
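A minimal sketch of this position convention, combined with the nullptr-keyed lookup discussed earlier in the thread, might look like the following. It is an illustration under stated assumptions, not the PR's code: `ForLoop` is an empty stand-in for `kir::ForLoop`, hoisted indices are stored as plain strings, and `HoistedMap`, `scopeAt`, and `alreadyHoisted` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for kir::ForLoop; only its address is used as a key.
struct ForLoop {};

// Indices already hoisted to each scope, kept here as plain strings.
// The nullptr key holds the top-level scope, which has no loop.
using HoistedMap =
    std::unordered_map<const ForLoop*, std::vector<std::string>>;

// Maps a position to its map key: -1 is the top level, 0 is loops[0], etc.
const ForLoop* scopeAt(const std::vector<const ForLoop*>& loops,
                       int64_t pos) {
  assert(pos >= -1 && pos < static_cast<int64_t>(loops.size()));
  return pos == -1 ? nullptr : loops[static_cast<std::size_t>(pos)];
}

// Checks only the scope at `pos` for an entry equal to `index`.
// find() hashes and compares the pointer value, so a nullptr key is
// never dereferenced.
bool alreadyHoisted(const HoistedMap& map,
                    const std::vector<const ForLoop*>& loops,
                    int64_t pos,
                    const std::string& index) {
  auto it = map.find(scopeAt(loops, pos));
  if (it == map.end()) return false;
  for (const auto& existing : it->second) {
    if (existing == index) return true;
  }
  return false;
}
```

Using -1 for the top level keeps `pos` directly usable as an index into `loops` for every other value, which is one reason to prefer it over the 0-based alternative mentioned above.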
Checked matmul perf: The diff in the generated code: https://www.diffchecker.com/ekzrijW0. Will also check batchnorm.
Seems like there have been a lot of changes since I looked at this PR last week. Revisiting now.
| int64_t position, // if `index` is given (i.e., is_give == true), then | ||
| // this is the position of the outer-most loop nest that | ||
| // contains all the dependencies of `index`. if `index` | ||
| // is a subexpression of the given index, then this is |
A bit confused here: index is a subexpression of the given index. Which index are you referring to by "the given index"?
"The given index" refers to the index given to the public method hoistIndex. I have renamed this function to hoistIndexImpl and updated the comment here to be clearer.
| //! try insert ((i1*1 + i2*2) + i3*3) to common_index_map_[FOR i3], | ||
| //! try insert ((i1*1 + i2*2) + i3*3) + i4*4 to common_index_map_[FOR i4], | ||
| //! Before insertion, this function recursively uses | ||
| //! eliminateCommonSubexpression to find existing expressions/subexpressions |
Yes, it was renamed to reuseIndexIfAlreadyComputed. I have fixed the comment. Thanks for catching it!
naoyam left a comment
LGTM. Thanks for the cleanup and enhancement!


This is a complete rewrite of the index hoisting. I no longer use IterDomain mapping for the analysis. Instead, the analysis is based only on the indexing expressions, their subexpression equivalence, and their dependencies. As a result, the code is shorter and cleaner, and it can find more hoisting opportunities. I don't have benchmarks yet; I will run them this weekend.
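To make the idea concrete, here is a rough sketch of hoisting driven purely by expression structure. It is a toy model under stated assumptions, not the implementation in this PR: `Node`, `leaf`, `op`, `dependencyPosition`, and `hoist` are hypothetical names, loops are identified only by the names of their variables, and `sameAs` here is a plain structural comparison on the toy tree.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Toy expression node; hypothetical, not nvFuser's IR.
struct Node {
  std::string label;                        // operator or leaf name
  std::vector<std::shared_ptr<Node>> args;  // empty for leaves
};
using NodePtr = std::shared_ptr<Node>;

NodePtr leaf(const std::string& name) {
  return std::make_shared<Node>(Node{name, {}});
}

NodePtr op(const std::string& label, const NodePtr& a, const NodePtr& b) {
  return std::make_shared<Node>(Node{label, {a, b}});
}

// Structural equivalence: same label and pairwise-equivalent operands.
bool sameAs(const NodePtr& a, const NodePtr& b) {
  if (a->label != b->label || a->args.size() != b->args.size()) return false;
  for (std::size_t i = 0; i < a->args.size(); ++i) {
    if (!sameAs(a->args[i], b->args[i])) return false;
  }
  return true;
}

// Position of the innermost loop whose variable appears in `e`
// (loop_vars is ordered outermost first); -1 if `e` reads no loop variable.
int dependencyPosition(const NodePtr& e,
                       const std::vector<std::string>& loop_vars) {
  if (e->args.empty()) {
    auto it = std::find(loop_vars.begin(), loop_vars.end(), e->label);
    return it == loop_vars.end() ? -1
                                 : static_cast<int>(it - loop_vars.begin());
  }
  int pos = -1;
  for (const auto& a : e->args) {
    pos = std::max(pos, dependencyPosition(a, loop_vars));
  }
  return pos;
}

// Records every non-leaf subexpression at the scope given by its dependency
// position, unless an equivalent one is already recorded there.
// scopes[0] is the top level; scopes[p + 1] belongs to loop p.
void hoist(const NodePtr& e,
           const std::vector<std::string>& loop_vars,
           std::vector<std::vector<NodePtr>>& scopes) {
  if (e->args.empty()) return;
  assert(scopes.size() == loop_vars.size() + 1);
  for (const auto& a : e->args) hoist(a, loop_vars, scopes);
  auto& scope =
      scopes[static_cast<std::size_t>(dependencyPosition(e, loop_vars) + 1)];
  for (const auto& existing : scope) {
    if (sameAs(existing, e)) return;  // reuse instead of recomputing
  }
  scope.push_back(e);
}
```

In this model, hoisting two indices that share the prefix `i1*s1 + i2*s2` records that prefix once in the scope of the `i2` loop, and only the parts that differ are added to the `i3` scope.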