Conversation
A tree where leaves are matrix multiplies and inner vertices are adds can be computed as a single mm. Such subgraph often appear in backward if a single weight is reused multiple times (e.g. in RNNs). NOTE: this seems to be slightly slower on the GPU than the naive implementation, but it's a huge win on the CPU (think 100x lower overhead)
| // This pass looks for trees in the graph, where leaves are mm ops, and the inner | ||
| // vertices are add nodes. Once we have such a tree they can be reduced to two | ||
| // concats and a single mm (basically into a single multiply of a wide matrix, with | ||
| // a tall matrix). |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| // a tall matrix). | ||
| // Such patterns show up mostly in backward of RNNs, since the derivative of many | ||
| // uses of matrix multiplies with same weights forms exactly such a tree | ||
| // (note that it's usually also highly imbalanced i.e. has O(n) depth). |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return arr; | ||
| } | ||
|
|
||
| struct TreeToken { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| std::array<int64_t, 2> lhs_sizes; | ||
| std::array<int64_t, 2> rhs_sizes; | ||
| Node *node = nullptr; | ||
| bool valid = false; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| Node *node = nullptr; | ||
| bool valid = false; | ||
|
|
||
| static TreeToken from_mm(Node *mm) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| // TreeTokens will be used to label nodes of the graph, if the nodes will fit | ||
| // our mm/add tree pattern. Basically we do dynamic programming on DAGs, where | ||
| // when we reach node N with inputs A and B, then A and B have already been | ||
| // procesed, and we can try to unify their TreeTokens (if they have them) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zdevito
left a comment
There was a problem hiding this comment.
This looks good! Clean and understandable.
| // | R2 | | ||
| // | | | ||
| // +------+ | ||
| // +------+------+ +------+ |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| // If we ever get around implementing this, the right solution is probably to fuse | ||
| // MMs for the common part, and assume it's an input leaf for the outer two parts | ||
| // (I don't think it's beneficial to recompute, unless the subtree is super small, | ||
| // but let's not get into such details). |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| // See Note [Overlapping trees] | ||
| if (&l == &r || !l.is_root || !r.is_root) | ||
| return token; | ||
| // We can batch the tree only if all sizes match, because we need to |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| // See Note [Overlapping trees] (regarding the uses().size() == 1 check) | ||
| // We could treat a subtree with multiple uses as if it was overlapping. | ||
| if (lhs_it != tokens.end() && rhs_it != tokens.end() && | ||
| lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if (lhs_it != tokens.end() && rhs_it != tokens.end() && | ||
| lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) { | ||
| if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second)) | ||
| tokens[node] = token; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| // topological order and labeling nodes with TreeTokens. Then, we look for roots of | ||
| // the trees we formed and fuse them. | ||
|
|
||
| enum class Side { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Hmm the Windows contbuilds seem to be failing at test stage... Any ideas for non-standard things I could have used? cc: @peterjc123 |
| const std::vector<std::int64_t>& strides() const { return strides_; } | ||
|
|
||
| TypePtr withSizesStrides(const std::vector<std::int64_t>& sizes, const std::vector<std::int64_t>& strides) const { | ||
| TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Implement MM fusion (MM with add reduction tree) A tree where leaves are matrix multiplies and inner vertices are adds can be computed as a single mm. Such subgraph often appear in backward if a single weight is reused multiple times (e.g. in RNNs). NOTE: this seems to be slightly slower on the GPU than the naive implementation, but it's a huge win on the CPU (think 100x lower overhead)
A tree where leaves are matrix multiplies and inner
vertices are adds can be computed as a single mm.
Such subgraph often appear in backward if a single weight
is reused multiple times (e.g. in RNNs).
NOTE: this seems to be slightly slower on the GPU than the
naive implementation, but it's a huge win on the CPU
(think 100x lower overhead)