
autumn9999/GO4Align


GO4Align

Welcome to the official repository for "GO4Align: Group Optimization for Multi-Task Alignment," an effective and efficient approach to multi-task optimization.

Project Webpage

Details will be available soon.

Abstract

This paper proposes GO4Align, a multi-task optimization approach that tackles task imbalance by explicitly aligning the optimization across tasks. To achieve this, we design an adaptive group risk minimization strategy comprising two crucial techniques in its implementation:

  • dynamical group assignment, which clusters similar tasks based on task interactions;
  • risk-guided group indicators, which exploit consistent task correlations with risk information from previous iterations.

Comprehensive experimental results on diverse typical benchmarks demonstrate our method's superior performance at even lower computational cost.
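The README does not show how group assignment and group indicators combine into a training objective, so the following is only a toy NumPy sketch of the general idea of group-based loss weighting. It is not the GO4Align algorithm. It assumes that tasks are grouped by running k-means on their relative loss trajectories (each loss divided by that task's first recorded loss), and that every task in a group receives the same weight, computed as a softmax over the group's mean current relative risk. The feature choice, the k-means grouping, the temperature, and all function names are illustrative assumptions.

```python
import numpy as np


def assign_groups(features: np.ndarray, num_groups: int, iters: int = 20, seed: int = 0) -> np.ndarray:
    """Plain k-means over task feature rows; returns one group label per task."""
    rng = np.random.default_rng(seed)
    centers = features[rng.choice(len(features), size=num_groups, replace=False)]
    labels = np.zeros(len(features), dtype=int)
    for _ in range(iters):
        # Squared Euclidean distance from each task (row) to each center: shape (K, G).
        dists = ((features[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for g in range(num_groups):
            members = features[labels == g]
            if len(members) > 0:  # keep the previous center if a group is empty
                centers[g] = members.mean(axis=0)
    return labels


def group_loss_weights(risk_history: np.ndarray, num_groups: int, temperature: float = 1.0):
    """Hypothetical group-based task weighting (illustration only).

    risk_history: (T, K) array holding each task's loss over the last T iterations.
    Returns (weights, labels); weights has shape (K,) and sums to K.
    """
    num_tasks = risk_history.shape[1]
    # Relative risk: divide by each task's first recorded loss so that
    # tasks with very different loss scales become comparable.
    relative = risk_history / risk_history[0]
    # Cluster tasks by their relative-risk trajectories (one row per task).
    labels = assign_groups(relative.T.copy(), num_groups)
    current = relative[-1]
    # Hypothetical group indicator: mean current relative risk within the task's group.
    group_risk = np.array([current[labels == labels[k]].mean() for k in range(num_tasks)])
    # Softmax over tasks: groups that have made less progress get larger weights,
    # and tasks sharing a group share the same weight.
    scores = np.exp(group_risk / temperature)
    weights = num_tasks * scores / scores.sum()
    return weights, labels


def weighted_total_loss(losses: np.ndarray, weights: np.ndarray) -> float:
    """Scalar objective sum_k w_k * L_k that would be back-propagated."""
    return float(np.dot(weights, losses))
```

With a history in which tasks 0 and 1 shrink quickly (at different absolute scales) while tasks 2 and 3 stay flat, the sketch places tasks 0 and 1 in one group and tasks 2 and 3 in the other, and gives the stalled group the larger weight. Grouping on relative rather than raw losses is what lets tasks with different loss magnitudes land in the same group.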

Paper

The preprint of our paper is available on arXiv (arXiv:2404.06486).

Framework of Adaptive Group Risk Minimization


Setup Environment

We recommend using miniconda to create a virtual environment for running the code:

conda create -n go4align python=3.9.7
conda activate go4align
conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install pyg -c pyg -c conda-forge

Install the package by running the following commands in the terminal:

git clone https://github.com/autumn9999/GO4Align.git
cd GO4Align
pip install -e .

GPU used in our experiments: NVIDIA A100-SXM4-40GB
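After installation, a quick sanity check can catch version drift before a long training run. The script below is a minimal sketch that compares installed package versions against the pins from the setup commands and reports whether PyTorch can see a GPU. It is not part of the repository; the helper names are hypothetical, and it assumes that the conda `pyg` package registers its metadata under the distribution name `torch_geometric`.

```python
import sys
from importlib import metadata

# Versions pinned by the conda commands above (Python 3.9.7, torch 1.13.1, torchvision 0.14.1).
EXPECTED = {"torch": "1.13.1", "torchvision": "0.14.1"}


def installed_version(dist_name: str):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_versions(expected: dict) -> dict:
    """Return {name: (found, wanted)} for every package that is missing or mismatched."""
    problems = {}
    for name, wanted in expected.items():
        found = installed_version(name)
        # Builds may carry a local tag such as "1.13.1+cu117"; compare only the public part.
        if found is None or found.split("+")[0] != wanted:
            problems[name] = (found, wanted)
    return problems


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "(3.9.7 recommended)")
    # Assumption: the conda "pyg" package registers its metadata as "torch_geometric".
    print("torch_geometric:", installed_version("torch_geometric") or "not installed")
    for name, (found, wanted) in check_versions(EXPECTED).items():
        print(f"{name}: found {found}, expected {wanted}")
    try:
        import torch
        if torch.cuda.is_available():
            print("GPU:", torch.cuda.get_device_name(0))
        else:
            print("CUDA is not available to PyTorch")
    except ImportError:
        print("PyTorch could not be imported")
```

An empty mismatch report means the pinned packages are in place; a `None` entry means the package is not installed in the active environment.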

Download Datasets

We evaluate GO4Align on four multi-task learning benchmarks:

  1. NYUv2 (3 tasks); the download link is provided by the previous multi-task optimization (MTO) work CAGrad.
  2. CityScapes (2 tasks); the download link is provided by CAGrad.
  3. CelebA (40 tasks); details can be found in the previous MTO work FAMO.
  4. QM9 (11 tasks), which is downloaded automatically by torch_geometric.datasets; details can be found in FAMO.

Run Experiments

Here we provide the experiment code for NYUv2. To run experiments on the other benchmarks, please refer to the unified APIs in NashMTL or FAMO.

cd experiment/nyuv2
python trainer.py --method go4align 
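To compare GO4Align with several baselines on NYUv2, the command above can be repeated once per method. The launcher below is a minimal sketch that builds one `trainer.py --method <name>` call per method and, by default, only prints the commands (a dry run). Only `go4align` is confirmed as a method name by this README; the other entries in `METHODS`, the helper names, and the assumption that the script is started from the repository root are hypothetical, so check `trainer.py` for the exact spellings it accepts.

```python
import subprocess
import sys
from pathlib import Path

# Hypothetical sweep list: only "go4align" is confirmed by the README.
METHODS = ["go4align", "ls", "famo"]


def build_command(method: str) -> list:
    """Command line equivalent to `python trainer.py --method <method>`."""
    return [sys.executable, "trainer.py", "--method", method]


def run_sweep(methods, workdir="experiment/nyuv2", dry_run=True):
    """Run (or only print, when dry_run=True) one training job per method, sequentially."""
    commands = [build_command(m) for m in methods]
    for cmd in commands:
        print("$", " ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first run that exits with an error.
            subprocess.run(cmd, cwd=Path(workdir), check=True)
    return commands


if __name__ == "__main__":
    run_sweep(METHODS, dry_run=True)
```

Running the jobs sequentially keeps a single GPU free of contention; set `dry_run=False` once the printed commands look right.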

The code also supports the following multi-task learning (MTL) methods as alternatives to GO4Align.

Method (code name)   Paper (notes)

Gradient-oriented methods
  MGDA       Multi-Task Learning as Multi-Objective Optimization
  PCGrad     Gradient Surgery for Multi-Task Learning
  CAGrad     Conflict-Averse Gradient Descent for Multi-task Learning
  IMTL-G     Towards Impartial Multi-task Learning
  NashMTL    Multi-Task Learning as a Bargaining Game

Loss-oriented methods
  LS         (no paper; equal weighting)
  SI         (no paper; see the Nash-MTL paper for details)
  RLW        A Closer Look at Loss Weighting in Multi-Task Learning
  DWA        End-to-End Multi-Task Learning with Attention
  UW         Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
  FAMO       FAMO: Fast Adaptive Multitask Optimization
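To make the two families concrete, here is a minimal NumPy sketch of one representative of each, written from the formulas in the cited papers rather than from this repository's code. DWA (loss-oriented) sets w_k = K exp(r_k / T) / sum_i exp(r_i / T) with r_k = L_k(t-1) / L_k(t-2), and the DWA paper uses T = 2 (and equal weights for the first two epochs, which this sketch leaves to the caller). PCGrad (gradient-oriented) projects a task gradient onto the normal plane of another task's gradient when the two conflict; the full method applies this against every other task in random order, while `pcgrad_pair` shows a single pairwise step. Function names are illustrative.

```python
import numpy as np


def dwa_weights(prev_losses: np.ndarray, prev_prev_losses: np.ndarray, temperature: float = 2.0) -> np.ndarray:
    """Dynamic Weight Average: w_k = K * softmax(r / T)_k with r_k = L_k(t-1) / L_k(t-2).

    Tasks whose loss is shrinking more slowly (larger ratio) receive larger weights;
    the weights always sum to the number of tasks K.
    """
    ratios = prev_losses / prev_prev_losses
    scores = np.exp(ratios / temperature)
    return len(ratios) * scores / scores.sum()


def pcgrad_pair(g_i: np.ndarray, g_j: np.ndarray) -> np.ndarray:
    """One PCGrad projection of task i's gradient against task j's gradient.

    If the gradients conflict (negative inner product), remove from g_i its
    component along g_j; otherwise return g_i unchanged.
    """
    dot = float(np.dot(g_i, g_j))
    if dot < 0:
        return g_i - dot / float(np.dot(g_j, g_j)) * g_j
    return g_i
```

For example, with g_i = (1, 0) and g_j = (-1, 1) the inner product is -1, so `pcgrad_pair` returns (0.5, 0.5), which is orthogonal to g_j. The contrast illustrates the table's split: loss-oriented methods only rescale scalar losses, whereas gradient-oriented methods need per-task gradients, which is usually the more expensive option.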

Citation

This repo is built upon NashMTL and FAMO. If our work GO4Align is helpful in your research or projects, please cite the following papers:

@article{shen2024go4align,
  title={GO4Align: Group Optimization for Multi-Task Alignment},
  author={Shen, Jiayi and Wang, Cheems and Xiao, Zehao and Van Noord, Nanne and Worring, Marcel},
  journal={arXiv preprint arXiv:2404.06486},
  year={2024}
}

@article{liu2024famo,
  title={{FAMO}: Fast Adaptive Multitask Optimization},
  author={Liu, Bo and Feng, Yihao and Stone, Peter and Liu, Qiang},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}

@article{navon2022multi,
  title={Multi-Task Learning as a Bargaining Game},
  author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan},
  journal={arXiv preprint arXiv:2202.01017},
  year={2022}
}

