Prompt Caching: Poor man’s guide to zero shot vision-LLM classification

Deep Learning
LLM
huggingface
Using KV caching and logit ratios to speed up and control LLM/VLM outputs.
Author

Sachin Abeywardana

Published

June 29, 2024

This tutorial is for the GPU poor. Zero-shotting vision-based large language models (VLMs) offers an affordable way to make predictions without the need for manual labeling or training. While methods like OWLv2 and CLIP are available, this guide will help you utilize more powerful VLMs. A common challenge with these VLMs is that they do not give you probability estimates for their predictions; the techniques presented here address that, and they can also be extended to text-based classifiers.

This is not a prompt engineering tutorial. Instead, I will demonstrate two key techniques:

  1. Key-Value caching to store common prompts.

  2. Using vocabulary logits to estimate the probability of a given statement being true or false.

These techniques are applicable if you can frame your problem as a yes/no question. For demonstration purposes, the root prompt is: Answer with just yes or no. Does the following image contain the stated attribute? Image: <image> Attribute:. You can substitute the attribute as needed.

As a test, we will label the following image to determine if it depicts fried chicken or a dog. You can fork the full code from this kaggle kernel here (please upvote if useful).

Dogs that look like fried chicken. Most likely cavoodles.

KV Caching

One of the main tricks used in the vLLM package is to cache common parts of incoming prompts. In our case we will cache the root prompt shown above.

The reason that this works is due to the following equation:

\[ \begin{align} a_{ij} &= \frac{\exp \left( q_i^\top k_j/\sqrt{d} \right)}{\sum_{t=1}^{i} \exp \left( q_i^\top k_t/\sqrt{d}\right)}\\ o_i &= \sum_{j=1}^{i} a_{ij} v_j \end{align} \]

The query \(q_i\) comes from the incoming token(s), while the keys and values \(k_j, v_j\) for \(j \le i - 1\) belong to the common prefix and are identical no matter which new query prompt we append. Here \(a_{ij}\) is the attention that the i-th token pays to the j-th token, and \(o_i\) is the output of the i-th token in the transformer network. It therefore makes sense to cache the prefix's keys and values and reuse them across prompts.
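
The snippets that follow assume that model, processor, image and device already exist. A minimal setup sketch, assuming the microsoft/Phi-3-vision-128k-instruct checkpoint and a placeholder image URL, might look like this:

import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "microsoft/Phi-3-vision-128k-instruct"  # assumed checkpoint

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    _attn_implementation="eager",  # avoids a hard dependency on flash-attn
).to(device).eval()

# Placeholder image URL; swap in your own test image
image_url = "https://example.com/dogs_that_look_like_fried_chicken.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)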

In order to get these keys and values we use the following snippet:

prompt = "<|user|>\nYou are an expert on dogs and fried chicken. Only answer yes or no. <|image_1|>\n Does this image contain "
prompt_end = "<|end|>\n<|assistant|>\n"
# Run the root prompt (text + image) through the model once and keep the per-layer key/value cache
root_inputs = processor(text=prompt, images=[image], padding="longest", return_tensors="pt").to(device)
with torch.inference_mode():
    kv_cache = model(**root_inputs, return_dict=True).past_key_values

Fun Fact: Note that kv_cache is a tuple of 32 (key, value) pairs, not just a single pair. This is because all LLMs these days are stacks of transformer layers. As shown below, each key and value has shape (1, 32, 2540, 96). The 1 is the batch size, which here is the single root prompt. The 32 is the number of attention heads and 96 is the dimension of each head (their product, 3072, is the hidden dimensionality of the network, as seen in model.config). The 2540 is the number of tokens (sequence length) in the root prompt. We will come back to this.

[[k.shape, v.shape] for k, v in kv_cache]
>>> 
[[torch.Size([1, 32, 2540, 96]), torch.Size([1, 32, 2540, 96])],
 [torch.Size([1, 32, 2540, 96]), torch.Size([1, 32, 2540, 96])],
 [torch.Size([1, 32, 2540, 96]), torch.Size([1, 32, 2540, 96])],
  ...
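
As a quick sanity check of the head arithmetic above (assuming the config exposes num_attention_heads and hidden_size, as most HuggingFace configs do), you can compare against model.config:

# 32 attention heads with 96 dimensions each should recover the 3072-dim hidden size
n_heads = model.config.num_attention_heads
hidden_size = model.config.hidden_size
print(n_heads, hidden_size // n_heads, hidden_size)  # expecting 32, 96, 3072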

Logit Ratio

In any transformer language model, the logits of the final layer have the shape (batch_size, sequence_length, vocabulary_size). For our application, we only care about the logits of the final token, and within those, only the entries for the words "yes" and "no." We therefore begin by identifying and storing the positions of the "yes" and "no" tokens in the vocabulary. I have included a check that each word maps to a single token, since modern tokenizers often break words into sub-words.

tokenizer = processor.tokenizer
# Check that "yes" and "no" each map to a single token before grabbing their ids
if len(tokenizer.encode("yes", add_special_tokens=False)) != 1:
    raise ValueError("yes is multiple tokens")
if len(tokenizer.encode("no", add_special_tokens=False)) != 1:
    raise ValueError("no is multiple tokens")
# Keep each id as a (1, 1) tensor so it can be used to index logits later
yes_id = torch.tensor(tokenizer.encode("yes", add_special_tokens=False)[-1]).unsqueeze(0).unsqueeze(0)
no_id = torch.tensor(tokenizer.encode("no", add_special_tokens=False)[-1]).unsqueeze(0).unsqueeze(0)

Then we take the softmax at the last position (-1) over only the yes and no logits, rather than the entire vocabulary, as shown below:


with torch.inference_mode():
    probs_iterative = []
    class_names = ["a dog", "fried chicken", "a lion"]
    for class_name in class_names:
        # Only the new suffix gets tokenised; the root prompt already lives in kv_cache
        inputs = processor(text=[class_name + prompt_end], padding=True, truncation=True, return_tensors="pt").to(device)
        # The attention mask must cover the cached root prompt as well as the new suffix
        inputs["attention_mask"] = torch.cat([root_inputs["attention_mask"], inputs["attention_mask"]], dim=-1)
        outputs = model(**inputs, past_key_values=kv_cache, return_dict=True)
        # Keep only the "yes" and "no" logits at the final position, then softmax over the pair
        logits = torch.tensor([outputs.logits[:, -1, yes_id], outputs.logits[:, -1, no_id]], device=device)
        probs_iterative.append(F.softmax(logits, dim=-1))

        print(f"The probability of seeing {class_name} is {probs_iterative[-1][0].item():.4f}")
>>>
The probability of seeing a dog is 0.7585
The probability of seeing fried chicken is 0.1931
The probability of seeing a lion is 0.0026
CPU times: user 581 ms, sys: 0 ns, total: 581 ms
Wall time: 579 ms

I have included a random class of lion above to see what the model does with it, and it passes the test, assigning that class a <1% probability. You can see that the fried chicken probability is around 19%, which is not as low as I would like, but it is good enough in my opinion.

It is worth noting that I have appended prompt_end after class_name. This is because chat-tuned LLMs expect the assistant turn marker (here <|end|>\n<|assistant|>\n) before they start generating text.
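
To make this concrete, the text the model effectively sees for one class (the cached root prompt followed by the fresh suffix) looks like this:

# Illustration only: the cached root prompt plus the per-class suffix
full_text = prompt + "a dog" + prompt_end
print(full_text)
>>>
<|user|>
You are an expert on dogs and fried chicken. Only answer yes or no. <|image_1|>
 Does this image contain a dog<|end|>
<|assistant|>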

Fun Fact: You do not need to make the class names a single word. In fact, you can make each one a complete sentence and it will still work, as in the sketch below.
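
For example, a purely hypothetical set of class names could be full sentences, and the yes/no scoring works in exactly the same way:

# Hypothetical sentence-level "classes"
class_names = [
    "a dog sitting or lying down",
    "a plate of fried chicken ready to eat",
    "a wild animal in its natural habitat",
]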

Batching

The above method took ~500ms to generate the output. We can optimise further by using a batch instead of the for loop shown above. The steps needed to convert the above process are as follows:

  1. Expand each of the key-value pairs to the batch size. I have used torch.expand here instead of torch.repeat in the hope of avoiding a copy, although I am not certain whether the tensors are actually copied or not:
expanded_kv_cache = tuple(
    (
        k.expand(num_classes, -1, -1, -1), 
        v.expand(num_classes, -1, -1, -1)
    ) 
    for k, v in kv_cache
)
  2. Despite being able to pass in the expanded keys and values, for some reason the model still requires you to expand the root prompt's attention_mask and concatenate it with the new one (but not the input_ids):
    inputs["attention_mask"] = torch.cat(
        [
            root_inputs["attention_mask"].expand(num_classes, -1), 
            inputs["attention_mask"]
        ], 
        dim=-1
    )
    outputs = model(**inputs, past_key_values=expanded_kv_cache, return_dict=True)
  3. Identify the last index of each sentence. Because we are batching sentences of varying lengths, the relevant token may not sit at the final position along the sequence dimension, and naively taking the last index could pick up a padding token. Instead, use the attention mask to find the true last index of each sentence:
    last_indices = inputs["attention_mask"][:, -seq_length:].sum(dim=-1) - 1

The final code snippet is as follows. This method takes ~250ms, roughly halving the time taken by the initial for loop.

with torch.inference_mode():
    inputs = processor(
        text = [class_name + prompt_end for class_name in class_names], 
        padding=True, 
        truncation=True, 
        return_tensors="pt"
    ).to(device)
    num_classes = len(class_names)
    inputs["attention_mask"] = torch.cat(
        [
            root_inputs["attention_mask"].expand(num_classes, -1), 
            inputs["attention_mask"]
        ], 
        dim=-1
    ) 
    expanded_kv_cache = tuple((k.expand(num_classes, -1, -1, -1), v.expand(num_classes, -1, -1, -1)) for k, v in kv_cache)
    outputs = model(**inputs, past_key_values=expanded_kv_cache, return_dict=True)
    num_elements, seq_length = inputs["input_ids"].shape
    last_indices = inputs["attention_mask"][:, -seq_length:].sum(dim=-1) - 1
    probs = F.softmax(
        torch.cat(
            [
                outputs.logits[torch.arange(num_elements), last_indices, yes_id], 
                outputs.logits[torch.arange(num_elements), last_indices, no_id]
            ], 
            dim=0
        ).squeeze(), dim=0
    )
probs.T
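
Since the "yes" logits are concatenated first, row 0 of probs (equivalently, column 0 of probs.T) holds the probability of "yes" for each class. As a usage sketch, you could then pick the most likely class like this:

# Row 0 = P(yes) per class, row 1 = P(no) per class
yes_probs = probs[0]
best_class = class_names[yes_probs.argmax().item()]
print(best_class, yes_probs.tolist())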

Fun Fact: We can use other VLMs besides phi-3-vision. However, make sure you set processor.tokenizer.padding_side = "right", since I spent almost a day chasing a bug caused by padding happening on the left with the Llava models.

Multi-class extension

I attempted a similar approach for a multiclass scenario, but with limited success. You might achieve better results with a larger model, considering that phi-3 has "only" 3 billion parameters. Note that I did not use Key-Value (KV) caching in this case. However, you could still use KV caching if you want to swap out images, in which case the cached root prompt would exclude the image.

Long-sleeve t-shirt on a white background.

The prompt in this scenario was structured as follows for the image shown:

long_sleeve_tshirt = "https://i.ebayimg.com/images/g/-eoAAOSwnHZYRpKL/s-l1200.webp"
image = load_image(long_sleeve_tshirt)

prompt = """<|user|>\nYou are a fashion expert. Only answer with the number of the following options:
1. Shoes
2. Long-sleeve shirt
3. T-shirt
4. Jacket
5. Skirt
6. Sandals
<|image_1|>\n This image contains the option: """
prompt_end = "<|end|>\n<|assistant|>\n"
multiclass_inputs = processor(text=prompt + prompt_end, images=[image], padding="longest", return_tensors="pt").to(device)
with torch.inference_mode():
    output = model(**multiclass_inputs, return_dict=True)

The tokens are extracted as follows. Note that I used range(1, 7) rather than range(6), since Python's range starts at zero while the options are numbered 1 through 6.

token_ids = []
tokenizer = processor.tokenizer
for i in range(1, 7):
    token_id = torch.tensor(tokenizer.encode(str(i), add_special_tokens=False)[-1]).unsqueeze(0).unsqueeze(0)
    if not token_id.shape == torch.Size([1, 1]):
        raise ValueError("id is multiple tokens")
    token_ids.append(token_id)
all_token_ids = torch.cat(token_ids).squeeze()

Finally, in order to get the probabilities, I can do:

F.softmax(output.logits[:, -1, all_token_ids], dim=-1)
>>>
tensor([[0.3414, 0.2539, 0.0979, 0.1892, 0.0834, 0.0341]], device='cuda:0')

The shoes unfortunately got the highest probability, while the correct answer of "Long-sleeve shirt" came in second. I tried swapping the order of the items, and it seems that the LLM has a tendency to prefer option 1 regardless of the context.

Gotchas and Final Thoughts

When I tried Llava instead of Phi-3 during my experiments, I could not get the probabilities from the batched method and the for loop to match for a while. It turned out this was because padding does not happen on the right by default for all models. So, just to be safe, I would add the following snippet to force right padding.

processor.tokenizer.padding_side = "right"  

Unfortunately, images do take up a large number of tokens (roughly 2000 in my case), so the benefit of caching everything before the <image> token does not amount to much, especially given the O(n^2) complexity of attention. I tried resizing the image, but this did not help. However, I am hoping that there is some parameter in the processor which will allow the image to take up fewer tokens. It makes very little sense for an image to cost more than 1000 tokens given that even GPT-4 only charges 255 tokens per image according to their pricing page. I have posed this question on stackoverflow, and if you have any insights I would really appreciate them.

Also, as a side note, remember to wrap model calls in torch.inference_mode() to ensure we don't waste time accidentally calculating gradients.

There are other tricks such as torch.compile, but unfortunately that trick is not available for anything less than an A10 machine.

Kudos

Kudos to whoever wrote this thread on reddit, which I based my method on.