Transformers Inference Optimizations ⏰🚀

LLM
pytorch
Optimization tricks used for speeding up transformer inference
Author

Sachin Abeywardana

Published

January 21, 2025

Here’s a quick (and definitely not all-inclusive) write-up of the optimizations I stumbled upon in the last two days. It’s a bit rushed, so bear with me on this one! 💨

Install Flash Attention 2 🚀

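Flash Attention 2 is the way to go. Flash Attention 3 exists, but as far as I know it targets Hopper GPUs such as the H100. Installing Flash Attention 2 was a fucking headache 😅, but the trick was to use a -devel Docker image: flash-attn may need to compile its CUDA kernels from source, and only the -devel images ship the CUDA compiler (nvcc) required for that. Also, don’t forget to install ninja; it runs the build in parallel and speeds it up like a champ.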

Here’s what worked for me:

FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
RUN python --version
WORKDIR /app

# Install build dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    cmake \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements.txt file into the container at /app
COPY requirements.txt /app/

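# flash-attn may compile CUDA kernels from source, so install its build helpers first.
# --no-build-isolation lets the build see the torch that is already in the image.
RUN pip install --no-cache-dir packaging ninja
RUN pip install --no-cache-dir flash-attn --no-build-isolation

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
# Optional: CUDA extensions for fused dense layers (disabled)
# RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib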

# Copy the source code into /app (matches WORKDIR, PYTHONPATH and the chown below)
COPY ./src/ /app/src/

RUN addgroup --system somebody && \
    adduser --system --home /app --ingroup somebody somebody && \
    chown -R somebody:somebody /app

USER somebody

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
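Once the image builds, a quick sanity check inside the container saves a lot of guessing. Here’s a minimal, stdlib-only sketch that reports whether torch, flash-attn and ninja are importable and which versions got installed. The file name `check_install.py` and the helper `installed_version` are my own placeholders. Note that the PyPI name `flash-attn` differs from the import name `flash_attn`.

```python
# check_install.py (hypothetical file name): confirm the build produced importable packages.
import importlib.metadata
import importlib.util
from typing import Optional


def installed_version(dist_name: str, module_name: Optional[str] = None) -> Optional[str]:
    """Return the installed version of a distribution, or None if its module can't be found.

    dist_name is the PyPI name (e.g. "flash-attn"); module_name is the import
    name (e.g. "flash_attn") and defaults to dist_name.
    """
    if importlib.util.find_spec(module_name or dist_name) is None:
        return None
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        # Importable (e.g. a stdlib module) but no installed distribution metadata.
        return "unknown"


if __name__ == "__main__":
    for dist, module in [("torch", "torch"), ("flash-attn", "flash_attn"), ("ninja", "ninja")]:
        print(f"{dist}: {installed_version(dist, module) or 'NOT INSTALLED'}")
```

`find_spec` only locates the package without importing it, so a compiled extension that was built against the wrong CUDA or torch version would still pass; actually running `import flash_attn` on a GPU machine is the stronger test.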

Torch tricks 🔧✨

Here’s a neat trick: always compile your model first! This applies to any PyTorch model, whether you’re using transformers or something else. I was surprised to learn it’s not just about running model = torch.compile(model). torch.compile is lazy: that call returns immediately, and the actual compilation kicks off on the first forward pass (and can happen again when new input shapes show up). So you’ll want to include a few warm-up iterations before you start timing or serving requests.

Here’s what that looked like for me. Note how I used a dummy input of "Hello, how can I help you?".

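import time
import torch

# torch.compile returns immediately; compilation happens on the first call below.
self.model = torch.compile(self.model)
# Warm-up: the first iteration(s) are slow because they trigger compilation.
# _generate_response and logger are defined elsewhere in my code.
for i in range(5):
    starting_time = time.perf_counter()
    _ = _generate_response(
        self.model, self.tokenizer, "Hello, how can I help you?", self.generation_kwargs
    )
    logger.info(
        f"Model response in {time.perf_counter() - starting_time:.2f} seconds for iteration {i}"
    )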
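To make the warm-up effect visible, here’s a small framework-agnostic sketch: it times a callable over several iterations so you can compare the first call against the steady state. The names `time_iterations` and `LazyModel` are my own; `LazyModel` only imitates torch.compile’s behaviour (a one-off expensive step on the first call) so the sketch runs without a GPU. In practice you’d pass something like `lambda: _generate_response(model, tokenizer, prompt, kwargs)`.

```python
import statistics
import time
from typing import Callable, List


def time_iterations(fn: Callable[[], object], n_iters: int = 5) -> List[float]:
    """Call fn() n_iters times and return each call's wall-clock duration in seconds."""
    durations = []
    for _ in range(n_iters):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


class LazyModel:
    """Toy stand-in for a torch.compile'd model: the first call pays a one-off setup cost."""

    def __init__(self, setup_seconds: float = 0.2):
        self.setup_seconds = setup_seconds
        self.compiled = False

    def __call__(self) -> int:
        if not self.compiled:
            time.sleep(self.setup_seconds)  # stands in for compilation on the first pass
            self.compiled = True
        return sum(range(10_000))  # stands in for the actual forward pass


if __name__ == "__main__":
    timings = time_iterations(LazyModel(), n_iters=5)
    print(f"first call: {timings[0]:.3f}s")
    print(f"steady state (median of the rest): {statistics.median(timings[1:]):.4f}s")
```

Only the steady-state numbers matter for serving. Since new input shapes can trigger recompilation, warming up with a prompt that looks like real traffic is probably a good idea.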

And don’t forget to use torch.inference_mode()! If you skip it, PyTorch’s autograd keeps recording operations (and holding on to intermediate tensors) so it could compute gradients later, which is pure overhead at inference time. I personally prefer slapping it on as a decorator (@torch.inference_mode()) rather than using a with torch.inference_mode(): block. Saves a bit of boilerplate and looks cleaner, too! 🧹
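Here’s a minimal sketch of both forms, using a toy `nn.Linear` as a stand-in for the LLM. It shows that outputs produced under inference mode carry no autograd graph. For what it’s worth, transformers already wraps `generate()` in `torch.no_grad()`, so the extra win there comes from inference mode also skipping view tracking and version-counter bumps.

```python
import torch

model = torch.nn.Linear(16, 4)  # toy model; its parameters require grad by default
x = torch.randn(2, 16)


@torch.inference_mode()
def predict(inputs: torch.Tensor) -> torch.Tensor:
    # Decorator form: everything inside runs without autograd tracking.
    return model(inputs)


def predict_ctx(inputs: torch.Tensor) -> torch.Tensor:
    # Equivalent context-manager form.
    with torch.inference_mode():
        return model(inputs)


if __name__ == "__main__":
    y_plain = model(x)
    y_inf = predict(x)
    print(y_plain.requires_grad)  # True: autograd recorded a graph for this output
    print(y_inf.requires_grad)  # False: no graph was built
    print(torch.is_inference(y_inf))  # True: an "inference tensor"
```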

Loading the Model ⚡️🤖

When loading your model, two big things to remember: use Flash Attention and set torch_dtype to bfloat16 (Flash Attention 2 only works in fp16 or bf16 anyway). This combo works wonders for performance. However, I’m still on the fence about quantization. On an L4 GPU (roughly comparable to the A10G you get on AWS), loading in 8-bit mode actually slowed things down for me. Turns out, dequantizing weights during inference eats up compute time. Ditching 8-bit mode cut my inference time in half. 🕒✂️
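To see why 8-bit loading is a memory trade-off rather than a free speed-up, here’s a toy pure-Python sketch of absmax int8 weight quantization. It is a deliberate simplification: bitsandbytes’ LLM.int8() is more involved (it also quantizes activations and keeps outlier features in 16-bit), so treat this only as an illustration of the two sides. Weights shrink from 2 bytes (bfloat16) to 1 byte, but there is an extra conversion step before they can be used. All helper names are my own.

```python
from typing import List, Tuple

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1}


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Approximate memory for the weights alone (no KV cache or activations), in GB (1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


def quantize_absmax(weights: List[float]) -> Tuple[List[int], float]:
    """Map floats to int8 values in [-127, 127] using a single per-tensor scale."""
    scale = max(abs(w) for w in weights) / 127 or 1.0  # avoid a zero scale for all-zero input
    return [round(w / scale) for w in weights], scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """The extra step paid at inference time: int8 back to float before the matmul."""
    return [v * scale for v in q]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


if __name__ == "__main__":
    n = 1.5e9  # a 1.5B-parameter model
    for dtype in ("float32", "bfloat16", "int8"):
        print(f"{dtype:>8}: {weight_memory_gb(n, dtype):.1f} GB of weights")

    w = [0.12, -0.5, 0.33, 0.01]
    x = [1.0, 2.0, -1.0, 0.5]
    q, s = quantize_absmax(w)
    print("int8 weights:", q)
    print("exact:", dot(w, x), "via int8:", dot(dequantize(q, s), x))
```

If the bfloat16 weights already fit comfortably on the GPU (3 GB for a 1.5B model), the memory saving buys nothing, and the extra per-layer work is what you end up paying for.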

Here’s my typical model-loading setup:

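import torch
import transformers

llm_model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # example; swap in your own model

model = transformers.AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    # quantization_config=bnb_config,  # 8-bit loading was slower on my L4, so it's off
    device_map="auto",  # place layers on the available GPU(s); needs accelerate
    trust_remote_code=True,  # only needed for models that ship custom modeling code
    torch_dtype=torch.bfloat16,  # Flash Attention 2 requires fp16 or bf16
    attn_implementation="flash_attention_2",
)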
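Hard-coding `attn_implementation="flash_attention_2"` fails on machines where flash-attn isn’t installed or the GPU is too old; Flash Attention 2 needs an Ampere-or-newer card (compute capability 8.0+), so a T4 (7.5) won’t work. Here’s a small sketch of a fallback to PyTorch’s built-in `"sdpa"` attention. The helper `pick_attn_implementation` is hypothetical; the hardware probing is kept in comments so the logic itself runs anywhere.

```python
from typing import Optional, Tuple


def pick_attn_implementation(
    compute_capability: Optional[Tuple[int, int]], flash_attn_installed: bool
) -> str:
    """Return an attn_implementation value for from_pretrained (hypothetical helper)."""
    # Flash Attention 2 needs compute capability 8.0+, e.g. A100/A10G (Ampere),
    # L4 (Ada) or H100 (Hopper), plus the flash-attn package.
    if flash_attn_installed and compute_capability is not None and compute_capability >= (8, 0):
        return "flash_attention_2"
    # Otherwise fall back to PyTorch's scaled_dot_product_attention.
    return "sdpa"


if __name__ == "__main__":
    # On a real machine you'd probe the hardware like this (requires torch):
    #   cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
    #   installed = importlib.util.find_spec("flash_attn") is not None
    for cap, installed in [((8, 9), True), ((7, 5), True), ((9, 0), False), (None, False)]:
        print(cap, installed, "->", pick_attn_implementation(cap, installed))
```

The returned string then goes straight into `from_pretrained(..., attn_implementation=...)`.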

Also, size matters! Smaller models generally run faster, but it’s not a sure thing. For example, ibm-granite-1b-instruct was slower than qwen2.5-1.5b-instruct for me, even though it’s the smaller of the two. Model architecture plays a role here, so don’t assume smaller is always quicker 🤔. That said, ibm-granite did score higher on the Open LLM Leaderboard.
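Since parameter count isn’t a reliable proxy for speed, it’s worth timing candidates on your own hardware. Here’s a minimal sketch of such a comparison. The `benchmark` helper and the `model-a`/`model-b` entries are placeholders of mine; the sleeps just stand in for real `generate()` calls so the sketch runs anywhere.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(
    candidates: Dict[str, Callable[[], object]], n_warmup: int = 2, n_runs: int = 5
) -> Dict[str, float]:
    """Median wall-clock latency (seconds) per candidate, ignoring warm-up calls."""
    results = {}
    for name, run in candidates.items():
        for _ in range(n_warmup):
            run()  # absorb one-off costs (compilation, CUDA init, caches)
        timings = []
        for _ in range(n_runs):
            start = time.perf_counter()
            run()
            timings.append(time.perf_counter() - start)
        results[name] = statistics.median(timings)  # robust to the odd slow run
    return results


if __name__ == "__main__":
    # In practice each entry would wrap generate() on a different model,
    # using the same prompt and the same max_new_tokens.
    fake_models = {
        "model-a": lambda: time.sleep(0.02),
        "model-b": lambda: time.sleep(0.01),
    }
    for name, latency in sorted(benchmark(fake_models).items(), key=lambda kv: kv[1]):
        print(f"{name}: {latency * 1000:.1f} ms")
```

One caveat when comparing real models: one model may simply write longer answers, so it’s fairer to fix the output length or compare tokens per second rather than raw latency.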

Oh, and if you’re serving models, go for *-instruct versions over *-chat. While chat models shine for chatbot applications, most use cases benefit more from instruction-tuned models.

Future Directions 🔮💡

Here’s a quick list of things I didn’t fully explore (or barely scratched the surface of) but think are worth diving into:

  1. Use Unsloth: I’ve heard great things about this library; it’s supposed to save loads of time on research like this. Definitely on my to-try list! 🦥❌

  2. Try Sampling for Inference: In my limited testing, sampling gave me some…let’s just say interesting outputs (a.k.a. gibberish 😂). I suspect there’s a better set of hyperparameters that could give more cohesive results. Here’s what I tried:

    self.generation_kwargs = {
        "do_sample": True,  # Sample from the distribution instead of always taking the top token
        "top_k": 50,  # Limits sampling to the 50 most likely tokens
        "top_p": 0.95,  # Nucleus sampling: smallest token set with cumulative probability >= 0.95
        "temperature": 0.3,  # Scales logits before sampling (lower = more deterministic, higher = more random)
        "no_repeat_ngram_size": 2,  # Forbids any 2-gram from appearing twice (aggressive for long outputs)
        "early_stopping": True,  # Only affects beam search (num_beams > 1); no effect when sampling
        "use_cache": True,  # Reuse the KV cache across decoding steps
        "max_new_tokens": 192,  # Maximum number of tokens to generate
    }
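To get a feel for what those knobs do, here’s a simplified pure-Python sketch of temperature, then top-k, then top-p filtering on a made-up five-token vocabulary. It’s my own illustration, not transformers’ implementation, though as far as I can tell it applies the filters in the same order. On this toy example, a temperature of 0.3 already puts about 99.7% of the probability on the top two tokens, so top_k=50 and top_p=0.95 only trim tokens that were nearly impossible anyway.

```python
import math
import random
from typing import List


def sampling_distribution(
    logits: List[float], temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0
) -> List[float]:
    """Probabilities that sampling draws from after temperature, top-k and top-p filtering."""
    # 1) Temperature: < 1 sharpens the distribution, > 1 flattens it.
    scaled = [x / temperature for x in logits]
    # 2) Top-k: keep only the k highest-scoring tokens (0 disables the filter).
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    kept = order[:top_k] if top_k > 0 else order
    # Softmax over the survivors (subtracting the max keeps exp() from overflowing).
    top = scaled[kept[0]]
    exps = {i: math.exp(scaled[i] - top) for i in kept}
    total = sum(exps.values())
    probs = {i: e / total for i, e in exps.items()}
    # 3) Top-p (nucleus): smallest prefix of tokens whose cumulative probability reaches top_p.
    nucleus, cumulative = set(), 0.0
    for i in kept:  # kept is sorted from most to least likely
        nucleus.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    norm = sum(probs[i] for i in nucleus)
    return [probs[i] / norm if i in nucleus else 0.0 for i in range(len(logits))]


def sample(logits: List[float], rng: random.Random, **kwargs) -> int:
    """Draw one token id from the filtered distribution."""
    probs = sampling_distribution(logits, **kwargs)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    toy_logits = [2.0, 1.5, 0.3, -1.0, -3.0]  # a made-up 5-token vocabulary
    for temp in (0.3, 1.0):
        dist = sampling_distribution(toy_logits, temperature=temp, top_k=50, top_p=0.95)
        print(f"temperature={temp}: {[round(p, 3) for p in dist]}")
    rng = random.Random(0)
    print([sample(toy_logits, rng, temperature=0.3, top_k=50, top_p=0.95) for _ in range(10)])
```

Sampling may not be the only culprit for the gibberish, either: no_repeat_ngram_size=2 bans every word pair from recurring, which in longer outputs can force odd token choices, so it might be worth relaxing too.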

That’s it for now! I’ll keep experimenting, but these tweaks already gave me some solid improvements. If you’ve got any tips or questions, drop me a line! ✌️