Оптимальный транспорт
OT в генеративных моделях и NLP
Stable Diffusion 3 и Flux генерируют изображения через Rectified Flow - это OT в непрерывном времени. WGAN стабилизировал обучение GAN через Wasserstein loss. Model merging (Mistral, Llama derivatives) использует OT barycenters. Optimal Transport перестал быть академической теорией - это инструмент, стоящий за генерацией изображений, text в production.
- **Stable Diffusion 3 и FLUX.1:** Rectified Flow / Flow Matching - прямые OT пути от шума к изображению. 25 шагов вместо 1000. Открытый код, billions downloads.
- **WGAN-GP:** стандарт для высококачественной генерации до диффузионных моделей. StyleGAN2 использует R1 regularization - вариант gradient penalty из WGAN-GP.
- **MBR decoding в Google Translate и NLLB Meta:** Minimum Bayes Risk с Wasserstein/chrF метриками улучшает translation quality без изменения модели.
Flow Matching: Stable Diffusion через OT в непрерывном времени
**Stable Diffusion 3 и Flux используют Rectified Flow - это optimal transport flow matching.** Идея: соединить шум (N(0,I)) с данными (изображение) прямой линией в пространстве. Нейросеть учится предсказывать направление движения по этой прямой. Это проще и быстрее чем DDPM диффузия.
Flow matching обучает vector field v_θ(x,t) такой, что ODE dx/dt = v_θ(x,t) транспортирует p₀ (шум) в p₁ (данные). OT flow matching: v*(x,t) = (x₁ - x₀) - прямая линия от шума к данным. Это минимальный путь по Wasserstein метрике.
**Flow Matching (Lipman 2022, Liu 2022):** Обучение: min_θ E_{t,x₀,x₁} ||v_θ(tx₁ + (1-t)x₀, t) - (x₁-x₀)||² где t ~ Uniform(0,1), x₀ ~ p₀ (шум), x₁ ~ p₁ (данные) **Rectified Flow (Liu 2022):** ODE: dx/dt = v_θ(x,t) Стартовать из x₀ ~ N(0,I), решить ODE к t=1 → x₁ ~ p_data **Преимущество vs DDPM:** - DDPM: 1000 шагов Euler solver - Rectified Flow: 10-25 шагов (прямые пути) - Stable Diffusion 3 и Flux используют именно это **OT connection:** OT coupling (x₀,x₁) = прямые пути. Это минимизирует E[||x₀-x₁||²] - transport cost. "Reflow" итерация: переобучить на прямых парах, ещё больше выпрямить пути.
Почему Rectified Flow требует ~40x меньше шагов, чем DDPM (25 vs 1000)?
DDPM имеет криволинейные пути в пространстве данных - нужно много шагов для точной аппроксимации SDE. RF обучает поле скоростей, которое транспортирует по прямым линиям. Прямые пути отлично аппроксимируются даже за 1 шаг Эйлера.
WGAN: Wasserstein как adversarial loss
Оригинальный GAN минимизирует Jensen-Shannon divergence. При несовпадении support (что часто в начале обучения) gradient = 0 - обучение останавливается. **WGAN** (Arjovsky 2017) использует Wasserstein-1 как loss: всегда даёт информативный gradient, не требует точного matching supports.
**Wasserstein-1 через двойственность Канторовича-Рубинштейна:** W₁(p_r, p_g) = max_{||f||_L ≤ 1} E_{x~p_r}[f(x)] - E_{x~p_g}[f(x)] где f - 1-Lipschitz функция. **WGAN алгоритм:** - Critic f_w (вместо дискриминатора): max_w E[f_w(real)] - E[f_w(fake)] - Generator G_θ: min_θ -E[f_w(G_θ(z))] - Lipschitz constraint: weight clipping [-c, c] (оригинал) или gradient penalty ||∇f||=1 (WGAN-GP) **WGAN-GP (Gulrajani 2017):** штраф за ||∇_x̂ f(x̂)||₂ ≠ 1 на интерполяции между real и fake. Стабильнее weight clipping.
Почему critic в WGAN не использует sigmoid и binary cross-entropy?
Классический GAN: min JS-div через BCE - при несовпадающих носителях JS=log2 и градиенты исчезают. WGAN critic выдаёт f(x)∈ℝ без ограничений, аппроксимируя W₁ = sup_{‖f‖_L≤1} E[f(real)] - E[f(fake)].
OT в LLM: beam search, decoding и распределённое обучение
OT появляется в LLM в трёх местах: (1) **Minimum Bayes Risk decoding** - выбор предсказания, минимизирующего ожидаемую Wasserstein distance к другим гипотезам; (2) **Knowledge distillation** через Wasserstein loss между token distributions; (3) **OT Barycenters** для federated learning - усреднение моделей как Wasserstein barycenter.
**OT Barycenter для federated learning:** Федеративное обучение: K клиентов, каждый обучает свою модель θₖ. Стандартное усреднение: θ_avg = (1/K)Σθₖ (FedAvg). Проблема: усреднение весов работает плохо при heterogeneous data. Wasserstein Barycenter нейронных сетей: θ* = argmin_θ Σₖ wₖ · W₂²(ν_θ, ν_θₖ) где ν_θ - распределение активаций модели θ. Решение: согласовать permutation нейронов (activation matching) перед усреднением. Model Merging (2023): Git Re-basin, SLERP interpolation - конкретные алгоритмы для LLM merging без дообучения.
Model merging через Wasserstein barycenter используется в production: Mistral-7B-Instruct-v0.3 - это merger нескольких checkpoint через SLERP/TIES методы. Hugging Face mergekit библиотека реализует Git Re-basin (permutation matching), SLERP interpolation и TIES merging.
Почему FedAvg плохо работает при heterogeneous data, и как Wasserstein barycenter решает это?
Нейросети эквивалентны с точностью до перестановки нейронов. При FedAvg нейрон 1 клиента A усредняется с нейроном 1 клиента B, хотя они могут делать разное. WB находит оптимальную перестановку через OT перед усреднением активаций.
Ключевые идеи
- **Flow Matching:** обучить v_θ(x,t) ≈ x₁-x₀ на прямых путях. ODE dx/dt=v_θ транспортирует шум в данные за 25 шагов. Stable Diffusion 3, Flux.
- **WGAN:** critic f_w максимизирует E[f(real)] - E[f(fake)] при ||f||_L ≤ 1. W₁ ≈ critic score. WGAN-GP: gradient penalty вместо weight clipping.
- **MBR decoding:** выбрать hypothesis с min E_{h'}[dist(h,h')]. Wasserstein между token distributions - семантически осмысленная метрика.
- **OT Barycenters:** θ* = argmin Σwₖ·W₂(ν_θ, ν_θₖ). Federated learning и model merging без дообучения.
- **Rectified Flow vs DDPM:** прямые OT пути → меньше curvature → меньше шагов Euler solver. Reflow итерации делают пути ещё прямее.
Связанные темы
OT в современном ML:
- Flow Matching (базовый) — Теория flow matching и CFM
- Wasserstein Gradient Flows — Диффузия как gradient flow - математический фундамент diffusion models
Связанные уроки
- ot-11-flow-matching — Flow matching = OT в непрерывном времени
- ot-07-wgan — WGAN использует W1 как adversarial loss
- ot-14-gradient-flows — Gradient flows объясняют динамику диффузионных моделей