Move at::chunk into the graph fuser #10178
Conversation
zdevito
left a comment
The fusion pass changes look good. The fusion compiler stuff looks pretty good too, but I have some suggestions to simplify it and to avoid allocations along the fast path that I think we should address; otherwise it will be difficult to add more functionality to the fuser later.
I've modified the pull request to use ConcatDesc and fixed the extra std::vector allocations. I measured some rough numbers on how much overhead the changes add:

Before (master): 11.73 microseconds
After changes: 12.14 microseconds

so the current changes have not added too much overhead.
@zou3519 Ping :D
This should be ready for another review now, despite the failing (unrelated) tests.
apaszke
left a comment
LGTM, but I'm not sure if you're handling contiguity information correctly in PartitionDesc
facebook-github-bot
left a comment
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This is done through the following:
1) Absorb starting chunks into the FusionGroup as part of the graph fuser pass.
2) When compiling a kernel, move chunks out of the FusionGroup and use the resulting graph. Emit a std::vector<MaybeChunkDesc> that describes whether an input (of the original graph) will be chunked.
3) When launching a kernel, use the std::vector<MaybeChunkDesc> to chunk an input tensor on the CPU. This chunk takes in an at::Tensor and outputs four TensorInfo structs, bypassing intermediate Tensors.
4) The resulting TensorInfo structs are sent into the compiled kernel.

Test Plan
- Expect test and correctness test to see if a single chunk is fused by the graph fuser
- Correctness test for a variety of chunks (dimension = beginning, middle, end) and tensors (contiguous, non-contiguous, edge case of splitSize = 1) for both CPU/CUDA
- Expect test for multiple chunks fused into the same kernel, and correctness test.

Absorb starting at::chunk into FusionGroups
If all outputs of an at::chunk are inputs to a FusionGroup and "chunks" and "dim" are both constants, then the at::chunk is moved to the beginning of the FusionGroup.

Teach fusion compiler about at::chunk inside a FusionGroup
When compiling, the compiler emits an extra std::vector<ConcatDesc> that says which inputs are chunked into how many pieces. The compiler scans the inputs and produces a list of "flat inputs". When launching, the compiler scans the inputs and the chunk_desc to see which inputs are chunked, and uses this information to prepare the list of flat inputs sent to the compiled kernel.

Update expect files
Fix nit
Windows fix
Address most comments, still working on the rest
Use prim::FusedChunk for chunks inside FusionGroup.
Addressed comments:
- add assert
- separate PartitionDesc chunk / cat ctors so the logic is clearer
If one has a graph like the following:
```
y1, y2 = chunk(x)
z1, z2 = chunk(x)
fusiongroup(y1, y2, z1, z2)
```
only one chunk should become a prim::FusedChunk inside the fusion group, because of the invariant that no two prim::FusedChunk nodes inside a fusion group may share the same input. This invariant exists because of how the fusion compiler replaces the input to be chunked with its chunked tensors.
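A minimal sketch of how a pass could enforce that invariant while absorbing chunks, assuming a simplified model of the problem: the real pass operates on torch::jit IR nodes, and the type alias and helper below are hypothetical illustrations only.

```cpp
#include <set>

// Hypothetical stand-in for the identity of a torch::jit::Value.
using ValueId = int;

// Sketch of the uniqueness guard: a chunk may be absorbed as a
// prim::FusedChunk only if no FusedChunk in the group already consumes
// the same input value. Returns true and records the input on success.
bool tryAbsorbChunk(std::set<ValueId>& fusedChunkInputs, ValueId input) {
  // set::insert reports false in .second when the key already exists,
  // i.e. when another FusedChunk already claimed this input.
  return fusedChunkInputs.insert(input).second;
}
```

Under this model, the second `chunk(x)` in the example above fails the guard and stays outside the fusion group.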
@pytorchbot retest this please
Summary: ... to avoid slow at::chunk (it is slow due to tensor initialization). Picking up from pytorch#10026

This is done through the following:
1) Absorb starting chunks into the FusionGroup as part of the graph fuser pass.
2) When compiling a kernel, emit a `std::vector<ConcatDesc>` that describes whether an input (of the original graph) will be chunked.
3) When launching a kernel, use the `std::vector<ConcatDesc>` to chunk an input tensor on the CPU. This chunk directly takes in an at::Tensor and creates four TensorInfo structs in-place in the argument list, bypassing the creation of intermediate Tensors.

Test Plan
- Expect test and correctness test to see if a single chunk is fused by the graph fuser
- Correctness test for a variety of chunks (dimension = beginning, middle, end) and tensors (contiguous, non-contiguous, edge case of splitSize = 1) for both CPU/CUDA
- Expect test for multiple chunks fused into the same kernel, and correctness test.

cc zdevito apaszke

Perf: LSTM forward pass, 1 layer, 512 hidden size and input size, 100 seq length, requires_grad=False on all inputs and weights.

After changes:
```
thnn    cudnn   jit
8.8468  6.5797  9.3470
```
Before changes:
```
thnn    cudnn   jit
9.9221  6.6539  11.2550
```

Pull Request resolved: pytorch#10178
Differential Revision: D9382661
Pulled By: zou3519
fbshipit-source-id: 1f8a749208fbdd45559775ce98cf4eb9558448f8
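The absorption precondition described above (constant "chunks" and "dim", and every chunk output consumed by one FusionGroup) can be sketched like this; the struct and function are illustrative stand-ins under those stated assumptions, not the real graph-fuser API.

```cpp
#include <vector>

// Illustrative summary of an at::chunk node as seen by the fuser pass.
struct ChunkNodeInfo {
  bool chunksIsConstant;             // the "chunks" argument is a constant
  bool dimIsConstant;                // the "dim" argument is a constant
  std::vector<int> outputConsumers;  // fusion-group id consuming each output
};

// A chunk is absorbed only when both arguments are constants and every
// output feeds into the same FusionGroup (id == groupId).
bool canAbsorbChunk(const ChunkNodeInfo& n, int groupId) {
  if (!n.chunksIsConstant || !n.dimIsConstant) return false;
  if (n.outputConsumers.empty()) return false;
  for (int consumer : n.outputConsumers)
    if (consumer != groupId) return false;
  return true;
}
```

If any output escapes to a different consumer, the chunk must stay in the outer graph so the escaping value remains materialized.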