implement Repeat with fixed output shape #7114
Conversation
```python
# shift the repeats by one
# tensor([0, 0, 1, 2, 0, 4, 0, 6, 7, 8])
exclusive_repeats = torch.roll(repeats, shifts=1)
```
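My reading of why the shift is needed: the cumulative sum of the *shifted* repeats is an exclusive prefix sum, which gives each block's start offset in the output, while the cumsum of the unshifted repeats gives block *ends*. A plain-Python sketch using the repeats implied by the quoted tensor (note that `torch.roll` wraps the last element to the front, so this relies on either that element being zero or the front slot being overwritten with 0 afterwards):

```python
from itertools import accumulate

repeats = [0, 1, 2, 0, 4, 0, 6, 7, 8, 0]

# torch.roll(repeats, shifts=1): the last element wraps to the front.
exclusive_repeats = repeats[-1:] + repeats[:-1]
print(exclusive_repeats)  # [0, 0, 1, 2, 0, 4, 0, 6, 7, 8]

# Cumsum of the shifted repeats -> START offset of each block in the output
# (an exclusive prefix sum).
starts = list(accumulate(exclusive_repeats))
print(starts)  # [0, 0, 1, 3, 3, 7, 7, 13, 20, 28]

# Cumsum of the unshifted repeats -> END offset (one past the last slot)
# of each block, which is why the unshifted variant needs a -1 adjustment.
ends = list(accumulate(repeats))
print(ends)  # [0, 1, 3, 3, 7, 7, 13, 20, 28, 28]
```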
Why do we want to shift right? If we don't shift, I assume we don't need the -1 at L700, and it will still output the correct answer?
```python
# value in gather_indices represents the index in the input.
# tensor([1, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7])
gather_indices = torch.cumsum(block_split_indicators, dim=0) - 1
```
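To see the whole pipeline end to end, here is a plain-Python sketch of the indicator-then-cumsum trick. The `repeats` values are hypothetical but chosen to be consistent with the tensors quoted in the diff; the scatter-add step stands in for whatever scatter the actual implementation uses:

```python
from itertools import accumulate

repeats = [0, 1, 2, 0, 4, 0, 6, 7]  # hypothetical, consistent with the quoted tensors
total = sum(repeats)                # 20 output elements

# Exclusive prefix sum -> start offset of each block in the output.
exclusive_repeats = [0] + repeats[:-1]
starts = list(accumulate(exclusive_repeats))

# Scatter-add a 1 at every block start. A zero-repeat block shares its
# start with the next block, so that position can accumulate more than 1.
block_split_indicators = [0] * total
for s in starts:
    block_split_indicators[s] += 1

# Inclusive cumsum, minus 1, maps each output slot to its input index.
gather_indices = [c - 1 for c in accumulate(block_split_indicators)]
print(gather_indices)
# [1, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7]
```

The cumsum bumps the running input index by the indicator value at each position, so positions where several blocks start at once (because earlier blocks were empty) skip the corresponding input indices.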
Trying to reason about this step, but it's really hard to understand. Can you explain it a little? I couldn't get why block_split_indicators can be converted to gather_indices.
I can understand that for each indicator, the index needs to bump; I'm trying to understand the value to bump by.
Okay, now I understand. If some indices are skipped, it means the indicator position gets selected more than once. The more indices it skips, the more times it gets selected, which equals the value to bump by.
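That "value to bump" observation can be checked numerically: at each output position, the indicator value equals 1 plus the number of zero-repeat inputs whose (empty) blocks start at the same offset, i.e. exactly how many input indices get skipped there. A plain-Python check under the same hypothetical repeats as above:

```python
from itertools import accumulate

repeats = [0, 1, 2, 0, 4, 0, 6, 7]  # hypothetical, consistent with the quoted tensors
starts = list(accumulate([0] + repeats[:-1]))

indicators = [0] * sum(repeats)
for s in starts:
    indicators[s] += 1

# Wherever the cumsum jumps by k, exactly k - 1 input indices are skipped.
# E.g. indicators[3] == 2 because input 3 has zero repeats, so the gathered
# index jumps from 2 straight to 4.
for pos, v in enumerate(indicators):
    if v > 1:
        zero_blocks = sum(1 for s, r in zip(starts, repeats) if s == pos and r == 0)
        assert v == zero_blocks + 1
```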
Let's skip the GPU tests to move fast.
Thanks, Jack!