Architecting an AI Inference Stack

· 6 min read
Sanjeev Sarda
High Performance Developer

Notes on architecting AI inference stacks and TPUs from Google's learning path, "Inference on TPUs".

Architecting an AI Inference Stack

The state of the art is rapidly changing and there is a high cost of change.

What is Inference?

Inference is how we use a trained model to do something useful. It needs to be fast, scalable and cost-effective.

AI serving: the infrastructure and processes required to deploy a trained model and make it available at scale.

Frameworks for AI/ML

Keras

An API for building models and neural networks. Multi-backend: it works with different execution engines.

PyTorch and JAX

Numerical computing libraries. Used for model definitions.

All three frameworks can also be used to train models.

Inference

We need dedicated inference engines and frameworks.

NVIDIA Triton Inference Server

Hugging Face TGI (Text Generation Inference)

vLLM

General-purpose LLM serving.

llm-d

For large-scale distributed deployments.

Fine tuning

Adapting an existing model for a domain or task

PEFT and LoRA

Defining and Training, Inference and Fine Tuning

Model types and Performance Bottlenecks

Compute, memory, memory bandwidth and network.

Types of models:

LLMs

Diffusion - Text to Image

Visual Language Models - multi-modal

Mixture of Experts

Common Bottlenecks

Compute Bound

Occurs during training of dense models or during diffusion inference. Mitigation: add more chips.

Memory Bound

Provision high-memory machines.

Quantize the model down to 8 or 4 bits.
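As a rough illustration of why quantization shrinks the memory footprint, here is a minimal symmetric int8 quantization sketch in NumPy (the scheme and names are illustrative, not from the course material):

```python
import numpy as np

def quantize_int8(w):
    # Symmetric per-tensor quantization: map float32 weights onto [-127, 127]
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.random.randn(256, 256).astype(np.float32)
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

# int8 storage is 4x smaller than float32...
assert q.nbytes * 4 == w.nbytes
# ...and the rounding error per weight is bounded by half the scale
assert np.max(np.abs(w - w_hat)) <= scale / 2 + 1e-6
```

4-bit quantization halves the footprint again, at a larger accuracy cost.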

Memory Bandwidth Bound

The TPU is waiting for data: moving data from memory to the TPU is the bottleneck. MoE models can suffer from this.

Network Bound

Occurs when you need to split the model across multiple machines. Distributed training and large-scale MoE models can require a low-latency interconnect.

Orchestration

Vertex AI is Google's managed solution.

GKE Approach

Hardware and job orchestration

You can also use LeaderWorkerSets, or a framework like Ray, to distribute tasks.

Slurm/HPC Approach

Use Google's Cluster Director.

vLLM for Throughput and Lower Latency

Memory Inefficiency

Standard serving practices are memory-inefficient.

Latency from Queueing

Multi-host serving due to large memory requirements

vLLM

  • Paged attention - manages KV-cache memory in non-contiguous blocks

  • Prefix caching - caches computation for shared prompt prefixes

  • Multi-host serving - distributes the model across multiple GPUs and TPUs

It has lots of tunable parameters.
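The paged-attention idea can be sketched with a toy block table that maps a sequence's logical token positions onto non-contiguous physical KV-cache blocks (the class and names are illustrative, not vLLM's actual API):

```python
BLOCK_SIZE = 4  # tokens per physical KV-cache block (toy value)

class BlockTable:
    """Maps a sequence's logical token positions to physical KV blocks."""
    def __init__(self, free_blocks):
        self.blocks = []         # physical block ids, in logical order
        self.free = free_blocks  # shared pool of free physical block ids

    def slot_for(self, pos):
        # Allocate a new physical block only when a block boundary is crossed,
        # so memory is claimed on demand rather than pre-reserved contiguously
        if pos % BLOCK_SIZE == 0 and pos // BLOCK_SIZE == len(self.blocks):
            self.blocks.append(self.free.pop())
        return (self.blocks[pos // BLOCK_SIZE], pos % BLOCK_SIZE)

pool = list(range(8))                          # 8 physical blocks in the pool
seq = BlockTable(pool)
slots = [seq.slot_for(p) for p in range(10)]   # a 10-token sequence
# 10 tokens at block size 4 occupy ceil(10/4) = 3 physical blocks,
# and those blocks need not be adjacent in physical memory
assert len(seq.blocks) == 3
```

This is what lets vLLM avoid reserving one large contiguous region per sequence.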

Prefill

Prefill: this stage processes the input prompt and generates an intermediate representation (the key-value cache). It's often compute intensive.

Decode

Decode: this stage generates the output tokens, one by one, using the prefill representation. It is typically memory-bandwidth bound.
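The two stages can be sketched in NumPy: prefill fills the KV cache in one large batched step, while decode re-reads the whole cache for every generated token (single head, no projections; purely illustrative):

```python
import numpy as np

d = 8  # toy head dimension

def attend(q, K, V):
    # Scaled dot-product attention of one query over the cached keys/values
    w = np.exp(q @ K.T / np.sqrt(d))
    return (w / w.sum()) @ V

# Prefill: the whole prompt is processed at once (large matmuls, compute bound)
prompt = np.random.randn(16, d)                  # 16 prompt tokens
K_cache, V_cache = prompt.copy(), prompt.copy()  # stand-in for projected K/V

# Decode: one token per step; each step streams the entire cache from memory,
# which is why this stage tends to be memory-bandwidth bound
tok = np.random.randn(d)
for _ in range(4):
    tok = attend(tok, K_cache, V_cache)
    K_cache = np.vstack([K_cache, tok])
    V_cache = np.vstack([V_cache, tok])

assert K_cache.shape == (16 + 4, d)  # cache grows by one row per decoded token
```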

Disaggregated Serving

Separating the prefill stage from the decode stage, allowing them to run in parallel to improve throughput and latency.

Scalable Inference

Multi-region setup:

Reduces end-user latency.

Observability metrics:

You can also increase availability using a disaggregated serving approach:

Storage

For cached data, or data loaded at startup.

For slowly changing data, use buckets.

For rapidly changing data

Reference Architecture

GKE Inference Gateway

GIQ - GKE Inference Quickstart

Analyzes benchmark data to identify the most cost-effective hardware configurations for specific performance needs.

Using TPUs for Inference

CPU - Central processing unit

GPU - Graphics processing unit

TPU - Tensor processing unit

Tensors

A data structure, a multi-dimensional array of numbers.
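In NumPy terms, for example (an illustration, not from the course material):

```python
import numpy as np

scalar = np.array(3.0)                 # rank-0 tensor: a single number
vector = np.array([1.0, 2.0, 3.0])     # rank-1 tensor
matrix = np.ones((2, 3))               # rank-2 tensor
batch  = np.zeros((32, 224, 224, 3))   # rank-4 tensor, e.g. a batch of RGB images

assert scalar.ndim == 0 and batch.ndim == 4
```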

CPUs

CPUs offer flexibility. They load values from memory, do something with them and store them back in memory. Memory access is slow relative to calculation speed and can limit the CPU's throughput (the von Neumann bottleneck).

GPUs

GPUs were originally created for rendering and graphics workloads. They have thousands of smaller cores (ALUs), allowing for a high level of parallelism. This is important for things like matrix operations. GPUs must also access registers and shared memory to read operands and store intermediate results.

TPUs

A TPU is a custom ASIC made by Google, designed specifically for tensor operations. Its primary task is matrix processing: multiplication and accumulation.

TPUs use a systolic array architecture - a matrix of physically connected multiply-accumulators.

HBM = High bandwidth memory
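A systolic matmul can be modelled in plain Python: each step pushes one "wavefront" of operands through the grid of multiply-accumulators, and every cell does a single multiply-add without intermediate loads or stores (a conceptual sketch only; it ignores the real pipelining and dataflow timing):

```python
import numpy as np

def systolic_matmul(A, B):
    # Position (i, j) models one multiply-accumulator cell that accumulates
    # output element C[i, j] as operands flow past it, one wavefront per step
    m, k = A.shape
    _, n = B.shape
    C = np.zeros((m, n))
    for step in range(k):          # one wavefront of A-rows / B-columns
        for i in range(m):
            for j in range(n):
                C[i, j] += A[i, step] * B[step, j]
    return C

A = np.random.randn(4, 5)
B = np.random.randn(5, 3)
assert np.allclose(systolic_matmul(A, B), A @ B)
```

The point of the hardware layout is that the partial sums stay inside the array, avoiding the memory round-trips a CPU would make.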

TPU Process

  • The TPU host streams data into an inbound queue

  • The TPU loads data from the queue and stores it in HBM

  • To do matrix operations, the TPU loads data and parameters from HBM into the matrix multiplication unit (MXU)

  • Post computation, the results get dropped onto an outbound queue

  • The TPU host reads the results from the outbound queue and puts them into the host's memory

Cloud TPU Architecture

Hardware utilization is important.

Data flows through the systolic array

Sparse cores - used for sparse data.

MXUs (matrix multiply units) - for dense calculations.

Physical organisation

64x64 cubes

Slices

High-speed slice interconnect (ICI, or inter-chip interconnect) - nearest-neighbour connectivity.

Useful for distributed training.

Built in resilience and routing.

Multi-slice training across data centres

There are a variety of TPU generations that balance memory, cost, etc., making them more or less suitable for training vs inference.

TPU Cloud Architecture

Made available on Google cloud as compute resources, as TPU VMs.

Or also on K8s:

Consumption and Usage

  • Speed - how quickly do you need the TPU compute capacity?

  • Duration - how long do you need it for?

  • Batch or online inference?

  • Can your workload tolerate pre-emption?

  • Pricing/budget

CUDs - Committed use discounts

DWS - Dynamic workload scheduler

Calendar - Reserved capacity for a specific amount of time

Flex-start - Start time flexibility, batch jobs, not pre-emptible

vLLM GPUs and TPUs

Use a dual-container approach within a single pod to support both GPUs and TPUs, since the base images are different. Only one vLLM server will be active, based on the underlying hardware.

This also allows us to use GKE for scaling, with new nodes coming online that could be GPUs or TPUs.
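The selection logic inside each container might look roughly like this entrypoint sketch (the probe and server commands are assumptions, not from the course material):

```shell
#!/bin/sh
# Each container runs this at startup; only the one whose accelerator matches
# the node actually starts its vLLM server, the other stays idle.
pick_server() {
    # $1 is the detected hardware; in practice this would come from probing
    # the node (e.g. nvidia-smi for GPUs, the TPU device files for TPUs)
    case "$1" in
        gpu) echo "start GPU vLLM server" ;;
        tpu) echo "start TPU vLLM server" ;;
        *)   echo "idle: no matching accelerator" ;;
    esac
}

pick_server gpu
```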

vLLM Performance Tuning

Fast and efficient AI inference with new NVIDIA Dynamo recipe

Reference GKE Architecture

Inference Quickstart

Scalable and Distributed LLM Inference on GKE with vLLM

Model fine tuning pipeline

RAG Pipeline

Set up a Google Cloud project for TPUs

TPU Inference Recipes

Scaling an LLM Model (book)

JAX and Keras for LLM Development using TPU for Distributed Fine-Tuning