[JIT] AutoBatching - IR transformation(basic operators)#9198
[JIT] AutoBatching - IR transformation(basic operators)#9198ChunliF wants to merge 21 commits intopytorch:masterfrom
Conversation
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/csrc/jit/passes/to_batch.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/passes/to_batch.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/passes/to_batch.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/passes/to_batch.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/passes/to_batch.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zdevito
left a comment
There was a problem hiding this comment.
This looks pretty good! I mostly have code organizational comments and a few nits.
torch/csrc/jit/passes/to_batch.cpp
Outdated
| } | ||
|
|
||
| void to_batch_graph(std::shared_ptr<Graph>& graph, int64_t batch_size){ | ||
| // batch_size: not used yet, will be used to deal with scalarType |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/passes/to_batch.cpp
Outdated
| namespace torch { namespace jit { | ||
|
|
||
| // map from batchTensor to {data, mask, dims} | ||
| static std::unordered_map<Value*, std::vector<Value*>> batch_map; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| auto new_input = batch_map.at(input); | ||
| new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end()); | ||
| } | ||
| else{ |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/passes/to_batch.cpp
Outdated
| batch_map[output] = std::vector<Value*>(outputs.begin() + i * 3, outputs.begin() + i * 3 + 3); | ||
| } | ||
| } | ||
| // control flow: not supported yet, will be added further |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/init.cpp
Outdated
| // There is a thin wrapper on top of this method in the C++ version of | ||
| // ScriptModule. | ||
| return runMethodFromPython(self.get_method("forward"), args); | ||
| }) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
| def batch(batch_size=1, optimize=True, _frames_up=0): | ||
| def decorator(fn): | ||
| mod = script(fn, optimize, _frames_up) | ||
| res_graph = mod.to_batch(batch_size) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ChunliF is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Use decorator `torch.jit.batch` to implement auto-batching (call `to_batch` pass to do IR tranformation). - `to_batch` pass: "to_batch.h/cpp" in csrc/jit/passess to transform a graph to a new batched graph. - Write several basic operators for BatchTensor (add, mul, sigmoid, tanh, mm, matmul, select). - Register the operators in a lookup table `<std::string, std::shared_ptr<Graph>>`. (use the Graph to replace the original node in IR graph) Move BatchTensor in python from torch.BatchTensor to torch.jit.BatchTensor Pull Request resolved: pytorch#9198 Reviewed By: zdevito Differential Revision: D8744466 Pulled By: ChunliF fbshipit-source-id: 9ea56a30f55cb870f13a2069a47cc635419763ff
Use decorator
torch.jit.batchto implement auto-batching (callto_batchpass to do IR tranformation).to_batchpass: "to_batch.h/cpp" in csrc/jit/passess to transform a graph to a new batched graph.<std::string, std::shared_ptr<Graph>>. (use the Graph to replace the original node in IR graph)Move BatchTensor in python from torch.BatchTensor to torch.jit.BatchTensor