Transformers Inference Optimizations ⏰🚀
Here’s a quick (and definitely not all-inclusive) write-up of the optimizations I stumbled upon in the last two days. It’s a bit rushed, so bear with me on this one! 💨
Install Flash Attention 2 🚀
Flash Attention 2 is the way to go; Flash Attention 3 exists, but as far as I know, it’s exclusive to H100 machines. Installing it was a fucking headache 😅, but the trick was to use a `-devel` Docker container. Also, don’t forget to install `ninja`; it speeds up the build process like a champ.
Here’s what worked for me:
```dockerfile
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel

RUN python --version

WORKDIR /app

# Install build dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    cmake \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements.txt file into the container at /app
COPY requirements.txt /app/

# Install any needed packages specified in requirements.txt
RUN pip install ninja
RUN pip install --no-cache-dir flash-attn --no-build-isolation
RUN pip install --no-cache-dir -r requirements.txt

# # Install CUDA extensions for fused dense
# RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib

# Copy the current directory contents into the container at /app
COPY ./src/ /my/working/directory/src/

RUN addgroup --system somebody && \
    adduser --system --home /app --ingroup somebody somebody && \
    chown -R somebody:somebody /app

USER somebody

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
```
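Once the image builds, it’s worth a quick sanity check that `flash-attn` actually imports and that CUDA is visible. This little snippet is just a sketch for verification inside the container, not part of my Dockerfile:

```python
# Sanity check: confirm flash-attn built correctly and CUDA is available.
# A quick verification sketch, not part of the image itself.
import flash_attn
import torch

print(f"flash-attn version: {flash_attn.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
```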
Torch tricks 🔧✨
Here’s a neat trick: always compile your model first! This applies to any model, whether you’re using transformers or something else. I was surprised to learn it’s not just about running `model = torch.compile(model)`. Torch uses lazy compilation, meaning the first data pass actually kicks off the compile process. So, you’ll want to include a few warm-up iterations.
Here’s what that looked like for me. Note how I used a dummy input of `"Hello, how can I help you?"`.
```python
self.model = torch.compile(self.model)

for i in range(5):
    starting_time = time.time()
    _ = _generate_response(
        self.model, self.tokenizer, "Hello, how can I help you?", self.generation_kwargs
    )
    logger.info(f"Model response in {time.time() - starting_time:.2f} seconds for iteration {i}")
```
And don’t forget to use `torch.inference_mode()`! If you skip it, Torch will still track gradients and build the autograd graph, which adds unnecessary overhead. I personally prefer slapping it on as a decorator rather than using `with torch.inference_mode():`. Saves a bit of boilerplate and looks cleaner, too! 🧹
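For illustration, here’s a minimal sketch of what that decorated helper could look like. I’m assuming `_generate_response` is the same helper from the warm-up loop above and a standard Hugging Face model/tokenizer pair, so the body below is an assumption rather than my exact code:

```python
import torch


@torch.inference_mode()  # disables autograd tracking for everything inside
def _generate_response(model, tokenizer, prompt, generation_kwargs):
    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate a completion with whatever generation kwargs were configured
    output_ids = model.generate(**inputs, **generation_kwargs)
    # Decode the output into plain text
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```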
Loading the Model ⚡️🤖
When loading your model, two big things to remember: use Flash Attention and set `torch_dtype` to `bfloat16`. This combo works wonders for performance. However, I’m still on the fence about quantization. On an L4 GPU (similar to AWS’s A10), loading in 8-bit mode actually slowed things down for me. Turns out, dequantizing weights during inference eats up compute time. Ditching 8-bit mode cut my inference time in half. 🕒✂️
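For reference, the 8-bit experiment goes through the `quantization_config` argument that’s commented out in the loading snippet below. A config along these lines is what gets passed in; the exact parameters of my original `bnb_config` aren’t reproduced here, so treat this as a sketch:

```python
from transformers import BitsAndBytesConfig

# Rough reconstruction of the 8-bit config I experimented with (and dropped);
# the exact parameters of the original bnb_config are an assumption.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
```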
Here’s my typical model-loading setup:
```python
model = transformers.AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```
Also, size matters! Smaller models generally run faster, but it’s not always a sure thing. For example, the `ibm-granite-1b-instruct` model was slower than `qwen2.5-1.5b-instruct`, even though the former is smaller. Model architecture plays a role here, so don’t assume smaller is always quicker 🤔. It is worth noting that `ibm-granite` scored higher on the Open LLM Leaderboard.
Oh, and if you’re serving models, go for `*-instruct` versions over `*-chat`. While chat models shine for chatbot applications, most use cases benefit more from instruction-tuned models.
Future Directions 🔮💡
Here’s a quick list of things I didn’t fully explore (or barely scratched the surface of) but think are worth diving into:
- Use Unsloth: I’ve heard great things about this library; it’s supposed to save loads of time on research like this. Definitely on my to-try list! 🦥❌
- Try Sampling for Inference: In my limited testing, sampling gave me some… let’s just say interesting outputs (a.k.a. gibberish 😂). I suspect there’s a better set of hyperparameters that could give more cohesive results. Here’s what I tried:
```python
self.generation_kwargs = {
    "do_sample": True,          # Sample instead of decoding greedily
    "top_k": 50,                # Limit sampling to the top 50 tokens (controls diversity)
    "top_p": 0.95,              # Nucleus sampling (cumulative probability threshold)
    "temperature": 0.3,         # Scale logits before sampling (higher = more randomness)
    "no_repeat_ngram_size": 2,  # Avoid repetitive sequences
    "early_stopping": True,     # Beam-search stopping criterion (no effect with pure sampling)
    "use_cache": True,
    "max_new_tokens": 192,      # Maximum number of tokens to generate
}
```
That’s it for now! I’ll keep experimenting, but these tweaks already gave me some solid improvements. If you’ve got any tips or questions, drop me a line! ✌️