When I first heard about weak labelling a little over a year ago, I was sceptical. The premise of weak labelling is that you can replace manually annotated data with data created from heuristic rules written by domain experts. This didn't make any sense to me. If you can create a really good rules-based system, then why wouldn't you just use that system?! And if the rules aren't good, then won't a model trained on the noisy data also be bad? It sounded like a return to the world of feature engineering that deep learning was supposed to replace.

Over the past year, I've had my mind completely changed. I worked on numerous NLP projects that involved data extraction and delved a lot deeper into the literature surrounding weak supervision. I also spoke to ML team leads at companies like Apple, where I heard stories of entire systems being replaced in the space of weeks: by combining weak supervision and machine translation they could create massive datasets in under-resourced languages that previously were just not served!

Since I now have the zeal of a convert, I want to explain what weak supervision is, what I learned and why I think it complements techniques like active learning for data annotation.

Weak Supervision is a form of programmatic labelling

The term "weak supervision" was first introduced by Alex Ratner and Chris Re who then went on to create the Snorkel open source package and eventually the unicorn company Snorkel AI. It's a three-step process that lets you train supervised machine learning models starting from having very little annotated data. The three steps are:

  1. Create a set of "labelling functions" that can be applied across your dataset

    Labelling functions are simple rules that can guess the correct label for a data point.

    For example, if you were trying to classify news headlines as "clickbait" or "legitimate news" you might use a rule that says:

    If the headline starts with a number, predict that the label is clickbait; if not, don't predict a label.

    or in code:

    CLICKBAIT = 1
    ABSTAIN = -1  # -1 means "don't predict a label for this headline"

    def starts_with_digit(headline: Doc) -> int:
        """Label any headline which starts with a digit as clickbait."""
        if headline.tokens[0].is_digit():
            return CLICKBAIT
        return ABSTAIN
    

    It's clear that this rule won't be 100% accurate and won't capture all clickbait headlines. That's ok. What's important is that we can come up with a few different rules and that each rule is pretty good. Once we have a collection of rules, we can then learn which ones to trust and when.
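
    To make that concrete, here is a rough sketch of what a couple of extra rules might look like. The keyword lists and the `text` attribute on `Doc` are illustrative assumptions rather than pieces of a real project, but they show the pattern: each rule either votes for a class or abstains, and none of them needs to be perfect on its own.

    NEWS = 0  # alongside CLICKBAIT = 1 and ABSTAIN = -1 defined above

    def mentions_listicle_phrase(headline: Doc) -> int:
        """Vote clickbait if the headline contains a classic listicle phrase."""
        phrases = ("you won't believe", "will shock you", "the reason why")
        text = headline.text.lower()
        if any(phrase in text for phrase in phrases):
            return CLICKBAIT
        return ABSTAIN

    def cites_official_source(headline: Doc) -> int:
        """Vote news if the headline mentions an official-sounding source."""
        sources = ("minister", "official", "police", "report")
        text = headline.text.lower()
        if any(word in text for word in sources):
            return NEWS
        return ABSTAIN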

  2. Use a Bayesian model to work out the most likely label for each data point

    The good news for a practitioner is that a lot of the hardest math is taken care of by open-source packages.

    It's helpful to keep an example in mind, so let's stick to thinking about classifying news headlines into one of two categories. Imagine that you have 10,000 headlines and 5 labelling functions. Each labelling function tries to guess whether the headline is "news" or "clickbait". If you wanted to visualise it, you could put it all together into a large table with one row per headline and one column per labelling function. If you did that, you'd get a table like this with 10,000 rows and 5 columns:

    Labelling functions for clickbait detection. Each rule can vote for "News", "Clickbait" or abstain.

The goal now is to work out the most likely label for each data-point. The simplest thing you could do would be to just take the majority vote in each row. So if four labelling functions voted for clickbait, we'd assume the label is clickbait.
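
As a rough sketch of what that looks like in code (the tiny label matrix below is made up for illustration; in the running example it would have 10,000 rows and 5 columns):

    import numpy as np

    NEWS, CLICKBAIT, ABSTAIN = 0, 1, -1

    # One row per headline, one column per labelling function (-1 = abstain).
    L = np.array([
        [CLICKBAIT, CLICKBAIT, ABSTAIN,   CLICKBAIT, NEWS],
        [NEWS,      ABSTAIN,   NEWS,      ABSTAIN,   NEWS],
        [ABSTAIN,   CLICKBAIT, CLICKBAIT, NEWS,      ABSTAIN],
    ])

    def majority_vote(row: np.ndarray) -> int:
        """Return the label with the most votes in a row, ignoring abstentions."""
        votes = row[row != ABSTAIN]
        if votes.size == 0:
            return ABSTAIN  # no labelling function voted on this headline
        return int(np.argmax(np.bincount(votes)))

    hard_labels = np.array([majority_vote(row) for row in L])
    print(hard_labels)  # [1 0 1]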

The problem with this approach is that we know some of the rules are going to be bad and some are going to be good. We also know that the rules might be correlated. The way we get around this is first to train a probabilistic model that learns an estimate of the accuracy of each rule. Once we've trained this model, we can calculate the distribution p(y = clickbait | labelling functions) for each of our data-points. This is a more intelligent weighting of the 5 votes we get for each data-point.

Intuitively, a model can tell that a rule is accurate if it consistently votes with the majority. Conversely, a rule that is erratic and only sometimes votes with the majority is less likely to be good. The more approximate rules we have, the better we can clean them up.

There is a lot of freedom in what model you choose to clean up the labels, and how you train it. Most of the research in the early days of weak supervision went into improving the model used and the efficiency of training. The original paper used a naive Bayes model and SGD with Gibbs sampling to learn the accuracy of each rule. Later methods were developed that can also learn the correlations between labelling functions, and recent work has used matrix completion methods to train the model efficiently.
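
In practice you rarely have to implement this step yourself. As a minimal sketch, here is roughly what it looks like with the open-source Snorkel package mentioned above, reusing the label matrix `L` from the previous snippet (the exact API may differ between Snorkel versions):

    from snorkel.labeling.model import LabelModel

    # cardinality=2 because there are two classes: news (0) and clickbait (1).
    label_model = LabelModel(cardinality=2, verbose=False)

    # Fit the model on the label matrix alone -- no ground-truth labels needed.
    label_model.fit(L_train=L, n_epochs=500, seed=123)

    # p(y = clickbait | labelling functions) for every headline.
    probs = label_model.predict_proba(L)
    print(probs[:3])  # one (p_news, p_clickbait) pair per row of L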

  3. Train a normal machine learning model on the dataset produced by steps one and two

    Once you have your estimates for the probability of each label, you can either threshold those probabilities to get a hard label for each data-point, or use the probability distribution as the target for a model. Either way, you now have a labelled dataset and you can use it just like any other labelled dataset!

    You can go ahead and train a machine learning model and get good performance, even though you might have had no annotated data when you started!
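
    For example, a minimal sketch of the hard-label route might look like the following, assuming `headlines` is a list of raw headline strings aligned with the rows of `probs` from the previous step, and using scikit-learn purely for illustration:

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import LogisticRegression

    # Threshold the probabilistic labels: call it clickbait if p > 0.5.
    hard_labels = (probs[:, 1] > 0.5).astype(int)

    # From here on it is ordinary supervised learning on bag-of-words features.
    vectoriser = TfidfVectorizer()
    X = vectoriser.fit_transform(headlines)

    classifier = LogisticRegression()
    classifier.fit(X, hard_labels)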

A big weakly-supervised dataset is just as good as a small hand-labelled dataset, but it's much easier to get!

One of the cool proofs in the original paper shows that the performance of a weakly supervised model improves as you add more unlabelled data at the same rate that a supervised model improves as you add more labelled data!

The above sentence is a bit of a mouthful and is easy to miss but is really profound. It's something that I didn't appreciate when I first read about weak supervision. Understanding this properly is one of the reasons I've become a weak labelling advocate.

What it means is that once you have a good set of labelling functions, the more unlabelled data you add, the better the performance will be. Compare that to supervised learning, where to get better performance you have to add more labelled data, which might require expert time.