[jit] Add Custom graph fusion#18588
Conversation
zdevito
left a comment
There was a problem hiding this comment.
I commented on the other WIP TVM IR with my fuser comments. I think they still apply here.
| TypePtr typ); // value of None with type Optional[typ] | ||
| TORCH_API Node* createAutogradZero(); | ||
| TORCH_API Node* createFusionGroup(); | ||
| TORCH_API Node* createFusionGroup(Symbol kind = prim::FusionGroup); |
There was a problem hiding this comment.
Since this is no longer specific to FusionGroup, maybe it would be better to rename it to createWithSubgraph(Symbol kind)?
| Block* block_; | ||
| std::unique_ptr<AliasDb> aliasDb_; | ||
| std::shared_ptr<Graph> graph_; | ||
| using FusionCallback = std::function<bool(Node*)>; |
There was a problem hiding this comment.
nit: can we move this to not appear inside a sequence of member decalarations? It is quite confusing at the moment
| std::unique_ptr<AliasDb> aliasDb_; | ||
| std::shared_ptr<Graph> graph_; | ||
| using FusionCallback = std::function<bool(Node*)>; | ||
| FusionCallback callback_ = [&](Node* n) { return isFusableDefault(n); }; |
There was a problem hiding this comment.
Hmm can you really capture this in a lambda that's in a default initializer of a member? It's a surprising syntax.
| } | ||
|
|
||
| void CustomFuseGraph(std::shared_ptr<Graph>& graph, std::function<bool(Node*)> fn, Symbol kind) { | ||
| if (canFuseOnCPU() || canFuseOnGPU()) { |
There was a problem hiding this comment.
My previous comment still applies: why do we check the capabilities of PyTorch fuser, when you're really applying this pass to obtain a fused node that will be passed through a completely different backend.
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
|
clang-tidy is complaining |
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
|
|
||
| TORCH_API void CustomFuseGraph( | ||
| std::shared_ptr<Graph>& graph, | ||
| std::function<bool(Node*)> fn, |
There was a problem hiding this comment.
It would be better to call this is_fusable than fn because that name carries no information.
| TORCH_API void CustomFuseGraph( | ||
| std::shared_ptr<Graph>& graph, | ||
| std::function<bool(Node*)> fn, | ||
| Symbol tag); |
| bool isFusableDefault(Node* node) { | ||
| bool fusableDevice = true; | ||
| for (const auto& output : node->outputs()) { | ||
| fusableDevice &= isFusableDevice(output); |
There was a problem hiding this comment.
Please check this only for outputs which have uses.
| producer->node()->matches( | ||
| "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) { | ||
|
|
||
| if (kind_ == prim::FusionGroup && |
There was a problem hiding this comment.
This place is a bit meh, but I guess it's ok. It's a good sign that we might want to rethink the API you're adding because right now there's no way to do decompose ops lazily with custom fusions.
|
@pytorchbot retest this please |
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
Summary: Pull Request resolved: pytorch#18588 ghimport-source-id: f40df17 Differential Revision: D14901297 Pulled By: bwasti fbshipit-source-id: 1b6371a5175b3d63dad542b7cc22cb82e8c6cfd0
Stack from ghstack:
Differential Revision: D14901297