
MHLA: Restoring Expressivity of Linear Attention via Token-Level Multi-Head

Kewei Zhang¹*, Ye Huang¹*, Yufan Deng¹, Jincheng Yu², Junsong Chen²,
Huan Ling², Enze Xie², Daquan Zhou¹

¹Peking University   ²NVIDIA

MHLA Preview

MHLA Overview

MHLA is a universal, high-efficiency linear attention operator. It can be applied to image classification, image generation, language modeling, and video generation, matching the performance of Flash Attention while achieving significant speedups over it on long sequences. For more details, please refer to our paper.
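For intuition only, below is a minimal sketch of the generic (non-causal) linear attention formulation that MHLA builds on: keys and values are aggregated once, so the cost grows linearly with sequence length rather than quadratically. The function name and the (elu + 1) feature map are illustrative assumptions, not the operator shipped in this repository; see the paper for the actual token-level multi-head design.

import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (batch, heads, seq_len, head_dim). Non-causal linear attention:
    # cost is O(seq_len * head_dim^2) instead of the O(seq_len^2 * head_dim)
    # of softmax attention. The (elu + 1) feature map is a common choice and
    # only an assumption here, not necessarily what MHLA uses.
    q = F.elu(q) + 1.0
    k = F.elu(k) + 1.0
    kv = torch.einsum("bhnd,bhne->bhde", k, v)          # aggregate K^T V once
    z = torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2))  # per-query normalizer
    return torch.einsum("bhnd,bhde->bhne", q, kv) / (z.unsqueeze(-1) + eps)

# Example: doubling seq_len roughly doubles the cost (linear scaling).
q = torch.randn(1, 8, 8192, 64)
k = torch.randn(1, 8, 8192, 64)
v = torch.randn(1, 8, 8192, 64)
out = linear_attention(q, k, v)   # shape: (1, 8, 8192, 64)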

This repository is organized into four sub-projects: mhla_dit, mhla_image_classification, mhla_nlp, and mhla_videogen. Each corresponds to the experimental code for the four tasks presented in our paper. Each sub-project contains its own README.md with detailed instructions.
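An indicative top-level layout (directory names taken from the list above; the task descriptions are inferred from the paper's tasks and may differ from the actual tree):

MHLA/
├── mhla_dit/                    # image generation (DiT-based)
├── mhla_image_classification/   # ImageNet classification
├── mhla_nlp/                    # language modeling
└── mhla_videogen/               # video generation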

🔥 News

  • [2026.01.12] 🔥 Our paper is available on arXiv.
  • [2026.01.12] 🔥 We release the code of MHLA, including training and inference code for image classification, image generation, language modeling, and video generation.

🎥 Demo

Please note that the following video is a compressed version. You can view the full HD demo by visiting this link.

MHLA_demo.mp4

Todo List

  • Release code of MHLA on Video Generation
  • Release code of MHLA on DiT
  • Release code of MHLA on NLP
  • Release code of MHLA on ImageNet classification
  • Release code of MHLA on Sana
  • Release pretrained weights of Wan-MHLA
  • Release pretrained weights of DiT-MHLA
  • Release pretrained weights of Sana-MHLA
  • Release pretrained weights of image classification models with MHLA
  • Release pretrained weights of language models with MHLA

Installation & Usage

git clone -b main --single-branch https://github.com/DAGroup-PKU/MHLA

Please refer to the README.md in each sub-project for detailed instructions: mhla_dit, mhla_image_classification, mhla_nlp, and mhla_videogen.
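A typical workflow after cloning (the exact environment setup, training, and inference commands live in each sub-project's README.md and are not reproduced here):

cd MHLA/mhla_dit                 # or mhla_image_classification, mhla_nlp, mhla_videogen
# follow that sub-project's README.md for setup, training, and inference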

Performance & Efficiency

On Wan2.1-1.3B

Method            Quality score   Semantic score   Total   Latency
Wan2.1 1.3B       85.23           75.65            83.31   139s
Full MHLA         83.93           78.40            82.83   62s
Full Linear       69.96           11.38            58.24   62s
MHLA Hybrid 2/3   84.87           79.59            83.82   84s

Full MHLA (Wan-MHLA) and Full Linear (Wan-LA) replace all attention layers with MHLA and linear attention, respectively; MHLA Hybrid 2/3 (Wan-MHLA-H) replaces only 2/3 of the layers.

Excellent Convergence

Training loss comparison.

As shown in the figure, linear attention fails to converge when trained on the ultra-long sequences used for video generation, whereas MHLA converges well.

Acknowledgement

Our project builds on several inspiring projects, including timm, DiT, Sana, and flash-linear-attention.

Support Us

If you find this work useful, please consider:

  • Starring the repository
  • Citing our paper
  • Contributing to the codebase

Citation

@misc{mhla,
      title={MHLA: Restoring Expressivity of Linear Attention via Token-Level Multi-Head}, 
      author={Kewei Zhang and Ye Huang and Yufan Deng and Jincheng Yu and Junsong Chen and Huan Ling and Enze Xie and Daquan Zhou},
      year={2026},
      eprint={2601.07832},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2601.07832}, 
}
