FastAPI for ML Engineers: Building Production APIs
Author: Jared Chung
Introduction
You've trained a great model. Now what? Most ML value comes from deploying models where they can serve predictions. FastAPI has become the go-to framework for ML APIs: it's fast, modern, and designed with async support that's perfect for I/O-heavy ML workloads.
This guide covers everything you need to build production-ready ML APIs with FastAPI.
Why FastAPI for ML?
| Feature | Benefit for ML |
|---|---|
| Async native | Handle concurrent requests while the model is busy |
| Automatic docs | OpenAPI docs generated from type hints |
| Pydantic validation | Validate inputs before hitting your model |
| Streaming support | Stream LLM tokens in real-time |
| Background tasks | Long-running inference without blocking |
| Easy testing | Built-in test client |
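Most of these features come with almost no code. A minimal sketch showing the type-hint-driven validation and the auto-generated OpenAPI docs (the endpoint and model names here are made up for illustration; `/docs` is FastAPI's default docs URL):

```python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class ScoreRequest(BaseModel):
    text: str

@app.post("/score")
async def score(request: ScoreRequest):
    # Invalid payloads are rejected with a 422 before this code ever runs
    return {"length": len(request.text)}

# Run with `uvicorn demo:app` and open /docs for the interactive OpenAPI UI
```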
Project Structure
A well-organized ML API project:
```
ml_api/
├── app/
│   ├── __init__.py
│   ├── main.py            # FastAPI app
│   ├── config.py          # Settings
│   ├── models/
│   │   ├── __init__.py
│   │   ├── schemas.py     # Pydantic models
│   │   └── ml_models.py   # ML model loading
│   ├── routers/
│   │   ├── __init__.py
│   │   ├── predictions.py # Prediction endpoints
│   │   └── health.py      # Health checks
│   ├── services/
│   │   ├── __init__.py
│   │   └── inference.py   # Business logic
│   └── middleware/
│       ├── __init__.py
│       └── logging.py     # Request logging
├── tests/
├── Dockerfile
├── requirements.txt
└── pyproject.toml
```
Basic ML API
Let's start with a simple image classification API:
```python
# app/main.py
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.models.ml_models import load_model

ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load model on startup
    ml_models["classifier"] = load_model("resnet50")
    yield
    # Cleanup on shutdown
    ml_models.clear()

app = FastAPI(
    title="ML Prediction API",
    version="1.0.0",
    lifespan=lifespan,
)
```
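The lifespan handler loads the model once per process. The routers still need to be wired into the app; a minimal sketch of that wiring, assuming the router modules from the project structure above (the health router is assumed to have no prefix so it serves `/health` directly):

```python
# app/main.py (continued) -- illustrative wiring
from app.routers import health, predictions

app.include_router(predictions.router)  # exposes /api/v1/predict
app.include_router(health.router)       # exposes /health, /health/live, /health/ready
```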
```python
# app/models/schemas.py
from typing import List

from pydantic import BaseModel, Field

class PredictionRequest(BaseModel):
    image_url: str = Field(..., description="URL of image to classify")

class Prediction(BaseModel):
    label: str
    confidence: float = Field(..., ge=0, le=1)

class PredictionResponse(BaseModel):
    predictions: List[Prediction]
    model_version: str
    inference_time_ms: float
```
```python
# app/routers/predictions.py
import time

from fastapi import APIRouter, HTTPException

from app.models.schemas import PredictionRequest, PredictionResponse
from app.services.inference import classify_image

router = APIRouter(prefix="/api/v1", tags=["predictions"])

@router.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    start = time.perf_counter()
    try:
        predictions = await classify_image(request.image_url)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    inference_time = (time.perf_counter() - start) * 1000
    return PredictionResponse(
        predictions=predictions,
        model_version="1.0.0",
        inference_time_ms=round(inference_time, 2),
    )
```
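The `classify_image` service is referenced above but not shown. A minimal sketch of `app/services/inference.py`, assuming httpx for fetching the image and a classifier wrapper whose `predict` method returns `(label, confidence)` pairs; the preprocessing details depend on how `load_model()` wraps ResNet-50, so treat this as illustrative:

```python
# app/services/inference.py -- illustrative sketch
import asyncio

import httpx

from app.models.schemas import Prediction

async def classify_image(image_url: str) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=10) as client:
        response = await client.get(image_url)
    if response.status_code != 200:
        # Surfaces as a 400 via the ValueError handling in the router
        raise ValueError(f"Could not fetch image from {image_url}")

    # Imported lazily to avoid a circular import with app.main
    from app.main import ml_models
    model = ml_models["classifier"]

    # Run the blocking model call in a thread so the event loop stays responsive
    results = await asyncio.to_thread(model.predict, response.content)
    return [Prediction(label=label, confidence=conf) for label, conf in results]
```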
Handling Long-Running Inference
ML inference can be slow. Handle it properly.
Background Tasks
For non-blocking inference, kick off a job in the background and let clients poll its status:
```python
from typing import Optional
from uuid import uuid4

from fastapi import BackgroundTasks
from pydantic import BaseModel

# In-memory store (use Redis in production)
job_store = {}

class JobStatus(BaseModel):
    job_id: str
    status: str  # pending, processing, completed, failed
    result: Optional[dict] = None

@router.post("/predict/async")
async def predict_async(
    request: PredictionRequest,
    background_tasks: BackgroundTasks,
):
    job_id = str(uuid4())
    job_store[job_id] = {"status": "pending", "result": None}
    background_tasks.add_task(run_inference_job, job_id, request)
    return {"job_id": job_id, "status_url": f"/api/v1/jobs/{job_id}"}

async def run_inference_job(job_id: str, request: PredictionRequest):
    job_store[job_id]["status"] = "processing"
    try:
        result = await classify_image(request.image_url)
        job_store[job_id] = {"status": "completed", "result": result}
    except Exception as e:
        job_store[job_id] = {"status": "failed", "result": str(e)}

@router.get("/jobs/{job_id}")
async def get_job_status(job_id: str):
    if job_id not in job_store:
        raise HTTPException(status_code=404, detail="Job not found")
    return job_store[job_id]
```
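The in-memory `job_store` disappears on restart and isn't shared between workers. A rough sketch of the same idea backed by Redis, assuming redis-py's asyncio client and a local Redis instance (the connection URL and expiry are assumptions):

```python
import json

import redis.asyncio as redis

redis_client = redis.from_url("redis://localhost:6379/0")

async def set_job(job_id: str, status: str, result=None):
    await redis_client.set(
        f"job:{job_id}",
        json.dumps({"status": status, "result": result}),
        ex=3600,  # expire finished jobs after an hour
    )

async def get_job(job_id: str):
    raw = await redis_client.get(f"job:{job_id}")
    return json.loads(raw) if raw else None
```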
Streaming Responses (LLMs)
For LLM token streaming:
```python
from fastapi.responses import StreamingResponse

@router.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    async def generate():
        async for token in llm.stream(request.message):
            yield f"data: {token}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
    )
```
For real-world LLM streaming with proper error handling:
```python
import json

from openai import AsyncOpenAI

client = AsyncOpenAI()

@router.post("/chat")
async def chat(request: ChatRequest):
    async def event_generator():
        try:
            stream = await client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": request.message}],
                stream=True,
            )
            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield f"data: {json.dumps({'content': content})}\n\n"
            yield f"data: {json.dumps({'done': True})}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
```
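On the client side, the stream can be consumed line by line. A small sketch using httpx; the endpoint URL and the `{"message": ...}` payload shape are assumptions based on the examples above:

```python
import asyncio
import json

import httpx

async def consume_chat_stream(message: str):
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/api/v1/chat",
            json={"message": message},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                event = json.loads(line.removeprefix("data: "))
                if event.get("done") or event.get("error"):
                    break
                print(event["content"], end="", flush=True)

asyncio.run(consume_chat_stream("Hello!"))
```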
Model Loading Patterns
Lazy Loading
Load models only when first requested:
```python
from fastapi import Depends

class ModelManager:
    def __init__(self):
        self._models = {}

    def get_model(self, model_name: str):
        if model_name not in self._models:
            self._models[model_name] = self._load_model(model_name)
        return self._models[model_name]

    def _load_model(self, model_name: str):
        # Load based on name
        if model_name == "classifier":
            return load_classifier()
        elif model_name == "embedder":
            return load_embedder()
        raise ValueError(f"Unknown model: {model_name}")

model_manager = ModelManager()

# In endpoint
def get_model_manager():
    return model_manager

@router.post("/embed")
async def embed(
    request: EmbedRequest,
    manager: ModelManager = Depends(get_model_manager),
):
    model = manager.get_model("embedder")
    return model.embed(request.text)
```
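For a single model, `functools.lru_cache` gives the same lazy, load-once behaviour with less machinery. A minimal sketch, reusing the same hypothetical `load_embedder` and `EmbedRequest` as above:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def get_embedder():
    # Runs once on first call; later calls return the cached model
    return load_embedder()

@router.post("/embed/simple")
async def embed_simple(request: EmbedRequest):
    return get_embedder().embed(request.text)
```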
Multiple Model Versions
Support A/B testing with model versions:
```python
from typing import Optional

from fastapi import Query

class ModelRegistry:
    def __init__(self):
        self.models = {
            "v1": load_model_v1(),
            "v2": load_model_v2(),
        }
        self.default_version = "v2"

    def predict(self, input_data, version: Optional[str] = None):
        version = version or self.default_version
        model = self.models.get(version)
        if not model:
            raise ValueError(f"Model version {version} not found")
        return model.predict(input_data)

registry = ModelRegistry()

@router.post("/predict")
async def predict(
    request: PredictionRequest,
    model_version: Optional[str] = Query(default=None),
):
    return registry.predict(request.data, model_version)
```
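To actually A/B test rather than let the caller pick a version, route a fraction of traffic to the candidate model server-side. A rough sketch; the 10% split and the version names are assumptions:

```python
import random

AB_SPLIT = 0.10  # send ~10% of traffic to the candidate model

@router.post("/predict/ab")
async def predict_ab(request: PredictionRequest):
    version = "v2" if random.random() < AB_SPLIT else "v1"
    result = registry.predict(request.data, version)
    # Return the version used so it can be logged and joined with outcomes later
    return {"model_version": version, "result": result}
```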
Request Validation
Pydantic makes input validation easy:
```python
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator

class TextGenerationRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4096)
    max_tokens: int = Field(default=256, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0, le=2)
    stop_sequences: Optional[List[str]] = Field(default=None, max_length=4)

    @field_validator("prompt")
    @classmethod
    def prompt_not_empty(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("Prompt cannot be empty or whitespace")
        return v.strip()

    @field_validator("stop_sequences")
    @classmethod
    def stop_sequences_valid(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        if v is not None and any(len(s) > 20 for s in v):
            raise ValueError("Stop sequence too long")
        return v
```
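Validation runs before any model code. A quick illustration of the constraints above, exercised directly with Pydantic (no server needed); in FastAPI the same failure becomes a 422 response:

```python
from pydantic import ValidationError

# Valid request: defaults fill in max_tokens and temperature
req = TextGenerationRequest(prompt="Write a haiku about GPUs")
print(req.max_tokens, req.temperature)  # 256 0.7

# Invalid request: whitespace-only prompt is rejected by the validator
try:
    TextGenerationRequest(prompt="   ")
except ValidationError as e:
    print(e.errors()[0]["msg"])
```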
Error Handling
Consistent error responses:
```python
from fastapi import Request
from fastapi.responses import JSONResponse

class MLException(Exception):
    def __init__(self, message: str, error_code: str, status_code: int = 500):
        self.message = message
        self.error_code = error_code
        self.status_code = status_code

class ModelNotLoadedError(MLException):
    def __init__(self, model_name: str):
        super().__init__(
            message=f"Model {model_name} is not loaded",
            error_code="MODEL_NOT_LOADED",
            status_code=503,
        )

class InvalidInputError(MLException):
    def __init__(self, detail: str):
        super().__init__(
            message=detail,
            error_code="INVALID_INPUT",
            status_code=400,
        )

@app.exception_handler(MLException)
async def ml_exception_handler(request: Request, exc: MLException):
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.error_code,
            "message": exc.message,
            "path": str(request.url),
        },
    )
```
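Endpoints and services can then raise these exceptions and rely on the handler for a consistent response shape. A short sketch reusing the classifier endpoint from earlier (the extra URL check is just for illustration):

```python
@router.post("/predict/strict")
async def predict_strict(request: PredictionRequest):
    # Translated by ml_exception_handler into a 503 with MODEL_NOT_LOADED
    if "classifier" not in ml_models:
        raise ModelNotLoadedError("classifier")
    # Translated into a 400 with INVALID_INPUT
    if not request.image_url.startswith(("http://", "https://")):
        raise InvalidInputError("image_url must be an http(s) URL")
    return await classify_image(request.image_url)
```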
Health Checks
Essential for production deployments:
```python
from enum import Enum

import torch

class HealthStatus(str, Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"

class HealthResponse(BaseModel):
    status: HealthStatus
    model_loaded: bool
    gpu_available: bool
    version: str

@router.get("/health", response_model=HealthResponse)
async def health_check():
    model_loaded = "classifier" in ml_models
    gpu_available = torch.cuda.is_available()

    if model_loaded and gpu_available:
        status = HealthStatus.HEALTHY
    elif model_loaded:
        status = HealthStatus.DEGRADED
    else:
        status = HealthStatus.UNHEALTHY

    return HealthResponse(
        status=status,
        model_loaded=model_loaded,
        gpu_available=gpu_available,
        version=app.version,
    )

# Kubernetes-style probes
@router.get("/health/live")
async def liveness():
    return {"status": "alive"}

@router.get("/health/ready")
async def readiness():
    if "classifier" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not ready")
    return {"status": "ready"}
```
Rate Limiting
Protect your API from abuse:
```python
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@router.post("/predict")
@limiter.limit("10/minute")
async def predict(request: Request, data: PredictionRequest):
    # Rate limited to 10 requests per minute per IP
    return await run_prediction(data)
```
Middleware
Add observability:
```python
import logging
import time
from uuid import uuid4

from fastapi import Request

logger = logging.getLogger(__name__)

@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    request_id = str(uuid4())
    start_time = time.perf_counter()

    # Add request ID to state
    request.state.request_id = request_id

    response = await call_next(request)
    duration = time.perf_counter() - start_time

    logger.info(
        "Request completed",
        extra={
            "request_id": request_id,
            "method": request.method,
            "path": request.url.path,
            "status_code": response.status_code,
            "duration_ms": round(duration * 1000, 2),
        },
    )

    response.headers["X-Request-ID"] = request_id
    return response
```
Deployment
Docker
```dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application
COPY app/ app/

# Non-root user
RUN useradd -m appuser && chown -R appuser:appuser /app
USER appuser

# Run with uvicorn
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```
Production Configuration
```python
# app/config.py
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env")

    app_name: str = "ML API"
    debug: bool = False
    model_path: str = "/models/classifier.pt"
    max_batch_size: int = 32
    request_timeout: int = 30
    workers: int = 4

settings = Settings()
```
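Settings can then be read directly or injected through a dependency so tests can override them. A small sketch of the dependency style; `get_settings` and the `/info` endpoint are names introduced here for illustration:

```python
from functools import lru_cache

from fastapi import Depends

from app.config import Settings

@lru_cache
def get_settings() -> Settings:
    return Settings()

@router.get("/info")
async def info(settings: Settings = Depends(get_settings)):
    return {"app_name": settings.app_name, "max_batch_size": settings.max_batch_size}
```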
Run with Gunicorn for production:
```bash
gunicorn app.main:app \
  --workers 4 \
  --worker-class uvicorn.workers.UvicornWorker \
  --bind 0.0.0.0:8000 \
  --timeout 120 \
  --keep-alive 5
```
Testing
```python
import pytest
from fastapi.testclient import TestClient

from app.main import app

@pytest.fixture
def client():
    # Using the context manager ensures the lifespan (model loading) runs for tests
    with TestClient(app) as test_client:
        yield test_client

def test_health_check(client):
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json()["status"] in {"healthy", "degraded"}  # degraded on CPU-only machines

def test_prediction(client):
    response = client.post(
        "/api/v1/predict",
        json={"image_url": "https://example.com/cat.jpg"},
    )
    assert response.status_code == 200
    assert "predictions" in response.json()

def test_invalid_input(client):
    response = client.post(
        "/api/v1/predict",
        json={},  # missing required image_url
    )
    assert response.status_code == 422  # validation error
```
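For endpoints that are easier to exercise asynchronously (streaming, for example), httpx can drive the ASGI app directly. A minimal sketch using pytest-asyncio, which would be an added dev dependency not shown in the project layout above:

```python
import httpx
import pytest

from app.main import app

@pytest.mark.asyncio
async def test_liveness_async():
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/health/live")
    assert response.status_code == 200
    assert response.json() == {"status": "alive"}
```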
Conclusion
FastAPI provides everything you need for production ML APIs. Start simple, add complexity as needed, and always prioritize reliability and observability. Your users don't care how clever your API is; they care that it works, every time.