avatar
gonzo-обзоры ML статей
@gonzo_ML
02.06.2021 12:06
VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning
Adrien Bardes, Jean Ponce, Yann LeCun

Статья: https://arxiv.org/abs/2105.04906
Неофициальный код: https://github.com/vturrisi/solo-learn (с разными методами SSL, включая VICReg)
Библиотека для SSL от FAIR (где можно ожидать появления VICReg, но пока нет): https://github.com/facebookresearch/vissl + https://vissl.ai/

Май 2021, продолжение развития Barlow Twins от Яна Лекуна и ко.

Проблема всё та же — хочется избежать коллапса репрезентаций, но желательно не разными [работающими, заметим] хаками и имплементациями, а каким-нибудь красивым методом, который был бы обоснован и интерпретируем.

Архитектура сиамская симметричная, обе ветви одинаковые и с расшаренными весами, ResNet-50 энкодер + MLP проектор, в каждую ветвь подаётся своя аугментация входного изображения. Вся инновация снова в лосс-функции.

В данной работе предлагается новая трёхкомпонентная целевая функция под названием VICReg (Variance-Invariance-Covariance Regularization), самое новое из которого — это Variance.

Разберём по частям.

Invariance regularization — в целом общее место во многих методах, отвечающее за одинаковость репрезентаций аугментаций одного и того же объекта. Здесь реализуется через mean-squared euclidean distance.

Covariance regularization — идея аналогичная Barlow Twins для того, чтобы получить независимые фичи, декоррелируя их. Реализована в виде суммы квадратов внедиагональных элементов матрицы ковариации, которые данный лосс стремится занулить. Но разница в том, что в BT была кросс-корреляция векторов из двух ветвей сиамской сети, а здесь это ковариация между фичами внутри каждой ветви сети, делающая фичи независимыми друг от друга локально в пределах ветви. Заодно избавились от стандартизации, которая была в BT (и приводила к численной нестабильности, когда эмбеддинг уезжал в константу), здесь вместо неё работает Variance regularization.

Variance regularization — метод регуляризации дисперсии фич, защищает от коллапса. Технически это hinge loss на стандартное отклонение фич проекции по измерению батча. В лоссе фигурирует параметр гамма, который является целевым для стандартного отклонения по каждой фиче, в работе был установлен в 1.

То есть по сути в работе декомпозировали различные цели SSL в отдельные компоненты лосса. Они замешиваются в итоговый лосс с гиперпараметрами, дающими вес каждому компоненту.

Из новых свойств метода то, что он (вроде как первый) не требует какой-либо отдельной нормализации фич.

Эксперименты снова на ImageNet в режимах linear classification (обучают линейный классификатор поверх эмбеддингов полученных в self-supervised режиме) и semi-supervised (когда дополнительно обучают с учителем на 1% или 10% разметки). Метод примерно в одной группе с BYOL, SwAV и Barlow Twins. Аналогично на других картиночных датасетах и задаче детекции объектов.

Есть довольно подробные абляции чтобы понять что именно и как работает. Заодно показывают, что регуляризацию дисперсии можно использовать и с другими архитектурами (например, BYOL или SimSiam) и это даёт некоторое улучшение.

Метод снова не особо чувствителен к размеру батча и нормально работает на небольших батчах.

Потенциальный недостаток метода — необходимость вычисления матрицы ковариации, а это квадратичная сложность по размеру фич (но всё равно не настолько ужасно как в методах, требующих вычисления обратных матриц, что используется в методе W-MSE на основе whitening). Собираются в будущем поискать более эффективные решения, не требующие вычисления полной матрицы.
GitHub
GitHub - vturrisi/solo-learn: solo-learn: a library of self-supervised methods for visual representation learning powered by Pytorch Lightning
solo-learn: a library of self-supervised methods for visual representation learning powered by Pytorch Lightning - vturrisi/solo-learn
16 2.3K

Обсуждение 0

Обсуждение не доступно в веб-версии. Чтобы написать комментарий, перейдите в приложение Telegram.

Обсудить в Telegram