Project Name: A case study on Weakly Supervised Learning.

Team:

Jacques Thibodeau, Arian Pasquali, and Kevin Koehncke

Project Goal:

Explore weak supervision approaches for training text classification models and compare them with traditional supervised learning. We want to investigate how close weak supervision can get to state-of-the-art (SOTA) supervised text classification, which relies on a manually built gold standard. We will use a well-established public dataset with ground-truth annotations.

What baseline will you use?

Baseline: a traditional supervised learning model for text classification, fine-tuned from pre-trained BERT.

Weak supervision approaches:

  • Weak labeling with Snorkel: we will emulate an unlabeled dataset by ignoring the dataset's labels, define labeling functions from scratch, and apply them with Snorkel to produce weak labels. We will then train a BERT model on the labels produced by the labeling functions.
  • Weak labeling through model distillation: an NLI model acts as the teacher and provides labels through zero-shot classification. The student model will use the same architecture as the other models but will be trained on the labels provided by the teacher.
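To make the labeling-function idea concrete, here is a minimal pure-Python sketch of the weak-labeling step. It does not use Snorkel's API; it only mimics its conventions (each labeling function returns a class id or -1 to abstain, and the outputs form a label matrix). The keyword rules, label ids, and sentences are hypothetical, and plain majority vote stands in for Snorkel's label model, which instead learns to weight labeling functions by their estimated accuracy.

```python
from collections import Counter

ABSTAIN = -1  # mirrors Snorkel's convention for "this function does not vote"

# Hypothetical label ids for illustration only; real ids come from the dataset.
COMPANY, ATHLETE, FILM = 0, 1, 2


def lf_company(text: str) -> int:
    # Keyword heuristic: corporate words suggest a company article.
    return COMPANY if any(w in text for w in (" Inc.", " Ltd", " company")) else ABSTAIN


def lf_athlete(text: str) -> int:
    lowered = text.lower()
    return ATHLETE if ("footballer" in lowered or "athlete" in lowered) else ABSTAIN


def lf_film(text: str) -> int:
    lowered = text.lower()
    return FILM if (" film " in lowered or "directed by" in lowered) else ABSTAIN


LABELING_FUNCTIONS = [lf_company, lf_athlete, lf_film]


def apply_lfs(texts):
    """Build the label matrix: one row per text, one column per labeling function."""
    return [[lf(t) for lf in LABELING_FUNCTIONS] for t in texts]


def majority_vote(row):
    """Return the most common non-abstain vote, or ABSTAIN if there is none or a tie."""
    votes = Counter(v for v in row if v != ABSTAIN)
    if not votes:
        return ABSTAIN
    ranked = votes.most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return ABSTAIN  # tie: leave the example unlabeled
    return ranked[0][0]


texts = [
    "Acme Inc. is an American company that makes rockets.",
    "John Smith is an English footballer who plays as a striker.",
    "The Long Road is a 1999 film directed by Jane Doe.",
    "The river flows through three valleys.",
]
matrix = apply_lfs(texts)
weak_labels = [majority_vote(row) for row in matrix]
print(weak_labels)  # [0, 1, 2, -1]
```

Examples on which every function abstains (the last one here) would typically be dropped before training BERT on the weak labels.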

Research questions:

  • Which weak supervision approach performs best?
  • Is weak supervision good enough to reduce the amount of manually annotated data needed, or to remove the need for it altogether?

What dataset will you use, or how do you plan to collect data?

To encourage reproducibility, we plan to use a well-established public dataset designed for text classification. Our first option is https://huggingface.co/datasets/dbpedia_14, chosen for its number of labels and its size.

We will load the DBpedia14 dataset directly from Hugging Face Datasets. It contains Wikipedia content annotated with 14 different labels. We chose this labeled dataset for its size: 560,000 training examples and 70,000 test examples. This lets us emulate the real-world scenario in which we must choose between manually labeling a subsample of the dataset and applying weak supervision to the entire dataset, and then compare both approaches.
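One way the labeling budget could be emulated is sketched below, assuming the training labels are available as a list of class ids (with Hugging Face Datasets they would come from the `label` column of the train split). The helper name `stratified_subsample` and the budget of 5 examples per class are hypothetical choices for illustration, and a small toy label list replaces the real 560,000 examples.

```python
import random
from collections import defaultdict


def stratified_subsample(labels, per_class, seed=0):
    """Return sorted indices with `per_class` examples drawn from each class.

    Emulates a manual-annotation budget: only these indices keep their gold
    labels; every other example is treated as unlabeled.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    chosen = []
    for label in sorted(by_class):
        pool = by_class[label]
        chosen.extend(rng.sample(pool, min(per_class, len(pool))))
    return sorted(chosen)


# Toy stand-in for the training labels: 14 classes, 40 examples each.
toy_labels = [i % 14 for i in range(14 * 40)]
budget_idx = stratified_subsample(toy_labels, per_class=5)
budget_set = set(budget_idx)
unlabeled_idx = [i for i in range(len(toy_labels)) if i not in budget_set]
print(len(budget_idx), len(unlabeled_idx))  # 70 490
```

The manually labeled arm trains on `budget_idx` only, while the weak supervision arms label the whole split, so varying `per_class` shows how many gold labels weak supervision is worth.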

What model architecture / loss function do you propose?

We will use bert-base-cased (https://huggingface.co/bert-base-cased) for both the baseline model and the model trained on the Snorkel labels. For the distilled model, roberta-large-mnli will serve as the teacher, producing labels through zero-shot classification, and the student model will also be based on BERT (bert-base-cased).
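The teacher's zero-shot step can be sketched as follows. To our understanding this mirrors the single-label case of the Hugging Face zero-shot-classification pipeline (each candidate label is turned into a hypothesis such as "This example is {}.", and a softmax is taken over the entailment logits), but the function is our own illustration: `fake_nli` is a hypothetical stub standing in for roberta-large-mnli so the example runs without downloading a model.

```python
import math
from typing import Callable, Dict, List

# Hypothesis template (the zero-shot pipeline's default, to our understanding).
TEMPLATE = "This example is {}."


def zero_shot_scores(
    text: str,
    labels: List[str],
    entailment_logit: Callable[[str, str], float],
) -> Dict[str, float]:
    """Score each label by how strongly `text` entails the label hypothesis.

    `entailment_logit(premise, hypothesis)` stands in for the NLI teacher and
    returns its entailment logit. A softmax over the labels' entailment
    logits gives a single-label probability distribution.
    """
    logits = [entailment_logit(text, TEMPLATE.format(label)) for label in labels]
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]  # subtract max for stability
    total = sum(exps)
    return {label: e / total for label, e in zip(labels, exps)}


def fake_nli(premise: str, hypothesis: str) -> float:
    # Hypothetical stub: rewards word overlap between premise and hypothesis.
    premise_words = set(premise.lower().split())
    hypothesis_words = set(hypothesis.lower().rstrip(".").split())
    return float(len(premise_words & hypothesis_words))


labels = ["company", "athlete", "film"]
scores = zero_shot_scores("Jane Doe is an athlete who won three medals.", labels, fake_nli)
teacher_label = max(scores, key=scores.get)  # hard label for the student
print(teacher_label)  # athlete
```

With transformers installed, the real teacher would presumably come from `pipeline("zero-shot-classification", model="roberta-large-mnli")` instead of the stub. Keeping the full `scores` distribution, rather than only the argmax, would also allow training the student on soft targets.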

Since the dataset poses a multi-class classification problem, we will use categorical cross-entropy as our loss function.
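For a single example with logits z and target distribution y, categorical cross-entropy is L = -Σ_c y_c log softmax(z)_c. A minimal pure-Python sketch follows; the function names are our own, and the target may be one-hot (gold or Snorkel hard labels) or a soft distribution such as teacher probabilities.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax via the log-sum-exp trick."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return [z - log_total for z in logits]


def categorical_cross_entropy(logits: Sequence[float], target: Sequence[float]) -> float:
    """Compute -sum_c y_c * log p_c for one example.

    Classes with zero target probability contribute nothing, so they are skipped.
    """
    return -sum(y * lp for y, lp in zip(target, log_softmax(logits)) if y > 0)


# Three-class toy example: the true class (index 1) has the largest logit.
logits = [1.0, 3.0, 0.5]
one_hot = [0.0, 1.0, 0.0]
loss = categorical_cross_entropy(logits, one_hot)
print(round(loss, 4))  # 0.1967
```

In PyTorch, `torch.nn.CrossEntropyLoss` applies log-softmax internally, so the model should output raw logits rather than probabilities.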

What will the end result look like?

A Streamlit app where the user can paste text into a text box and submit it to have each paragraph labeled by one of the models. The output will be a table showing each paragraph with its predicted label, plus a button to download the newly labeled dataset as a CSV file.

The system will also apply each classification model and show the results side by side, so that the confidence scores of each approach can be compared.
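The app's core logic, independent of the UI, might look like the sketch below. It assumes paragraphs are separated by blank lines and that each model is wrapped as a function returning a `(label, confidence)` pair; both models here are hypothetical stubs. In Streamlit, the pasted text would come from a widget such as `st.text_area`, the rows could be displayed with `st.dataframe`, and the CSV string passed to `st.download_button`.

```python
import csv
import io
from typing import Callable, Dict, List, Tuple

# A classifier maps a paragraph to (label, confidence); real models would be wrapped to fit.
Classifier = Callable[[str], Tuple[str, float]]


def split_paragraphs(text: str) -> List[str]:
    """Treat blank-line-separated blocks as paragraphs, dropping empty ones."""
    return [p.strip() for p in text.split("\n\n") if p.strip()]


def label_paragraphs(text: str, models: Dict[str, Classifier]) -> List[dict]:
    """One row per paragraph, with each model's label and confidence side by side."""
    rows = []
    for paragraph in split_paragraphs(text):
        row = {"text": paragraph}
        for name, classify in models.items():
            label, confidence = classify(paragraph)
            row[f"{name}_label"] = label
            row[f"{name}_confidence"] = round(confidence, 3)
        rows.append(row)
    return rows


def to_csv(rows: List[dict]) -> str:
    """Serialize rows to CSV text, e.g. for a download button."""
    if not rows:
        return ""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()


# Hypothetical stand-ins for the baseline and Snorkel-trained models.
def baseline(paragraph: str) -> Tuple[str, float]:
    return ("film", 0.91) if "film" in paragraph else ("company", 0.55)


def snorkel_model(paragraph: str) -> Tuple[str, float]:
    return ("film", 0.78) if "film" in paragraph else ("company", 0.62)


pasted = "Acme Inc. builds rockets.\n\nThe Long Road is a 1999 film."
rows = label_paragraphs(pasted, {"baseline": baseline, "snorkel": snorkel_model})
csv_text = to_csv(rows)
```

Keeping one label/confidence column pair per model makes the side-by-side comparison a plain table, which is also what the CSV export contains.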

What’s your stretch goal?

We would like to incorporate active learning into our system: the user can correct predictions generated by the model, save the corrected examples, and then use them to retrain the model and improve it iteratively. The system should rank predictions by confidence score and ask the user to confirm or change those with the lowest scores.
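Ranking by lowest confidence corresponds to the least-confidence strategy of uncertainty sampling. A minimal sketch of the selection and feedback steps follows, assuming predictions are stored as `(text, label, confidence)` tuples; the helper names, label names, and data are hypothetical, and retraining on the returned pairs is left to the training pipeline.

```python
from typing import Dict, List, Optional, Tuple

# (text, predicted_label, confidence) as produced by the current model.
Prediction = Tuple[str, str, float]


def least_confident(predictions: List[Prediction], k: int) -> List[Prediction]:
    """Pick the k predictions the model is least sure about."""
    return sorted(predictions, key=lambda p: p[2])[:k]


def collect_feedback(
    queue: List[Prediction], user_answers: Dict[str, Optional[str]]
) -> List[Tuple[str, str]]:
    """Turn user reviews into (text, label) training pairs.

    A missing or None answer means the user confirmed the predicted label;
    any other value is the user's corrected label.
    """
    examples = []
    for text, predicted, _ in queue:
        corrected = user_answers.get(text)
        examples.append((text, corrected if corrected is not None else predicted))
    return examples


predictions = [
    ("Acme Inc. builds rockets.", "company", 0.97),
    ("The Blue Hour is a novel.", "film", 0.41),
    ("Mount Example is a peak in Peru.", "natural_place", 0.38),
    ("John Smith plays as a striker.", "athlete", 0.88),
]
queue = least_confident(predictions, k=2)  # the two lowest-confidence items
new_examples = collect_feedback(queue, {"The Blue Hour is a novel.": "written_work"})
```

Confirmed predictions are kept as well as corrections, since both give the next training round verified labels on exactly the examples the model found hardest.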

Tooling

We would also like to include the following:

  • GPU training environment: TBD (local GPU, AWS SageMaker, Colab, ...)
  • Experimentation tracking using Weights and Biases
  • Web app using Streamlit
  • Source control: GitHub
  • Model deployment using Lambda
  • Dataset management with a dataset management tool (TBD)
  • Label management with Label Studio
  • Model service monitoring via telemetry written to the cloud provider and visualized in a dashboard (TBD)
  • Data processing and pipeline management using https://github.com/prefecthq/prefect, https://github.com/dagster-io/dagster or https://github.com/apache/airflow

We are also interested in publishing the model to a shared hub such as the Hugging Face Model Hub. This way the model is not only deployed but also publicly available to everyone.