
LLM Inference Optimization

From 4s to 450ms: A production journey through transformer acceleration

Zev Uhuru
Engineering Research
March 18, 2025

When we launched our LLM-powered writing assistant, users were waiting 4+ seconds for responses. Six months later, we've reduced that to 450ms while handling 4x the traffic. This is the story of how we optimized transformer inference from the ground up.

The journey wasn't just about faster hardware or bigger caches. It required rethinking our entire inference pipeline, from memory management to request batching, from model quantization to speculative decoding. Here's what we learned.

The Performance Baseline

Our initial setup was straightforward: a 7B parameter model running on A100 GPUs with standard PyTorch inference. The results were... educational.
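
The serving path at launch was essentially vanilla autoregressive generation, roughly like this (a simplified sketch using the Hugging Face transformers API as a stand-in for our stack; the checkpoint name and prompt are placeholders):

python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Plain FP16 generation: no batching, no speculative decoding, no tuning
model_name = "our-7b-model"  # placeholder, not the actual checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

inputs = tokenizer("Rewrite this paragraph more concisely:", return_tensors="pt").to("cuda")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output[0], skip_special_tokens=True))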

Latency Improvements by Optimization Stage

- Initial P95 latency: 4.2s
- Final P95 latency: 450ms
- Overall improvement: 9.3x
- Throughput: 4x

Optimization Strategy

We approached optimization systematically, measuring each change's impact before moving to the next. Here's the progression that got us to sub-500ms latency:

1. KV Cache Optimization

During autoregressive generation, the attention mechanism needs the key-value pairs of every previous token at each decoding step. By caching these tensors instead of recomputing them from scratch, we eliminated the redundant work.

python
import torch
import torch.nn as nn

class OptimizedAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, past_key_values=None):
        batch_size, seq_len, hidden_size = hidden_states.shape
        shape = (batch_size, seq_len, self.num_heads, self.head_dim)

        # Project only the newly arrived tokens
        query = self.q_proj(hidden_states).view(shape).transpose(1, 2)
        new_key = self.k_proj(hidden_states).view(shape).transpose(1, 2)
        new_value = self.v_proj(hidden_states).view(shape).transpose(1, 2)

        # Reuse cached K,V for previous tokens instead of recomputing them
        if past_key_values is not None:
            key = torch.cat([past_key_values[0], new_key], dim=2)
            value = torch.cat([past_key_values[1], new_value], dim=2)
        else:
            key, value = new_key, new_value

        # Cache for the next decoding step
        present_kv = (key, value)

        # Attend over the full (cached + new) sequence; causal masking is only
        # needed on the prefill pass, before any cache exists
        attention_output = nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=past_key_values is None
        )
        attention_output = attention_output.transpose(1, 2).reshape(batch_size, seq_len, hidden_size)
        return self.o_proj(attention_output), present_kv

Result: 25% latency reduction. Memory usage increased by ~30%, but the speed gains were worth it for our use case.
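
That ~30% is easy to sanity check with a back-of-the-envelope cache-size estimate (the layer/head counts and batch size below are illustrative for a 7B-class decoder, not our exact configuration):

python
# Rough KV cache footprint for a 7B-class decoder (illustrative numbers)
layers, heads, head_dim = 32, 32, 128
seq_len, batch_size, bytes_per_elem = 2048, 4, 2   # FP16

# Factor of 2 for keys and values
kv_bytes = 2 * layers * heads * head_dim * seq_len * batch_size * bytes_per_elem
print(f"KV cache: {kv_bytes / 1e9:.1f} GB")  # ~4.3 GB, roughly 30% on top of ~14 GB of FP16 weights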

2. Dynamic Quantization

Moving from FP16 to INT8 quantization reduced memory bandwidth requirements significantly. The key was maintaining accuracy through careful calibration.

python
import torch
import torch.quantization as quant

# Calibration dataset preparation: a held-out set of batches. Dynamic
# quantization computes activation scales at runtime, so these samples are
# mainly reused for the quality check below.
def prepare_calibration_data(model, dataloader, device):
    calibration_data = []
    model.eval()

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['input_ids'].to(device)
            _ = model(inputs)  # sanity pass over representative inputs
            calibration_data.append(inputs)
            if len(calibration_data) >= 100:  # enough samples for validation
                break

    return calibration_data

# Dynamic quantization: weights stored as INT8, activations quantized on the fly
model_quantized = quant.quantize_dynamic(
    model,
    {torch.nn.Linear},  # the attention and MLP projections are nn.Linear modules
    dtype=torch.qint8
)

# Quality check: run both models over the same held-out batches
def validate_quantized_model(original, quantized, test_data):
    original_outputs = []
    quantized_outputs = []

    for input_ids in test_data:
        with torch.no_grad():
            original_outputs.append(original(input_ids))
            quantized_outputs.append(quantized(input_ids))

    # Compare perplexity scores (calculate_perplexity_diff is a project-specific helper)
    return calculate_perplexity_diff(original_outputs, quantized_outputs)

Trade-off: 33% faster inference, but required careful monitoring of output quality. We saw a 2% increase in perplexity that was acceptable for our use case.
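
For completeness, here is one minimal way to measure that perplexity delta directly from logits (a sketch independent of the project-specific calculate_perplexity_diff helper above; it assumes Hugging Face-style outputs with a .logits field and that eval_batches is a list of input_id tensors):

python
import torch
import torch.nn.functional as F

def perplexity(model, eval_batches):
    # Average next-token perplexity over a list of input_id tensors
    losses = []
    model.eval()
    with torch.no_grad():
        for input_ids in eval_batches:
            logits = model(input_ids).logits
            # Position t predicts token t+1, so shift logits and targets by one
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, logits.size(-1)),
                input_ids[:, 1:].reshape(-1),
            )
            losses.append(loss)
    return torch.exp(torch.stack(losses).mean()).item()

ppl_fp16 = perplexity(model, eval_batches)
ppl_int8 = perplexity(model_quantized, eval_batches)
print(f"Perplexity increase: {(ppl_int8 - ppl_fp16) / ppl_fp16:.1%}")  # ~2% in our case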

3. Continuous Batching

Traditional batching waits for all sequences to complete. Continuous batching processes requests as they arrive and finish, maximizing GPU utilization.

python
import asyncio
import time
import uuid

class ContinuousBatcher:
    def __init__(self, model, max_batch_size=32, max_wait_time=50):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time  # milliseconds
        self.pending_requests = []
        self.active_batches = {}

    async def process_request(self, request):
        request.arrival_time = time.time()
        self.pending_requests.append(request)

        # Dispatch immediately once a full batch is available
        if len(self.pending_requests) >= self.max_batch_size:
            return await self.create_batch()

        # Otherwise wait briefly for more requests to accumulate
        await asyncio.sleep(self.max_wait_time / 1000)
        return await self.create_batch()

    async def create_batch(self):
        # Another coroutine may have already drained the queue
        if not self.pending_requests:
            return None

        # Take up to max_batch_size requests
        batch_requests = self.pending_requests[:self.max_batch_size]
        self.pending_requests = self.pending_requests[self.max_batch_size:]

        # Track the in-flight batch until generation completes
        batch_id = uuid.uuid4()
        self.active_batches[batch_id] = batch_requests

        try:
            return await self.model.generate_batch(batch_requests)
        finally:
            del self.active_batches[batch_id]
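
Wired into the request handler, usage looks roughly like this (a sketch; Request, serve, and the prompts list are illustrative, and model.generate_batch is assumed to come from the serving layer):

python
import asyncio
from dataclasses import dataclass

@dataclass
class Request:
    prompt: str
    arrival_time: float = 0.0

async def serve(model, prompts):
    batcher = ContinuousBatcher(model, max_batch_size=32, max_wait_time=50)
    # Each prompt is submitted independently; the batcher groups whatever is
    # pending into a batch once it fills up or the wait window expires.
    tasks = [batcher.process_request(Request(prompt=p)) for p in prompts]
    return await asyncio.gather(*tasks)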

4. Speculative Decoding

The breakthrough came with speculative decoding: using a smaller, faster model to predict multiple tokens, then verifying with the main model in parallel.

python
import torch
import torch.nn.functional as F

class SpeculativeDecoder:
    def __init__(self, main_model, draft_model, gamma=4):
        self.main_model = main_model
        self.draft_model = draft_model  # Smaller, faster model
        self.gamma = gamma  # Speculation window

    def decode_step(self, input_ids, past_kv=None):
        # Draft model generates gamma tokens speculatively
        draft_tokens = []
        draft_probs = []
        current_input = input_ids

        for _ in range(self.gamma):
            with torch.no_grad():
                draft_output = self.draft_model(current_input)
                draft_logits = draft_output.logits[:, -1, :]
                draft_prob = F.softmax(draft_logits, dim=-1)

                # Sample the next draft token
                next_token = torch.multinomial(draft_prob, 1)
                draft_tokens.append(next_token)
                draft_probs.append(draft_prob)

                current_input = torch.cat([current_input, next_token], dim=1)

        # Verify all draft tokens with a single main-model forward pass.
        # For clarity this re-scores the full sequence; a production version
        # passes only the new tokens and reuses past_kv to avoid recomputation.
        verification_input = torch.cat([input_ids] + draft_tokens, dim=1)
        with torch.no_grad():
            main_output = self.main_model(verification_input, use_cache=True)
        main_logits = main_output.logits

        # Accept/reject tokens based on probability ratios
        accepted_tokens = []
        for i, (draft_token, draft_prob) in enumerate(zip(draft_tokens, draft_probs)):
            # Logits at position p predict the token at position p + 1,
            # so draft token i is scored by the logits one step earlier
            pos = input_ids.shape[1] + i - 1
            main_prob = F.softmax(main_logits[:, pos, :], dim=-1)

            # Acceptance probability min(1, p_main / p_draft)
            accept_prob = torch.min(
                torch.ones_like(main_prob),
                main_prob / (draft_prob + 1e-10)
            )

            if torch.rand(1).item() < accept_prob[0, draft_token.item()].item():
                accepted_tokens.append(draft_token)
            else:
                # Resample from the corrected distribution max(0, p_main - p_draft)
                corrected_prob = torch.clamp(main_prob - draft_prob, min=0)
                corrected_prob = corrected_prob / corrected_prob.sum()
                resampled_token = torch.multinomial(corrected_prob, 1)
                accepted_tokens.append(resampled_token)
                break  # Stop speculation on the first rejection

        return accepted_tokens, main_output.past_key_values

Breakthrough: Speculative decoding gave us the final 2.5x speedup. On average, we accept 2.8 out of 4 speculated tokens, dramatically reducing the number of main model forward passes.
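
A quick back-of-the-envelope check shows why that acceptance rate matters so much (the draft-model cost below is an assumption, and the "+1" token per pass follows the standard accept/resample scheme):

python
# In the standard scheme, each verification pass yields the accepted draft
# tokens plus one more token from the main model (a corrected or bonus sample).
accepted = 2.8                       # average accepted tokens per window (gamma = 4)
tokens_per_main_pass = accepted + 1

# Assume each draft forward pass costs ~10% of a main-model pass (assumption)
draft_overhead = 4 * 0.10

# Relative cost per token vs. plain autoregressive decoding (one main pass per token)
cost_per_token = (1 + draft_overhead) / tokens_per_main_pass
print(f"~{1 / cost_per_token:.1f}x fewer main-model-equivalent passes per token")  # ~2.7x

This rough estimate lands in the same ballpark as the 2.5x speedup we measured end to end.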

Production Results

The optimizations transformed our service. Here's how performance and costs evolved over six months of production deployment:

[Chart: Daily throughput comparison]

[Chart: Monthly cost reduction]

Lessons Learned

Profile Before Optimizing

We spent weeks optimizing the wrong bottlenecks initially. Comprehensive profiling with tools like NVIDIA Nsight and PyTorch Profiler revealed that memory bandwidth, not compute, was our primary constraint.
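
For reference, a PyTorch Profiler pass that surfaces this kind of breakdown looks roughly like the following (a minimal sketch; the model, inputs, and generate call are placeholders for the actual inference step):

python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
) as prof:
    with record_function("inference_step"):
        with torch.no_grad():
            model.generate(**inputs, max_new_tokens=64)  # placeholder inference call

# Sort by GPU time to see where the step actually spends its budget
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))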

Hardware-Software Co-design

The biggest gains came from aligning our software optimizations with hardware capabilities. Understanding GPU memory hierarchy and tensor core utilization patterns was crucial for effective optimization.
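
One concrete way to see this is a memory-bandwidth roofline for a 7B FP16 model (illustrative numbers; the bandwidth figure assumes the 80GB A100 SKU):

python
# Back-of-the-envelope roofline: at batch size 1, every generated token has to
# stream all model weights through the GPU memory system at least once.
params = 7e9
bytes_per_param = 2             # FP16
a100_bandwidth = 2.0e12         # ~2 TB/s (A100 80GB spec, assumption)

min_latency_per_token = params * bytes_per_param / a100_bandwidth
print(f"Bandwidth-bound floor: ~{min_latency_per_token * 1000:.1f} ms/token")  # ~7 ms

# INT8 weights halve the bytes moved, and batching amortizes the weight traffic
# across requests, which is why those optimizations paid off more than raw FLOPs.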

Measure Everything

We instrumented every optimization with detailed metrics: latency percentiles, throughput, memory usage, and quality scores. This data-driven approach prevented us from making changes that helped one metric while hurting others.
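
The latency side of that instrumentation can be as simple as a wrapper around the inference call (a minimal sketch; the wrapper and report function are illustrative, not our actual metrics pipeline):

python
import time
import statistics

latencies_ms = []

def timed(fn, *args, **kwargs):
    # Wrap any inference call and record its wall-clock latency in milliseconds
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    latencies_ms.append((time.perf_counter() - start) * 1000)
    return result

def latency_report():
    cuts = statistics.quantiles(latencies_ms, n=100)  # 99 percentile cut points
    print(f"P50 {cuts[49]:.0f}ms  P95 {cuts[94]:.0f}ms  P99 {cuts[98]:.0f}ms")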

What's Next

Our optimization journey continues. We're exploring model distillation for even smaller draft models, investigating mixed-precision training for better quantization, and experimenting with custom CUDA kernels for specific operations.

"The future of LLM inference isn't just about bigger models—it's about smarter systems that can deliver human-quality responses at machine speed."

The techniques we've shared here represent just the beginning. As models grow larger and more capable, the optimization challenges will only intensify. But with systematic approaches and careful measurement, sub-second inference for even the largest models is within reach.