
TransferMPL-Roleplaying

Final project for the Technion's EE Deep Learning course (046211)

Implementation of Meta Pseudo Labels with "Role-Playing" and Transfer Learning.

Shira Lifshitz: LinkedIn, GitHub
Saar Stern: LinkedIn, GitHub

Based on the paper: Hieu Pham, Zihang Dai, Qizhe Xie, Minh-Thang Luong, Quoc V. Le, "Meta Pseudo Labels".

Background

In the semi-supervised setting we are given a dataset in which only a small fraction of the samples are labeled (e.g. 3%). MPL is an algorithm for this case: it trains a teacher model with three terms, a supervised loss (on the labeled data), a self-supervised UDA loss (consistency between weak and strong augmentations of unlabeled data), and a semi-supervised feedback signal derived from a student model that learns from the teacher's pseudo labels.
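The combination of the three teacher terms can be sketched as follows. This is a simplified NumPy illustration, not the repository's code: in real MPL the feedback term is a meta-gradient (the change in the student's labeled loss after it updates on the pseudo labels), which is reduced here to the student's loss on the labeled batch, and all names are illustrative.

```python
import numpy as np

def cross_entropy(probs, labels):
    # mean negative log-likelihood of the true class
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

def teacher_loss(p_labeled, y, p_weak, p_strong, p_student_on_labeled, lambda_u):
    """Toy MPL-style teacher objective (illustrative only).

    p_labeled:            teacher probabilities on the labeled batch
    y:                    ground-truth labels
    p_weak / p_strong:    teacher probabilities on weakly / strongly
                          augmented views of the unlabeled batch
    p_student_on_labeled: student probabilities on the labeled batch,
                          standing in for the meta-gradient feedback
    lambda_u:             weight of the unlabeled (UDA) term
    """
    l_sup = cross_entropy(p_labeled, y)                  # supervised loss
    uda_targets = p_weak.argmax(axis=1)                  # hard pseudo labels
    l_uda = cross_entropy(p_strong, uda_targets)         # consistency (UDA) loss
    l_feedback = cross_entropy(p_student_on_labeled, y)  # student feedback signal
    return l_sup + lambda_u * l_uda + l_feedback
```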


Figure taken from the original paper.

This design is asymmetric, so we introduce Role-Playing: the student and the teacher switch positions during the training phase, and the two models are eventually combined with ensemble learning. We also wanted to examine:

  1. The 16-class flowers dataset, with 3% labels.
  2. A negative cosine similarity criterion instead of cross-entropy (CE).
  3. Optuna hyper-parameter tuning.
  4. Augmentations different from those in the original paper.

Results

On the flowers dataset with 3% labels, using Role-Playing and an ensemble of the two models:

Test-set accuracy: 68.22%, an improvement of 13.35% over the supervised baseline (labels only).

Installation

Clone the repository and run:

$ conda env create --name TransferMPL --file env.yml
$ conda activate TransferMPL
$ python MPL.py

Files in the repository

| File name | Purpose |
|---|---|
| MPL.py | general-purpose main application |
| models.py | creates the models |
| optuna.py | Optuna optimization script |
| data.py | loaders, splits, and augmentations |
| args.py | argument parser |
| utils.py | utility functions |
| train.py | training functions |
| visualizatoin.py | plots graphs, confusion matrices, etc. |

API (MPL.py --help)

You should use the MPL.py file with the following arguments:

| Argument | Description |
|---|---|
| -h, --help | show this help message and exit |
| --name | experiment name |
| --data_dir | data path (must start with /datasets, e.g. /datasets/flowers) |
| --load_path | folder inside /checkpoints/ to load "best_student" from |
| --num_labels_percent | percent of labeled data |
| --num_epochs | number of epochs to run |
| --warmup_epoch_num | warmup steps |
| --model_name | model name for feature extraction (e.g. vgg16) |
| --unsupervised | unsupervised loss, either "CE" or "cos" |
| --seed | seed for initializing training |
| --threshold | pseudo-label threshold |
| --temperature | pseudo-label temperature |
| --lambda_u | coefficient of the unlabeled loss |
| --uda_steps | warmup steps for lambda_u |
| --show_images | show samples from the dataset and from the augmented dataset |
| --load_best | load the best model for that dataset |
| --print_model | print the model being trained |
| --optuna_mode | run Optuna hyper-parameter search |
| --test_mode | test only |
| --switch_mode | switch models every epoch (Role-Playing) |
| --finetune_mode | only finetune the model on the labeled dataset |
| --n_trials | number of Optuna trials |
| --timeout | timeout [sec] for Optuna |
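The `--switch_mode` flag realizes the Role-Playing idea. Conceptually the swap amounts to something like the following sketch, where `train_epoch` stands for one MPL training epoch; all names are illustrative, not the repository's actual code:

```python
def run_training(model_a, model_b, num_epochs, train_epoch, switch_mode=True):
    """Alternate teacher/student roles between two models each epoch."""
    teacher, student = model_a, model_b
    for epoch in range(num_epochs):
        train_epoch(teacher, student, epoch)     # one MPL epoch with current roles
        if switch_mode:
            teacher, student = student, teacher  # Role-Playing: swap roles
    return teacher, student
```

After training, the two models can be combined by ensembling their predictions, as described in the Background section.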

Usage

  1. Download the flowers dataset and put it in /datasets/flowers
  2. Run: `python MPL.py --name NameOfExperiment --data_dir /datasets/flowers`
  3. Finetune: `python MPL.py --name finetune --finetune_mode --load_best --data_dir /datasets/flowers --load_path /checkpoints/flowers/NameOfExperiment`
  4. Results are written to /results/NameOfExperiment

References

Hieu Pham, Zihang Dai, Qizhe Xie, Minh-Thang Luong, Quoc V. Le. "Meta Pseudo Labels".
