Оптимальный транспорт
Computational OT
2021: команда Meta оценивает качество генеративных моделей на 10 миллионах изображений через Wasserstein дистанцию. Точный OT - O(n^3) - невозможен физически: это 10^21 операций. GeomLoss на двух A100 решает ту же задачу за 3 секунды. Выбор правильной библиотеки - разница между working research и production system.
- **Domain adaptation в NLP**: перенос эмбеддингов между языками через OT (Wasserstein Procrustes). GPT-эмбеддинги - 50k токенов. POT справится за минуты, GeomLoss - за секунды на GPU.
- **Оценка качества генеративных моделей**: FID заменяется на Wasserstein дистанцию для более интерпретируемой метрики. Sliced Wasserstein - стандарт для image generation benchmarks: дёшево и коррелирует с визуальным качеством.
- **Логистическая оптимизация**: распределение товаров между складами и магазинами - классический транспортный LP. Для 500 складов и 500 магазинов ot.emd даёт точное решение за секунды; при 10k объектов нужен уже Sinkhorn или специализированный солвер.
Предварительные знания
- Sinkhorn алгоритм и энтропийная регуляризация (ot-04-sinkhorn)
- Wasserstein дистанция: постановка и свойства (ot-03-wasserstein)
- Базовый Python и NumPy для работы с примерами кода
Библиотека POT: Swiss army knife для OT
**POT (Python Optimal Transport)** - главная библиотека для OT в Python. Публикация: Flamary et al. 2021, JMLR. Документация: pythonot.github.io. Покрывает весь стек - от точного Wasserstein до частичного OT и барицентров.
Основных функции три. `ot.emd(a, b, M)` - точный Wasserstein через **network simplex**: решает транспортный LP оптимально за $O(n^3)$. Детерминирован и точен, но медленен для $n > 1000$. `ot.sinkhorn(a, b, M, reg)` - энтропийный OT с параметром $\varepsilon$ (`reg`): быстрые матричные итерации, GPU-friendly, результат сглаженный. `ot.sliced_wasserstein_distance(X, Y)` - sliced OT за $O(n \log n)$: проектирует облака точек на случайные направления, вычисляет 1D Wasserstein аналитически.
**Partial OT** (`ot.partial_wasserstein`) переносит только долю `m` полной массы. Незаменим когда данные содержат выбросы - outliers просто не попадают в транспортный план. Например, сравниваем два эмбеддинга, где 30% точек - шум.
`ot.emd2` возвращает скаляр (стоимость). `ot.emd` возвращает матрицу транспортного плана. Легко перепутать при первом использовании - потеряешь час на отладку.
Какой алгоритм в библиотеке POT решает задачу Wasserstein точно (без аппроксимации)?
GeomLoss и OTT-JAX: миллионы точек на GPU
**GeomLoss** (Feydy et al. 2019) решает главную проблему Sinkhorn: матрица ядра $K \in \mathbb{R}^{n \times m}$ для $n = 10^6$ требует 4 TB памяти. GeomLoss использует **KeOps** - символьные kernel операции без материализации матрицы. Сложность памяти $O(1)$ вместо $O(n^2)$: каждая строка вычисляется on-the-fly при суммировании.
**OTT-JAX** (Cuturi et al. 2022) - библиотека для OT на JAX. JIT-компиляция через XLA, автодифференцирование через всю Sinkhorn-итерацию. Ключевое: поддерживает два режима дифференцирования - **unrolled** (через все итерации, точно но памятеемко) и **implicit differentiation** (через теорему о неявной функции, стабильно и быстро).
**Benchmark**: GeomLoss на GPU vs POT на CPU для $n = 50000$: GeomLoss быстрее в 100-1000 раз. При $n < 1000$ - разница незначительна, накладные расходы GPU нивелируют выигрыш. Порог переключения на GPU-библиотеки: $n \approx 5000-10000$ точек.
Почему GeomLoss (через KeOps) может обрабатывать миллионы точек, а стандартный Sinkhorn - нет?
Выбор алгоритма в продакшне
В реальных системах выбор алгоритма OT зависит от трёх факторов: **размер данных** $n$, **нужен ли градиент** (обучение vs инференс), **допустима ли аппроксимация**. Чёткая матрица принятия решений спасает от часов перебора.
**Epsilon и точность Sinkhorn**: при малом $\varepsilon$ Sinkhorn медленно сходится (нужно много итераций) и численно нестабилен - log-domain Sinkhorn (`ot.sinkhorn(method='sinkhorn_log')`) решает проблему нестабильности. При большом $\varepsilon$ сходимость быстрая, но результат смазывает транспортные пары - пары точек в плане $\gamma$ не четкие.
**Mini-batch OT** - важное предупреждение: усреднение Wasserstein по мини-батчам **не аппроксимирует** истинный $W_2$ между полными распределениями. Это другой функционал со своим bias. В контексте обучения генеративных моделей это часто приемлемо, но нельзя сравнивать числа mini-batch OT с полным OT - они измеряют разное.
**Implicit vs unrolled differentiation**: при обратном проходе через Sinkhorn есть два подхода. Unrolled - раскрутить все итерации как вычислительный граф: точно, но $O(T \cdot n^2)$ памяти. Implicit - применить теорему о неявной функции к KKT условиям: одна матрица $O(n^2)$, численно стабильнее. OTT-JAX реализует оба; для production используй implicit (по умолчанию).
Sinkhorn всегда быстрее network simplex - зачем вообще использовать ot.emd?
Для малых n (менее 500 точек) network simplex быстрее и точнее Sinkhorn. Sinkhorn выигрывает только при больших n благодаря GPU-параллелизму.
Sinkhorn имеет накладные расходы: инициализация, много итераций до сходимости, чувствительность к epsilon. Для n=100 network simplex решает за микросекунды и даёт точный ответ без подбора гиперпараметров.
Дано 200,000 точек в R^10, нужен градиент по координатам для обучения нейросети. GPU доступен. Что выбрать?
Ключевые идеи
- **POT для прототипирования**: ot.emd (точный, O(n^3)) для малых данных; ot.sinkhorn (быстрый, аппроксимация) для средних; sliced и partial OT для специальных случаев. Один import - вся экосистема.
- **GeomLoss и OTT-JAX для scale**: KeOps убирает O(n^2) узкое место памяти через символьные операции; JAX даёт JIT и implicit differentiation. При n > 10k на GPU - единственный разумный выбор.
- **Матрица трейдоффов**: n, GPU, нужен ли градиент - три вопроса, которые определяют выбор алгоритма. Mini-batch OT не равен истинному Wasserstein. Epsilon в Sinkhorn - ключевой гиперпараметр: маленький = точнее, большой = быстрее.
Связанные темы
Computational OT - мост между теорией (уроки 1-11) и реальными ML-системами. Вот как эта тема вписывается в более широкий контекст:
- Sinkhorn алгоритм — Основной алгоритм, который реализуют POT, GeomLoss и OTT-JAX под капотом
- Wasserstein дистанция — То, что мы вычисляем на практике - понимание свойств W2 критично для интерпретации результатов
- Flow Matching — Использует scalable OT (через GeomLoss/OTT-JAX) для построения путей переноса в генеративных моделях
Вопросы для размышления
- Задача: сравнить два датасета эмбеддингов по 100k векторов размерности 768 (BERT). GPU нет. Какой алгоритм OT выбрать и почему? Как изменится ответ если GPU появится?
- Почему слишком маленький epsilon в Sinkhorn - это проблема? Что происходит с планом транспорта при epsilon -> 0, и почему это вычислительно опасно?
- Коллега говорит: 'я использую mini-batch OT с batch_size=256, получил W2=0.34'. Другой вычислил полный Sinkhorn на всём датасете и получил W2=0.51. Почему числа разные? Можно ли их сравнивать напрямую?
Связанные уроки
- ot-04-sinkhorn — Sinkhorn - основной алгоритм, который реализуют все библиотеки
- ot-03-wasserstein — Wasserstein дистанция - то, что мы вычисляем на практике
- ot-11-flow-matching — Flow matching использует scalable OT в реальных ML-пайплайнах
- calc-01-sequences