Оптимизация

Распределённая оптимизация

GPT-4 не помещается на один GPU. Датасет МРТ-снимков не может покинуть больницу. Миллиарды смартфонов с пользовательскими данными хотят улучшить общую модель клавиатуры. Каждая из этих ситуаций требует распределённой оптимизации. И каждая - со своими ограничениями: privacy, bandwidth, fault tolerance. За последние десять лет три алгоритма решили большую часть индустрии: FedAvg для privacy-preserving обучения, ADMM для cluster-вычислений с гарантиями, и Ring All-Reduce для масштабирования глубокого обучения на тысячах GPU. Этот урок - о том, как они работают и когда какой выбирать.

  • **Google Gboard** - федеративное обучение модели клавиатуры на миллиардах смартфонов; данные не покидают устройство
  • **PyTorch DDP + NCCL** - стандарт индустрии для multi-GPU обучения; использует Ring All-Reduce под капотом
  • **DeepSpeed ZeRO-3** - оптимизация для моделей 100B+ параметров; sharded all-reduce плюс gradient compression

Федеративное обучение: FedAvg

Google в 2017 году решила задачу, которая казалась невозможной: обучить общую модель клавиатуры Gboard на миллиардах смартфонов без сбора пользовательских сообщений в дата-центре. Решение - **федеративное обучение**: модель отправляется на устройство, локально обучается на пользовательских данных, потом возвращается на сервер, где усредняется с другими устройствами. Данные не покидают телефон. Это - алгоритм FedAvg (McMahan et al., 2017), и он лёг в основу всех современных privacy-preserving ML-систем.

Формализация: есть K клиентов (устройств), каждый со своим датасетом D_k размера n_k. Глобальная цель - минимизировать F(w) = sum_k (n_k/n) * F_k(w), где F_k - локальная функция потерь, n = sum n_k. На каждом раунде сервер выбирает подмножество клиентов S_t, рассылает текущие веса w_{t-1}, каждый клиент локально обучает E эпох на своих данных, возвращает обновлённые веса, сервер усредняет: w_t = sum (n_k/n) * w_k. E=1 даёт FedSGD (просто усреднение градиентов), E>1 экономит коммуникационные раунды ценой client drift.

В FedAvg после каждого communication round сервер выполняет:

ADMM для распределённых задач

FedAvg - простой эвристический алгоритм, и его сходимость для non-IID данных не доказана. ADMM (Alternating Direction Method of Multipliers) - более глубокий метод с гарантированной сходимостью для выпуклых задач. Идея: задача распадается на локальные подзадачи плюс глобальное consensus-ограничение. Формально: минимизируем sum_k f_k(x_k) при условии x_k = z (все локальные переменные равны общему z). ADMM чередует три шага: x-шаг (локально, параллельно), z-шаг (на сервере), u-шаг (обновление двойственных множителей).

Сравнение FedAvg vs ADMM: FedAvg - эвристика, ADMM имеет гарантии сходимости. FedAvg передаёт только веса, ADMM ещё и двойственные переменные. FedAvg устойчив к отказам клиентов (можно просто игнорировать), ADMM чувствителен. FedAvg прост в реализации, ADMM требует настройки rho (параметр аугментированного лагранжиана). На практике FedAvg доминирует в edge-сценариях (смартфоны), ADMM - в HPC-кластерах с надёжной сетью.

В распределённом ADMM z-шаг выполняется:

Сжатие градиентов и Ring All-Reduce

Обучение GPT-3 (175B параметров): один градиентный шаг требует передачи 175B * 4 байта = 700GB данных между узлами. На InfiniBand 400Gb/s это 14 секунд на один шаг - неприемлемо. Два класса решений: **сжатие градиентов** (квантизация до 8/4/1 бит, Top-K sparsification, error feedback) и **архитектура коммуникации** (Ring All-Reduce вместо Parameter Server). Ring All-Reduce - стандарт индустрии: каждый узел передаёт ровно 2N * (K-1)/K байт, что почти не растёт с числом узлов K.

Parameter Server (PS) - простая архитектура: все workers отправляют градиенты на центральный сервер, он усредняет и рассылает обновлённые веса. Минус: PS становится узким горлышком при большом K. Ring All-Reduce: узлы образуют кольцо, scatter-reduce (K-1 шагов) суммирует частичные градиенты, all-gather (K-1 шагов) распространяет результат. Суммарно каждый узел отправляет ~2N байт независимо от K - идеальное масштабирование. PyTorch DDP, NCCL, Horovod используют именно Ring All-Reduce.

All-Reduce масштабируется плохо с числом узлов

Ring All-Reduce передаёт ровно 2*(K-1)/K*N байт с каждого узла, что асимптотически не растёт с K. Это - оптимальная нижняя граница для consensus в синхронной модели

Распространённое заблуждение основано на том, что K-1 шагов кольца кажется длинным. Но каждый шаг передаёт только N/K байт, в сумме - 2N. Реальное узкое место - latency и straggler problem, но не bandwidth

Почему Ring All-Reduce предпочтительнее Parameter Server для синхронного обучения на большом числе узлов?

Ключевые идеи

  • **FedAvg** - клиенты обучаются локально E эпох, сервер усредняет веса с весом n_k/n; основа privacy-preserving ML
  • **Client drift** при non-IID данных лечится FedProx (проксимальный член) или FedNova (нормализация по числу шагов)
  • **ADMM** - консенсусная формулировка с гарантированной сходимостью; x-шаг параллелен, z-шаг централизован, u-шаг обновляет двойственные
  • **Ring All-Reduce** - оптимальная архитектура коммуникации: ~2N байт на узел независимо от K; стандарт PyTorch DDP, Horovod, NCCL
  • **Сжатие градиентов** - квантизация (8/4/1 бит) + Top-K sparsification + error feedback; критично при триллионах параметров

Связанные темы

Возврат к мотивации: распределённая оптимизация - не отдельная дисциплина, а развитие классических методов под ограничения сети и privacy. Связь с предыдущими уроками:

  • Multi-Objective Optimization — ADMM возник как метод декомпозиции multi-objective задач; распределённый ADMM - его естественное продолжение
  • Stochastic Gradient Descent — FedSGD (E=1) - это просто SGD с агрегацией градиентов; FedAvg обобщает идею до E локальных эпох
  • Subgradient methods — Многие практические задачи распределённой оптимизации недифференцируемы (L1 регуляризация); ADMM элегантно справляется через soft-thresholding

Вопросы для размышления

  • Если FedAvg настолько прост и работает, зачем вообще существует ADMM? Какие практические сценарии требуют гарантий сходимости, оправдывающих сложность?
  • Ring All-Reduce оптимален по объёму, но straggler problem (медленный узел тормозит всех) ограничивает реальную пропускную способность. Какие подходы существуют для борьбы со straggler'ами?
  • Возврат к мотивации: представь, что нужно обучить модель на медицинских данных трёх больниц без обмена сырыми данными. Какие три-четыре алгоритмических решения должны быть приняты, и в каком порядке?

Связанные уроки

  • opt-13 — Распределённая оптимизация обобщает одноузловые методы
  • par-06 — MPI как инфраструктура для распределённого градиентного спуска
  • dl-12 — Distributed training - ключевое применение распределённой оптимизации
  • ml-09-gradient-descent — SGD - основа AllReduce и параллельного обучения
  • ds-01-intro — CAP-теорема в распределённых системах актуальна и для distributed opt
  • calc-01-sequences
Распределённая оптимизация

0

1

Войти