[JIT] Auto-batching IR transformation for control flow #9392
ChunliF wants to merge 20 commits into pytorch:master
Conversation
facebook-github-bot
left a comment
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
jamesr66a
left a comment
Looking good! A couple of overall comments:
- In a lot of places in the code you have a literal `3` to represent the size of the expanded Values. It might be better to factor that out as a constant, so if the number of values used to represent a batch changes we can change it in one place. In general I think it's safe to assume this is going to be a 1 -> N value transform for all values, so factoring it out as a constant might help.
- It would probably also be good to add some calls to `self.assertExpected` and pass in the batched graphs in the test cases. This way, we can manually inspect the outputs of the graph in review. That also has the nice property of failing the test if we pass it "on accident" (right answer with wrong method).
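To make the first suggestion concrete, here is a minimal sketch (pure Python, with hypothetical names — not the pass's actual identifiers) of naming the expanded-value size once instead of repeating the literal `3`:

```python
# Sketch only: each tensor Value is expanded into a fixed-size tuple
# (data, mask, dims). Naming the size once avoids scattering a literal 3.
EXPANDED_VALUE_SIZE = 3  # change here if the batch representation grows

def expand_value(name):
    """Map one Value name to its batched component names (illustrative)."""
    components = [name + "_data", name + "_mask", name + "_dims"]
    assert len(components) == EXPANDED_VALUE_SIZE
    return components
```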
```cpp
  }
}

// clone prim::Constant to new graph
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
void ToBatch::toBatch(Block* block, Block* res_block) {
// change inputs of a graph - expand tensor to {data, mask, dims}
// eg: "a.1" -> {"a", "1"}; "a" -> {"a"}
std::vector<std::string> ToBatch::get_name(std::string name) {
```
This comment was marked as off-topic.
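The comment in the hunk above specifies the name-splitting rule; a pure-Python equivalent of that rule (an illustrative sketch, not the C++ implementation) is:

```python
def get_name(name):
    # Split a unique Value name into base and version,
    # e.g. "a.1" -> ["a", "1"]; "a" -> ["a"].
    return name.split(".", 1)
```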
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
// replace aten operator node with BatchTensor operator graph
void ToBatch::visitAten(Node* n, Block* block, Block* res_block, std::unordered_map<std::string, Value*>& var_map){
  if(n->outputs().size() > 1){
    throw std::runtime_error("Cannot process multiple assignment");
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
// do update on assignment
auto name_base = get_name(n->output()->uniqueName())[0];
if(var_map.find(name_base) != var_map.end()){
  std::vector<Value*> inputs(batch_map.at(var_map.at(name_base)));
```
This comment was marked as off-topic.
```cpp
// elif is not supported
// assume every variable assigned in an if statement is already defined before
void ToBatch::visitIf(Node* n, Block* block, Block* res_block, std::unordered_map<std::string, Value*>& var_map){
  auto res_graph = res_block->owningGraph();
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
// elif is not supported
// assume every variable assigned in an if statement is already defined before
void ToBatch::visitIf(Node* n, Block* block, Block* res_block, std::unordered_map<std::string, Value*>& var_map){
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
auto res_graph = res_block->owningGraph();

// create prim::If node for res_block
auto add_if_node = [&](Block* block, std::shared_ptr<Graph> cond_graph, std::vector<Value*> cond, std::vector<Value*> unchanged_outputs){
```
This comment was marked as off-topic.
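Per the PR summary, the batched `prim::If` executes both branches and then merges their results using `cond`. A simplified pure-Python model of that per-example merge (assumed semantics, not the real pass) is:

```python
def combine_if_outputs(cond, if_out, else_out):
    # Per-example select: where cond holds, take the if-branch result,
    # otherwise keep the else-branch result.
    return [i if c else e for c, i, e in zip(cond, if_out, else_out)]
```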
facebook-github-bot
left a comment
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zdevito
left a comment
Nice progress! I see some issues involving update statements that I put in the comments. Let me know if you have any questions about it.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
  }
}

void ToBatch::toBatch(Block* block, Block* res_block, std::unordered_map<std::string, Value*>& upper_var_map) {
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
  visitAten(n, block, res_block, var_map);
}
else if(n->kind().is_prim()){
  if(n->kind() == prim::Constant){
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
res_block->registerOutput(r_output[1]);
res_block->registerOutput(r_output[2]);
// change outputs of block - expand tensor to batchtensor(data, mask, dims)
// for block in prim::Loop, register outputs separately
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
}
auto outputs = script::inlineCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);

// do update on assignment
```
This comment was marked as off-topic.
```
graph(%a.1_data : Dynamic
      %a.1_mask : Dynamic
      %a.1_dims : Dynamic) {
  %3 : int = prim::Constant[value={1}]()
```
This comment was marked as off-topic.
```
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%9 : Dynamic = aten::mul(%6, %7)
%10 : Dynamic = aten::sum(%9)
%11 : Dynamic = aten::gt[other={0}](%10)
```
This comment was marked as off-topic.
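The batched IR above appears to mask an elementwise comparison (`mul`), reduce it (`sum`), and test against zero (`gt`) to get a per-example condition. A rough pure-Python analogue of that three-op pattern (assumed semantics, not taken from the pass) is:

```python
def any_valid_true(cmp_data, mask):
    # Mask out padded entries (mul), reduce (sum), then compare with
    # zero (gt) to produce a single boolean condition for one example.
    masked = [c * m for c, m in zip(cmp_data, mask)]
    return sum(masked) > 0
```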
test/test_jit.py (Outdated)
```python
@torch.jit.script
def batch_for(x, y):
    for _i in range(10):
```
This comment was marked as off-topic.
```python
@torch.jit.script
def batch_argmax(data, mask, dims, dim, keepdim):
    # if dim == 0:
```
This comment was marked as off-topic.
facebook-github-bot
left a comment
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
```python
graph = torch.to_batch_graph(batch_for.graph)
self.assertExpected(str(graph))

def test_lstm(self):
```
This comment was marked as off-topic.
test/test_jit.py (Outdated)
```python
def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
           b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
    iter_count = torch.zeros_like(iter_num)
    while(iter_count < iter_num):
```
This comment was marked as off-topic.
torch/csrc/jit/passes/to_batch.cpp (Outdated)
```cpp
auto size = block->inputs().size();
for(size_t i = 0; i < size; i++){
  auto input = block->inputs()[i];
std::shared_ptr<Graph> ToBatch::getBatchOperator(std::string name, int64_t input_num){
```
This comment was marked as off-topic.
test/test_jit.py (Outdated)
```python
def batch_if(a, b):
    if a > b:
        a += b
    a = a + b
```
This comment was marked as off-topic.
test/test_jit.py (Outdated)
```python
s_t = s_t.view([1, -1])
p_t = torch.softmax(s_t, 1)
# print(p_t)
prob_t, i_t = torch.topk(p_t, k, 1)
```
This comment was marked as off-topic.
facebook-github-bot
left a comment
@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
facebook-github-bot
left a comment
ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please

1 similar comment

@pytorchbot retest this please
zdevito
left a comment
I didn't do a detailed review, but I think we should merge this, assuming its tests pass, so that it doesn't get broken by future JIT changes.
facebook-github-bot
left a comment
ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
facebook-github-bot
left a comment
ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Implement IR transformation for control flow
- `prim::Constant`: clone to new graph directly
- `prim::NumToTensor`: create a `BatchTensor` from the output tensor with `batch_size = 1`
- `prim::TensorToNum`: clone to new graph
- `prim::ListConstruct`: clone to new graph
- `prim::If`: execute both `if_block` and `else_block` and combine their results using `cond`
- `prim::Loop`:
  - for loop
  - while loop: change the while `cond` to `cond_any`, and use `cond` to update outputs

Test cases: hand-written LSTM, greedy search, beam search

Pull Request resolved: pytorch#9392
Differential Revision: D8822369
Pulled By: ChunliF
fbshipit-source-id: 8f03c95757d32e8c4580eeab3974fd1bc429a1e5
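The while-loop rewrite in the summary (run while `cond_any` holds, use per-example `cond` to gate updates) can be sketched as a simplified pure-Python model — an illustrative sketch of the semantics, not the IR pass itself:

```python
def batched_while(cond_fn, body_fn, state, max_iter=100):
    """Run a while loop over a batch of per-example states.

    cond_fn(s) -> bool, body_fn(s) -> new state. The loop continues while
    any example still satisfies its condition (cond_any); finished examples
    keep their old state instead of taking the update.
    """
    for _ in range(max_iter):
        conds = [cond_fn(s) for s in state]
        if not any(conds):  # cond_any: every example has finished
            break
        state = [body_fn(s) if c else s for s, c in zip(state, conds)]
    return state
```

For example, counting each element down to zero: `batched_while(lambda s: s > 0, lambda s: s - 1, [3, 1])` leaves every example at 0, with the second example idle after its first step.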