Skip to content

Add pruning-aware training in torchao.prototype.pat#3429

Merged
lisjin merged 3 commits into
mainfrom
lvj/pat
Feb 23, 2026
Merged

Add pruning-aware training in torchao.prototype.pat#3429
lisjin merged 3 commits into
mainfrom
lvj/pat

Conversation

@lisjin

@lisjin lisjin commented Dec 3, 2025

Copy link
Copy Markdown
Contributor

Adding our pruning-aware training (PAT) library as a prototype. The original library is under fairinternal/qpat but we would like to surface it in torchao for broader adoption.

The interface is almost identical to torchao.prototype.parq, but we use (group) Lasso instead of piecewise-affine regularization. More details on code organization and usage can be found in the README.

@lisjin lisjin requested a review from andrewor14 December 3, 2025 21:54
@pytorch-bot

pytorch-bot Bot commented Dec 3, 2025

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3429

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1c753dc with merge base d988122 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 3, 2025
@lisjin lisjin added the topic: new feature Use this tag if this PR adds a new feature label Dec 3, 2025
@lisjin lisjin force-pushed the lvj/pat branch 2 times, most recently from ffa338e to 4f78b65 Compare December 8, 2025 14:08
@lisjin

lisjin commented Dec 8, 2025

Copy link
Copy Markdown
Contributor Author

@andrewor14 Let me know if anything needs to be cleared up in this diff. I'm hoping to update D88501706 so that it imports from torchao.prototype.pat instead of copying code.

@meta-codesync

meta-codesync Bot commented Dec 8, 2025

Copy link
Copy Markdown

@lisjin has imported this pull request. If you are a Meta employee, you can view this in D88638093.

a base optimizer (e.g., SGD or AdamW)
- update the latent variables for QAT
Other parameters:
warmup_steps: int >= 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the central API right, can we add an example usage in this docstring?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call—I updated the README example to include keyword args like warmup_steps and reg_lambda

return out


class MaskedLayerNorm(nn.LayerNorm):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this is not used anywhere other than in tests. Can we delete this? Am I missing something?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm hoping to keep this class since it's important for converting pruned models to their compressed inference-ready forms. This functionality can be added to PAT in the future.

Comment thread torchao/prototype/pat/utils/__init__.py Outdated
Comment thread torchao/prototype/pat/utils/distributed.py Outdated
Comment thread torchao/prototype/pat/optim/pruneopt.py Outdated
Comment thread torchao/prototype/pat/optim/nm_sgd.py Outdated
from .pruneopt import PruneOptimizer


class NMSGDOptimizer(PruneOptimizer):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General question: I notice a lot of APIs in this PR that are not used or referenced anywhere. Are these all user-facing APIs? If so can we document them somewhere (e.g. main README) and explain how they're related to the main PruneOptimizer API? If they're not user-facing APIs and they're not used, do we still need them?

Some examples:

  • NMSGDOptimizer
  • ProxNuclearNorm
  • all the groupers like QKSVDGrouper

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • The NMSGDOptimizer was written by a summer intern last year and has shown promising results. Since it's an experimental feature, we don't have unit tests for it yet.
  • ProxNuclearNorm is important for applying low-rank pruning to embeddings. Here's an example config.
  • The other groupers are more experimental. It would be great to keep them around so that we can stay in sync with the original repo, but I can also remove them if you'd like.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, we can keep them if we document them somewhere. If they're experimental we can mark them as such in the README. In general public APIs should have associated documentation somewhere, otherwise users won't be able to find them

@andrewor14

Copy link
Copy Markdown
Contributor

Hi @lisjin looks good overall. My main comment is just my confusion about how the APIs are used, seems like the code snippet in the main README only references 1 or 2 of these, so it's unclear to me how the rest are related. Would be great if you can clarify this in documentation.

Separately do you have any initial results? If so, would be great to include these in the README too.

@lisjin lisjin force-pushed the lvj/pat branch 3 times, most recently from 73a572c to 71c2270 Compare February 12, 2026 15:58
@lisjin

lisjin commented Feb 12, 2026

Copy link
Copy Markdown
Contributor Author

@andrewor14 Thanks for taking the time to review this back in Dec! I found out in January that the team I was collaborating with no longer needed to use PAT in torchao. However, now @Ninja91 and his team are planning to experiment with PAT. Could you please check that my fixes addressed all your comments? I've also added some initial results on unstructured pruning to the README.

Comment thread torchao/prototype/pat/distributed_utils.py Outdated
Comment thread torchao/prototype/pat/distributed_utils.py Outdated
Comment thread torchao/prototype/pat/README.md Outdated
{
"params": weights",
"group_type": "pat.group.Dim0Grouper",
"prox_type": "pat.prox.ProxGroupLasso",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these take in actual classes instead of strings of classes? Seems like it'll be more robust

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah this usage is actually outdated. I updated it a while back to accept strings like "Dim0Grouper" and "ProxGroupLasso" so that there's no dependency on import structure. The README is fixed to reflect this.

Comment thread torchao/prototype/pat/__init__.py
Comment thread torchao/prototype/pat/optim/nm_sgd.py Outdated
from .pruneopt import PruneOptimizer


class NMSGDOptimizer(PruneOptimizer):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, we can keep them if we document them somewhere. If they're experimental we can mark them as such in the README. In general public APIs should have associated documentation somewhere, otherwise users won't be able to find them

@lisjin

lisjin commented Feb 13, 2026

Copy link
Copy Markdown
Contributor Author

@andrewor14 Thanks for the suggestions again. Here's what I've updated in the latest commit:

  • Removed experimental classes like NMSGDOptimizer, QKGrouper, QKSVDGrouper
  • Documented all remaining grouper and proximal mapping classes in a new table of the README
  • Added underscores to non user-facing methods in distributed_utils.py

Let me know if anything's missing—this is very much a research prototype :)

@andrewor14 andrewor14 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

@lisjin lisjin enabled auto-merge (squash) February 23, 2026 14:44
@lisjin lisjin merged commit 2a37912 into main Feb 23, 2026
21 of 22 checks passed
@lisjin lisjin deleted the lvj/pat branch February 23, 2026 15:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants