ML · Cost Optimization · Cloud · Infrastructure

Cost Optimization Strategies for ML Workloads

Practical techniques to reduce your machine learning infrastructure costs without sacrificing performance or reliability.

David Kim
8 min read

Machine learning workloads can quickly become your largest cloud expense. Between GPU-intensive training, large-scale data storage, and high-throughput inference serving, costs can spiral out of control. Here's how to rein them in without giving up performance.

Understanding ML Cost Structure

Training Costs (60-70% of ML spend)

  • Compute: GPUs/TPUs for model training
  • Storage: Training datasets and checkpoints
  • Data transfer: Moving data between storage and compute
  • Experimentation: Multiple training runs and hyperparameter tuning

Inference Costs (25-35% of ML spend)

  • Serving infrastructure: Always-on prediction endpoints
  • Compute: CPU/GPU for model inference
  • Caching: Storing frequent predictions
  • Scaling: Auto-scaling based on traffic

Data Costs (5-10% of ML spend)

  • Storage: Raw data, processed features, model artifacts
  • Processing: ETL pipelines, feature engineering
  • Transfer: Moving data between regions/clouds
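
To make these proportions concrete, here is a back-of-envelope sketch; the monthly budget figure and the exact shares are illustrative assumptions, not benchmarks.

# Hypothetical monthly budget split using midpoints of the ranges above
monthly_ml_spend = 50_000  # illustrative figure, not a benchmark

cost_shares = {
    "training": 0.65,   # 60-70% range
    "inference": 0.30,  # 25-35% range
    "data": 0.05,       # 5-10% range
}

for bucket, share in cost_shares.items():
    print(f"{bucket}: ${monthly_ml_spend * share:,.0f}/month")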

Training Optimization Strategies

1. Use Spot/Preemptible Instances

Save 60-90% on training costs by using interruptible compute.

# Example: Training with checkpointing so spot interruptions only cost you one epoch
import tensorflow as tf

def train_with_checkpoints(model, dataset, checkpoint_dir, num_epochs, train_step):
    # Configure checkpointing
    checkpoint = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=3
    )

    # Resume from the latest checkpoint if one exists
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Restored from {manager.latest_checkpoint}")

    # Training loop with frequent checkpoints
    for epoch in range(num_epochs):
        for batch in dataset:
            train_step(model, batch)

        # Save a checkpoint every epoch
        save_path = manager.save()
        print(f"Saved checkpoint: {save_path}")

Best practices:

  • Save checkpoints frequently (every epoch or every N minutes)
  • Use auto-restart scripts when instances terminate (see the interruption-handling sketch below)
  • Spot instances suit fault-tolerant training jobs, not latency-sensitive inference
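
On AWS specifically, you can handle interruptions proactively: the instance metadata service publishes a spot interruption notice roughly two minutes before the instance is reclaimed. Below is a minimal, AWS-specific sketch; save_final_checkpoint is a hypothetical hook into your training loop, and if your instances enforce IMDSv2 you would also need to fetch a metadata session token first.

# Minimal AWS-specific sketch: poll the instance metadata service for a spot
# interruption notice and write a final checkpoint before termination.
# save_final_checkpoint() is a hypothetical hook into your training loop.
import time
import requests

SPOT_ACTION_URL = "http://169.254.169.254/latest/meta-data/spot/instance-action"

def watch_for_interruption(save_final_checkpoint, poll_seconds=5):
    while True:
        try:
            resp = requests.get(SPOT_ACTION_URL, timeout=1)
            if resp.status_code == 200:
                # A 200 response means termination is about two minutes away
                save_final_checkpoint()
                return
        except requests.exceptions.RequestException:
            pass  # metadata service briefly unreachable; keep polling
        time.sleep(poll_seconds)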

2. Optimize Batch Sizes

Larger batches mean better GPU utilization, faster epochs, and lower cost, at least until you hit memory limits or convergence starts to suffer.

import tensorflow as tf

def find_optimal_batch_size(model, dataset, max_batch=1024):
    """
    Binary search for the largest batch size that fits in GPU memory.
    """
    min_batch = 1
    optimal_batch = min_batch

    while min_batch <= max_batch:
        batch_size = (min_batch + max_batch) // 2
        try:
            # Test whether a batch of this size fits in memory
            test_batch = next(iter(dataset.batch(batch_size)))
            model(test_batch)  # Forward pass
            optimal_batch = batch_size
            min_batch = batch_size + 1
        except tf.errors.ResourceExhaustedError:
            max_batch = batch_size - 1

    return optimal_batch

# Usage
optimal_batch = find_optimal_batch_size(model, train_dataset)
print(f"Optimal batch size: {optimal_batch}")

3. Mixed Precision Training

Run most operations in FP16 while keeping critical parts in FP32, cutting memory use and speeding up training on GPUs with Tensor Cores.

import tensorflow as tf

# Enable mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Build model (layers now compute in float16 by default)
# Tip: give the final output layer dtype='float32' for numerical stability
model = create_model()

# Benefits:
# - Up to ~2x faster training on GPUs with Tensor Cores
# - Roughly half the activation memory
# - Minimal accuracy impact for most models

Savings: 30-50% reduction in training costs
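
One caveat: model.fit applies loss scaling automatically, but in a custom training loop you should wrap the optimizer in a LossScaleOptimizer so small FP16 gradients don't underflow. A minimal sketch:

import tensorflow as tf

# For custom training loops only: scale the loss up and the gradients back down
optimizer = tf.keras.optimizers.Adam()
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)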

4. Distributed Training

Scale horizontally to reduce wall-clock time.

import tensorflow as tf

# Multi-GPU training strategy
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = create_model()
    model.compile(...)

# Train on all available GPUs
model.fit(train_dataset, epochs=10)

# 4 GPUs = 3.5x speedup (not quite 4x due to overhead)
# Result: Same cost, 3.5x faster iteration

5. Transfer Learning

Start with pre-trained models instead of training from scratch.

import tensorflow as tf
from tensorflow.keras.applications import ResNet50

# Load pre-trained model
base_model = ResNet50(weights='imagenet', include_top=False)
base_model.trainable = False  # Freeze base layers

# Add custom top layers
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

# Fine-tune only top layers (much faster and cheaper)
model.compile(...)
model.fit(train_dataset, epochs=10)

Savings: 70-90% reduction in training time and cost

Inference Optimization Strategies

1. Model Quantization

Reduce model size and inference cost with minimal accuracy loss.

import tensorflow as tf

# Post-training quantization with TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Full INT8 quantization: calibrate with a small representative dataset
converter.representative_dataset = representative_dataset  # generator yielding sample inputs
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_model = converter.convert()

# Benefits:
# - ~4x smaller model size
# - 2-3x faster inference
# - Often fast enough to serve on CPU instead of GPU

2. Model Caching

Cache predictions for repeated inputs.

import hashlib

class CachedModel:
    def __init__(self, model, cache_size=10000):
        self.model = model
        self.cache = {}
        self.cache_size = cache_size
        self.hits = 0
        self.misses = 0
    
    def predict(self, input_data):
        # Create a cache key (str() is fine for simple inputs; use a stable
        # serialization such as input_data.tobytes() for NumPy arrays)
        key = hashlib.md5(str(input_data).encode()).hexdigest()
        
        # Check cache
        if key in self.cache:
            self.hits += 1
            return self.cache[key]
        
        # Compute prediction
        self.misses += 1
        prediction = self.model.predict(input_data)
        
        # Update cache (with LRU eviction)
        if len(self.cache) >= self.cache_size:
            # Remove oldest entry
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = prediction
        
        return prediction
    
    @property
    def cache_hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0

# Usage
cached_model = CachedModel(model)
# A 20% cache hit rate means roughly 20% fewer model invocations

3. Batch Inference

Process multiple requests together for better efficiency.

import asyncio
from collections import deque

class BatchedInference:
    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.queue = deque()
    
    async def predict(self, input_data):
        # Add the request to the queue
        future = asyncio.Future()
        self.queue.append((input_data, future))

        # Flush immediately if the batch is full, otherwise wait briefly for more requests
        if len(self.queue) >= self.max_batch_size:
            await self._process_batch()
        else:
            asyncio.create_task(self._wait_and_process())

        return await future

    async def _wait_and_process(self):
        # Let more requests accumulate for up to max_wait_time, then flush the queue
        await asyncio.sleep(self.max_wait_time)
        await self._process_batch()

    async def _process_batch(self):
        if not self.queue:
            return
        
        # Collect batch
        batch_data = []
        batch_futures = []
        while self.queue and len(batch_data) < self.max_batch_size:
            data, future = self.queue.popleft()
            batch_data.append(data)
            batch_futures.append(future)
        
        # Process batch
        predictions = self.model.predict(batch_data)
        
        # Return results
        for future, prediction in zip(batch_futures, predictions):
            future.set_result(prediction)
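
A quick usage sketch (the handler and model names here are illustrative):

# Hypothetical async request handler in front of the batcher
batcher = BatchedInference(model, max_batch_size=32, max_wait_time=0.05)

async def handle_request(input_data):
    return await batcher.predict(input_data)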

Savings: 40-60% reduction in inference costs

4. Auto-Scaling

Scale inference capacity based on demand.

# Kubernetes HPA configuration
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-model-server
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-model-server
  minReplicas: 2
  maxReplicas: 20
  metrics:
  - type: Resource
    resource:
      name: cpu
      target:
        type: Utilization
        averageUtilization: 70
  behavior:
    scaleDown:
      stabilizationWindowSeconds: 300
      policies:
      - type: Percent
        value: 50
        periodSeconds: 60

Data Cost Optimization

1. Lifecycle Policies

Move cold data to cheaper storage tiers automatically.

# AWS S3 lifecycle policy example
lifecycle_policy = {
    'Rules': [
        {
            'Id': 'MoveToGlacier',
            'Status': 'Enabled',
            'Filter': {'Prefix': ''},  # apply to the whole bucket
            'Transitions': [
                {
                    'Days': 90,
                    'StorageClass': 'STANDARD_IA'  # Infrequent Access
                },
                {
                    'Days': 180,
                    'StorageClass': 'GLACIER'  # Long-term archive
                }
            ]
        },
        {
            'Id': 'DeleteOldData',
            'Status': 'Enabled',
            'Expiration': {
                'Days': 365
            },
            'Filter': {
                'Prefix': 'temporary/'
            }
        }
    ]
}
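
To apply the policy, pass it to S3, for example via boto3 (the bucket name below is a placeholder):

import boto3

# Attach the lifecycle policy defined above to a bucket
s3 = boto3.client('s3')
s3.put_bucket_lifecycle_configuration(
    Bucket='my-ml-data-bucket',  # placeholder bucket name
    LifecycleConfiguration=lifecycle_policy
)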

Savings: 50-90% on long-term data storage

2. Data Compression

Compress datasets to reduce storage and transfer costs.

import pickle
import gzip

# Compress training data
def save_compressed(data, filename):
    with gzip.open(f"{filename}.gz", 'wb') as f:
        pickle.dump(data, f)

def load_compressed(filename):
    with gzip.open(f"{filename}.gz", 'rb') as f:
        return pickle.load(f)

# Usage
save_compressed(train_data, 'train_data')
# Result: 70-90% size reduction for typical datasets

3. Feature Store

Avoid recomputing features across projects.

class FeatureStore:
    def __init__(self, cache_backend):
        self.cache = cache_backend
        self.compute_stats = {'hits': 0, 'misses': 0}
    
    def get_features(self, entity_id, feature_names):
        cache_key = f"{entity_id}:{':'.join(feature_names)}"
        
        # Try cache first
        cached = self.cache.get(cache_key)
        if cached:
            self.compute_stats['hits'] += 1
            return cached
        
        # Compute features (_compute_features is a placeholder for your feature-engineering logic)
        self.compute_stats['misses'] += 1
        features = self._compute_features(entity_id, feature_names)
        
        # Cache for reuse
        self.cache.set(cache_key, features, ttl=3600)
        return features

Monitoring and Optimization

Track Key Metrics

import logging
from datetime import datetime

class CostMetrics:
    def log_training_run(self, duration_hours, instance_type, cost_per_hour, num_epochs):
        total_cost = duration_hours * cost_per_hour
        logging.info({
            'event': 'training_complete',
            'timestamp': datetime.now(),
            'duration_hours': duration_hours,
            'instance_type': instance_type,
            'total_cost': total_cost,
            'cost_per_epoch': total_cost / num_epochs
        })
    
    def log_inference_batch(self, batch_size, latency_ms, cost):
        logging.info({
            'event': 'inference_batch',
            'timestamp': datetime.now(),
            'batch_size': batch_size,
            'latency_ms': latency_ms,
            'cost_per_prediction': cost / batch_size
        })

Set Up Alerts

def check_cost_anomalies(current_spend, historical_avg, threshold=1.5):
    # alert_team() is a placeholder for your notification channel (Slack, PagerDuty, email, ...)
    if current_spend > historical_avg * threshold:
        alert_team(f"Cost spike detected: ${current_spend} vs avg ${historical_avg}")
        return True
    return False

Implementation Roadmap

Week 1-2: Baseline and analyze current costs

  • Audit current spending
  • Identify top cost drivers
  • Document current architecture

Week 3-4: Quick wins

  • Implement spot instances for training
  • Add basic caching
  • Set up monitoring

Week 5-8: Medium-term optimizations

  • Deploy model quantization
  • Implement batch inference
  • Optimize data pipelines

Week 9-12: Advanced optimizations

  • Fine-tune auto-scaling
  • Implement feature store
  • Optimize model architecture

Expected Savings

| Optimization | Typical Savings | Difficulty | Priority |
|---|---|---|---|
| Spot instances | 60-90% | Low | High |
| Mixed precision | 30-50% | Low | High |
| Model quantization | 40-60% | Medium | High |
| Batch inference | 40-60% | Medium | High |
| Transfer learning | 70-90% | Low | Medium |
| Data compression | 70-90% | Low | Medium |
| Auto-scaling | 20-40% | Medium | Medium |
| Feature caching | 10-30% | Medium | Low |

Total potential savings: 50-70% reduction in ML infrastructure costs
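
The per-technique percentages don't simply add up, because each one applies to a different (and already reduced) slice of spend. A rough illustration with assumed numbers:

# Rough illustration: apply each family of savings to its own slice of spend.
# Shares and savings rates are assumptions for the sake of the arithmetic.
baseline = 100_000  # hypothetical monthly ML spend

training = baseline * 0.65 * (1 - 0.70)   # spot + mixed precision on the training slice
inference = baseline * 0.30 * (1 - 0.50)  # quantization + batching on the inference slice
data = baseline * 0.05 * (1 - 0.60)       # lifecycle tiers + compression on the data slice

optimized = training + inference + data
print(f"Optimized spend: ${optimized:,.0f} (~{1 - optimized / baseline:.0%} total savings)")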

Conclusion

ML cost optimization isn't a one-time effort—it's an ongoing practice. Start with the high-impact, low-effort optimizations (spot instances, mixed precision), then work your way through more complex improvements.

Remember: the goal isn't to minimize costs at all costs. It's to maximize value per dollar spent. Sometimes spending more on infrastructure enables faster iteration and better models, which drives more business value.

Need help optimizing your ML infrastructure costs? Contact us for a free cost assessment.


Part of our ML Operations series
