Оптимальный транспорт
OT для генеративных моделей (WGAN)
Цели урока
- Вывести двойственность Канторовича-Рубинштейна и понять, как она позволяет оценивать W₁ нейросетью
- Объяснить проблему исчезающих градиентов в GAN и почему W₁ её решает
- Сравнить методы реализации ограничения Липшица: clipping, GP, spectral norm
Предварительные знания
- Расстояние Вассерштейна W₁ и W₂
- Теория двойственности в LP
- Архитектура GAN: генератор и дискриминатор
Представьте архив из 1000 пар 'реальное изображение - синтетическое'. Как измерить, насколько синтетика похожа на реальность? KL-дивергенция сказала бы 'бесконечно' - носители не пересекаются. JS - 'log 2', константа. Только Вассерштейн даёт осмысленное число. WGAN превратил это в градиент.
- Синтез медицинских изображений (MRI, КТ): WGAN обучается на 10-50K снимков и генерирует реалистичные патологии для аугментации обучающих данных
- Молекулярный дизайн: MolGAN и аналоги используют W₁ для обучения генераторов молекул с заданными биохимическими свойствами
- Аномальное обнаружение в NLP: WGAN-тренированный критик используется как детектор out-of-distribution текстов
- EvolutionaryScale ESMFold: WGAN компонент для генерации белковых последовательностей с нужными 3D-структурами
От земляной метрики к революции GAN
Леон Боту заметил связь между GAN и транспортными расстояниями ещё в 2014, но вычислительные трудности казались непреодолимыми. Мартин Арьовский в 2017 году, ещё будучи PhD-студентом, написал теоретический анализ, почему GAN нестабильны - и как W₁ это исправляет. Первая версия WGAN была простой: weight clipping. Gulrajani и соавторы через несколько месяцев предложили WGAN-GP с gradient penalty - этот вариант стал стандартом. FID (Frechet Inception Distance), ставший метрикой качества генерации, тоже основан на W₂.
WGAN и двойственность Канторовича-Рубинштейна
2017 год. GAN обучаются нестабильно: mode collapse, исчезающие градиенты, training curves без смысла. Мартин Арьовский и соавторы нашли диагноз - Jensen-Shannon дивергенция даёт нулевые градиенты когда распределения не пересекаются. Рецепт: заменить JS на расстояние Вассерштейна. WGAN перевернул всё.
Ключевое преимущество W₁ перед KL и JS: при непересекающихся носителях (типичная ситуация в начале обучения) W₁ конечно и непрерывно зависит от параметров генератора. JS = log 2 (константа, нулевые градиенты), KL = ∞.
Почему расстояние Вассерштейна W₁ лучше JS-дивергенции для обучения GAN?
Когда носители P_r и P_g не пересекаются (начало обучения GAN), JS(P_r||P_g) = log 2 - константа, градиент нулевой. W₁ при этом конечно и непрерывно зависит от параметров генератора, давая полезный градиент.
Обучение WGAN: практика и ловушки
Теория красивая, практика жёстче. Ограничение 1-Липшиц на критика - главная трудность. Три подхода: weight clipping (оригинальный WGAN, работает плохо), gradient penalty (WGAN-GP, стандарт), spectral normalization (отдельная история). У каждого цена - скорость, сложность, гиперпараметры.
В WGAN-GP критика обучают n_critic=5 шагов на каждый шаг генератора. Критику нужно сойтись к хорошей оценке W₁ прежде чем давать градиент генератору. При n_critic=1 оценка нестабильна.
WGAN в синтезе молекул и белков
DeepMind применяет WGAN для молекулярного дизайна: критик обучается на пространстве молекулярных графов, W₁ оценивает расстояние между распределением реальных и синтетических молекул. MolGAN (2018) генерирует молекулы с нужными свойствами через оптимизацию W₁ + reward signal. В белковой инженерии: EvolutionaryScale использует WGAN для генерации последовательностей - 35 000 задач в сутки на A100 кластере.
Почему в оригинальном WGAN с weight clipping критик теряет выразительность?
Weight clipping загоняет критика в пространство функций с очень ограниченной сложностью - фактически бинарные паттерны ±c во всех слоях. WGAN-GP позволяет любые веса при сохранении нормы градиента.
После WGAN: Sinkhorn-дивергенция и дифференцируемые потери
WGAN дал расстояние, но вычисление точного W₁ остаётся LP. Для батчей из 64 изображений 256×256 - нереально. Решение: Sinkhorn-дивергенция. Это регуляризованный Вассерштейн, дифференцируемый по параметрам генератора, вычисляемый за O(n² · n_iter) на GPU.
GeomLoss (Feydy et al.) реализует Sinkhorn-дивергенцию для PyTorch с поддержкой GPU. Используется в Stable Diffusion 3 и Flow Matching архитектурах. API: SamplesLoss('sinkhorn', p=2, blur=0.05).
Почему W_ε(μ,μ) ≠ 0 и как это исправляет Sinkhorn-дивергенция?
Штраф -εH(P) при P = идентичный план (μ против μ) даёт ненулевой вклад. Debiased Sinkhorn убирает его симметричным вычитанием, получая истинную дивергенцию.
Куда ведёт тема
WGAN - первое массовое применение OT в deep learning. Следующие шаги: дискретный OT и приложения (ot-23) для задач без градиентов, Flow Matching (ot-25) как эволюция идеи - обучать транспорт напрямую, а не через adversarial игру.
- Optimal Transport — Связанная тема
Итоги
- W₁ = sup по 1-Липшицевым f разности ожиданий: это позволяет оценивать Вассерштейн нейросетью-критиком
- WGAN заменяет JS-дивергенцию на W₁: конечные градиенты при непересекающихся носителях
- Gradient penalty WGAN-GP: штраф (||∇f||₂ - 1)² на интерполяциях - лучше weight clipping
- Sinkhorn-дивергенция = деблюрованный W_ε, дифференцируемый по параметрам генератора
Вопросы для размышления
- Почему критика в WGAN обучают больше шагов на каждый шаг генератора - что происходит если этот баланс нарушен?
- Чем Sinkhorn-дивергенция отличается от просто регуляризованного OT в качестве функции потерь?
- Как spectral normalization гарантирует ограничение Липшица для всей нейросети через нормировку отдельных слоёв?
Связанные уроки
- ot-07-wgan — WGAN - прямое применение OT к обучению GAN
- ot-03-wasserstein — расстояние Вассерштейна - основа WGAN
- ot-21 — регуляризованный OT - вычислительный движок WGAN-GP
- ot-25-flow-matching — Flow Matching - следующее поколение OT-генеративных моделей