Zero Shot Classification with Huggingface + Sentence Transformers πŸ€— πŸ€–

pytorch
huggingface
Fast Zero Shot classification of text
Author: Sachin Abeywardana

Published: October 10, 2021

Photo credit: Bert.

Introduction

When it comes to text classification, BERT/DistilBERT is our go-to. However, quite often we lack the labels needed to kick off the classification process. Hugging Face released a tool about a year ago to do exactly this, but using BART. The concept behind zero-shot classification is to match the text to a topic word or phrase: the words in a topic description carry information describing the cluster, as opposed to a one-hot encoded vector.

What’s wrong with BART?

I personally believe that BART is a heavy-handed way of doing this, as its complexity is O(NK), whereas with a sentence transformer the complexity is roughly O(N + K) (where N is the number of sentences and K is the number of topics).

When using BART to check whether a sentence matches a topic, we must concatenate the sentence with the candidate topic (separated by a <SEP> token) and pass the pair through the BART transformer. This has to be done against every candidate topic. BART outputs the probabilities of the two sentences being neutral (nothing to do with each other), an entailment, or a contradiction. In the HF repo the entailment probabilities are normalised across topics to choose the most likely topic.
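For reference, this is roughly what the BART-based approach looks like via the Hugging Face pipeline (the example sentence and candidate labels below are made up for illustration):

from transformers import pipeline

# The model name here is the one used in the HF zero-shot demo.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier(
    "The Mets won in extra innings last night.",
    candidate_labels=["baseball", "politics", "religion"],
)
print(result["labels"][0], result["scores"][0])  # most likely topic first

Note that each call runs the sentence through the transformer once per candidate label, which is where the O(NK) cost comes from.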

Sentence Transformers

Sentence Transformers summarise a sentence into a single vector, which makes them ideal for comparing text against topic descriptions, and this works reasonably well, as shown below. One benefit of using Sentence Transformers is that they offer a small model (127 MB!) compared to BART, which is 500 MB. Another benefit you get for free is that the sentence transformer is multilingual!
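The post does not show the model being loaded, so here is a minimal sketch using plain transformers. The multilingual MiniLM checkpoint below is an assumption, since the exact model is not named; any sentence-transformers checkpoint can be swapped in:

import torch
from transformers import AutoModel, AutoTokenizer

# Checkpoint name is an assumption; the post does not name the exact model used.
MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device).eval()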

Experiment

To demonstrate zero-shot classification I have used the 20 Newsgroups dataset, classifying news articles into one of 20 topics such as politics, religion and baseball.
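The data-loading step is not shown in the post; a sketch using scikit-learn is below. The topic phrases are derived from the raw newsgroup names, so the exact phrasing the author used (e.g. "middle east") may differ slightly:

from sklearn.datasets import fetch_20newsgroups

newsgroups = fetch_20newsgroups(subset="test", remove=("headers", "footers", "quotes"))
sentences = newsgroups.data
# Turn newsgroup names such as "comp.sys.ibm.pc.hardware" into readable
# phrases like "sys ibm pc hardware" to use as topic descriptions.
topics = [name.split(".", 1)[-1].replace(".", " ") for name in newsgroups.target_names]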

In order to calculate the sentence embedding, the mean_pooling function takes the transformed outputs for all tokens and averages them, taking the attention mask into account. We go further and normalise these embedding vectors to unit length.

from typing import List

import torch
from tqdm.auto import tqdm

MAX_TEXT_LENGTH = 256  # assumed value; the post does not state the length used

# Mean pooling - take the attention mask into account for correct averaging.
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output holds all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def get_embedding_batch(model, tokenizer, sentences: List[str]) -> torch.FloatTensor:
    x = tokenizer(
        sentences,
        max_length=MAX_TEXT_LENGTH,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    x_dev = {k: v.to(device) for k, v in x.items()}
    out = model(**x_dev)
    embeds = mean_pooling(out, x_dev["attention_mask"]).cpu()
    # Normalise to unit length so that dot products become cosine similarities.
    embed_lens = torch.norm(embeds, dim=-1, keepdim=True)
    return embeds / embed_lens

def get_embeddings(model, tokenizer, sentences, batch_size):
    with torch.no_grad():  # inference only, no gradients needed
        embeds = []
        for i in tqdm(range(0, len(sentences), batch_size)):
            embeds.append(get_embedding_batch(model, tokenizer, sentences[i:i + batch_size]))
    return torch.cat(embeds)
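With these helpers in place, embedding the topic phrases and the articles might look like this (the batch size is an arbitrary choice, not the author's):

topic_embeds = get_embeddings(model, tokenizer, topics, batch_size=32)  # (K, dim)
sentence_embeds = get_embeddings(model, tokenizer, sentences, batch_size=32)  # (N, dim)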

We pass the topics as well as the candidate sentences through the sentence transformer separately. Since the embeddings are unit length, taking the dot product gives us a cosine-similarity metric. Below we add one and halve it so that the numbers lie in [0, 1]; strictly speaking this rescaling is not necessary.

similarity = 0.5 * (1 + sentence_embeds @ topic_embeds.T)
confidence, idx = similarity.topk(k=2)
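To turn the top-k indices back into topic names (assuming the topics list and labels from the loading sketch above), something like the following works:

# Inspect the first few predictions; idx indexes into the topics list.
for i in range(5):
    predicted = [topics[j] for j in idx[i].tolist()]
    true_topic = topics[newsgroups.target[i]]
    print(f"true: {true_topic}, predicted: {predicted}, confidence: {confidence[i].tolist()}")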

As can be seen below, even when the model gets it wrong, the predictions are close. If your topics are quite distinct you might observe better results than those shown below.

|   | True Topic          | Predicted Topic                          | Confidence   |
|---|---------------------|------------------------------------------|--------------|
| 0 | hockey              | [hockey, baseball]                       | [0.64, 0.58] |
| 1 | sys ibm pc hardware | [graphics, sys ibm pc hardware]          | [0.63, 0.62] |
| 2 | middle east         | [middle east, politics guns]             | [0.64, 0.60] |
| 3 | sys ibm pc hardware | [sys mac hardware, sys ibm pc hardware]  | [0.69, 0.68] |
| 4 | sys mac hardware    | [sys ibm pc hardware, sys mac hardware]  | [0.64, 0.61] |

Looking at the top-k accuracy, we get the following result:

Top 1 Accuracy is: 58.04%, Top 2 Accuracy is: 58.82%
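For completeness, here is a sketch of how such top-k accuracies can be computed from the topk outputs above, assuming the label order matches the topics list:

import torch

targets = torch.as_tensor(newsgroups.target)
top1_acc = (idx[:, 0] == targets).float().mean().item()
top2_acc = (idx == targets.unsqueeze(-1)).any(dim=-1).float().mean().item()
print(f"Top 1 Accuracy is: {top1_acc:.2%}, Top 2 Accuracy is: {top2_acc:.2%}")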

Shameless Self Promotion

If you enjoyed this tutorial, buy my course (usually 90% off).