Нормализация по батчу

Нормализация по батчу (англ. batch normalization, также встречается как «батч-норм») — это метод нормализации, применяемый в искусственных нейронных сетях для повышения скорости и устойчивости обучения за счёт выравнивания («рецентрирования») и масштабирования входных данных каждого слоя. Метод был предложен Сергеем Иоффе и Кристианом Сеге́ди в 2015 году[1].

Внутренний ковариативный сдвиг

В каждом слое нейронной сети входные данные имеют определённое распределение, меняющееся в процессе обучения по двум причинам: случайная инициализация параметров (инициализация параметров) и естественная изменчивость обучающих данных. Такое изменение входных данных каждого внутреннего слоя называется внутренним ковариативным сдвигом. Строгого общепринятого определения этого эффекта нет, однако эксперименты показывают, что речь идёт о смещениях средних и изменениях дисперсий входных значений в процессе обучения.

Нормализация по батчу была изначально предложена для решения данной проблемы[1]. При обучении изменение параметров предшествующих слоёв вызывает смещение распределения входов текущего слоя, из-за чего он вынужден приспосабливаться к новым условиям. Это особенно критично в глубоких сетях, где небольшие изменения на ранних уровнях усиливаются на пути к глубоким слоям. Введение нормализации по батчу снижает такие нежелательные сдвиги, ускоряет обучение и повышает надёжность моделей.

Кроме воздействия на внутренний ковариативный сдвиг, нормализация по батчу даёт и другие преимущества. Она позволяет использовать более высокую скорость обучения (learning rate) без риска возникновения проблемы исчезающих или взрывающихся градиентов, ведёт к регуляризации (улучшает обобщающие способности сети), снижает потребность в методе Dropout для борьбы с переобучением, а также делает модель менее чувствительной к начальному выбору параметров или скорости обучения.

Принципы работы

Трансформация

В нейронной сети нормализация по батчу реализуется посредством нормализации — вычитания среднего значения и деления на стандартное отклонение — для входов каждого слоя. В идеале, нормировку следовало бы проводить по всему обучающему множеству, но для совместимости с стохастическими методами оптимизации используют лишь значения, вычисленные по мини-батчу.

Пусть B — мини-батч размера m, взятый из обучающей выборки. Эмпирическое среднее и дисперсия батча определяются как

Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \mu_B = \frac 1 m \sum_{i=1}^m x_i} , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \sigma_B^2 = \frac 1 m \sum_{i=1}^m (x_i-\mu_B)^2} .

Для слоя с входом размерности d (Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle x = (x^{(1)},...,x^{(d)})} ), каждый компонент нормализуется по формуле:

Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \hat{x}_{i}^{(k)} = \frac {x_i^{(k)}-\mu_B^{(k)}} {\sqrt{\left(\sigma_B^{(k)}\right)^2+\epsilon }}} , где Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle k = 1,...,d} , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle i = 1,...,m} , а Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \mu_B^{(k)}} и Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \sigma_B^{(k)}} — среднее и стандартное отклонение по k-му признаку.

Величина (малая положительная константа) добавляется для численной устойчивости. После нормализации среднее каждого признака становится равным нулю, а дисперсия — единице (без учёта ). Для сохранения представляемой мощности сети далее применяется аффинное преобразование:

Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle y_i^{(k)} = \gamma^{(k)} \hat{x}_{i}^{(k)} + \beta^{(k)}} ,

где параметры Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \gamma^{(k)}} и Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \beta^{(k)}} обучаются вместе с основными параметрами сети.

Операция нормализации по батчу, формально, это отображение Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle BN_{\gamma^{(k)},\beta^{(k)}}: x^{(k)}_{1...m} \rightarrow y^{(k)}_{1...m}} (Batch Normalizing transform). На следующих слоях используется выход Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle y^{(k)}} после преобразования, тогда как нормированные значения Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \hat{x}_{i}^{(k)}} остаются внутренними для текущего слоя.

Обратное распространение ошибки

Описанная операция BN является дифференцируемой, что позволяет вычислять производные функции потерь (l) по её параметрам с помощью правила цепочки.

В частности:

Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \frac{\partial l}{\partial \hat{x}_i^{(k)}} = \frac{\partial l}{\partial y_i^{(k)}}\gamma^{(k)} } , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \frac{\partial l}{\partial \gamma^{(k)}} = \sum_{i=1}^m \frac{\partial l}{\partial y_i^{(k)}}\hat{x}_i^{(k)} } , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \frac{\partial l}{\partial \beta^{(k)}} = \sum_{i=1}^m \frac{\partial l}{\partial y_i^{(k)}} } , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \frac{\partial l}{\partial \sigma_B^{(k)^2}} = \sum_{i=1}^m \frac{\partial l}{\partial y_i^{(k)}} (x_i^{(k)}-\mu_B^{(k)})\left(-\frac {\gamma^{(k)}} 2 (\sigma_B^{(k)^2}+\epsilon)^{-3/2}\right) } , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \frac{\partial l}{\partial \mu_B^{(k)}} = \sum_{i=1}^m \frac{\partial l}{\partial y_i^{(k)}}\frac{-\gamma^{(k)}}{\sqrt{\sigma_B^{(k)^2}+\epsilon}}+\frac{\partial l}{\partial \sigma_B^{(k)^2}}\frac{1}{m}\sum_{i=1}^m (-2)\cdot (x_i^{(k)}-\mu^{(k)}_B) } , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \frac{\partial l}{\partial x_i^{(k)}} = \frac{\partial l}{\partial \hat{x}^{(k)}_i}\frac{1}{\sqrt{\sigma_B^{(k)^2}+\epsilon}}+\frac{\partial l}{\partial \sigma_B^{(k)^2}}\frac{2(x_i^{(k)}-\mu_B^{(k)})}{m}+\frac{\partial l}{\partial \mu_B^{(k)}}\frac{1}{m} } .

Применение при инференсе

На этапе обучения нормализация опирается на значения в мини-батчах. Однако при инференсе (использовании обучённой модели) используют оценки по всей обучающей выборке — средние Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle E[x^{(k)}]} и дисперсии Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \operatorname{Var}[x^{(k)}]} , вычисляемые как

Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle E[x^{(k)}] = E_{B}[\mu^{(k)}_B]} , Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \operatorname{Var}[x^{(k)}] = \frac{m}{m-1}E_{B}[(\sigma^{(k)}_B)^2]} .

Таким образом, итоговые оценки по всей выборке представляют усреднённую по батчам статистику.

В режиме инференса BN-преобразование выглядит так:

Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle y^{(k)} = BN^{\text{inf}}_{\gamma^{(k)},\beta^{(k)}}(x^{(k)})=\gamma^{(k)}\frac{x^{(k)} - E[x^{(k)}]}{\sqrt{\operatorname{Var}[x^{(k)}]+\epsilon}} + \beta^{(k)}} ,

и переход к следующим слоям сети идёт от Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle y^{(k)}} . Здесь параметры BN фиксированы, и преобразование сводится к линейному преобразованию над активацией.

Теоретические аспекты

Хотя нормализация по батчу приобрела широкое распространение благодаря высокой практической эффективности, до сих пор не существует строгого и общепринятого объяснения её действия. В оригинальной работе[1] предполагалось, что BN помогает, устраняя внутренний ковариативный сдвиг, но последующие исследования это поставили под сомнение. Один из опытов[2] проводился с сетью VGG-16[3] в трёх режимах: без BN, с BN и с BN плюс шум (для явного создания ковариативного сдвига). Несмотря на шум, второй и третий варианты показали одинаковую точность, оба превосходя первый, что говорит о том, что не устранение ковариативного сдвига определяет улучшение.

Применение нормализации по батчу приводит к тому, что элементы внутри одного батча перестают быть независимо и одинаково распределёнными, что иногда ухудшает оценку градиента и, соответственно, замедляет обучение[4].

Сглаживание функций потерь

Одна из популярных альтернативных гипотез[2] объясняет эффективность BN тем, что её применение приводит к сглаживанию пространства параметров и градиентов (уменьшение Липшицевой константы).

Для двух идентичных сетей — с BN и без BN — можно показать, что величина градиента для BN ограничена сверху выражением:

Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle ||\triangledown_{y_i}\hat{L}||^2 \le \frac{\gamma^2}{\sigma_j^2}\Bigg(||\triangledown_{y_i}L||^2-\frac{1}{m}\langle 1,\triangledown_{y_i}L\rangle^2-\frac{1}{m}\langle\triangledown_{y_i}L,\hat{y}_j\rangle^2\bigg) } ,

где — параметр масштаба, — дисперсия по батчу. Чем больше дисперсия и сильнее корреляция между градиентом и активацией, тем выраженнее сглаживающий эффект.

Аналогичные оценки можно получить и для гессиана функции потерь и других характеристик, что свидетельствует о повышении устойчивости градиентов при использовании BN.

Однако часть исследователей полагает, что для полного анализа эффективности BN необходимо учитывать весь спектр собственных значений гессиана, а не только крайнее[2][5].

Измерение эффекта ковариативного сдвига

Для проверки влияния BN на ковариативный сдвиг экспериментально измеряют корреляцию градиентов потерь до и после изменения параметров предыдущих слоёв, что даёт численную оценку величины сдвига. Четыре сравниваемые модели: стандартная сеть VGG, VGG с BN, глубокая линейная сеть (DLN) и DLN с BN, показали, что дополнительные слои BN не уменьшают степень внутреннего ковариативного сдвига, что ставит под сомнение изначальное объяснение эффективности метода.

Взрыв и исчезновение градиентов

Хотя BN изначально вводилась как средство борьбы с проблемой затухания или взрыва градиентов, на практике в очень глубоких BN-сетях наблюдается выраженный взрыв градиента при инициализации, независимо от используемой функции активации. Если сеть имеет слоёв, норма градиента первого слоя растёт экспоненциально с L: Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle > c\lambda^L} , где Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \lambda > 1, c > 0} . Для ReLU-подобных функций уменьшается с ростом батча, но остаётся больше единицы. Это делает глубокие BN-сети практически необучаемыми — ситуация нормализуется лишь в архитектурах с остаточными (skip-) связями[6].

Декорреляция длины и направления

Считается также, что важный вклад BN даёт благодаря декорреляции длины и направления весовых векторов: обучение разлагается на задачи по модулю и по направлению, что ускоряет сходимость.

В частности, с помощью BN параметризация весов приводится к форме Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \tilde{w} = \gamma \frac{w}{||w||_s}} , где и играют роли масштаба и направления раздельно; это приводит к существенно более быстрой сходимости при оптимизации[5].

Линейная сходимость

Задача наименьших квадратов

Анализ показывает, что применение BN к задаче обыкновенных наименьших квадратов (OLS) переводит её из сублинейного в линейный режим сходимости по градиентному спуску. В частности, задача сведения к обобщённому частному Рэлея позволяет получить явную оценку на скорость приближения к оптимуму.

Обучение разделяющей поверхности

В простейшем случае линейного персептрона задача обучения — это минимизация функции вида Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle E_{y,x}[\phi(z^T\tilde{w})]} по Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle \tilde{w}} , где Невозможно разобрать выражение (SVG с запасным PNG (MathML можно включить с помощью плагина для браузера): Недопустимый ответ («Math extension cannot connect to Restbase.») от сервера «http://restbase-svc.restbase.svc.production22.local:7231/ru-mediawiki.ruwiki.svc.production22.local/v1/»:): {\displaystyle z = -yx} , — произвольная выпуклая функция потерь. Для нормального распределения признаков всё множество критических точек (оптимумов) располагается на одной прямой; декомпозиция весов BN ускоряет сходимость и здесь.

Описанный алгоритм GDNP (Gradient Descent in Normalized Parameterization, Градиентный спуск в нормализованных параметрах), основанный на BN, гарантирует линейную сходимость решения задачи — для длины и направления весов раздельно, что подтверждается теоретически и экспериментально.

Многослойные нейронные сети

В случае многослойного персептрона (MLP) с одним скрытым слоем и функцией активации tanh, оптимизация по входным и выходным весам каждого скрытого нейрона вновь приводит к тому, что точки оптимума всех скрытых слоёв лежат на единой прямой — что, в сочетании с BN, гарантирует линейную сходимость по соответствующим параметрам[5].

Примечания

Литература

  • Ioffe, Sergey; Szegedy, Christian (2015). "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", ICML'15: Proceedings of the 32nd International Conference on International Conference on Machine Learning — Volume 37, июль 2015, стр. 448–456.
  • Simonyan, Karen; Zisserman, Andrew Very Deep Convolutional Networks for Large-Scale Image Recognition (англ.). arXiv. arXiv (2014). Дата обращения: 9 июня 2024. Архивировано 31 марта 2024 года.

Категории