Jared AI Hub

Test-Time Augmentation (TTA): Boosting Model Performance at Inference

Author: Jared Chung

Test-Time Augmentation (TTA) is a powerful technique that improves model predictions by applying data augmentation not just during training, but also at inference time. By creating multiple augmented versions of a test image and aggregating their predictions, TTA can often improve accuracy and robustness without retraining the model. The cost is extra inference compute: roughly one additional forward pass per augmented view.

In this comprehensive guide, we'll explore various TTA strategies, implementation techniques, and advanced methods for maximizing their effectiveness.

What is Test-Time Augmentation?

Test-Time Augmentation involves:

  1. Creating multiple versions of the input image using various transformations
  2. Running inference on each augmented version
  3. Aggregating the predictions (typically by averaging) to get the final result

Together, these steps aim to improve robustness by reducing prediction variance.

The key insight is that while individual predictions may vary, the averaged prediction across multiple augmentations tends to be more reliable and accurate.
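To make the three steps concrete, here is a minimal NumPy sketch. The `toy_model`, `hflip`, and `predict_with_tta` names are hypothetical stand-ins (the toy model only compares left and right brightness), not part of any library; a real pipeline would call a trained network instead.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def toy_model(image):
    # Hypothetical 2-class "model" that compares left/right brightness,
    # so its output changes when the image is flipped
    half = image.shape[1] // 2
    left, right = image[:, :half].mean(), image[:, half:].mean()
    return softmax(np.array([left - right + 0.5, right - left]))

def hflip(image):
    # Horizontal flip of an [H, W] array
    return image[:, ::-1]

def predict_with_tta(model, image, augmentations):
    # Steps 1-2: run the model on the original image and every augmented view
    views = [image] + [aug(image) for aug in augmentations]
    per_view = np.stack([model(v) for v in views])
    # Step 3: aggregate by averaging the class probabilities
    return per_view.mean(axis=0), per_view

image = np.zeros((4, 4))
image[:, :2] = 1.0  # bright left half

avg, per_view = predict_with_tta(toy_model, image, [hflip])
avg_flipped, _ = predict_with_tta(toy_model, hflip(image), [hflip])
```

Because the augmentation set {original, horizontal flip} maps onto itself under flipping, the averaged prediction is identical for an image and its mirror image, even though the toy model on its own is not flip-invariant. This is one concrete way in which averaging over augmentations removes variation between views.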

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Callable, Union
import cv2

# Basic TTA implementation
class TestTimeAugmentation:
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.eval()
    
    def predict_with_tta(self, image, augmentations, aggregation='mean'):
        """
        Perform TTA on a single image
        
        Args:
            image: Input image (PIL Image or tensor)
            augmentations: List of augmentation functions
            aggregation: Method to aggregate predictions ('mean', 'geometric_mean', 'voting')
        """
        predictions = []
        
        with torch.no_grad():
            # Get prediction for original image
            original_pred = self._get_prediction(image)
            predictions.append(original_pred)
            
            # Get predictions for augmented versions
            for aug_fn in augmentations:
                aug_image = aug_fn(image)
                aug_pred = self._get_prediction(aug_image)
                predictions.append(aug_pred)
        
        # Aggregate predictions
        predictions = torch.stack(predictions)
        
        if aggregation == 'mean':
            final_pred = torch.mean(predictions, dim=0)
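        elif aggregation == 'geometric_mean':
            # Geometric mean of probabilities, renormalized so it sums to 1
            final_pred = torch.exp(torch.mean(torch.log(predictions + 1e-8), dim=0))
            final_pred = final_pred / final_pred.sum()
        elif aggregation == 'voting':
            # Hard voting (classification only): returns the winning class index,
            # not a probability vector
            votes = torch.argmax(predictions, dim=-1)
            final_pred = torch.mode(votes, dim=0)[0]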
        else:
            raise ValueError(f"Unknown aggregation method: {aggregation}")
        
        return final_pred, predictions
    
    def _get_prediction(self, image):
        """Get model prediction for a single image"""
        if isinstance(image, Image.Image):
            # Convert PIL to tensor
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
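            image = transform(image).unsqueeze(0)
        elif image.dim() == 3:
            # Tensor input [C, H, W], assumed already normalized: add a batch dimension
            image = image.unsqueeze(0)
        
        image = image.to(self.device)
        output = self.model(image)
        
        # Apply softmax for classification (assumes the model returns logits)
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)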
        
        return output.squeeze(0)
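The three aggregation modes do not always agree. The worked example below uses made-up probabilities for three views (illustrative numbers, not measured results) in which one very confident view favours class 0 while two mildly confident views favour class 1. It also shows that a geometric mean of probabilities does not sum to 1 unless it is renormalized.

```python
import numpy as np

# Hypothetical softmax outputs for 3 views (rows) over 3 classes (columns)
preds = np.array([
    [0.90, 0.05, 0.05],
    [0.30, 0.40, 0.30],
    [0.30, 0.40, 0.30],
])

mean_pred = preds.mean(axis=0)                        # soft averaging
geo_pred = np.exp(np.log(preds + 1e-8).mean(axis=0))  # geometric mean, unnormalized
geo_pred_normalized = geo_pred / geo_pred.sum()
votes = preds.argmax(axis=1)                          # hard voting: one vote per view
majority = np.bincount(votes, minlength=preds.shape[1]).argmax()

print(mean_pred.argmax(), geo_pred_normalized.argmax(), majority)  # prints: 0 0 1
```

Soft averaging and the geometric mean let the confident view dominate, while hard voting counts each view once; which behaviour is preferable likely depends on how well calibrated the model's probabilities are.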

Standard TTA Transformations

Geometric Augmentations

class GeometricTTA:
    """Standard geometric transformations for TTA"""
    
    @staticmethod
    def horizontal_flip(image):
        """Horizontal flip"""
        if isinstance(image, Image.Image):
            return transforms.functional.hflip(image)
        return torch.flip(image, [-1])
    
    @staticmethod
    def vertical_flip(image):
        """Vertical flip"""
        if isinstance(image, Image.Image):
            return transforms.functional.vflip(image)
        return torch.flip(image, [-2])
    
    @staticmethod
    def rotation_90(image):
        """90-degree rotation"""
        if isinstance(image, Image.Image):
            return image.rotate(90, expand=True)
        return torch.rot90(image, k=1, dims=[-2, -1])
    
    @staticmethod
    def rotation_180(image):
        """180-degree rotation"""
        if isinstance(image, Image.Image):
            return image.rotate(180)
        return torch.rot90(image, k=2, dims=[-2, -1])
    
    @staticmethod
    def rotation_270(image):
        """270-degree rotation"""
        if isinstance(image, Image.Image):
            return image.rotate(270, expand=True)
        return torch.rot90(image, k=3, dims=[-2, -1])
    
    @staticmethod
    def transpose(image):
        """Transpose (diagonal flip)"""
        if isinstance(image, Image.Image):
            return image.transpose(Image.TRANSPOSE)
        return torch.transpose(image, -2, -1)
    
    @staticmethod
    def transverse(image):
        """Transverse (anti-diagonal flip)"""
        if isinstance(image, Image.Image):
            return image.transpose(Image.TRANSVERSE)
        # Flip both axes, then transpose: out[i, j] = in[H-1-j, W-1-i]
        return torch.transpose(torch.flip(image, [-2, -1]), -2, -1)

# Eight-fold TTA (D4 group transformations)
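def get_d4_transformations():
    """Get all 8 transformations of the D4 dihedral group.
    
    The list includes the identity; drop it (get_d4_transformations()[1:])
    when passing the list to TestTimeAugmentation.predict_with_tta, which
    already predicts on the original image.
    """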
    return [
        lambda x: x,  # Identity
        GeometricTTA.horizontal_flip,
        GeometricTTA.vertical_flip,
        GeometricTTA.rotation_90,
        GeometricTTA.rotation_180,
        GeometricTTA.rotation_270,
        GeometricTTA.transpose,
        GeometricTTA.transverse
    ]
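A quick way to check a D4 implementation is to apply every transform to an asymmetric array and confirm that all eight results differ and that each transform is undone by its inverse. The sketch below does this with NumPy equivalents of the `GeometricTTA` methods (the `D4` and `INVERSE` dictionaries are illustrative names). The anti-diagonal flip is written as "flip both axes, then transpose"; transposing and then flipping only the last axis yields a 90-degree rotation instead, which would silently duplicate one of the rotations.

```python
import numpy as np

# The 8 elements of D4 acting on the last two axes (H, W) of an array
D4 = {
    "identity": lambda x: x,
    "hflip": lambda x: np.flip(x, -1),
    "vflip": lambda x: np.flip(x, -2),
    "rot90": lambda x: np.rot90(x, 1, axes=(-2, -1)),
    "rot180": lambda x: np.rot90(x, 2, axes=(-2, -1)),
    "rot270": lambda x: np.rot90(x, 3, axes=(-2, -1)),
    "transpose": lambda x: np.swapaxes(x, -2, -1),
    "transverse": lambda x: np.swapaxes(np.flip(x, (-2, -1)), -2, -1),
}

# Inverse of each transform (needed to map dense predictions back)
INVERSE = {
    "identity": "identity",
    "hflip": "hflip",
    "vflip": "vflip",
    "rot90": "rot270",
    "rot180": "rot180",
    "rot270": "rot90",
    "transpose": "transpose",
    "transverse": "transverse",
}

x = np.arange(16).reshape(4, 4)  # asymmetric test pattern
distinct = len({D4[name](x).tobytes() for name in D4})
```

The inverse table matters for dense tasks such as segmentation, where each prediction must be mapped back before averaging; for classification only the forward transforms are needed.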

Multi-Scale TTA

class MultiScaleTTA:
    """Multi-scale test-time augmentation"""
    
    def __init__(self, scales=[0.8, 0.9, 1.0, 1.1, 1.2], base_size=224):
        self.scales = scales
        self.base_size = base_size
    
    def generate_scale_transforms(self):
        """Generate transforms for different scales"""
        transforms_list = []
        
        for scale in self.scales:
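            size = int(self.base_size * scale)
            # For scale < 1.0 the resized image is smaller than base_size and
            # CenterCrop pads the border with zeros; for scale > 1.0 it discards the border
            transform = transforms.Compose([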
                transforms.Resize((size, size)),
                transforms.CenterCrop(self.base_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
            transforms_list.append(transform)
        
        return transforms_list
    
    def predict_multi_scale(self, model, image, device='cuda'):
        """Predict with multi-scale TTA"""
        model.eval()
        predictions = []
        
        with torch.no_grad():
            for transform in self.generate_scale_transforms():
                scaled_image = transform(image).unsqueeze(0).to(device)
                output = model(scaled_image)
                
                if output.dim() == 2 and output.size(1) > 1:
                    output = F.softmax(output, dim=1)
                
                predictions.append(output.squeeze(0))
        
        # Average predictions
        final_prediction = torch.mean(torch.stack(predictions), dim=0)
        return final_prediction
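It helps to work out what `Resize((size, size))` followed by `CenterCrop(base_size)` actually does at each scale. torchvision's `CenterCrop` pads with zeros when the image is smaller than the requested size, so scales below 1.0 produce a zero border, while scales above 1.0 crop away part of the image. The helper below (`scale_plan` is a hypothetical name) reproduces that arithmetic.

```python
def scale_plan(scales, base_size=224):
    # Mirrors the size arithmetic of MultiScaleTTA.generate_scale_transforms:
    # Resize((size, size)) followed by CenterCrop(base_size)
    plan = []
    for s in scales:
        size = int(base_size * s)
        if size < base_size:
            # Fraction of the final crop that is zero padding
            padded, discarded = 1.0 - (size / base_size) ** 2, 0.0
        else:
            # Fraction of the resized image thrown away by the crop
            padded, discarded = 0.0, 1.0 - (base_size / size) ** 2
        plan.append({"scale": s, "resized": size,
                     "padded": round(padded, 3), "discarded": round(discarded, 3)})
    return plan

plan = scale_plan([0.8, 0.9, 1.0, 1.1, 1.2])
for row in plan:
    print(row)
```

For the default scales this gives resized sides of 179, 201, 224, 246 and 268 pixels; at scale 0.8 roughly 36% of the 224x224 input is padding.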

Crop-based TTA

class CropBasedTTA:
    """TTA using different crop positions"""
    
    def __init__(self, crop_size=224, stride=32):
        self.crop_size = crop_size
        self.stride = stride
    
    def generate_crops(self, image):
        """Generate multiple crops from different positions"""
        if isinstance(image, Image.Image):
            w, h = image.size
            image_np = np.array(image)
        else:
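            # Tensor [C, H, W] with values in [0, 1]: convert to uint8 HWC for PIL
            h, w = image.shape[-2:]
            image_np = (image.permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)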
        
        crops = []
        positions = []
        
        # Center crop
        center_x, center_y = w // 2, h // 2
        x1 = center_x - self.crop_size // 2
        y1 = center_y - self.crop_size // 2
        x2 = x1 + self.crop_size
        y2 = y1 + self.crop_size
        
        if x1 >= 0 and y1 >= 0 and x2 <= w and y2 <= h:
            crop = image_np[y1:y2, x1:x2]
            crops.append(Image.fromarray(crop))
            positions.append((x1, y1, x2, y2))
        
        # Corner crops
        corners = [
            (0, 0),  # Top-left
            (w - self.crop_size, 0),  # Top-right
            (0, h - self.crop_size),  # Bottom-left
            (w - self.crop_size, h - self.crop_size)  # Bottom-right
        ]
        
        for x1, y1 in corners:
            if x1 >= 0 and y1 >= 0:
                x2, y2 = x1 + self.crop_size, y1 + self.crop_size
                crop = image_np[y1:y2, x1:x2]
                crops.append(Image.fromarray(crop))
                positions.append((x1, y1, x2, y2))
        
        # Random crops (non-deterministic; seed np.random for reproducible results)
        for _ in range(5):
            x1 = np.random.randint(0, max(1, w - self.crop_size + 1))
            y1 = np.random.randint(0, max(1, h - self.crop_size + 1))
            x2, y2 = x1 + self.crop_size, y1 + self.crop_size
            
            crop = image_np[y1:y2, x1:x2]
            crops.append(Image.fromarray(crop))
            positions.append((x1, y1, x2, y2))
        
        return crops, positions
    
    def predict_with_crops(self, model, image, device='cuda'):
        """Predict using multiple crops"""
        crops, positions = self.generate_crops(image)
        
        model.eval()
        predictions = []
        
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        
        with torch.no_grad():
            for crop in crops:
                crop_tensor = transform(crop).unsqueeze(0).to(device)
                output = model(crop_tensor)
                
                if output.dim() == 2 and output.size(1) > 1:
                    output = F.softmax(output, dim=1)
                
                predictions.append(output.squeeze(0))
        
        # Average predictions
        final_prediction = torch.mean(torch.stack(predictions), dim=0)
        return final_prediction, predictions
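The random crops in `CropBasedTTA` make predictions vary from run to run. A deterministic alternative, in the spirit of torchvision's `FiveCrop`, is to use only the center and the four corners. The sketch below computes those boxes in the same `(x1, y1, x2, y2)` convention as `generate_crops`; `five_crop_boxes` is a hypothetical helper, and its center rounding (floor division) may differ by a pixel from torchvision's.

```python
def five_crop_boxes(width, height, crop):
    # Deterministic center + four-corner crop boxes as (x1, y1, x2, y2)
    if crop > width or crop > height:
        raise ValueError("crop size is larger than the image")
    cx = (width - crop) // 2
    cy = (height - crop) // 2
    return [
        (cx, cy, cx + crop, cy + crop),                # center
        (0, 0, crop, crop),                            # top-left
        (width - crop, 0, width, crop),                # top-right
        (0, height - crop, crop, height),              # bottom-left
        (width - crop, height - crop, width, height),  # bottom-right
    ]

boxes = five_crop_boxes(256, 256, 224)
```

Each box can be passed directly to `PIL.Image.crop`, which takes the same (left, upper, right, lower) tuple.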

Advanced TTA Techniques

Noise-based TTA

class NoiseTTA:
    """TTA using different types of noise"""
    
    def __init__(self, noise_levels=[0.01, 0.02, 0.03]):
        self.noise_levels = noise_levels
    
    def add_gaussian_noise(self, image, std):
        """Add Gaussian noise to image"""
        if isinstance(image, Image.Image):
            image = transforms.functional.to_tensor(image)
        
        noise = torch.randn_like(image) * std
        noisy_image = torch.clamp(image + noise, 0, 1)
        
        return transforms.functional.to_pil_image(noisy_image)
    
    def add_uniform_noise(self, image, magnitude):
        """Add uniform noise to image"""
        if isinstance(image, Image.Image):
            image = transforms.functional.to_tensor(image)
        
        noise = (torch.rand_like(image) - 0.5) * 2 * magnitude
        noisy_image = torch.clamp(image + noise, 0, 1)
        
        return transforms.functional.to_pil_image(noisy_image)
    
    def generate_noise_transforms(self):
        """Generate noise-based transformations"""
        transforms_list = []
        
        for std in self.noise_levels:
            transforms_list.append(lambda x, s=std: self.add_gaussian_noise(x, s))
            transforms_list.append(lambda x, m=std: self.add_uniform_noise(x, m))
        
        return transforms_list

# Color augmentation TTA
class ColorTTA:
    """TTA using color space manipulations"""
    
    def __init__(self):
        pass
    
    def brightness_variants(self, image, factors=[0.9, 1.0, 1.1]):
        """Generate brightness variants with fixed factors"""
        # ColorJitter(brightness=f) would sample a random factor from
        # [max(0, 1 - f), 1 + f]; adjust_brightness applies exactly `factor`
        variants = []
        for factor in factors:
            variants.append(transforms.functional.adjust_brightness(image, factor))
        return variants
    
    def contrast_variants(self, image, factors=[0.9, 1.0, 1.1]):
        """Generate contrast variants with fixed factors"""
        variants = []
        for factor in factors:
            variants.append(transforms.functional.adjust_contrast(image, factor))
        return variants
    
    def saturation_variants(self, image, factors=[0.9, 1.0, 1.1]):
        """Generate saturation variants with fixed factors"""
        variants = []
        for factor in factors:
            variants.append(transforms.functional.adjust_saturation(image, factor))
        return variants
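`transforms.ColorJitter(brightness=b)` does not apply the factor `b`. On every call it samples a factor uniformly from [max(0, 1 - b), 1 + b] (contrast and saturation behave the same way), so passing 0.9, 1.0 or 1.1 as if they were fixed factors yields random and sometimes extreme variants; with 1.0 the brightness factor can land anywhere in [0, 2]. For repeatable TTA views, torchvision's functional `adjust_brightness`, `adjust_contrast` and `adjust_saturation` apply an exact factor. The NumPy sketch below (helper names are hypothetical) shows the sampling interval and a deterministic brightness change for float images in [0, 1].

```python
import numpy as np

def colorjitter_factor_range(value):
    # Interval that ColorJitter(brightness=value) samples its factor from,
    # uniformly and independently on every call
    return (max(0.0, 1.0 - value), 1.0 + value)

def adjust_brightness(image, factor):
    # Deterministic brightness change for a float image in [0, 1]:
    # blend with a black image, i.e. scale by `factor` and clip
    return np.clip(image * factor, 0.0, 1.0)

wide_range = colorjitter_factor_range(1.0)
brighter = adjust_brightness(np.array([0.5, 0.95]), 1.1)
```

Running the same fixed factors on every image keeps TTA predictions reproducible, which also makes before/after comparisons easier to interpret.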

Ensemble TTA

class EnsembleTTA:
    """TTA with model ensembles"""
    
    def __init__(self, models, device='cuda'):
        self.models = models
        self.device = device
        
        # Set all models to evaluation mode
        for model in self.models:
            model.eval()
    
    def predict_ensemble_tta(self, image, augmentations, 
                           model_weights=None, aug_weights=None):
        """
        Perform TTA with multiple models
        
        Args:
            image: Input image
            augmentations: List of augmentation functions
            model_weights: Weights for each model (optional)
            aug_weights: Weights for each augmentation (optional)
        """
        if model_weights is None:
            model_weights = [1.0] * len(self.models)
        
        if aug_weights is None:
            aug_weights = [1.0] * (len(augmentations) + 1)  # +1 for original
        
        all_predictions = []
        
        # Get predictions from all models and augmentations
        with torch.no_grad():
            for model_idx, model in enumerate(self.models):
                model_preds = []
                
                # Original image
                original_pred = self._get_prediction(model, image)
                model_preds.append(original_pred * aug_weights[0])
                
                # Augmented images
                for aug_idx, aug_fn in enumerate(augmentations):
                    aug_image = aug_fn(image)
                    aug_pred = self._get_prediction(model, aug_image)
                    model_preds.append(aug_pred * aug_weights[aug_idx + 1])
                
                # Weight by model importance
                weighted_model_pred = torch.stack(model_preds).sum(dim=0) * model_weights[model_idx]
                all_predictions.append(weighted_model_pred)
        
        # Combine all predictions
        final_prediction = torch.stack(all_predictions).sum(dim=0)
        final_prediction = final_prediction / final_prediction.sum()  # Normalize
        
        return final_prediction
    
    def _get_prediction(self, model, image):
        """Get prediction from a single model"""
        if isinstance(image, Image.Image):
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
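            image = transform(image).unsqueeze(0)
        elif image.dim() == 3:
            # Tensor input [C, H, W], assumed already normalized: add a batch dimension
            image = image.unsqueeze(0)
        
        image = image.to(self.device)
        output = model(image)
        
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)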
        
        return output.squeeze(0)
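The weighting in `predict_ensemble_tta` amounts to a weighted average over every (model, view) pair. Written as a single array operation, under the assumption that all outputs are already softmax probabilities, it looks like the sketch below (`weighted_ensemble` is a hypothetical helper). Dividing by the total weight gives the same result as the original normalization by `final_prediction.sum()` whenever every probability vector sums to 1.

```python
import numpy as np

def weighted_ensemble(probs, model_weights, aug_weights):
    # probs: [num_models, num_views, num_classes] softmax outputs
    # (view 0 is the original image, as in predict_ensemble_tta)
    probs = np.asarray(probs, dtype=float)
    w = np.outer(model_weights, aug_weights)          # [num_models, num_views]
    combined = (probs * w[:, :, None]).sum(axis=(0, 1))
    return combined / w.sum()

probs = np.array([
    [[0.7, 0.3], [0.6, 0.4]],   # model A: original, flipped
    [[0.2, 0.8], [0.4, 0.6]],   # model B: original, flipped
])
uniform = weighted_ensemble(probs, [1, 1], [1, 1])
favour_a = weighted_ensemble(probs, [3, 1], [1, 1])
```

With equal weights the result reduces to a plain mean over all model/view pairs; raising one model's weight shifts the average toward that model's predictions.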

Segmentation TTA

class SegmentationTTA:
    """TTA for segmentation tasks"""
    
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.eval()
    
    def predict_with_tta(self, image, return_all=False):
        """
        Perform TTA for segmentation
        
        Args:
            image: Input image tensor [C, H, W]
            return_all: Whether to return all predictions
        """
        predictions = []
        
        with torch.no_grad():
            # Original
            pred = self._predict_single(image)
            predictions.append(pred)
            
            # Horizontal flip
            flipped_h = torch.flip(image, [-1])
            pred_h = self._predict_single(flipped_h)
            pred_h = torch.flip(pred_h, [-1])  # Flip back
            predictions.append(pred_h)
            
            # Vertical flip
            flipped_v = torch.flip(image, [-2])
            pred_v = self._predict_single(flipped_v)
            pred_v = torch.flip(pred_v, [-2])  # Flip back
            predictions.append(pred_v)
            
            # Both flips
            flipped_hv = torch.flip(image, [-2, -1])
            pred_hv = self._predict_single(flipped_hv)
            pred_hv = torch.flip(pred_hv, [-2, -1])  # Flip back
            predictions.append(pred_hv)
            
            # Rotations by 90 and 270 degrees (a 180-degree rotation equals
            # flipping both axes, which is already included above)
            for k in [1, 3]:
                rotated = torch.rot90(image, k=k, dims=[-2, -1])
                pred_rot = self._predict_single(rotated)
                pred_rot = torch.rot90(pred_rot, k=-k, dims=[-2, -1])  # Rotate back
                predictions.append(pred_rot)
        
        # Average predictions
        predictions = torch.stack(predictions)
        avg_prediction = torch.mean(predictions, dim=0)
        
        if return_all:
            return avg_prediction, predictions
        return avg_prediction
    
    def _predict_single(self, image):
        """Get prediction for a single image"""
        if image.dim() == 3:
            image = image.unsqueeze(0)
        
        image = image.to(self.device)
        with torch.no_grad():
            output = self.model(image)
            
            # Apply softmax if multi-class
            if output.size(1) > 1:
                output = F.softmax(output, dim=1)
        
        return output.squeeze(0)
    
    def multi_scale_tta(self, image, scales=[0.75, 1.0, 1.25]):
        """Multi-scale TTA for segmentation"""
        original_size = image.shape[-2:]
        predictions = []
        
        with torch.no_grad():
            for scale in scales:
                # Resize image
                new_size = [int(s * scale) for s in original_size]
                scaled_image = F.interpolate(
                    image.unsqueeze(0), 
                    size=new_size, 
                    mode='bilinear', 
                    align_corners=False
                ).squeeze(0)
                
                # Get prediction
                pred = self._predict_single(scaled_image)
                
                # Resize prediction back to original size
                pred = F.interpolate(
                    pred.unsqueeze(0),
                    size=original_size,
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)
                
                predictions.append(pred)
        
        # Average predictions
        avg_prediction = torch.mean(torch.stack(predictions), dim=0)
        return avg_prediction
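The essential step in dense TTA is mapping each prediction back through the inverse transform before averaging, as `SegmentationTTA` does with its flip-back and rotate-back calls. The NumPy sketch below isolates that idea with a hypothetical per-pixel `toy_segmenter`: because the toy model is equivariant, the de-augmented average matches the plain prediction exactly, whereas skipping the inverse step averages misaligned pixels.

```python
import numpy as np

def identity(x):
    return x

def hflip(x):
    return np.flip(x, -1)

def vflip(x):
    return np.flip(x, -2)

def rot90(x):
    return np.rot90(x, 1, axes=(-2, -1))

def rot270(x):
    return np.rot90(x, -1, axes=(-2, -1))

def dense_tta(predict, image, forward, inverse):
    # predict maps an image [H, W] to per-pixel scores [C, H, W].
    # Each prediction is mapped back with the inverse transform before
    # averaging, so every output pixel is aligned with the input pixel.
    outs = [inv(predict(fwd(image))) for fwd, inv in zip(forward, inverse)]
    return np.mean(outs, axis=0)

def toy_segmenter(image):
    # Hypothetical per-pixel "model": channels [background, foreground]
    return np.stack([1.0 - image, image])

image = np.arange(12, dtype=float).reshape(3, 4) / 11.0
tta_pred = dense_tta(toy_segmenter, image,
                     [identity, hflip, vflip, rot90],
                     [identity, hflip, vflip, rot270])
misaligned = dense_tta(toy_segmenter, image, [identity, hflip], [identity, identity])
```

Note that flipping both axes is the same operation as a 180-degree rotation, so a view set containing both evaluates that view twice and gives it double weight.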

Adaptive TTA

class AdaptiveTTA:
    """Adaptive TTA that selects augmentations based on uncertainty"""
    
    def __init__(self, model, device='cuda', uncertainty_threshold=0.1):
        self.model = model
        self.device = device
        self.uncertainty_threshold = uncertainty_threshold
        self.model.eval()
    
    def calculate_uncertainty(self, predictions):
        """Calculate prediction uncertainty"""
        if predictions.dim() == 1:
            # For single prediction
            entropy = -torch.sum(predictions * torch.log(predictions + 1e-8))
        else:
            # For multiple predictions
            mean_pred = torch.mean(predictions, dim=0)
            entropy = -torch.sum(mean_pred * torch.log(mean_pred + 1e-8))
        
        return entropy.item()
    
    def predict_adaptive_tta(self, image, max_augmentations=8):
        """
        Perform adaptive TTA - add augmentations until uncertainty is low
        """
        augmentations = [
            GeometricTTA.horizontal_flip,
            GeometricTTA.vertical_flip,
            GeometricTTA.rotation_90,
            GeometricTTA.rotation_180,
            lambda x: transforms.ColorJitter(brightness=0.1)(x),
            lambda x: transforms.ColorJitter(contrast=0.1)(x),
            lambda x: self._add_noise(x, 0.01),
            lambda x: self._scale_image(x, 1.1)
        ]
        
        predictions = []
        
        with torch.no_grad():
            # Start with original image
            original_pred = self._get_prediction(image)
            predictions.append(original_pred)
            
            uncertainty = self.calculate_uncertainty(torch.stack(predictions))
            
            # Add augmentations until uncertainty is low or max reached
            for i, aug_fn in enumerate(augmentations):
                if uncertainty < self.uncertainty_threshold or i >= max_augmentations:
                    break
                
                aug_image = aug_fn(image)
                aug_pred = self._get_prediction(aug_image)
                predictions.append(aug_pred)
                
                uncertainty = self.calculate_uncertainty(torch.stack(predictions))
        
        # Return average prediction and number of augmentations used
        final_pred = torch.mean(torch.stack(predictions), dim=0)
        return final_pred, len(predictions), uncertainty
    
    def _get_prediction(self, image):
        """Get model prediction"""
        if isinstance(image, Image.Image):
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
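            image = transform(image).unsqueeze(0)
        elif image.dim() == 3:
            # Tensor input [C, H, W], assumed already normalized: add a batch dimension
            image = image.unsqueeze(0)
        
        image = image.to(self.device)
        output = self.model(image)
        
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)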
        
        return output.squeeze(0)
    
    def _add_noise(self, image, std):
        """Add Gaussian noise"""
        if isinstance(image, Image.Image):
            image = transforms.functional.to_tensor(image)
        
        noise = torch.randn_like(image) * std
        noisy_image = torch.clamp(image + noise, 0, 1)
        return transforms.functional.to_pil_image(noisy_image)
    
    def _scale_image(self, image, scale):
        """Scale image"""
        if isinstance(image, Image.Image):
            w, h = image.size
            new_size = (int(w * scale), int(h * scale))
            return image.resize(new_size, Image.BILINEAR)
        return image
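The stopping rule in `AdaptiveTTA` can be studied without a model. The sketch below (hypothetical `entropy` and `adaptive_average` helpers operating on precomputed per-view probabilities) keeps adding views while the entropy of the running mean is at or above the threshold, mirroring the loop in `predict_adaptive_tta`.

```python
import numpy as np

def entropy(p, eps=1e-8):
    # Shannon entropy (in nats) of a probability vector
    p = np.asarray(p, dtype=float)
    return float(-(p * np.log(p + eps)).sum())

def adaptive_average(view_probs, threshold, max_augmentations):
    # view_probs[0] is the original image; the rest are augmented views
    used = [np.asarray(view_probs[0], dtype=float)]
    h = entropy(np.mean(used, axis=0))
    for p in view_probs[1:max_augmentations + 1]:
        if h < threshold:
            break  # already confident enough: stop adding views
        used.append(np.asarray(p, dtype=float))
        h = entropy(np.mean(used, axis=0))
    return np.mean(used, axis=0), len(used), h

confident = adaptive_average([[0.99, 0.01], [0.5, 0.5]], threshold=0.1, max_augmentations=8)
uncertain = adaptive_average([[0.5, 0.5], [0.6, 0.4], [0.7, 0.3]], threshold=0.1, max_augmentations=8)
```

Because entropy is concave, the entropy of an average is never lower than the average of the individual entropies, so adding views that disagree tends to raise the uncertainty rather than lower it. In practice the rule is likely to save compute mostly on inputs whose original prediction is already confident.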

TTA Evaluation and Analysis

class TTAAnalyzer:
    """Analyze TTA performance and effectiveness"""
    
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
    
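    def analyze_tta_impact(self, test_loader, augmentations, num_samples=100):
        """Analyze the impact of TTA on model performance.
        
        Assumes test_loader yields unnormalized image tensors in [0, 1]
        (e.g. only transforms.ToTensor()), so they can be converted to PIL
        images for the augmentations; ImageNet normalization is applied here.
        num_samples caps the number of individual images evaluated.
        """
        self.model.eval()
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        
        results = {
            'without_tta': [],
            'with_tta': [],
            'confidence_without': [],
            'confidence_with': [],
            'uncertainty_reduction': []
        }
        
        tta = TestTimeAugmentation(self.model, self.device)
        seen = 0
        
        with torch.no_grad():
            for data, target in test_loader:
                if seen >= num_samples:
                    break
                
                for i in range(data.size(0)):
                    if seen >= num_samples:
                        break
                    seen += 1
                    image = data[i]
                    true_label = target[i].item()
                    
                    # Prediction without TTA, with the same normalization as the TTA path
                    pred_without = self.model(normalize(image).unsqueeze(0).to(self.device))
                    pred_without = F.softmax(pred_without, dim=1).squeeze(0)
                    
                    # Prediction with TTA (TestTimeAugmentation normalizes PIL inputs itself)
                    pil_image = transforms.ToPILImage()(image)
                    pred_with, all_preds = tta.predict_with_tta(pil_image, augmentations)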
                    
                    # Calculate accuracies
                    acc_without = int(torch.argmax(pred_without).item() == true_label)
                    acc_with = int(torch.argmax(pred_with).item() == true_label)
                    
                    # Calculate confidences
                    conf_without = torch.max(pred_without).item()
                    conf_with = torch.max(pred_with).item()
                    
                    # Calculate uncertainty reduction
                    uncertainty_without = self._calculate_entropy(pred_without)
                    uncertainty_with = self._calculate_entropy(pred_with)
                    uncertainty_reduction = uncertainty_without - uncertainty_with
                    
                    # Store results
                    results['without_tta'].append(acc_without)
                    results['with_tta'].append(acc_with)
                    results['confidence_without'].append(conf_without)
                    results['confidence_with'].append(conf_with)
                    results['uncertainty_reduction'].append(uncertainty_reduction)
        
        # Calculate summary statistics
        summary = {
            'accuracy_without_tta': np.mean(results['without_tta']),
            'accuracy_with_tta': np.mean(results['with_tta']),
            'accuracy_improvement': np.mean(results['with_tta']) - np.mean(results['without_tta']),
            'avg_confidence_without': np.mean(results['confidence_without']),
            'avg_confidence_with': np.mean(results['confidence_with']),
            'avg_uncertainty_reduction': np.mean(results['uncertainty_reduction'])
        }
        
        return summary, results
    
    def _calculate_entropy(self, probs):
        """Calculate entropy of probability distribution"""
        return -torch.sum(probs * torch.log(probs + 1e-8)).item()
    
    def visualize_tta_analysis(self, summary, results):
        """Visualize TTA analysis results"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Accuracy comparison
        axes[0, 0].bar(['Without TTA', 'With TTA'], 
                      [summary['accuracy_without_tta'], summary['accuracy_with_tta']])
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].set_title('Accuracy Comparison')
        axes[0, 0].set_ylim(0, 1)
        
        # Confidence comparison
        axes[0, 1].hist(results['confidence_without'], alpha=0.7, label='Without TTA', bins=20)
        axes[0, 1].hist(results['confidence_with'], alpha=0.7, label='With TTA', bins=20)
        axes[0, 1].set_xlabel('Confidence')
        axes[0, 1].set_ylabel('Frequency')
        axes[0, 1].set_title('Confidence Distribution')
        axes[0, 1].legend()
        
        # Uncertainty reduction
        axes[1, 0].hist(results['uncertainty_reduction'], bins=20)
        axes[1, 0].set_xlabel('Uncertainty Reduction')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].set_title('Uncertainty Reduction Distribution')
        
        # Improvement vs original confidence
        axes[1, 1].scatter(results['confidence_without'], 
                          np.array(results['with_tta']) - np.array(results['without_tta']))
        axes[1, 1].set_xlabel('Original Confidence')
        axes[1, 1].set_ylabel('Accuracy Improvement')
        axes[1, 1].set_title('TTA Improvement vs Original Confidence')
        axes[1, 1].axhline(y=0, color='r', linestyle='--', alpha=0.5)
        
        plt.tight_layout()
        plt.show()
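`accuracy_improvement` is a net figure: the samples TTA fixes minus the samples it breaks, divided by the number of samples. Counting both directions separately, as in the sketch below (`paired_tta_summary` is a hypothetical helper that takes per-sample 0/1 lists like `results['without_tta']` and `results['with_tta']`), shows whether a small net gain hides many changed predictions.

```python
import numpy as np

def paired_tta_summary(correct_without, correct_with):
    # Per-sample correctness (0/1) on the same samples, without and with TTA
    a = np.asarray(correct_without, dtype=bool)
    b = np.asarray(correct_with, dtype=bool)
    fixed = int(np.sum(~a & b))    # wrong without TTA, right with TTA
    broken = int(np.sum(a & ~b))   # right without TTA, wrong with TTA
    return {"fixed": fixed, "broken": broken, "net_gain": (fixed - broken) / a.size}

summary = paired_tta_summary([1, 1, 0, 0, 1], [1, 0, 1, 1, 1])
```

Here TTA fixes two samples and breaks one, so the net gain of 0.2 equals the difference between the two accuracies (0.8 and 0.6).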

Efficient TTA Implementation

class EfficientTTA:
    """Memory and time efficient TTA implementation"""
    
    def __init__(self, model, device='cuda', batch_size=8):
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.model.eval()
    
    def predict_batch_tta(self, images, augmentations):
        """Perform TTA on a batch of images efficiently"""
        batch_predictions = []
        
        # Process in batches to manage memory
        for i in range(0, len(images), self.batch_size):
            batch_images = images[i:i + self.batch_size]
            batch_pred = self._process_batch(batch_images, augmentations)
            batch_predictions.extend(batch_pred)
        
        return batch_predictions
    
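    def _process_batch(self, images, augmentations):
        """Process a single batch of PIL images with TTA.
        
        All augmented views are stacked into one tensor, so every augmentation
        must preserve the image size (90-degree rotations of non-square
        images, for example, make torch.stack fail).
        """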
        all_augmented = []
        
        # Prepare all augmented versions
        for image in images:
            # Original image
            all_augmented.append(image)
            
            # Augmented versions
            for aug_fn in augmentations:
                aug_image = aug_fn(image)
                all_augmented.append(aug_image)
        
        # Convert to tensor batch
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        
        batch_tensor = torch.stack([transform(img) for img in all_augmented])
        batch_tensor = batch_tensor.to(self.device)
        
        # Get predictions for entire batch
        with torch.no_grad():
            outputs = self.model(batch_tensor)
            if outputs.dim() == 2 and outputs.size(1) > 1:
                outputs = F.softmax(outputs, dim=1)
        
        # Reshape and average predictions
        num_augs = len(augmentations) + 1  # +1 for original
        outputs = outputs.view(len(images), num_augs, -1)
        averaged_outputs = torch.mean(outputs, dim=1)
        
        return averaged_outputs.cpu()
    
    def predict_with_memory_limit(self, image, augmentations, memory_limit_mb=1000):
        """Perform TTA with memory constraints"""
        import psutil
        import gc
        
        predictions = []
        
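        # Monitor host (CPU) memory through the process RSS; this does not
        # track GPU memory (torch.cuda.memory_allocated() reports that)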
        process = psutil.Process()
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB
        
        with torch.no_grad():
            # Original prediction
            pred = self._get_prediction(image)
            predictions.append(pred)
            
            for aug_fn in augmentations:
                current_memory = process.memory_info().rss / 1024 / 1024
                
                if current_memory - initial_memory > memory_limit_mb:
                    # Clear cache and garbage collect
                    torch.cuda.empty_cache()
                    gc.collect()
                    break
                
                aug_image = aug_fn(image)
                aug_pred = self._get_prediction(aug_image)
                predictions.append(aug_pred)
        
        # Average predictions
        final_pred = torch.mean(torch.stack(predictions), dim=0)
        return final_pred, len(predictions)
    
    def _get_prediction(self, image):
        """Get prediction for single image"""
        if isinstance(image, Image.Image):
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
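            image = transform(image).unsqueeze(0)
        elif image.dim() == 3:
            # Tensor input [C, H, W], assumed already normalized: add a batch dimension
            image = image.unsqueeze(0)
        
        image = image.to(self.device)
        output = self.model(image)
        
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)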
        
        return output.squeeze(0)
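`_process_batch` relies on the flat batch being ordered image by image, with all views of one image adjacent, so that `view(len(images), num_augs, -1)` groups the right rows together. The NumPy sketch below (with a hypothetical `toy_model`) reproduces that layout and checks it against a simple per-image loop.

```python
import numpy as np

def batched_tta_mean(model, images, augmentations):
    # Flat batch in image-major order: [img0, img0_aug1, ..., img1, img1_aug1, ...]
    views = [v for img in images for v in [img] + [aug(img) for aug in augmentations]]
    outputs = model(np.stack(views))                 # [N * A, C]
    num_views = len(augmentations) + 1               # +1 for the original
    # Reshape to [N, A, C]; correct only because of the image-major ordering above
    return outputs.reshape(len(images), num_views, -1).mean(axis=1)

def toy_model(batch):
    # Hypothetical batched "model": [B, H, W] -> [B, 2] scores
    return np.stack([batch.mean(axis=(1, 2)), batch[:, :, 0].mean(axis=1)], axis=1)

rng = np.random.default_rng(0)
images = [rng.random((4, 4)) for _ in range(3)]
augs = [lambda x: x[:, ::-1], lambda x: x[::-1, :]]

batched = batched_tta_mean(toy_model, images, augs)
looped = np.stack([toy_model(np.stack([img] + [a(img) for a in augs])).mean(axis=0)
                   for img in images])
```

One forward pass over N x A views replaces N x A separate calls, which is usually where the speedup comes from on a GPU; the trade-off is that peak memory grows with the number of views.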

Complete TTA Pipeline

class CompleteTTAPipeline:
    """Complete TTA pipeline with configurable strategies"""
    
    def __init__(self, model, device='cuda', config=None):
        self.model = model
        self.device = device
        self.model.eval()
        
        # Default configuration; keys in a user-supplied config override these,
        # so a partial config does not cause a KeyError later
        default_config = {
            'geometric_transforms': True,
            'color_transforms': False,
            'noise_transforms': False,
            'multi_scale': False,
            'crop_based': False,
            'adaptive': False,
            'scales': [0.9, 1.0, 1.1],
            'noise_levels': [0.01, 0.02],
            'uncertainty_threshold': 0.1,
            'max_augmentations': 8,
            'aggregation': 'mean'  # 'mean', 'geometric_mean', 'voting'
        }
        self.config = {**default_config, **(config or {})}
    
    def get_augmentations(self):
        """Get augmentations based on configuration"""
        augmentations = []
        
        if self.config['geometric_transforms']:
            augmentations.extend([
                GeometricTTA.horizontal_flip,
                GeometricTTA.vertical_flip,
                GeometricTTA.rotation_90,
                GeometricTTA.rotation_180
            ])
        
        if self.config['color_transforms']:
            augmentations.extend([
                lambda x: transforms.ColorJitter(brightness=0.1)(x),
                lambda x: transforms.ColorJitter(contrast=0.1)(x),
                lambda x: transforms.ColorJitter(saturation=0.1)(x)
            ])
        
        if self.config['noise_transforms']:
            noise_tta = NoiseTTA(self.config['noise_levels'])
            augmentations.extend(noise_tta.generate_noise_transforms())
        
        return augmentations[:self.config['max_augmentations']]
    
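    def predict(self, image, return_details=False):
        """Main prediction method.

        Strategies are checked in order (adaptive, multi_scale, crop_based,
        then standard TTA) and only the first enabled one runs, so the
        geometric/color/noise augmentations are used by the standard path only.
        """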
        augmentations = self.get_augmentations()
        
        if self.config['adaptive']:
            adaptive_tta = AdaptiveTTA(self.model, self.device, 
                                     self.config['uncertainty_threshold'])
            final_pred, num_augs, uncertainty = adaptive_tta.predict_adaptive_tta(
                image, self.config['max_augmentations']
            )
            
            if return_details:
                return final_pred, {'num_augmentations': num_augs, 'uncertainty': uncertainty}
            return final_pred
        
        elif self.config['multi_scale']:
            multi_scale_tta = MultiScaleTTA(self.config['scales'])
            final_pred = multi_scale_tta.predict_multi_scale(self.model, image, self.device)
            
            if return_details:
                return final_pred, {'num_augmentations': len(self.config['scales'])}
            return final_pred
        
        elif self.config['crop_based']:
            crop_tta = CropBasedTTA()
            final_pred, all_preds = crop_tta.predict_with_crops(self.model, image, self.device)
            
            if return_details:
                return final_pred, {'num_augmentations': len(all_preds)}
            return final_pred
        
        else:
            # Standard TTA
            tta = TestTimeAugmentation(self.model, self.device)
            final_pred, all_preds = tta.predict_with_tta(
                image, augmentations, self.config['aggregation']
            )
            
            if return_details:
                return final_pred, {'num_augmentations': len(all_preds)}
            return final_pred

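import torchvision  # `import torchvision.transforms as transforms` does not bind the name `torchvision`

# Usage example
def main():
    # Load an ImageNet-pretrained model (the weights= API requires torchvision >= 0.13)
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Configure TTA (multi_scale is enabled, so it takes precedence over the geometric/color options)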
    tta_config = {
        'geometric_transforms': True,
        'color_transforms': True,
        'multi_scale': True,
        'scales': [0.8, 0.9, 1.0, 1.1, 1.2],
        'aggregation': 'mean',
        'max_augmentations': 10
    }
    
    # Create TTA pipeline
    tta_pipeline = CompleteTTAPipeline(model, device, tta_config)
    
    # Load test image
    image = Image.open('test_image.jpg').convert('RGB')  # ensure 3 channels for grayscale/RGBA files
    
    # Get prediction with TTA
    prediction, details = tta_pipeline.predict(image, return_details=True)
    
    print(f"Prediction: {torch.argmax(prediction).item()}")
    print(f"Confidence: {torch.max(prediction).item():.4f}")
    print(f"Number of augmentations used: {details['num_augmentations']}")

if __name__ == "__main__":
    main()
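The `aggregation` option names three rules: 'mean', 'geometric_mean', and 'voting'. `predict_with_tta` is defined earlier in the article, so the standalone sketch below may not match it line for line; it uses plain Python lists as stand-ins for softmax vectors, and the function name `aggregate` is hypothetical.

```python
import math
from collections import Counter
from typing import List

def aggregate(preds: List[List[float]], method: str = "mean") -> List[float]:
    """Combine per-augmentation probability vectors into a single vector."""
    n, k = len(preds), len(preds[0])
    if method == "mean":
        # Arithmetic mean of the probabilities for each class
        return [sum(p[c] for p in preds) / n for c in range(k)]
    if method == "geometric_mean":
        # Average the log-probabilities, exponentiate, then renormalize to sum to 1
        eps = 1e-12  # guards against log(0)
        logs = [sum(math.log(p[c] + eps) for p in preds) / n for c in range(k)]
        unnorm = [math.exp(v) for v in logs]
        total = sum(unnorm)
        return [v / total for v in unnorm]
    if method == "voting":
        # Each augmentation votes for its argmax class; return the vote fractions
        votes = Counter(max(range(k), key=lambda c: p[c]) for p in preds)
        return [votes[c] / n for c in range(k)]
    raise ValueError(f"unknown aggregation: {method}")

preds = [[0.6, 0.4], [0.9, 0.1], [0.2, 0.8]]
print(aggregate(preds, "mean"))
print(aggregate(preds, "geometric_mean"))
print(aggregate(preds, "voting"))
```

On this toy input the arithmetic mean gives class 0 about 0.567, while the geometric mean gives 0.6: the geometric mean punishes class 1 heavily for its 0.1 score in the second view. Voting discards confidence entirely and returns two of three votes for class 0.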

Best Practices and Guidelines

When to Use TTA

  1. Competition settings where small improvements matter
  2. Critical applications requiring high reliability
  3. Limited training data scenarios
  4. When the inference-time compute budget can absorb several forward passes per input

TTA Selection Guidelines

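def select_tta_strategy(task_type, model_type, computational_budget):
    """Suggest TTA strategies for a task and budget ('low', 'medium', or anything else for heavy).

    model_type is not used yet; unknown task types fall back to ['horizontal_flip'].
    """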
    
    strategies = {
        'classification': {
            'lightweight': ['horizontal_flip'],
            'standard': ['horizontal_flip', 'vertical_flip', 'rotation_90', 'rotation_180'],
            'heavy': ['d4_transforms', 'multi_scale', 'color_jitter']
        },
        'segmentation': {
            'lightweight': ['horizontal_flip', 'vertical_flip'],
            'standard': ['d4_transforms'],
            'heavy': ['d4_transforms', 'multi_scale']
        },
        'object_detection': {
            'lightweight': ['horizontal_flip'],
            'standard': ['horizontal_flip', 'multi_scale'],
            'heavy': ['horizontal_flip', 'multi_scale', 'crop_based']
        }
    }
    
    if computational_budget == 'low':
        budget_level = 'lightweight'
    elif computational_budget == 'medium':
        budget_level = 'standard'
    else:
        budget_level = 'heavy'
    
    return strategies.get(task_type, {}).get(budget_level, ['horizontal_flip'])
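The strategy lists refer to 'd4_transforms' without defining them in this section. The standard meaning is the eight symmetries of the square (the dihedral group D4): four rotations by multiples of 90 degrees, each with and without a horizontal flip. For segmentation and detection, TTA only works if each prediction is mapped back through the inverse transform before averaging. The sketch below uses nested lists in place of mask tensors, and every name in it (`rot90`, `d4_pairs`, `unflip_box`, and so on) is hypothetical.

```python
from typing import Callable, List, Tuple

Grid = List[List[int]]
Box = Tuple[float, float, float, float]

def rot90(g: Grid) -> Grid:
    """Rotate a 2D grid 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*g)][::-1]

def hflip(g: Grid) -> Grid:
    """Mirror a 2D grid left-right."""
    return [row[::-1] for row in g]

def rotate(g: Grid, k: int) -> Grid:
    """Rotate counter-clockwise by k * 90 degrees."""
    for _ in range(k % 4):
        g = rot90(g)
    return g

def d4_pairs() -> List[Tuple[Callable[[Grid], Grid], Callable[[Grid], Grid]]]:
    """(forward, inverse) pairs for the 8 elements of D4."""
    pairs = []
    for k in range(4):
        # A rotation by k*90 is undone by rotating a further (4 - k)*90
        pairs.append((lambda g, k=k: rotate(g, k),
                      lambda g, k=k: rotate(g, 4 - k)))
        # Flip-then-rotate is undone by un-rotating, then flipping again
        pairs.append((lambda g, k=k: rotate(hflip(g), k),
                      lambda g, k=k: hflip(rotate(g, 4 - k))))
    return pairs

def unflip_box(box: Box, width: float) -> Box:
    """Map an (x1, y1, x2, y2) box from a horizontally flipped image back."""
    x1, y1, x2, y2 = box
    return (width - x2, y1, width - x1, y2)

mask = [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]
print(all(inv(fwd(mask)) == mask for fwd, inv in d4_pairs()))  # True
print(unflip_box((10, 5, 30, 20), width=100))  # (70, 5, 90, 20)
```

In a real segmentation pipeline the forward transform is applied to the input image and the inverse to the predicted mask; the box helper shows the equivalent step for detection with a horizontal flip.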

Conclusion

Test-Time Augmentation can improve accuracy and robustness with a relatively simple implementation, at the cost of extra inference compute. Key takeaways:

Benefits

  • Improved accuracy with minimal code changes
  • Uncertainty estimates from the spread of predictions across augmentations
  • Increased robustness to input variations
  • Model-agnostic approach

Considerations

  • Inference cost grows linearly with the number of augmentations (one forward pass each)
  • Returns diminish beyond a certain number of augmentations
  • Memory requirements can be significant
  • Task-specific augmentation selection is important
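A simple model makes the linear-cost and diminishing-returns points concrete. Assume each augmented prediction has variance σ² and every pair of predictions has correlation ρ (a simplifying assumption; real augmentations are not equally correlated). The variance of their average is then σ²(ρ + (1 - ρ)/n): cost grows with n, but the variance never falls below ρσ². The helper below is an illustrative sketch of that formula, not a measurement.

```python
def averaged_variance(sigma2: float, rho: float, n: int) -> float:
    """Variance of the mean of n equicorrelated predictions.

    Var(mean) = sigma2 * (rho + (1 - rho) / n)
    """
    if n < 1:
        raise ValueError("n must be >= 1")
    return sigma2 * (rho + (1 - rho) / n)

# Relative variance (sigma2 = 1) for moderately correlated views (rho = 0.5)
for n in (1, 2, 4, 8, 16):
    print(n, averaged_variance(1.0, 0.5, n))
```

With ρ = 0.5, going from 1 to 2 views cuts the variance by 25%, while doubling from 8 to 16 views cuts it by only about 5.6%, even though the second step costs eight extra forward passes.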

Best Practices

  • Start with simple geometric transforms
  • Use domain knowledge for augmentation selection
  • Monitor computational vs. accuracy trade-offs
  • Consider adaptive strategies for efficiency
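One way to make "adaptive" concrete is an early-stopping rule. The `AdaptiveTTA` class used by the pipeline is defined earlier in the article and may use a different criterion; the sketch below is a hypothetical rule: keep adding augmented predictions until the running mean changes by less than `tol` in every class, or until all augmentations are used. The `predict(i)` callback interface is an assumption for illustration.

```python
from typing import Callable, List, Tuple

Probs = List[float]

def adaptive_tta(
    predict: Callable[[int], Probs],
    num_available: int,
    tol: float = 0.01,
    min_augs: int = 2,
) -> Tuple[Probs, int]:
    """Average predictions until the running mean stabilizes.

    predict(i) returns the probability vector for augmentation i
    (i = 0 is the unaugmented input). Returns (mean, number used).
    """
    running = list(predict(0))
    used = 1
    for i in range(1, num_available):
        pred = predict(i)
        # Incremental mean update: m_new = m + (x - m) / (count + 1)
        updated = [m + (x - m) / (used + 1) for m, x in zip(running, pred)]
        delta = max(abs(a - b) for a, b in zip(updated, running))
        running, used = updated, used + 1
        if used >= min_augs and delta < tol:
            break
    return running, used

views = [[0.70, 0.30], [0.72, 0.28], [0.71, 0.29], [0.10, 0.90]]
mean, used = adaptive_tta(lambda i: views[i], len(views), tol=0.02)
print(used, [round(p, 3) for p in mean])  # 2 [0.71, 0.29]
```

Here the first two views agree, so the loop stops after two predictions and never evaluates the disagreeing fourth view. That is the price of the saving, which is why the order of augmentations matters for any early-stopping rule.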

Because it requires no retraining, TTA is a practical way to improve model performance at inference time and a useful tool in the computer vision practitioner's toolkit.
