Оптимизация
Optimization в ML и Production
В пространстве 175 миллиардов параметров вероятность случайно попасть в настоящий минимум - $2^{-175 \cdot 10^9}$. SGD никогда туда не попадает. И именно поэтому он работает.
- ResNet-50 inference: PyTorch eager 5.2 ms, TensorRT INT8 - 0.9 ms на A100. 5.8x за счёт fusion + квантизации.
- LLaMA-7B в FP16 требует 14 GB VRAM. INT4 через GPTQ - 3.5 GB, помещается на RTX 4060. Inference падает на 1.2x, потеря качества ниже 0.5 perplexity.
- Foret et al. (2021) показали: SAM поднимает ImageNet top-1 на 1.3 пункта без изменения архитектуры. Цена - 2 forward pass на шаг.
- torch.compile, добавленный одной строкой, ускоряет GPT-2 generation в 2.3x. Затраты на разработку - 0 человеко-часов.
Предварительные знания
## Оптимизация в реальном ML Теоретические гарантии сходимости работают на выпуклых функциях. Реальные нейросети - невыпуклые, с миллиардами параметров и странной геометрией loss landscape. В этом уроке: почему SGD находит хорошие решения (flat minima теория), как SAM явно ищет обобщаемые минимумы, как квантизация работает во время обучения, и как превратить PyTorch модель в production-ready inference pipeline с 5-10x ускорением.
## Итог: Optimization в ML Production **Loss landscape**: flat minima обобщают лучше, чем sharp. SAM находит их явно, выполняя 2 forward pass в итерации. Large batch → sharp minima → generalization gap (лечится linear scaling + warmup). **QAT vs PTQ**: Quantization-Aware Training симулирует INT8 во время обучения через FakeQuantize + STE. Результат: < 0.5% потери точности vs 1-3% у PTQ. **Inference pipeline**: PyTorch → ONNX → TensorRT даёт 4-6x ускорение. Operator fusion устраняет лишние round-trip в HBM. Flash Attention - 3x для трансформеров. torch.compile - одна строка для 2x ускорения. **Правило thumb**: начинай с torch.compile + FP16. Если недостаточно - TensorRT INT8 + Flash Attention. Для edge - ONNX Runtime + QAT.
Геометрия loss landscape
## Seddle points vs. flat minima Пространство параметров нейросети имеет особую геометрию: ``` Типичный loss landscape для нейросети: loss │ sharp minimum flat minimum │ ↓ ↓ │ /\ /‾‾‾‾‾‾\ │ / \ / \ │ / \_______________/ \_____ │ └──────────────────────────────────── params Sharp minimum: маленький basin, высокий Hessian → плохая генерализация Flat minimum: большой basin → хорошая генерализация (Hochreiter & Schmidhuber, 1997) ``` ## Saddle points в высоких измерениях Для сети с n параметрами Hessian имеет размер n×n. - В 2D: критическая точка - либо минимум, либо максимум - В nD: критическая точка, где k из n собственных значений отрицательны - **seddle point порядка k** ``` При случайном Hessian вероятность того, что все собственные значения положительны (минимум) ≈ 2^{-n} Для n=1000: P(настоящий минимум) ≈ 10^{-301} Вывод: SGD почти никогда не попадает в настоящие минимумы! Он эскейпит seddle points за счёт шума градиентов. ``` ## Sharpness-Aware Minimization (SAM) SAM (Foret et al., 2021) явно оптимизирует flat minima: ``` Обычный SGD минимизирует: L(w) SAM минимизирует: max_{||eps||<=rho} L(w + eps) ^^^^^^^^^^^^^^^^^^^^^^^^ worst-case loss в окрестности w Это заставляет оптимизатор искать точки, где ОКРЕСТНОСТЬ имеет низкий loss - т.е. flat regions. ``` ## Реализация SAM ```python import torch class SAM(torch.optim.Optimizer): def __init__(self, params, base_optimizer, rho=0.05, **kwargs): defaults = dict(rho=rho, **kwargs) super().__init__(params, defaults) self.base_optimizer = base_optimizer(self.param_groups, **kwargs) self.param_groups = self.base_optimizer.param_groups @torch.no_grad() def first_step(self, zero_grad=False): """Шаг 1: найти adversarial perturbation eps""" grad_norm = self._grad_norm() for group in self.param_groups: scale = group['rho'] / (grad_norm + 1e-12) for p in group['params']: if p.grad is None: continue # eps = rho * grad / ||grad|| (нормированный) e_w = p.grad * scale p.add_(e_w) # w_adv = w + eps self.state[p]['e_w'] = e_w if zero_grad: self.zero_grad() @torch.no_grad() def second_step(self, zero_grad=False): """Шаг 2: градиент в w_adv, шаг базового оптимизатора из w""" for group in self.param_groups: for p in group['params']: if p.grad is None: continue p.sub_(self.state[p]['e_w']) # вернуться в w self.base_optimizer.step() # SGD/Adam шаг if zero_grad: self.zero_grad() def _grad_norm(self): shared_device = self.param_groups[0]['params'][0].device norm = torch.norm( torch.stack([ p.grad.norm(p=2).to(shared_device) for group in self.param_groups for p in group['params'] if p.grad is not None ]), p=2 ) return norm # Использование (2 forward pass на шаг): # optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9) # for x, y in loader: # loss = criterion(model(x), y) # loss.backward() # optimizer.first_step(zero_grad=True) # criterion(model(x), y).backward() # второй forward pass # optimizer.second_step(zero_grad=True) ``` ## Эффекты batch size на генерализацию ``` Касаясь sharp/flat minima: Маленький batch (32-256): → больше шума в градиентах → escapes sharp minima → лучшая генерализация (flat minima) Большой batch (4096+): → детерминированный градиент → застревает в sharp minima → generalization gap Linear scaling rule (Goyal et al., 2017): При batch_size * k → lr * k, warmup на первые несколько эпох Позволяет обучать ImageNet за 1 час на 256 GPU! ```
Quantization-Aware Training
## Post-Training Quantization (PTQ) vs QAT ``` FP32 модель: Параметры: 4 байта × N = 4N байт Inference: медленно на CPU/mobile INT8 квантизация: Параметры: 1 байт × N = N байт (4x меньше памяти) Inference: x2-4x быстрее (SIMD/VNNI инструкции) Но: потеря точности! PTQ (Post-Training Quantization): Взять готовую FP32 модель → применить квантизацию Быстро, но accuracy drop 1-3% для сложных моделей QAT (Quantization-Aware Training): Симулировать квантизацию ВО ВРЕМЯ обучения Модель адаптируется к ошибке квантизации Accuracy drop < 0.5% ``` ## QAT как задача оптимизации Квантизация - это дискретная операция, не дифференцируемая. QAT обходит это через **Straight-Through Estimator (STE)**: ```python class FakeQuantize(torch.autograd.Function): @staticmethod def forward(ctx, x, scale, zero_point, bits=8): """ Forward: реальная квантизация + декантизация (симулирует потерю точности) """ q_min = -(2**(bits-1)) q_max = 2**(bits-1) - 1 # Квантизировать x_q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max) # Декантизировать (остаёмся в float для следующих слоёв) x_dq = (x_q - zero_point) * scale return x_dq @staticmethod def backward(ctx, grad_output): """ Backward: Straight-Through Estimator Градиент сквозь round() = identity function (как будто round не было) """ return grad_output, None, None, None # dL/dx = dL/d(x_dq) ``` ## QAT с PyTorch ```python import torch from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert class QuantizedModel(torch.nn.Module): def __init__(self, model): super().__init__() self.quant = QuantStub() # FP32 → INT8 на входе self.model = model self.dequant = DeQuantStub() # INT8 → FP32 на выходе def forward(self, x): x = self.quant(x) x = self.model(x) return self.dequant(x) # QAT pipeline: model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') prepare_qat(model, inplace=True) # вставить FakeQuantize nodes # Дообучить 5-10 эпох (fine-tuning под квантизацию) for epoch in range(10): train_epoch(model, optimizer, loader) model.eval() convert(model, inplace=True) # FakeQuantize → настоящий INT8 # Теперь модель работает на INT8! ``` ## Mixed Precision Training ```python from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() # динамическое масштабирование для предотвращения underflow for x, y in loader: with autocast(): # forward pass в FP16 logits = model(x) loss = criterion(logits, y) scaler.scale(loss).backward() # backward в FP16 scaler.step(optimizer) # optimizer.step() в FP32 scaler.update() # адаптировать scale factor # BF16 (Google Brain Float): лучше чем FP16 для обучения # Wider exponent range → нет проблемы underflow → не нужен GradScaler! # torch.autocast(device_type='cuda', dtype=torch.bfloat16) ``` ## Quantization formats сравнение | Формат | Бит | Memory | Speed | Use case | |--------|-----|--------|-------|----------| | FP32 | 32 | 1x | 1x | Обучение (baseline) | | BF16 | 16 | 2x | 2x | Обучение LLM (A100+) | | FP16 | 16 | 2x | 2x | Inference GPU | | INT8 | 8 | 4x | 4x | Inference CPU/GPU | | INT4 | 4 | 8x | 6x | LLM inference (GPTQ) | | INT1 | 1 | 32x | 10x | Extreme edge (XNor-Net) |
Inference Optimization: TensorRT, ONNX, Operator Fusion
## Путь модели от PyTorch к Production ``` Training Export Optimization Deploy ───────── ────── ──────────── ────── PyTorch model ───► ONNX graph ───► TensorRT engine ───► C++/Python API (FP32, eager) (FP32/FP16) (INT8, fused ops) (latency < 1ms) или Alternative path: PyTorch ───► torch.compile() ───► triton kernels ───► fastpath inference ``` ## ONNX: стандартный формат обмена ```python import torch import torch.onnx # Экспорт в ONNX dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export( model, dummy_input, 'model.onnx', opset_version=17, input_names=['image'], output_names=['logits'], dynamic_axes={'image': {0: 'batch_size'}} # динамический batch ) # Проверить граф import onnx model_onnx = onnx.load('model.onnx') onnx.checker.check_model(model_onnx) # валидация # Запустить через ONNX Runtime (оптимальнее PyTorch для inference) import onnxruntime as ort sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider']) output = sess.run(['logits'], {'image': image_np}) ``` ## Operator Fusion GPU плохо использует свою пропускную способность при множестве мелких операций из-за overhead каждого kernel launch: ``` Без fusion (наивно): 1. MatMul kernel launch ← memory read/write 2. Add bias kernel launch ← memory read/write 3. LayerNorm kernel launch ← memory read/write 4. GELU kernel launch ← memory read/write Итого: 4 round-trip в global memory С operator fusion (fused kernel): 1. Fused MatMul+Add+LayerNorm+GELU kernel ← одно чтение/запись Результат: 3-5x speedup на Transformer layers! Flash Attention (Dao et al., 2022): Классический attention: O(n^2) memory (хранить всю матрицу внимания) Flash Attention: IO-aware, tile-based → O(n) memory, 3x faster! Полностью на GPU SRAM, избегает HBM round-trips. ``` ## TensorRT оптимизация ```python # TensorRT через torch2trt или tensorrt-llm import tensorrt as trt def build_trt_engine(onnx_path, fp16=True, int8=False): logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network = builder.create_network( 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) ) parser = trt.OnnxParser(network, logger) with open(onnx_path, 'rb') as f: parser.parse(f.read()) config = builder.create_builder_config() config.max_workspace_size = 4 * (1024**3) # 4 GB if fp16: config.set_flag(trt.BuilderFlag.FP16) # FP16 kernels if int8: config.set_flag(trt.BuilderFlag.INT8) # INT8 kernels # Нужен calibrator для INT8! engine = builder.build_serialized_network(network, config) return engine # TensorRT автоматически: # 1. Fuses compatible layers (Conv+BN+ReLU → один kernel) # 2. Eliminates redundant ops # 3. Selects optimal kernel for each layer + GPU # 4. Manages memory pool ``` ## torch.compile() - новый подход ```python import torch # PyTorch 2.0+: JIT компиляция через Triton/TorchInductor model = torch.compile(model, mode='max-autotune', # самый агрессивный режим fullgraph=True # весь граф как один kernel ) # Режимы: # 'default' - баланс скорость/качество (рекомендуется) # 'reduce-overhead' - меньше kernel launches overhead # 'max-autotune' - долгая компиляция, максимальный speedup # Реальные ускорения на A100: # ResNet-50: 1.4x # BERT-large: 2.1x # GPT-2: 2.3x # LLM decode: до 4x с torch.compile + FlashAttention2 ``` ## Latency breakdown типичного inference ``` ResNet-50 inference (batch=1, A100 GPU): PyTorch eager: 5.2 ms ONNX Runtime: 3.1 ms (1.7x) TensorRT FP16: 1.8 ms (2.9x) TensorRT INT8: 0.9 ms (5.8x) torch.compile: 2.3 ms (2.3x) GPT-2 generation (256 tokens, A100): Vanilla PyTorch: 1240 ms FlashAttention2: 380 ms (3.3x) + torch.compile: 240 ms (5.2x) + INT8 (bitsandbytes): 180 ms (6.9x) ```
Связь с предыдущим
Distributed Optimization дал инструменты для масштабирования градиентного спуска на кластеры. Этот урок переходит от 'как обучать быстрее' к 'как развернуть быстрее': геометрия loss landscape определяет, какие минимумы достижимы, квантизация определяет память и latency, fusion определяет throughput.
- Distributed Optimization — Дал scale-out обучения, теперь нужен scale-out inference
- Gradient Descent — SGD-шум объясняет, почему обучение находит обобщаемые решения
Итоги
- Геометрия high-dim landscape: настоящих минимумов нет, есть seddle points и flat regions. SGD ищет именно flat.
- SAM формализует поиск flat minima как min-max задачу: 2 forward pass, +1-2% accuracy на ImageNet/CIFAR.
- QAT через FakeQuantize + STE: simulate INT8 во время обучения, чтобы модель адаптировалась к ошибке квантизации.
- BF16 vs FP16: одинаковая память, но BF16 имеет диапазон FP32 - не нужен GradScaler, поэтому стандарт для LLM с A100+.
- Operator fusion устраняет memory bandwidth bottleneck: Conv+BN+ReLU - один kernel вместо трёх, 3-5x speedup.
- Flash Attention заменяет O(n²) attention matrix на O(n) tiled compute - открыл дорогу к 128K context.
Вопросы для размышления
- Если flat minima обобщают лучше, можно ли явно добавить регуляризатор на curvature в loss? Что мешает это сделать в современных моделях?
- Почему BF16 победил FP16 для обучения LLM, но FP16 всё ещё используется для inference на GPU без bfloat16 support?
- Operator fusion даёт 3-5x на Conv+BN+ReLU, но почти ничего не даёт на attention. Почему - и что именно решает Flash Attention?
Связанные уроки
- opt-14 — ML-оптимизация строится на распределённых методах
- ml-09-gradient-descent — Градиентный спуск - основа всей ML-оптимизации
- dl-09 — AdamW и SGD в контексте обучения нейронных сетей
- aie-13-advanced-rag — Оптимизация поиска в RAG - практический кейс
- alg-01-big-o — Сложность алгоритма определяет выбор оптимизатора