
[JIT] Auto-batching IR transformation for control flow #9392

Closed
ChunliF wants to merge 20 commits into pytorch:master from ChunliF:control-flow

Conversation

ChunliF (Contributor) commented Jul 12, 2018

Implement IR transformation for control flow

  • `prim::Constant`: clone to the new graph directly
  • `prim::NumToTensor`: create a `BatchTensor` from the output tensor with `batch_size = 1`
  • `prim::TensorToNum`: clone to the new graph
  • `prim::ListConstruct`: clone to the new graph
  • `prim::If`: execute both `if_block` and `else_block`, then combine their results using `cond`
  • `prim::Loop`:
    • for loop
    • while loop: change the while `cond` to `cond_any`; use `cond` to update outputs

Test cases: hand-written LSTM, greedy search, beam search
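The `prim::If` and `prim::Loop` rules above can be sketched in plain Python over per-example lists (a hypothetical reference semantics for illustration, not the actual C++ IR pass; `combine_if` and `batched_while` are made-up names): both branches of an `if` run and `cond` selects the result per example, while a batched `while` keeps iterating as long as any example's condition still holds.

```python
# Hypothetical pure-Python sketch of the batched control-flow semantics
# described above (the real pass rewrites TorchScript IR in C++).

def combine_if(cond, then_out, else_out):
    # prim::If: both blocks were executed; pick per example using cond
    return [t if c else e for c, t, e in zip(cond, then_out, else_out)]

def batched_while(cond_fn, step_fn, state):
    # prim::Loop (while): keep iterating while ANY example's cond holds
    # (cond_any); only examples whose own cond holds get updated.
    while any(cond_fn(s) for s in state):
        state = [step_fn(s) if cond_fn(s) else s for s in state]
    return state

print(combine_if([True, False], [1, 2], [10, 20]))                 # -> [1, 20]
print(batched_while(lambda s: s > 0, lambda s: s - 1, [3, 1, 0]))  # -> [0, 0, 0]
```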

facebook-github-bot (Contributor) left a comment:

@ChunliF has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

jamesr66a (Collaborator) left a comment:

Looking good! A couple of overall comments:

  1. In a lot of places the code has a literal 3 to represent the size of the expanded Values. It might be better to factor that out as a named constant, so that if the number of values used to represent a batch changes, we only have to update one place. In general I think it's safe to assume this will be a 1 -> N value transform for all values, so factoring it out as a constant should help.
  2. It would also be good to add some calls to self.assertExpected and pass in the batched graphs in the test cases. This way we can manually inspect the graph outputs during review. That also has the nice property of failing the test if we pass it "by accident" (right answer with the wrong method).
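Suggestion 1 might look like the following sketch (the constant and helper names are hypothetical, not from the PR): each JIT Value is expanded into its data/mask/dims companions, and the width 3 is named exactly once.

```python
# Hypothetical sketch of factoring out the literal 3 (not the PR's C++ code).
EXPANSION_SIZE = 3  # each Value expands to (data, mask, dims)

def expand_value_names(name):
    # mirrors the naming seen in the batched graphs, e.g. a.1 -> a.1_data, ...
    suffixes = ["data", "mask", "dims"]
    assert len(suffixes) == EXPANSION_SIZE
    return [f"{name}_{s}" for s in suffixes]

print(expand_value_names("a.1"))  # -> ['a.1_data', 'a.1_mask', 'a.1_dims']
```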

```cpp
}
}

// clone prim::Constant to new graph
```

```cpp
void ToBatch::toBatch(Block* block, Block* res_block) {
// change inputs of a graph - expand tensor to {data, mask, dims}
// eg: "a.1" -> {"a", "1"}; "a" -> {"a"}
std::vector<std::string> ToBatch::get_name(std::string name) {
```
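A hypothetical Python equivalent of the `get_name` comment above, splitting a unique name into its base and suffix (sketch only, not the pass's C++ code):

```python
# Sketch only: "a.1" -> ["a", "1"]; "a" -> ["a"], per the comment above.
def get_name(name):
    return name.split(".")

print(get_name("a.1"))  # -> ['a', '1']
print(get_name("a"))    # -> ['a']
```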


```cpp
// replace aten operator node with BatchTensor operator graph
void ToBatch::visitAten(Node* n, Block* block, Block* res_block, std::unordered_map<std::string, Value*>& var_map){
  if(n->outputs().size() > 1){
    throw std::runtime_error("Cannot process multiple assignment");
```


```cpp
// do update on assignment
auto name_base = get_name(n->output()->uniqueName())[0];
if(var_map.find(name_base) != var_map.end()){
  std::vector<Value*> inputs(batch_map.at(var_map.at(name_base)));
```


```cpp
// elif is not supported
// assume every variable assigned in an if statement is already defined before
void ToBatch::visitIf(Node* n, Block* block, Block* res_block, std::unordered_map<std::string, Value*>& var_map){
  auto res_graph = res_block->owningGraph();
```




```cpp
auto res_graph = res_block->owningGraph();

// create prim::If node for res_block
auto add_if_node = [&](Block* block, std::shared_ptr<Graph> cond_graph, std::vector<Value*> cond, std::vector<Value*> unchanged_outputs){
```



zdevito (Contributor) left a comment:

Nice progress! I see some issues involving update statements, which I've noted in the inline comments. Let me know if you have any questions.

```cpp
}
}

void ToBatch::toBatch(Block* block, Block* res_block, std::unordered_map<std::string, Value*>& upper_var_map) {
```


```cpp
  visitAten(n, block, res_block, var_map);
}
else if(n->kind().is_prim()){
  if(n->kind() == prim::Constant){
```


```cpp
res_block->registerOutput(r_output[1]);
res_block->registerOutput(r_output[2]);
// change outputs of block - expand tensor to batchtensor(data, mask, dims)
// for block in prim::Loop, register outputs separately
```


```cpp
}
auto outputs = script::inlineCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);

// do update on assignment
```


```
graph(%a.1_data : Dynamic
      %a.1_mask : Dynamic
      %a.1_dims : Dynamic) {
  %3 : int = prim::Constant[value={1}]()
```


```
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%9 : Dynamic = aten::mul(%6, %7)
%10 : Dynamic = aten::sum(%9)
%11 : Dynamic = aten::gt[other={0}](%10)
```
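The `aten::mul` / `aten::sum` / `aten::gt` sequence above appears to compute the `cond_any` reduction for a batched while loop: multiply the condition's data by its mask, sum, and test whether the result is greater than zero. A pure-Python sketch (hypothetical helper name, for intuition):

```python
def cond_any(cond_data, cond_mask):
    # mul + sum + gt[other=0]: true if any valid (mask == 1)
    # element of the batched condition is set
    return sum(d * m for d, m in zip(cond_data, cond_mask)) > 0

print(cond_any([1, 0, 1], [1, 1, 0]))  # -> True  (first element valid and set)
print(cond_any([1, 0, 1], [0, 1, 0]))  # -> False (set elements are masked out)
```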


test/test_jit.py:

```python
@torch.jit.script
def batch_for(x, y):
    for _i in range(10):
```



```python
@torch.jit.script
def batch_argmax(data, mask, dims, dim, keepdim):
    # if dim == 0:
```


zou3519 added the oncall: jit label Jul 23, 2018
weiyangfb added the ready for review label Jul 24, 2018

jamesr66a (Collaborator) left a comment:

Looking good! @zdevito can you please have a look?

```python
graph = torch.to_batch_graph(batch_for.graph)
self.assertExpected(str(graph))

def test_lstm(self):
```


test/test_jit.py:

```python
def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
           b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
    iter_count = torch.zeros_like(iter_num)
    while(iter_count < iter_num):
```


```cpp
auto size = block->inputs().size();
for(size_t i = 0; i < size; i++){
  auto input = block->inputs()[i];
```

```cpp
std::shared_ptr<Graph> ToBatch::getBatchOperator(std::string name, int64_t input_num){
```


test/test_jit.py:

```python
def batch_if(a, b):
    if a > b:
        a += b
    a = a + b
```
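What auto-batching should compute for `batch_if` can be written as a per-example reference (a hypothetical check, assuming the final `a = a + b` sits after the `if` at function level):

```python
def batch_if_ref(a_batch, b_batch):
    # per-example semantics of batch_if: the branch is taken (or not)
    # independently for each example in the batch
    out = []
    for a, b in zip(a_batch, b_batch):
        if a > b:
            a = a + b
        a = a + b
        out.append(a)
    return out

print(batch_if_ref([3, 1], [2, 5]))  # -> [7, 6]
```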


test/test_jit.py:

```python
s_t = s_t.view([1, -1])
p_t = torch.softmax(s_t, 1)
# print(p_t)
prob_t, i_t = torch.topk(p_t, k, 1)
```
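The `torch.topk` call above can be sketched in pure Python for a single row (a hypothetical helper for intuition; tie-breaking may differ from torch):

```python
def topk(row, k):
    # values and indices of the k largest entries, descending, like torch.topk
    idx = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    return [row[i] for i in idx], idx

print(topk([0.1, 0.5, 0.4], 2))  # -> ([0.5, 0.4], [1, 2])
```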



ChunliF (Author) commented Jul 27, 2018

@pytorchbot retest this please


ChunliF (Author) commented Jul 31, 2018

@pytorchbot retest this please

1 similar comment

ChunliF (Author) commented Jul 31, 2018

@pytorchbot retest this please

zdevito (Contributor) left a comment:

I didn't do a detailed review, but I think we should merge this, assuming its tests pass, so that it doesn't get broken by future JIT changes.


ChunliF (Author) commented Aug 1, 2018

@pytorchbot retest this please


goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Summary:
Implement IR transformation for control flow

- `prim::Constant`: clone to new graph directly
- `prim::NumToTensor`: create a `BatchTensor` from output tensor with `batch_size = 1`
- `prim::TensorToNum`: clone to new graph
- `prim::ListConstruct`: clone to new graph
- `prim::If`: execute both `if_block` and `else_block` and combine results from them using `cond`
- `prim::Loop`:
  - for loop
  - while loop: change while `cond` to `cond_any`, use `cond` to update outputs

test case: hand-written LSTM, greedy search, beam search
Pull Request resolved: pytorch#9392

Differential Revision: D8822369

Pulled By: ChunliF

fbshipit-source-id: 8f03c95757d32e8c4580eeab3974fd1bc429a1e5
ezyang added the merged label Jun 26, 2019

Labels: oncall: jit, ready for review (this tag is deprecated)

7 participants