Численные методы

Автоматическое дифференцирование в ML

Цели урока

  • Понять разницу между прямым и обратным режимами AD и когда каждый выгоден
  • Освоить backpropagation как применение правила цепочки к вычислительному DAG
  • Разобраться в производных высших порядков, гессиан-векторных произведениях и adjoint method для Neural ODE

Предварительные знания

  • Параллельные вычисления в линейной алгебре
  • Параллельные вычисления в линейной алгебре

JAX (Google): градиенты функций с 540B параметров - основа обучения Gemini

  • JAX (Google): градиенты функций с 540B параметров - основа обучения Gemini
  • PyTorch: каждый шаг Adam в GPT-4 - обратный AD через граф из миллионов узлов
  • Neural ODE: normalizing flows и time-series модели через дифференцируемые ODE-решатели
  • Differentiable physics: оптимизация траектории робота через симуляцию физики с AD

История backpropagation

Алгоритм переоткрывался четырежды: Kelley (1960, теория управления), Bryson (1961, аэронавтика), Dreyfus (1962, динамическое программирование), Linnainmaa (1970, численный анализ). Rumelhart, Hinton и Williams (1986, Nature) сделали его известным в нейросетях - и забыли процитировать предшественников. Seppo Linnainmaa в 1970 году в дипломной работе описал полный алгоритм reverse-mode AD. JAX (2018) сделал его программной нормой.

Прямой и обратный режимы: когда что применять

В 2023 году JAX (Google) обрабатывал градиенты функций с 540 миллиардами параметров. Обратный режим AD вычисляет весь градиент за 2 прохода через граф. Прямой режим потребовал бы 540 миллиардов проходов. Это разница между обучением GPT-4 за несколько месяцев и несколькими тысячелетиями.

JAX реализует как forward-mode (jax.jvp), так и reverse-mode (jax.vjp) AD. Функция jax.jacobian выбирает оптимальный режим автоматически: reverse при n > m, forward при n < m.

Для функции f: R^{10^9} -> R (нейросеть), какой режим AD вычисляет градиент эффективнее?

При n=10^9, m=1: reverse AD = 2 прохода. Forward AD = 10^9 прогонов. Конечные разности = 10^9+1 вычислений f (плюс ошибка O(h)). Символьное дифференцирование - exponential expression swell.

Вычислительный граф и правило цепочки

Backpropagation - не магия нейросетей. Это применение правила цепочки к ациклическому вычислительному графу в обратном топологическом порядке. Алгоритм переоткрывали 4 раза: Kelley (1960), Bryson (1961), Dreyfus (1962) и Rumelhart-Hinton-Williams (1986). Последний стал знаменитым.

PyTorch строит динамический граф при каждом forward pass. TensorFlow 2 и JAX - тоже. Это позволяет условные ветвления (if/while) - граф меняется в зависимости от данных. Статические графы (TensorFlow 1, XLA) компилируются один раз и работают быстрее на повторяющихся операциях.

Reverse mode AD требует хранить все промежуточные значения v_k для backprop. Для сети из L слоёв с активациями размера n это O(L*n) памяти. Gradient checkpointing (torch.utils.checkpoint) сохраняет только checkpoints и пересчитывает остальные - компромисс: 2x время, sqrt(L) памяти.

Зачем reverse mode AD требует хранить все промежуточные значения forward pass?

Пример: для y = x^2, dy/dx = 2x. При backprop нужно знать x (значение из forward pass). Для сети каждый слой имеет локальную производную, зависящую от активаций. Gradient checkpointing решает это: пересчитывает нужные активации по требованию.

Производные высших порядков и AD в численных методах

JAX поддерживает производные произвольного порядка: jax.grad(jax.grad(f)) - вторая производная. Метод Ньютона требует гессиана. Оптимальная транспортировка требует производной Якобиана. PINN-уравнения - смешанные производные второго порядка в реальном времени.

Neural ODEs: ODESolve как слой нейросети

Chen et al. (NeurIPS 2018) предложили Neural ODE: dx/dt = f_theta(x,t). Решение через ODEsolve - дифференцируемый слой через adjoint method. Градиент вычисляется решением сопряжённого ОДУ назад - без хранения всей траектории. Используется в normalizing flows, time-series modeling, continuous depth networks.

Как вычислить произведение Гессиан-вектор H*v за O(T_f) без явного построения гессиана O(n^2)?

H*v = d/dt[grad f(x + tv)]|_{t=0} - это jvp применённый к grad(f). JAX: jax.jvp(jax.grad(f), (x,), (v,))[1]. Сложность O(T_f) - один forward + один backward проход. Используется в методе сопряжённых градиентов для Ньютона.

AD в численных методах: дифференцируемые решатели

JAX и PyTorch позволяют дифференцировать через любой код. Это открывает дифференцируемое программирование: градиентная оптимизация через итерационные решатели, ОДУ, FFT. differentiable physics - активная область исследований.

Дифференцируемые решатели СЛАУ: если нейросеть зависит от решения x* линейной системы Ax = b(theta), градиент по theta вычисляется через неявное дифференцирование без раскрутки итераций CG. Это используется в задачах оптимальной транспортировки (Wasserstein loss), и задачах физического моделирования (differentiable simulation для robotic control).

custom_vjp в JAX и register_hook в PyTorch позволяют задать аналитически правило backprop для конкретных операций - обойти автоматическое раскрытие вычислительного графа. Используется в FlashAttention (custom CUDA kernel с аналитически заданным backward pass).

Почему adjoint method для Neural ODE использует постоянный объём памяти, независимо от числа шагов ODE-решателя?

Классический backprop: хранить z(t_0),...,z(t_K) - O(K*n) памяти. Adjoint: решить ODE для a(t) назад, одновременно интегрируя grad_theta. Использует O(n) памяти (текущее состояние), независимо от K.

Связь с другими темами

AD - основа всего modern ML: обучение нейросетей (backward pass), оптимизация (Adam, LBFGS), Physics-Informed Networks (производные ДУ), Differentiable Rendering, Neural ODE. Каждый PINN (nm-29) строится на AD для вычисления производных физических уравнений.

  • Neural ODE Adjoint Training — Связанная тема
  • Higher-Order Optimization — Связанная тема
  • Differentiable Simulation — Связанная тема

Итоги

  • Reverse AD: 1 forward + 1 backward дают полный градиент за O(T_f), независимо от n; выгоден при n >> m (все нейросети). Forward AD: O(T_f) для одного JVP; выгоден при m >> n
  • Backpropagation = reverse AD на DAG нейросети: правило цепочки в обратном топологическом порядке, накопление сопряжённых переменных
  • Gradient checkpointing: компромисс 2x время vs sqrt(L) память; custom_vjp/register_hook для аналитически заданных backward операций (FlashAttention)
  • Neural ODE adjoint method: постоянная память через сопряжённое ОДУ; implicit differentiation для дифференцирования через итерационные решатели

Вопросы для размышления

  • Backpropagation переоткрывали 4 раза за 25 лет. Почему понадобилось так долго, чтобы оценить его значимость для нейросетей?
  • Gradient checkpointing торгует вычислениями за память: 2x время, sqrt(L) память. Когда это выгодно в производственном обучении GPT-класса?
  • Implicit differentiation позволяет дифференцировать через итерационные решатели (CG, ADMM) без раскрутки итераций. В чём ограничения этого подхода?

Связанные уроки

  • nm-24-parallel-linalg — параллельное GEMM - основная операция при backprop
  • nm-26-quantum-algorithms — квантовые симуляторы используют AD для вариационных алгоритмов
  • nm-29-scientific-ml-pinns — PINNs строятся на AD для вычисления производных физических уравнений
Автоматическое дифференцирование в ML

0

1

Войти