Terminology
Dependency chain:
A dependency chain from TV0->TV5 is a path through the arithmetic operations from TV0 to TV5. For example, if we had:
TV1 = TV0 * 2.0
TV3 = TV1 + TV2
TV4 = TV3 * TV0
TV5 = TV3 - TV4
then
TV0 -> TV1 -> TV3 -> TV4 -> TV5
and
TV0 -> TV4 -> TV5
would both be dependency chains of TV0->TV5
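The chains above can be enumerated mechanically. Below is a small illustrative sketch (plain Python, not the real nvFuser code) that walks the use edges of the example expressions and collects every path from TV0 to TV5:

```python
# Illustrative only: each entry maps a tensor to the tensors computed
# directly from it, per the example expressions above.
uses = {
    "TV0": ["TV1", "TV4"],  # TV1 = TV0 * 2.0, TV4 = TV3 * TV0
    "TV1": ["TV3"],         # TV3 = TV1 + TV2
    "TV2": ["TV3"],
    "TV3": ["TV4", "TV5"],  # TV4 = TV3 * TV0, TV5 = TV3 - TV4
    "TV4": ["TV5"],
    "TV5": [],
}

def dependency_chains(frm, to):
    """Return every path from `frm` to `to` through the use edges."""
    if frm == to:
        return [[to]]
    chains = []
    for nxt in uses[frm]:
        for tail in dependency_chains(nxt, to):
            chains.append([frm] + tail)
    return chains
```

For this DAG the function finds three chains from TV0 to TV5 in total, since TV4 = TV3 * TV0 consumes TV0 directly.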
Use chains: The use chains of a given tensor view are all paths from that tensor view to the outputs. The chains are connected through dependencies and consist of all tensors that depend in some way on the given tensor. Referring to the Dependency chain example, all use chains of TV3 are:
TV3 -> TV4 -> TV5
TV3 -> TV5
Producer/Consumer:
A producer-consumer relationship is one where there is a valid, non-empty dependency chain from one TensorView (the producer) to another TensorView (the consumer). A direct producer-consumer relationship is one where the producer is an input to the consumer's origin expression (the expression that generates the consumer).
Common consumer:
A common consumer of a producer is a TensorView that exists in the intersection of all use chains of a given producer. In the Use chains example, TV5 is a common consumer of TV3, but TV4 is not.
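These two definitions can be made concrete with a small sketch (illustrative Python, not the real implementation): use chains are all paths to the outputs, and a common consumer is whatever survives intersecting them:

```python
# Illustrative sketch only; tensor names follow the Dependency chain example.
uses = {
    "TV0": ["TV1", "TV4"],
    "TV1": ["TV3"],
    "TV2": ["TV3"],
    "TV3": ["TV4", "TV5"],
    "TV4": ["TV5"],
    "TV5": [],  # TV5 is the only output
}

def use_chains(tv):
    """All paths from tv to the outputs (tensors with no uses)."""
    if not uses[tv]:
        return [[tv]]
    return [[tv] + tail for nxt in uses[tv] for tail in use_chains(nxt)]

def common_consumers(tv):
    """Tensors appearing on every use chain of tv (excluding tv itself)."""
    chains = use_chains(tv)
    common = set(chains[0]).intersection(*map(set, chains[1:]))
    common.discard(tv)
    return common
```

TV4 is on only one of TV3's two use chains, so the intersection keeps TV5 but drops TV4, matching the definition above.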
replayPasC and replayCasP:
These are the two major replay functions used in computeAt transform replay. replayPasC stands for replay producer as consumer, and replayCasP stands for replay consumer as producer. Both functions take a position argument, which means we need to replay so that axes < pos match between the two TensorViews. The reason there are two functions is that the replay needs a map between the root domains of the two TensorViews, and how this mapping is built changes based on which TensorView is the consumer/producer and what we're trying to replay. As we run these two replays we also set the computeAt of the producer. replayPasC and replayCasP are wrapped in functions called computeAt_impl and forwardComputeAt_impl respectively in the code base.
Compute at current implementation and challenges
The computeAt pass currently has some challenges producing a correct structure. Of note is #110.
The syntax of computeAt is producer->computeAt(consumer, pos); producer and consumer here do not need to have a direct producer-consumer relationship.
The call means that we should generate producer at pos within the loop nest structure of consumer, i.e. what we'd like is a structure like:
for consumer->axis(0:pos)
for producer->axis(pos:-1)
producer = ...
for consumer->axis(pos:-1)
consumer = ... producer ...
Where (pos:-1) indicates iterating over the axes/domains starting from pos and going to the last axis of the tensor. The general process for doing this is to transform producer so that its axes (0:pos) are equivalent to consumer's axes (0:pos). This transformation is done with the minimal required transformations; for example, any axes that don't need to be transformed to generate producer's axes (0:pos) are unchanged from producer's previous state.
The general challenge here is the multiple consumer issue. If we have a second consumer of producer that is "unrelated" to the first consumer, that second consumer needs to follow the same pattern as the first, because we would need to generate a structure that looks like:
for consumer1->axis(0:pos)
for producer->axis(pos:-1)
producer = ...
for consumer1->axis(pos:-1)
consumer1 = ... producer ...
for consumer2->axis(pos:-1)
consumer2 = ... producer ...
Any other placement of consumer2 would be invalid due to the transformation on producer, unless we generated producer multiple times, which is not currently supported (this could be an optional flag on computeAt, or a transformation on producer). This means that we potentially need to modify tensors that are outside the dependency chains from producer->consumer, meaning we need to be able to propagate the transformation dictated by the computeAt call.
Current implementation:
1. Look for a common consumer of producer after consumer.
2. If a common consumer exists:
2.a. Start from consumer and follow a dependency chain from consumer to the common consumer. While walking this dependency chain, run replayCasP so that the common consumer will be transformed under pos like consumer.
2.b. Now that the common consumer matches the computeAt on consumer, follow all dependency chains backwards from producer to the common consumer and run replayPasC. Now all TensorViews from producer to the common consumer should match.
3. If a common consumer does not exist:
3.a. Take a dependency chain from producer to consumer, iterate the TensorViews from consumer to producer, and incrementally run replayPasC.
3.b. Run through all use chains of producer, iterate the TensorViews from producer to the outputs, and incrementally run replayCasP.
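The branching above can be sketched as a small planning function. This is hypothetical Python, not the real C++ machinery; "replayCasP"/"replayPasC" are just labels recorded in a log, and for simplicity the common consumer is computed from the producer's use chains alone (the real pass also requires it to come after consumer):

```python
# Illustrative sketch of steps 1-3 above. `uses` maps each tensor to the
# tensors computed directly from it.

def all_chains(uses, frm, to):
    """All dependency chains from frm to to."""
    if frm == to:
        return [[to]]
    return [[frm] + tail
            for nxt in uses.get(frm, [])
            for tail in all_chains(uses, nxt, to)]

def use_chains(uses, tv):
    """All paths from tv to the outputs."""
    if not uses.get(tv):
        return [[tv]]
    return [[tv] + tail for nxt in uses[tv] for tail in use_chains(uses, nxt)]

def common_consumer(uses, tv):
    """Earliest tensor shared by every use chain of tv, or None."""
    chains = use_chains(uses, tv)
    common = set(chains[0]).intersection(*map(set, chains[1:]))
    common.discard(tv)
    for node in chains[0]:
        if node in common:
            return node
    return None

def compute_at_plan(uses, producer, consumer):
    """Record, in order, which replay runs between which pair of tensors."""
    log = []
    cc = common_consumer(uses, producer)
    if cc is not None:
        # 2.a: forward from consumer to the common consumer
        chain = all_chains(uses, consumer, cc)[0]
        for a, b in zip(chain, chain[1:]):
            log.append(("replayCasP", a, b))
        # 2.b: backward along every chain from producer to the common consumer
        for chain in all_chains(uses, producer, cc):
            rev = chain[::-1]
            for a, b in zip(rev, rev[1:]):
                log.append(("replayPasC", a, b))
    else:
        # 3.a: backward along one producer -> consumer chain
        rev = all_chains(uses, producer, consumer)[0][::-1]
        for a, b in zip(rev, rev[1:]):
            log.append(("replayPasC", a, b))
        # 3.b: forward along all use chains of producer
        for chain in use_chains(uses, producer):
            for a, b in zip(chain, chain[1:]):
                log.append(("replayCasP", a, b))
    return log

# A graph shaped like the #110 example: T1 is consumed by both T2 and T3,
# which meet again at T4.
uses = {"T0": ["T1"], "T1": ["T2", "T3"], "T2": ["T4"], "T3": ["T4"], "T4": []}
plan = compute_at_plan(uses, "T1", "T2")
```

On this graph the common-consumer branch fires: one forward replayCasP from T2 to T4, then backward replayPasC along both chains from T4 back to T1.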
What's wrong:
As we iterate in steps 2.a., 2.b., 3.a., and 3.b., the computeAt position will change depending on whether we're running replayCasP or replayPasC. The position used in replayPasC is relative to the consumer, and the position in replayCasP is relative to the producer, which is correct.
Consider:
T2[i0, r1, i2] = T1[i0, i1, i2] ...
T3[i0, i2] = T2[i0, r1, i2] ...
and we start at T2, position 2, and we iterate forward with replayCasP. We will end up with the computeAt settings:
T1->computeAt(T2, 2)
T2->computeAt(T3, 1)
The second is because there was a reduction within the computeAt range, so we effectively lose an axis as we go over it.
Also consider:
T2[i0, b1, i2] = T1[i0, i2] ...
T3[i0, i1, i2] = T2[i0, b1, i2] ...
and we start at T3, position 2, and we iterate backward with replayPasC. We will end up with the computeAt settings:
T1->computeAt(T2, 1)
T2->computeAt(T3, 2)
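The two examples imply a simple bookkeeping rule, sketched below as a toy model (inferred from the examples, not the real nvFuser logic; axis labels follow the document's naming, "i" for iteration, "r" for reduction, "b" for broadcast): any reduction or broadcast axis crossed by the position drops the mapped position by one, because that axis exists in only one of the two tensors.

```python
# Toy model of the computeAt position adjustment across reductions
# (forward) and broadcasts (backward). Inferred behavior, illustrative only.

def pos_in_consumer(producer_axes, pos):
    """Forward (replayCasP): reduction axes among the first `pos` producer
    axes do not exist in the consumer."""
    return sum(1 for ax in producer_axes[:pos] if not ax.startswith("r"))

def pos_in_producer(consumer_side_axes, pos):
    """Backward (replayPasC): broadcast axes among the first `pos` axes of
    this tensor have no counterpart in the tensor it is computed from."""
    return sum(1 for ax in consumer_side_axes[:pos] if not ax.startswith("b"))
```

For the first example, pos_in_consumer(["i0", "r1", "i2"], 2) gives 1, matching T2->computeAt(T3, 1); for the second, pos_in_producer(["i0", "b1", "i2"], 2) gives 1, matching T1->computeAt(T2, 1).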
If this doesn't seem like a problem with our current approach yet, let's consider #110
We effectively have:
T1[i0, i1] = T0[i0, i1]
T2[i0, r1] = T1[i0, i1]
T3[i0, r1] = T1[i0, i1]
T4[i0] = T2[i0, r1], T3[i0, r1]
and we call:
T1->computeAt(T2, 2)
Based on our current procedure, we see that T1 has multiple uses, so we look for its common consumer, which is T4. We forward propagate from T2 to T4 based on a dependency chain and get:
T2->computeAt(T4, 1)
Then we go backwards through all dep chains from producer->common consumer:
T3->computeAt(T4, 1)
T2->computeAt(T4, 1)
T1->computeAt(T2, 1)
T1->computeAt(T3, 1)
Therefore there is no way to inline the consumption of T1 into both T2 and T3, even though, in theory, T2 and T3 can be inlined with each other.
This procedure would work if, instead of going backward through all dependency chains from producer to common consumer, we went forward calling replayCasP. However, if we had broadcasts instead of reductions, the forward direction would not work. In the current approach either one or the other works in this case; right now broadcast works and reduction breaks, as shown here.
Two observations about what is safe:
- Pushing everything forward or backward from consumer is safe.
- Pushing anything forward on (all use chains of producer) - (all dep chains from common_consumer -> producer, or consumer -> producer if common_consumer does not exist) is safe.
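The second rule can be read as set arithmetic over the DAG. A small sketch (illustrative Python; T5 is an extra output added here for illustration and is not part of the #110 example):

```python
# Safe-to-push-forward set = (tensors on use chains of producer)
#   - (tensors on dependency chains between producer and common_consumer).
# T5 is a hypothetical extra output hanging off T4.
uses = {"T0": ["T1"], "T1": ["T2", "T3"], "T2": ["T4"],
        "T3": ["T4"], "T4": ["T5"], "T5": []}

def reachable(uses, tv):
    """Every tensor downstream of tv."""
    seen, stack = set(), [tv]
    while stack:
        for nxt in uses.get(stack.pop(), []):
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen

def between(uses, frm, to):
    """Tensors on some dependency chain frm -> to (inclusive)."""
    return {frm, to} | {n for n in reachable(uses, frm)
                        if to in reachable(uses, n)}

# Everything downstream of producer T1, minus everything between T1 and its
# common consumer T4, leaves only T5 as safe to push forward.
safe_forward = reachable(uses, "T1") - between(uses, "T1", "T4")
```

In other words, only tensors strictly downstream of the common consumer escape the constraints of the producer->common_consumer region.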