VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning
Adrien Bardes, Jean Ponce, Yann LeCun
Статья:
https://arxiv.org/abs/2105.04906
Неофициальный код:
https://github.com/vturrisi/solo-learn (с разными методами SSL, включая VICReg)
Библиотека для SSL от FAIR (где можно ожидать появления VICReg, но пока нет):
https://github.com/facebookresearch/vissl +
https://vissl.ai/
Май 2021, продолжение развития Barlow Twins от Яна Лекуна и ко.
Проблема всё та же — хочется избежать коллапса репрезентаций, но желательно не разными [работающими, заметим] хаками и имплементациями, а каким-нибудь красивым методом, который был бы обоснован и интерпретируем.
Архитектура сиамская симметричная, обе ветви одинаковые и с расшаренными весами, ResNet-50 энкодер + MLP проектор, в каждую ветвь подаётся своя аугментация входного изображения. Вся инновация снова в лосс-функции.
В данной работе предлагается новая трёхкомпонентная целевая функция под названием VICReg (Variance-Invariance-Covariance Regularization), самое новое из которого — это Variance.
Разберём по частям.
Invariance regularization — в целом общее место во многих методах, отвечающее за одинаковость репрезентаций аугментаций одного и того же объекта. Здесь реализуется через mean-squared euclidean distance.
Covariance regularization — идея аналогичная Barlow Twins для того, чтобы получить независимые фичи, декоррелируя их. Реализована в виде суммы квадратов внедиагональных элементов матрицы ковариации, которые данный лосс стремится занулить. Но разница в том, что в BT была кросс-корреляция векторов из двух ветвей сиамской сети, а здесь это ковариация между фичами внутри каждой ветви сети, делающая фичи независимыми друг от друга локально в пределах ветви. Заодно избавились от стандартизации, которая была в BT (и приводила к численной нестабильности, когда эмбеддинг уезжал в константу), здесь вместо неё работает Variance regularization.
Variance regularization — метод регуляризации дисперсии фич, защищает от коллапса. Технически это hinge loss на стандартное отклонение фич проекции по измерению батча. В лоссе фигурирует параметр гамма, который является целевым для стандартного отклонения по каждой фиче, в работе был установлен в 1.
То есть по сути в работе декомпозировали различные цели SSL в отдельные компоненты лосса. Они замешиваются в итоговый лосс с гиперпараметрами, дающими вес каждому компоненту.
Из новых свойств метода то, что он (вроде как первый) не требует какой-либо отдельной нормализации фич.
Эксперименты снова на ImageNet в режимах linear classification (обучают линейный классификатор поверх эмбеддингов полученных в self-supervised режиме) и semi-supervised (когда дополнительно обучают с учителем на 1% или 10% разметки). Метод примерно в одной группе с BYOL, SwAV и Barlow Twins. Аналогично на других картиночных датасетах и задаче детекции объектов.
Есть довольно подробные абляции чтобы понять что именно и как работает. Заодно показывают, что регуляризацию дисперсии можно использовать и с другими архитектурами (например, BYOL или SimSiam) и это даёт некоторое улучшение.
Метод снова не особо чувствителен к размеру батча и нормально работает на небольших батчах.
Потенциальный недостаток метода — необходимость вычисления матрицы ковариации, а это квадратичная сложность по размеру фич (но всё равно не настолько ужасно как в методах, требующих вычисления обратных матриц, что используется в методе W-MSE на основе whitening). Собираются в будущем поискать более эффективные решения, не требующие вычисления полной матрицы.
Обсуждение 0
Обсуждение не доступно в веб-версии. Чтобы написать комментарий, перейдите в приложение Telegram.
Обсудить в Telegram