[quant][graph] Add valueQuantizable function #36548

supriyar wants to merge 7 commits into gh/supriyar/84/base from
Conversation
Summary: Refactor to be able to observe based on the inputs to ops.

Test Plan: test_quantize_script.py
As of commit 2077608, Dr. CI reports no build failures.
```cpp
if (n->kind() == prim::CallFunction &&
    getFuncName(n->inputs()[0]) == "batch_norm") {
  return use.offset == 1;
}
```
Could you refactor this further to take a list of (name, offset) tuples and check whether the use matches the name and offset?
Some similar functions are already implemented: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/quantization.cpp#L707, https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/quantization.cpp#L713. Please feel free to refactor these as well, as you see fit.
I've added the list of tuples in the function itself. I didn't want to spend too much time refactoring the other code, so I used it in the best way I saw fit, given that the conditions we were returning on in the two cases were different.
```cpp
for (const auto& func_arg : aten_func_args) {
  if (n->kind() == Symbol::aten(func_arg.func_name)) {
    return v == n->inputs().at(func_arg.arg_index);
  }
}

for (const auto& func_arg : call_func_args) {
  if (n->kind() == prim::CallFunction &&
      getFuncName(n->inputs()[0]) == func_arg.func_name) {
    return v == n->inputs().at(func_arg.arg_index);
  }
}
```
I think you can use isAtenFuncNthArg and isCallFunctionNthArg here. You could also refactor those functions to not take the value as input and just compare use.offset with the arg_index.
The logic is actually a bit different here, because of the convoluted check flow: we fall back to nodeQuantizable when the check fails. Can we add a default index for aten and CallFunction nodes in nodeQuantizable?
```cpp
// For each operator in this list, observers are inserted for the input
// at the index specified.
const AtenFuncArgs& aten_func_args = AtenFuncArgs({{"lstm", 2}});
const CallFuncArgs& call_func_args = CallFuncArgs({{"batch_norm", 1}});
```
This change belongs in the next PR.
```cpp
// Special checks for ops that do not require observers for all input tensors.
// For each operator in this list, observers are inserted for the input
// at the index specified.
const AtenFuncArgs& aten_func_args = AtenFuncArgs({{"lstm", 2}});
```
We also need to check is_dynamic for the "lstm" node.
I think that won't be necessary, because for dynamic quant we insert observers at the input, which is the same as static quant. We have an is_dynamic check in the parent function to insert the observer at the output of the node.
OK, then do you need is_dynamic as an argument here?
We don't support static quant for lstm right now, though.
```cpp
const AtenFuncArgs& aten_func_args = AtenFuncArgs({});
const CallFuncArgs& call_func_args = CallFuncArgs({});
```
Could you expose these two lists at the beginning of the file so that they are easier to find and change?
Also, a follow-up here: can we remove the isBias check for conv and linear and put the special case in the list?
> Also, a follow up here is, can we remove the isBias check for conv and linear and put the special case in the list?

I think we can do this separately. I also see an isWeight check that uses these lists.
```cpp
Node* node = use.user;
return node->kind() == prim::CallFunction &&
    getFuncName(node->inputs()[0]) == func_name &&
    (n.has_value() ? (n.value() == use.offset) : true);
```
Actually, this can be simplified to (!n.has_value() || n.value() == use.offset).
```cpp
AtenFuncArgs _aten_func_args = {};
CallFuncArgs _call_func_args = {};
```
Can you use better names for these?
Stack from ghstack:
Summary: Refactor to be able to observe based on the inputs to ops.

Test Plan: test_quantize_script.py
Differential Revision: D21048963