[invoke_subgraph] User facing API to support arbitrary args and kwargs #139162

anijain2305 wants to merge 16 commits into `gh/anijain2305/568/base`
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139162. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 60da56e with merge base bf1b8ad. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Commit update: "…s and kwargs"
```python
invoke_subgraph_placeholder = InvokeSubgraphPlaceholderHOP()
...
def wrap_with_invoke_subgraph(fn=None):
```
Since this is the user-facing API: can we bikeshed the name? I imagine this API will end up in the `torch.compiler` namespace.

- `torch.compiler.mark_subregion`?
- `torch.compiler.set_inline(False)`?
Yes, but not right now. I want to see some wins before introducing it in the compiler namespace.
I am imagining something like the `torch.compile` to `torch._dynamo.optimize` type of mapping, once we have established that this HOP is indeed useful in practice.
Can we give the function the name we want to see in the future? If we want to call it `mark_subregion` or `set_inline` in the future, then let's name it that now instead of `wrap_with_invoke_subgraph`. (Names tend to stick, and `wrap_with_invoke_subgraph` sounds too much like an implementation detail.)
Alright, calling it `mark_compile_region`.
```python
return super().__call__(subgraph, identifier, operands)
...
class InvokeSubgraphPlaceholderHOP(HigherOrderOperator):
```
Does this need to be a HOP? For `torch.utils.checkpoint`, we transformed the `torch.utils.checkpoint` function into the checkpoint HOP. So we can probably get away with this just being an `invoke_subgraph_placeholder` function.
```python
def test_multiple(self):
    n_layers = 2

    @wrap_with_invoke_subgraph
```
Can we test applying this to an nn.Module?
Good idea! I added a test: `test_module`.
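For readers following along: the decorator composes with methods the same way it composes with free functions, because the bound `self` simply travels through `*args`. A minimal pure-Python sketch of that mechanic, using a toy stand-in for `wrap_with_invoke_subgraph` (in eager mode the real decorator is documented to be a no-op, so the stand-in just forwards the call):

```python
def toy_mark_region(fn):
    """Toy stand-in for wrap_with_invoke_subgraph. Under torch.compile the
    real wrapper would route to invoke_subgraph_placeholder(fn, ...);
    in eager it behaves like this pass-through."""
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

class ToyModule:
    # Decorating a method works because the bound `self` is just the
    # first positional argument in *args.
    @toy_mark_region
    def forward(self, x):
        return x * 2

m = ToyModule()
print(m.forward(3))  # 6
```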
```python
def wrap(func):
    def inner(*args, **kwargs):
        if torch._dynamo.is_compiling():
            return invoke_subgraph_placeholder(func, *args, **kwargs)
```
Are there any constraints to the input types of (args, kwargs) at all?
I don't think so. Dynamo will traverse the user-defined objects here and pull out the tensors from the underlying data structures to feed to the invoke_subgraph HOP.
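That "pull out the tensors from the underlying data structures" step is essentially a pytree-style flatten (torch's real machinery lives in `torch.utils._pytree`; Dynamo's traversal is more involved). A hypothetical pure-Python sketch of the idea, with `FakeTensor` standing in for `torch.Tensor`:

```python
class FakeTensor:
    """Stand-in leaf type for torch.Tensor, for illustration only."""
    def __init__(self, name):
        self.name = name

def flatten_tensors(obj):
    """Recursively collect FakeTensor leaves from nested containers,
    roughly how tensor operands would be gathered from args/kwargs."""
    if isinstance(obj, FakeTensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return [t for item in obj for t in flatten_tensors(item)]
    if isinstance(obj, dict):
        return [t for key in sorted(obj) for t in flatten_tensors(obj[key])]
    return []  # non-tensor leaves (ints, strings, ...) are not HOP operands

args = (FakeTensor("a"), {"w": FakeTensor("b"), "cfg": 3}, [FakeTensor("c")])
print([t.name for t in flatten_tensors(args)])  # ['a', 'b', 'c']
```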
```python
This is a user facing API to wrap the function with invoke_subgraph HOP. For
PyTorch eager, this is a no-op. For torch.compile, we wrap the given
function into invoke_subgraph_placeholder, which is parsed by Dynamo and
replaced by invoke_subgraph.
```
This is a developer-facing docstring; we should make it more user-facing. Something like: "Use `mark_subregion` (or whatever we call it) to decorate a region so that torch.compile attempts to compile it only once (instead of inlining it multiple times). Under the hood, this creates an invoke_subgraph HOP..."
Sounds good. But I am thinking of this as a private API; the real public API will be in the compiler namespace, and I can add more docs there.
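The win the API is after, compiling a decorated region once and reusing it at every call site, can be sketched in pure Python with a per-function cache. This illustrates the intent only, not how invoke_subgraph is implemented; `fake_compile` and the cache are hypothetical:

```python
compile_counts = {}

def fake_compile(fn):
    """Hypothetical stand-in for compiling a subgraph once."""
    compile_counts[fn.__name__] = compile_counts.get(fn.__name__, 0) + 1
    return fn

_compiled_cache = {}

def mark_compile_region(fn):
    """Sketch of the intent: a decorated region is 'compiled' at most
    once, then reused, no matter how many times it is called."""
    def inner(*args, **kwargs):
        if fn not in _compiled_cache:
            _compiled_cache[fn] = fake_compile(fn)
        return _compiled_cache[fn](*args, **kwargs)
    return inner

@mark_compile_region
def layer(x):
    return x + 1

out = layer(layer(layer(0)))   # region executed three times...
print(out, compile_counts)     # ...but "compiled" exactly once
```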
zou3519 left a comment:

Left some questions, but overall this seems good.
@zou3519 This is ready for another round of review.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Pull Request resolved: pytorch#139162. Approved by: https://github.com/zou3519
Pull Request resolved: #140058 Approved by: https://github.com/ydwu4, https://github.com/eellison ghstack dependencies: #139162
Summary: X-link: pytorch/pytorch#139162 Approved by: https://github.com/zou3519 Reviewed By: ZainRizvi Differential Revision: D65661188 Pulled By: anijain2305 fbshipit-source-id: 5a090d84e22d020552a4901aacab60aabdc94d24
Adds an `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0). The primary motivations are:

- Unifying scattered reasoning for quant operators throughout the code base.
- Ease of pattern matching: compare this very large pattern match expression [here](https://github.com/pytorch/pytorch/blob/949fdd299764d4fbefe1db093717786d946aaa60/torch/_inductor/fx_passes/post_grad.py#L390-L426) to the pattern I have in the tests:

```python
register_graph_pattern(
    CallFunction(
        torch.ops.aten.mm,
        CallFunction(
            torch.ops.higher_order.invoke_quant,
            Ignored(),
            Ignored(),
            Ignored(),
            scheme="nf4",
        ),
        Arg(),
    ),
    pass_dict=test_pass,
)
```

- The ability to specify Inductor-specific logic, like codegen'ing the operators in lower precision, or forcing fusion into a matmul.

Example graph:

```python
# ===== AFTER POST GRAD =====
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');  repeated_subgraph0 = arg0_1 = arg1_1 = None
        return (invoke_quant,)

class repeated_subgraph0(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
        mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None
        add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1);  mul = arg1_1 = None
        return add
```

The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)`, where the scheme will not always be present.

I wasn't sure exactly how the Inductor-specific configurations like `codegen_low_precision` should be passed through. I didn't want to stuff them all in as kwargs, and I didn't want them to affect pattern matching, so they will be stored in the meta of the node itself. Following that, I wanted the invocation of the HOP to match how it will show up in the graph, so I decided to make it an object that is then invoked for the tracing:

```python
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```

Todo: not require the packing of args in a tuple; will do following #139162. Feedback welcome.
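The design point about storing Inductor-specific options on the node's meta (rather than as kwargs) can be illustrated with a toy node: pattern matching keys only on the op and its args/kwargs, so anything stashed in `meta` is invisible to it. A hypothetical sketch, not the real FX implementation:

```python
from dataclasses import dataclass, field

@dataclass
class ToyNode:
    op: str
    args: tuple
    kwargs: dict = field(default_factory=dict)
    meta: dict = field(default_factory=dict)  # side channel, ignored by matching

def matches(node, op, **expected_kwargs):
    """A matcher that, like FX pattern matching, never inspects node.meta."""
    return node.op == op and all(
        node.kwargs.get(k) == v for k, v in expected_kwargs.items()
    )

n = ToyNode("invoke_quant", ("subgraph", "x", "y"), {"scheme": "nf4"})
n.meta["codegen_low_precision"] = True  # config rides along without
                                        # changing what patterns see
print(matches(n, "invoke_quant", scheme="nf4"))  # True
```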
Pull Request resolved: #139102. Approved by: https://github.com/Chillee
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec