
Implement MM fusion (MM with add reduction tree)#4615

Merged
apaszke merged 5 commits into master from batch_mm on Jan 17, 2018
Conversation

@apaszke
Contributor

@apaszke apaszke commented Jan 11, 2018

A tree where the leaves are matrix multiplies and the inner
vertices are adds can be computed as a single mm.
Such subgraphs often appear in the backward pass when a single
weight is reused multiple times (e.g. in RNNs).

NOTE: this seems to be slightly slower on the GPU than the
naive implementation, but it's a huge win on the CPU
(think 100x lower overhead)

@apaszke apaszke requested review from ezyang and zdevito January 11, 2018 21:46
@pytorchbot
Collaborator

@apaszke, thanks for your PR! We identified @zdevito to be a potential reviewer.

// This pass looks for trees in the graph, where leaves are mm ops, and the inner
// vertices are add nodes. Once we have such a tree they can be reduced to two
// concats and a single mm (basically into a single multiply of a wide matrix, with
// a tall matrix).

// Such patterns show up mostly in backward of RNNs, since the derivative of many
// uses of matrix multiplies with same weights forms exactly such a tree
// (note that it's usually also highly imbalanced i.e. has O(n) depth).


return arr;
}

struct TreeToken {


Comment thread torch/csrc/jit/passes/batch_mm.cpp Outdated
std::array<int64_t, 2> lhs_sizes;
std::array<int64_t, 2> rhs_sizes;
Node *node = nullptr;
bool valid = false;


Comment thread torch/csrc/jit/passes/batch_mm.cpp Outdated

static TreeToken from_mm(Node *mm) {


// TreeTokens will be used to label nodes of the graph, if the nodes fit
// our mm/add tree pattern. Basically we do dynamic programming on DAGs:
// when we reach a node N with inputs A and B, then A and B have already been
// processed, and we can try to unify their TreeTokens (if they have them)


Contributor

@zdevito zdevito left a comment


This looks good! Clean and understandable.

// | R2 |
// | |
// +------+
// +------+------+ +------+


// If we ever get around to implementing this, the right solution is probably
// to fuse MMs for the common part, and assume it's an input leaf for the outer
// two parts (I don't think it's beneficial to recompute, unless the subtree is
// super small, but let's not get into such details).


// See Note [Overlapping trees]
if (&l == &r || !l.is_root || !r.is_root)
return token;
// We can batch the tree only if all sizes match, because we need to


// See Note [Overlapping trees] (regarding the uses().size() == 1 check)
// We could treat a subtree with multiple uses as if it was overlapping.
if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) {


if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
tokens[node] = token;


Comment thread torch/csrc/jit/passes/batch_mm.cpp Outdated
// topological order and labeling nodes with TreeTokens. Then, we look for roots of
// the trees we formed and fuse them.

enum class Side {


@apaszke
Contributor Author

apaszke commented Jan 13, 2018

Hmm, the Windows CI builds seem to be failing at the test stage... Any ideas about non-standard things I could have used? cc: @peterjc123

Comment thread torch/csrc/jit/type.h
const std::vector<std::int64_t>& strides() const { return strides_; }

TypePtr withSizesStrides(const std::vector<std::int64_t>& sizes, const std::vector<std::int64_t>& strides) const {
TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const {


@peterjc123
Collaborator

@apaszke Sorry, I don't have much insight into how to debug this. If this one can't pass, then how about skipping the tests, listing it in #4092, and waiting for future fixes?

@yf225 yf225 mentioned this pull request Jan 15, 2018
@apaszke apaszke merged commit 1a02d3a into master Jan 17, 2018
@apaszke apaszke deleted the batch_mm branch January 17, 2018 20:36
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Implement MM fusion (MM with add reduction tree)
5 participants