
[Inductor changes] Invoke Quant#139102

Closed
eellison wants to merge 17 commits into gh/eellison/711/base from gh/eellison/711/head

Conversation

@eellison
Contributor

@eellison eellison commented Oct 28, 2024

Stack from ghstack (oldest at bottom):

Adds an `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).

The primary motivations are

  • Unifying scattered reasoning for quant operators throughout the code base

  • Ease of pattern matching - see this very large pattern match expression [here](https://github.com/pytorch/pytorch/blob/949fdd299764d4fbefe1db093717786d946aaa60/torch/_inductor/fx_passes/post_grad.py#L390-L426):

    @register_lowering_pattern(
        CallFunction(
            aten.mm.default,
            KeywordArg("mat1"),
            CallFunction(
                aten.sub.Tensor,
                CallFunction(
                    prims.convert_element_type.default,
                    CallFunction(
                        aten.reshape.default,
                        CallFunction(
                            aten.cat.default,
                            ListOf(
                                CallFunction(
                                    aten.bitwise_and.Scalar,
                                    KeywordArg("mat2"),
                                    0xF,
                                ),
                                # CallFunction(
                                #     aten.__rshift__.Scalar,
                                #     KeywordArg("mat2"),
                                #     4,
                                # ),
                                True,
                            ),
                            1,
                        ),
                        KeywordArg("mat2_mm_shape"),
                    ),
                    KeywordArg("mat2_dtype"),
                ),
                8,
            ),
        ),
        extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
    )
    def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):

    Compared to the pattern I have in the tests:

        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
  • Ability to specify inductor-specific logic, like codegen'ing the operators in lower precision, or forcing fusion into a matmul.
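The pattern-matching win in the second bullet can be sketched in plain Python. This is an illustrative mock only: nodes are modeled as `(target, args, kwargs)` tuples rather than real FX nodes, and `is_invoke_quant` is a made-up helper, not part of Inductor's pattern matcher:

```python
# Illustrative sketch: matching a single tagged higher-order op is much
# simpler than matching a deep tree of aten ops. Nodes are modeled here as
# plain (target, args, kwargs) tuples, not real torch.fx nodes.
def is_invoke_quant(node, scheme):
    target, _args, kwargs = node
    return target == "higher_order.invoke_quant" and kwargs.get("scheme") == scheme

node = ("higher_order.invoke_quant", ("subgraph", "x", "y"), {"scheme": "nf4"})
print(is_invoke_quant(node, "nf4"))   # True
print(is_invoke_quant(node, "int8"))  # False
```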

Example graph:

 ===== AFTER POST GRAD =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
         # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');  repeated_subgraph0 = arg0_1 = arg1_1 = None
        return (invoke_quant,)
        
    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
             # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
            mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None
            add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1);  mul = arg1_1 = None
            return add

The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)`, where the scheme will not always be present.
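Semantically the hop just runs the wrapped subgraph on its operands; the scheme tag only labels the node for graph passes. A minimal eager sketch under that assumption, not the actual PyTorch implementation (which also handles tracing, autograd, and dispatch):

```python
# Hypothetical eager-mode sketch of the hop's semantics: the scheme tag does
# not change the computation, it only marks the node for pattern matching.
def invoke_quant_eager(subgraph, *args, scheme=None):
    return subgraph(*args)

result = invoke_quant_eager(lambda a, b: a * b + b, 3.0, 4.0, scheme="nf4")
print(result)  # 16.0
```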

I wasn't sure exactly how the inductor-specific configurations like `codegen_low_precision` should be passed through. I didn't want to stuff them all in as kwargs, and I didn't want them to affect pattern matching, so they will be stored as meta of the node itself. Following that, I wanted the invocation of the hop to match how it will show up in the graph, so I decided to make it an object that is then invoked for the tracing.

invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4") 
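That call shape can be sketched as a frozen, callable dataclass. This is a simplified stand-in (`InvokeQuantSketch` is a made-up name): the real class records the inductor options as node meta during tracing, whereas here they are just carried on the instance:

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class InvokeQuantSketch:
    # Inductor-specific option; frozen so instances are immutable and could
    # be stashed as node metadata without affecting pattern matching.
    codegen_low_precision: bool = True

    def __call__(self, subgraph, operands, *, scheme=None):
        # Eagerly, just run the subgraph on the packed operands; during
        # tracing the real class would also tag the emitted node.
        return subgraph(*operands)

invoke = InvokeQuantSketch(codegen_low_precision=True)
out = invoke(lambda a, b: (a - b) ** 2, (5.0, 2.0), scheme="nf4")
print(out)  # 9.0
```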

Todo - do not require packing the args in a tuple; will do following #139162.

Feedback welcome.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec

@pytorch-bot

pytorch-bot bot commented Oct 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139102

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 5 Pending

As of commit 0353760 with merge base 49082f9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eellison eellison changed the title [Inductor changes] add invoke quant → [Inductor changes] Invoke Quant on Oct 28, 2024
eellison added a commit that referenced this pull request Oct 28, 2024
ghstack-source-id: 0632fb5
Pull Request resolved: #139102
eellison added a commit that referenced this pull request Oct 29, 2024
ghstack-source-id: 2685546
Pull Request resolved: #139102
@dataclasses.dataclass(frozen=True)
class InvokeQuant:
    """
    Invoke a quantization function that will be preserved as a single operator. Preservation
Contributor


can you give some examples of quantization function here? are you referring exclusively to dequantize ops? or does it mean any quantization related functions like quantize_affine or quantized kernel as well?

I'm wondering if the higher order op has to mention quant in the name or it can be more general

Contributor Author


Uh, maybe @drisspg can help with some fp8 prologue scaling functions we want fused. The uint4x2_mixed_mm added here is another example that @HDCharles added.

Yea, it does not specifically have to be quant. The general thing here is about:

a) preservation as a top level op and scheme tagging for special cased lowerings/pattern matching
b) specific inductor behaviors.

Some of the patterns seem pretty specific to quant/dequant.

Specifically I was envisioning:

  • codegen_low_precision
  • forcing fusion to mm (both as prologue and epilogue) / autotuning when not max-autotune

Maybe there are others that come up, not sure.

Contributor


Thanks. If it's not specific to quantization, would it be more descriptive to use a different name that doesn't contain "quant" in it?

Contributor Author


What would a different name be? And what are the other use cases you're envisioning?

Contributor

@jerryzh168 jerryzh168 Oct 29, 2024


I don't have new use cases.

For naming, I just feel mentioning "quant" in a higher order op is a bit weird; if this is the best name we have now, that's fine too. Something for consideration: `invoke_undecomposed_op` / `invoke_high_level_op`.

remove_redundant_views(gm)


def canonicalize_quant_mapping(gm: torch.fx.GraphModule):
Contributor


This feels weird

Contributor Author


I am going to update this. To reviewers: let's skip this part, because I'm going to revise it in the base commit. I am more looking for feedback on the API for users (test files) and the API for developers (after this has occurred).

@drisspg
Contributor

drisspg commented Oct 31, 2024

It would be nice if we didn't have to hardcode the set of inductor configs on the class, since there might be others that affect the subgraph (you are the expert here, but there could be new ones).

Also I am curious what the expected flow is for a user registering their own dequant scheme?

Collaborator

@Chillee Chillee left a comment


Overall this API makes sense to me. I do think force_fuse_mm should probably just be "force fuse", but minor nits.

I also think we'll need a story on how to register dequant schemes for out-of-tree schemes.

max-autotune enabled.
"""

codegen_low_precision: bool = True
Collaborator


These don't do anything yet, right?

eellison added a commit that referenced this pull request Dec 10, 2024
ghstack-source-id: f716efd
Pull Request resolved: #139102
@eellison
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@eellison
Contributor Author

eellison commented Feb 7, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / unit-test / linux-jammy-cpu-py3.9-gcc11-inductor / test (inductor_amx, 2, 2, linux.8xlarge.amx)

Details for Dev Infra team: raised by workflow job.

@eellison
Contributor Author

eellison commented Feb 8, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@eellison
Contributor Author

eellison commented Feb 8, 2025

@pytorchbot merge -f "rocm hanging"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.


pytorchmergebot pushed a commit that referenced this pull request Feb 10, 2025
…ly) more aggressive fusion (#145104)

Respect invoke_quant low-precision options; also, be more aggressive in attempting fusion.

Pull Request resolved: #145104
Approved by: https://github.com/shunting314, https://github.com/jansel
ghstack dependencies: #139102
@github-actions github-actions bot deleted the gh/eellison/711/head branch March 11, 2025 02:08
