Add Patch Representation Refinement module #2678
sinahmr wants to merge 1 commit into huggingface:main
Conversation
@sinahmr this is fairly redundant when the attention pool is already there as an option, isn't it?
Thanks for the feedback @rwightman. Conceptually, PRR is not intended to improve the pooled [CLS] representation itself. Its purpose is to refine the final-layer patch tokens by allowing the classification signal to propagate to spatial tokens in a more diverse, content-dependent way during pretraining; we discuss this in the paper in comparison to GAP as well. While GAP does route supervision to patch tokens, it does so uniformly over positions, whereas PRR weights them in a content-dependent way.

On the implementation side, I agree that the current integration may not be ideal. If it would be better represented as part of the pooling path, I can rework it that way, for example by treating it as a pooling option. And if the concern is that this is too paper-specific to justify a dedicated switch, I'm open to alternatives.
Changes in #2685 look good to me, thank you for the implementation! |
This branch adds the Patch Representation Refinement (PRR) module from Locality-Attending Vision Transformer (ICLR 2026) (paper, code).
PRR is a parameter-free multi-head self-attention applied before the classification head. In standard ViT, only the [CLS] token receives direct supervision from the classification loss, leaving patch representations at the final layer under-optimized for dense prediction. PRR addresses this by aggregating information from all positions non-uniformly, ensuring diverse gradient flow to spatial tokens.
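As a rough illustration of the idea above, a parameter-free multi-head self-attention can be sketched as follows. This is not the actual timm implementation: the class name, default head count, and use of the input tokens directly as queries, keys, and values (no learned projections, hence no parameters) are assumptions for the sketch.

```python
import torch
import torch.nn.functional as F


class PRRSketch(torch.nn.Module):
    """Illustrative sketch of a parameter-free multi-head self-attention.

    There are no learned projections: the input tokens serve directly as
    queries, keys, and values, split across heads. Hypothetical names and
    defaults, not the actual timm module.
    """

    def __init__(self, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, channels); channels must divide evenly by heads
        B, N, C = x.shape
        q = k = v = x.reshape(B, N, self.num_heads, C // self.num_heads).transpose(1, 2)
        # Fused path: softmax(q k^T / sqrt(d)) v, with no trainable weights
        out = F.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2).reshape(B, N, C)
```

Because every position attends to all others with content-dependent weights, the classification gradient reaching the pooled output fans out non-uniformly to every patch token, which is the effect described above.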
Changes
- `timm/layers/prr.py` (new): PRR module with support for both fused (`scaled_dot_product_attention`) and manual attention paths.
- `timm/layers/__init__.py`: Export PRR.
- `timm/models/vision_transformer.py`: Add `prr` parameter to `VisionTransformer`. When enabled, PRR is applied in `forward_head` before pooling. Defaults to off, so no behavioral change for existing models.
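The "applied in `forward_head` before pooling, off by default" behavior can be sketched as a standalone function. This is a simplification under assumed names (`forward_head_sketch`, `use_prr`): the real `forward_head` also handles the [CLS]/prefix tokens and the classifier layer, and mean pooling stands in for whichever global pool the model uses.

```python
import torch
import torch.nn.functional as F


def forward_head_sketch(x: torch.Tensor, use_prr: bool = False, num_heads: int = 8) -> torch.Tensor:
    """Hypothetical head path: optionally refine patch tokens, then pool.

    x: (batch, tokens, channels). With use_prr=False the path is plain
    global average pooling, i.e. no behavioral change.
    """
    B, N, C = x.shape
    if use_prr:
        # Parameter-free refinement: tokens act as q, k, and v directly
        q = k = v = x.reshape(B, N, num_heads, C // num_heads).transpose(1, 2)
        x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C)
    # Pool the (possibly refined) tokens; GAP shown for simplicity
    return x.mean(dim=1)
```

Keeping the refinement gated this way means existing checkpoints and configs that never pass the flag take exactly the old pooling path.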