Оптимальный транспорт

Computational OT

2021: команда Meta оценивает качество генеративных моделей на 10 миллионах изображений через Wasserstein дистанцию. Точный OT - O(n^3) - невозможен физически: это 10^21 операций. GeomLoss на двух A100 решает ту же задачу за 3 секунды. Выбор правильной библиотеки - разница между working research и production system.

  • **Domain adaptation в NLP**: перенос эмбеддингов между языками через OT (Wasserstein Procrustes). GPT-эмбеддинги - 50k токенов. POT справится за минуты, GeomLoss - за секунды на GPU.
  • **Оценка качества генеративных моделей**: FID заменяется на Wasserstein дистанцию для более интерпретируемой метрики. Sliced Wasserstein - стандарт для image generation benchmarks: дёшево и коррелирует с визуальным качеством.
  • **Логистическая оптимизация**: распределение товаров между складами и магазинами - классический транспортный LP. Для 500 складов и 500 магазинов ot.emd даёт точное решение за секунды; при 10k объектов нужен уже Sinkhorn или специализированный солвер.

Предварительные знания

  • Sinkhorn алгоритм и энтропийная регуляризация (ot-04-sinkhorn)
  • Wasserstein дистанция: постановка и свойства (ot-03-wasserstein)
  • Базовый Python и NumPy для работы с примерами кода
  • Sinkhorn алгоритм
  • Flow Matching и OT

Библиотека POT: Swiss army knife для OT

**POT (Python Optimal Transport)** - главная библиотека для OT в Python. Публикация: Flamary et al. 2021, JMLR. Документация: pythonot.github.io. Покрывает весь стек - от точного Wasserstein до частичного OT и барицентров.

Основных функции три. `ot.emd(a, b, M)` - точный Wasserstein через **network simplex**: решает транспортный LP оптимально за $O(n^3)$. Детерминирован и точен, но медленен для $n > 1000$. `ot.sinkhorn(a, b, M, reg)` - энтропийный OT с параметром $\varepsilon$ (`reg`): быстрые матричные итерации, GPU-friendly, результат сглаженный. `ot.sliced_wasserstein_distance(X, Y)` - sliced OT за $O(n \log n)$: проектирует облака точек на случайные направления, вычисляет 1D Wasserstein аналитически.

**Partial OT** (`ot.partial_wasserstein`) переносит только долю `m` полной массы. Незаменим когда данные содержат выбросы - outliers просто не попадают в транспортный план. Например, сравниваем два эмбеддинга, где 30% точек - шум.

`ot.emd2` возвращает скаляр (стоимость). `ot.emd` возвращает матрицу транспортного плана. Легко перепутать при первом использовании - потеряешь час на отладку.

Какой алгоритм в библиотеке POT решает задачу Wasserstein точно (без аппроксимации)?

GeomLoss и OTT-JAX: миллионы точек на GPU

**GeomLoss** (Feydy et al. 2019) решает главную проблему Sinkhorn: матрица ядра $K \in \mathbb{R}^{n \times m}$ для $n = 10^6$ требует 4 TB памяти. GeomLoss использует **KeOps** - символьные kernel операции без материализации матрицы. Сложность памяти $O(1)$ вместо $O(n^2)$: каждая строка вычисляется on-the-fly при суммировании.

**OTT-JAX** (Cuturi et al. 2022) - библиотека для OT на JAX. JIT-компиляция через XLA, автодифференцирование через всю Sinkhorn-итерацию. Ключевое: поддерживает два режима дифференцирования - **unrolled** (через все итерации, точно но памятеемко) и **implicit differentiation** (через теорему о неявной функции, стабильно и быстро).

**Benchmark**: GeomLoss на GPU vs POT на CPU для $n = 50000$: GeomLoss быстрее в 100-1000 раз. При $n < 1000$ - разница незначительна, накладные расходы GPU нивелируют выигрыш. Порог переключения на GPU-библиотеки: $n \approx 5000-10000$ точек.

Почему GeomLoss (через KeOps) может обрабатывать миллионы точек, а стандартный Sinkhorn - нет?

Выбор алгоритма в продакшне

В реальных системах выбор алгоритма OT зависит от трёх факторов: **размер данных** $n$, **нужен ли градиент** (обучение vs инференс), **допустима ли аппроксимация**. Чёткая матрица принятия решений спасает от часов перебора.

**Epsilon и точность Sinkhorn**: при малом $\varepsilon$ Sinkhorn медленно сходится (нужно много итераций) и численно нестабилен - log-domain Sinkhorn (`ot.sinkhorn(method='sinkhorn_log')`) решает проблему нестабильности. При большом $\varepsilon$ сходимость быстрая, но результат смазывает транспортные пары - пары точек в плане $\gamma$ не четкие.

**Mini-batch OT** - важное предупреждение: усреднение Wasserstein по мини-батчам **не аппроксимирует** истинный $W_2$ между полными распределениями. Это другой функционал со своим bias. В контексте обучения генеративных моделей это часто приемлемо, но нельзя сравнивать числа mini-batch OT с полным OT - они измеряют разное.

**Implicit vs unrolled differentiation**: при обратном проходе через Sinkhorn есть два подхода. Unrolled - раскрутить все итерации как вычислительный граф: точно, но $O(T \cdot n^2)$ памяти. Implicit - применить теорему о неявной функции к KKT условиям: одна матрица $O(n^2)$, численно стабильнее. OTT-JAX реализует оба; для production используй implicit (по умолчанию).

Sinkhorn всегда быстрее network simplex - зачем вообще использовать ot.emd?

Для малых n (менее 500 точек) network simplex быстрее и точнее Sinkhorn. Sinkhorn выигрывает только при больших n благодаря GPU-параллелизму.

Sinkhorn имеет накладные расходы: инициализация, много итераций до сходимости, чувствительность к epsilon. Для n=100 network simplex решает за микросекунды и даёт точный ответ без подбора гиперпараметров.

Дано 200,000 точек в R^10, нужен градиент по координатам для обучения нейросети. GPU доступен. Что выбрать?

Ключевые идеи

  • **POT для прототипирования**: ot.emd (точный, O(n^3)) для малых данных; ot.sinkhorn (быстрый, аппроксимация) для средних; sliced и partial OT для специальных случаев. Один import - вся экосистема.
  • **GeomLoss и OTT-JAX для scale**: KeOps убирает O(n^2) узкое место памяти через символьные операции; JAX даёт JIT и implicit differentiation. При n > 10k на GPU - единственный разумный выбор.
  • **Матрица трейдоффов**: n, GPU, нужен ли градиент - три вопроса, которые определяют выбор алгоритма. Mini-batch OT не равен истинному Wasserstein. Epsilon в Sinkhorn - ключевой гиперпараметр: маленький = точнее, большой = быстрее.

Связанные темы

Computational OT - мост между теорией (уроки 1-11) и реальными ML-системами. Вот как эта тема вписывается в более широкий контекст:

  • Sinkhorn алгоритм — Основной алгоритм, который реализуют POT, GeomLoss и OTT-JAX под капотом
  • Wasserstein дистанция — То, что мы вычисляем на практике - понимание свойств W2 критично для интерпретации результатов
  • Flow Matching — Использует scalable OT (через GeomLoss/OTT-JAX) для построения путей переноса в генеративных моделях

Вопросы для размышления

  • Задача: сравнить два датасета эмбеддингов по 100k векторов размерности 768 (BERT). GPU нет. Какой алгоритм OT выбрать и почему? Как изменится ответ если GPU появится?
  • Почему слишком маленький epsilon в Sinkhorn - это проблема? Что происходит с планом транспорта при epsilon -> 0, и почему это вычислительно опасно?
  • Коллега говорит: 'я использую mini-batch OT с batch_size=256, получил W2=0.34'. Другой вычислил полный Sinkhorn на всём датасете и получил W2=0.51. Почему числа разные? Можно ли их сравнивать напрямую?

Связанные уроки

  • ot-04-sinkhorn — Sinkhorn - основной алгоритм, который реализуют все библиотеки
  • ot-03-wasserstein — Wasserstein дистанция - то, что мы вычисляем на практике
  • ot-11-flow-matching — Flow matching использует scalable OT в реальных ML-пайплайнах
  • calc-01-sequences
Computational OT

0

1

Войти