Skip to content

[jit] Add Custom graph fusion#18588

Closed
bwasti wants to merge 10 commits into
gh/bwasti/2/basefrom
gh/bwasti/2/head
Closed

[jit] Add Custom graph fusion#18588
bwasti wants to merge 10 commits into
gh/bwasti/2/basefrom
gh/bwasti/2/head

Conversation

@bwasti

@bwasti bwasti commented Mar 28, 2019

Copy link
Copy Markdown
Contributor

Stack from ghstack:

Differential Revision: D14901297

@zdevito zdevito left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I commented on the other WIP TVM IR with my fuser comments. I think they still apply here.

Comment thread torch/csrc/jit/ir.h Outdated
TypePtr typ); // value of None with type Optional[typ]
TORCH_API Node* createAutogradZero();
TORCH_API Node* createFusionGroup();
TORCH_API Node* createFusionGroup(Symbol kind = prim::FusionGroup);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is no longer specific to FusionGroup, maybe it would be better to rename it to createWithSubgraph(Symbol kind)?

Comment thread torch/csrc/jit/passes/graph_fuser.cpp Outdated
Block* block_;
std::unique_ptr<AliasDb> aliasDb_;
std::shared_ptr<Graph> graph_;
using FusionCallback = std::function<bool(Node*)>;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we move this to not appear inside a sequence of member decalarations? It is quite confusing at the moment

std::unique_ptr<AliasDb> aliasDb_;
std::shared_ptr<Graph> graph_;
using FusionCallback = std::function<bool(Node*)>;
FusionCallback callback_ = [&](Node* n) { return isFusableDefault(n); };

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm can you really capture this in a lambda that's in a default initializer of a member? It's a surprising syntax.

Comment thread torch/csrc/jit/passes/graph_fuser.cpp Outdated
}

void CustomFuseGraph(std::shared_ptr<Graph>& graph, std::function<bool(Node*)> fn, Symbol kind) {
if (canFuseOnCPU() || canFuseOnGPU()) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My previous comment still applies: why do we check the capabilities of PyTorch fuser, when you're really applying this pass to obtain a fused node that will be passed through a completely different backend.

[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
@wanchaol

Copy link
Copy Markdown
Collaborator

clang-tidy is complaining

bwasti added 3 commits April 22, 2019 13:21
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
Comment thread torch/csrc/jit/passes/graph_fuser.h Outdated

TORCH_API void CustomFuseGraph(
std::shared_ptr<Graph>& graph,
std::function<bool(Node*)> fn,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to call this is_fusable than fn because that name carries no information.

Comment thread torch/csrc/jit/passes/graph_fuser.h Outdated
TORCH_API void CustomFuseGraph(
std::shared_ptr<Graph>& graph,
std::function<bool(Node*)> fn,
Symbol tag);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: kind

Comment thread torch/csrc/jit/passes/graph_fuser.cpp Outdated
bool isFusableDefault(Node* node) {
bool fusableDevice = true;
for (const auto& output : node->outputs()) {
fusableDevice &= isFusableDevice(output);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check this only for outputs which have uses.

producer->node()->matches(
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {

if (kind_ == prim::FusionGroup &&

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This place is a bit meh, but I guess it's ok. It's a good sign that we might want to rethink the API you're adding because right now there's no way to do decompose ops lazily with custom fusions.

@bddppq

bddppq commented May 5, 2019

Copy link
Copy Markdown
Contributor

@pytorchbot retest this please

bwasti added 5 commits May 6, 2019 11:15
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion

gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
@zou3519 zou3519 deleted the gh/bwasti/2/head branch May 7, 2019 06:17
@facebook-github-bot

Copy link
Copy Markdown
Contributor

@bwasti merged this pull request in 4ca325d.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Pull Request resolved: pytorch#18588
ghimport-source-id: f40df17

Differential Revision: D14901297

Pulled By: bwasti

fbshipit-source-id: 1b6371a5175b3d63dad542b7cc22cb82e8c6cfd0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants