【无监督学习】1、MOCOv1

2023-11-19

在这里插入图片描述

论文:Momentum Contrast for Unsupervised Visual Representation Learning

代码:https://github.com/facebookresearch/moco

出处:FAIR | 何凯明 | CVPR2020

时间:2020.03

MOCO 分类:

  • 在 ImageNet 上使用 linear protocal 和之前的无监督学习方法取得了相当的性能(linear protocal 的意思就是在预训练数据上训练后冻结 backbone,当做特征提取器,只训练分类头的方式)

MOCO 下游任务:可以很好的迁移到下游任务

  • 之所以要做大规模的无监督预训练,就是为了能够学习到一个很好的特征提取器,而且有很好的迁移性,期望能够在没有很多标注数据的下游任务上获得很好的结果
  • MOCO 在 7 个下游任务上取得了超越之前有监督预训练模型的效果
  • 意味着无监督和有监督学习的鸿沟被填补了,因为之前的无监督是在小部分数据上表现好,而 MOCO 在这么多数据上都表现好

MOCO 的主要特点:

  • 将对比学习构建成字典查询的问题,且通过将字典大小和 batch size 解耦开来,使用队列的方式建立了一个足够大的字典
  • 使用动量大方式来更新 key 编码器的权重,保证每次 key 和 query 的对比尽可能的在相同的编码空间里

文章思想分析:

  • 本文的方式可以看做为了建立大的字典来提升编码空间的信息丰富程度,也就是样本能见到足够多的负样本,从而学习到它们是我的对立样本,是和我不一样的。但建立大的字典又非常依赖 batch size(因为负样本是 batch size 内的样本,要建立大的字典就表示当前 batch 内的样本要尽可能的多),大的 batch size 又对显存要求很大,所以就陷入了如何在不依赖 batch size 的情况下来建立大的字典

  • 那么建立大的字典就借用了队列的方式,每个 epoch 提取到的特征送入队列中,同时队列中还保存了之前很多 epoch 保留的特征,但是,不同 epoch 的编码器的权重是更新的,没法保证 query 和 key 是在一个绝对一致的编码空间的,这样的对比其实意义不大,因为标准都不一样了怎么判断相似度呢,所以作者就提出了一个缓解的方法,就是以动量的形式来更新编码器的权重,这样虽然也没法绝对的保证每次对比的编码空间是一致的,但总能拉小这种不一致性吧,所以就有了动量更新 query 编码器的形式

  • 当预训练数据集换成 ins(1亿) 时,提升相比 imagenet 的 100w 数据有点小,所以作者觉得大规模数据集还没有很好的利用起来,这可能是源于没有一个好的代理任务而导致的

  • MOCO 还提出一个讨论,是不是能使用除了 instance discrimination 之外的代理任务,如 masked auto-encoding 的方式,将图像的一块挡起来,来和 MOCO 一起训练,这个想法就是后面的 MAE

一、背景

无监督表达学习在自然语言处理方面已经有了很成功的应用,如 GPT 和 BERT,但当时(2020年左右)在计算机视觉中还是监督学习占主流。

其主要原因在于两者的特征信号不同,语言任务的数据是在离散数字空间(如单词),可以很方便的建立 tokenized 词典,可以很好的把一个单词对应成一个特征,一旦有了这个词典呢,无监督学习可以基于它很好的展开。因为可以简单的把字典里的所有 key 想象成一个类别,就变成了一个有监督学习的范式了,所以还是有一个类似标签一样的东西来帮助学习,所以 NLP 上无监督学习很容易建模。

但视觉任务的数据是在连续的高维空间,不像单词有很强的语义信息,没有那么容易建模,所以并不适合建立一个字典,没有字典的无监督学习很难建模

近期有一些使用对比学习的无监督方法,取得了较好的效果,都可以被归纳成一种方法,就是在构造动态的字典

当时的主流方法是怎么做的:

  • 很多方法使用对比学习 loss 来解决,也就是最小化 contrastive loss
  • 相当于构建一个动态字典,key(token)是从数据(image/patch)中采样的,使用 encoder 提取特征
  • 目标是让编码后的 query 和与其匹配的 key 最为相近,和其他没匹配的 key 都距离很远

如何理解:

  • 假设有一个数据集有 N 张图片,取一张图经过变换获得两个样本 x1 和 x2,这两个样本也是一对正样本,把 x1 叫做 anchor ,那么 x2 就是 positive,就是 x1 的正样本,剩下的其他样本都是负样本
  • 一旦有了正负样本后,就可以进入编码器来提取特征,x1 和 x2 的编码器可以相同也可以不同,但 x2 和其他负样本的编码器一定是要相同的,因为样本的正负是相比于 x1 这个 anchor 的,从而让除了 anchor 外的样本都使用相同的编码器来得到同一个编码空间中的特征
  • 对比学习就是让在编码空间里,正样本对儿尽可能相近,负样本对儿尽可能远离
  • MOCO 为什么认为上面的这些操作都可以被看做动态生成一个字典呢,因为如果把这些除了 anchor 之外的样本的特征都看做字典里的 key,anchor 的特征看做 query,那么对比学习就转换成了从字典里边查询和 query 相近的 key 的问题,从而只需要最小化对比学习 loss 即可。

因为将对比学习转换成了字典查询的问题,所以本文中很少出现正负样本等字样,而是使用 query 表示 anchor (x_key),key 表示其他样本(x_query),对应的特征使用 k 和 q 表示。

那么从字典查询的角度来理解对比学习呢,作者就说期望构建的字典有如下两个特征:

  • 第一个特征:大,因为字典越大(key 越多)则包含的视觉特征就越丰富,当拿 query 去和 key 做对比的时候,就真的可能学习到把物体能区分开的特征,如果字典很小,则可能不能很好的泛化
  • 第二个特征:在训练的时候要保持尽可能的一致性,字典里边的 key 都应该使用相同或相似的编码器抽取得到的,才能保证对比的时候尽可能的一致,如果由不同的编码器的得到的,则可能会出现对比偏差

基于上面的两个期望的特征,作者提出了 MOCO,目的就是构建一个大且连续的字典,如图 1 所示,关键的贡献在于 momentum encoder 和 queue

  • queue:为什么要用队列来表示字典呢,主要还是受限于显卡的内存,因为如果想要字典越大,就要输入很多的图片,如果字典是几千或上万的话,显卡肯定是受不了的,所以作者使用队列,将字典和大小和 batch_size 大小剥离开。因为每次训练时可见的负样本就是同一个 batch 中的样本,如果 batch 很小那么字典不够大,如果 batch 很大的话又需要很大的显存。所以把负样本的数量(key的数量)和 batch 的大小剥离开很重要。本文作者就使用队列的方法,最近的 batch 入队,最老的 batch 出队,所以队列的大小就可以设定的很大,且更新的慢。
  • momentum encoder:但是,query 虽然更新的慢,但也是更新的啊,前面作者说了 key 的特征要尽可能的保持一致,就是说最好使用相同的编码器来提取特征才能保证每次查询的时候特征是在相同的编码空间的。那么 query 的更新每次使用的 encoder 其实是会随着训练过程的进行来变化的,所以提取到的特征也是会变化的。所以作者使用了动量编码器,保证生成 key 特征的编码器(蓝色)不会每次都从绿色编码器更新,而是基于上一时刻的 key encoder 和本时刻的 query encoder 共同决定的。
    在这里插入图片描述

动量对比学习: y t = m ⋅ y t − 1 + ( 1 − m ) ⋅ x t y_t = m \cdot y_{t-1} + (1-m) \cdot x_t yt=myt1+(1m)xt

  • 动量是加权的意思,不想让当前时刻的输出完全依赖于当前时刻的输入,想让前一时刻的输出也影响当前时刻的输出,m 如果越大,则上一时刻的输出对当前时刻影响越大,m 如果越小,则当前时刻的输入对当前时刻输出影响越大
  • MOCO 就是使用动量来缓慢的更新编码器,让学习到的特征尽可能保持一致

二、方法

2.1 对比学习(字典查表)

代理任务 instance discrimination (个体判别)介绍:认为每张图片都是自己的类

  • 给定一堆数据,在一张图中随机裁剪和数据增强,得到两个图,但这两个图都是来自于同一个图中,那这两个图就是正样本,数据集中的其他图片都是负样本
  • 然后通过一些模型,得到特征,然后在特征上使用对比学习 loss 即可
  • 对比学习最厉害的地方就在于灵活性,可以大开脑洞制定正样本和负样本的规则,比如一个物体的 rgb 图像和深度图像可以被认为正样本,然后扩展到了多模态任务,促成了 CLIP 的出现

什么是对比学习:

  • 对比学习的输入是对每张图进行两种不同的变换,即经过不同的数据增强,会得到两种不同的数据 q 和 k,k 不需要使用梯度反传更新参数(是使用 q encoder 参数和历史的 k encoder 参数的加权和来更新的),q 需要使用梯度反传更新参数
  • 对特征 q 来说,总有一个特征 k + k_+ k+ 是其正样本,这个两个特征就是同一张原始图像的两种不同特征而已
  • 对特征 q 来说,同一个 batch 中的其他图像提取到的特征就是负样本
  • 从原理上说,提高对比学习的效果就是提供足够大的 batch 、研究更加有效的预处理方式,使得变换后的两个图像既能保留本质信息,又能尽可能的不一致、增加模型 encoder 的能力

对比学习是无监督学习/自监督学习,一般的研究点落在下面两个地方:

  • 代理任务
  • loss 函数

MOCO 的落脚点是在目标函数上,提出的这个大的字典主要会影响后面计算 loss,对比学习的 loss 不同于生成/判别式的目标都是固定的目标,而对比学习的目标是会变化的,在训练的过程中会变化,也就是目标是由编码器中抽出来的特征决定的,主要是在同一个特征表达空间中衡量样本之间的相似程度。

对比学习可以被看做为了字典查表任务训练一个 encoder 的任务:

  • 假设一个 encoded query q q q 和一系列编码好的样本 { k 0 , k 1 , k 2 , . . . } \{k_0, k_1, k_2, ...\} {k0,k1,k2,...}(即字典的 key)
  • 假设字典中只有一个 key k + k_+ k+ 是和 q q q 匹配的

那么训练这类任务的时候 loss 要有什么属性呢:

  • q q q k + k_+ k+ 非常近似且和其他 key (即 negative key)远离时,contrastive loss 要很小
  • 反之,loss 要很大来惩罚模型

在看对比学习 loss 之前先看一下交叉熵 loss:

在这里插入图片描述

假设分类任务的 gt 是 one-hot 向量,那么一般会对最后一层特征的输出做 softmax 后和 gt 来做交叉熵损失。

假设第 i 个输出为 z i z_i zi,那么经过 softmax 后为: e z i ∑ i = 0 k e z i \frac{e^{z_i}}{\sum_{i=0}^ke^{z_i}} i=0keziezi,交叉熵结果为:

− l o g e z i ∑ i = 0 k e z i -log\frac{e^{z_i}}{\sum_{i=0}^ke^{z_i}} logi=0keziezi

这里的 k 指的是训练的时候有多少类别,比如 1000 类就是 1000,是固定的数字

但对比学习,类别数 k 将会是非常大的数,有多少图片就有多少类,计算复杂度非常高,所以就出来了 NCE loss

因为之前类别太多,没法计算,NCE 就简化成二分类问题,一个是数据样本,一个是噪声样本,将数据样本和噪声样本做对比就可以了

但如何将整个数据集中剩下的所有图片当负样本,NCE 就解决了类别多的问题,但没有解决计算复杂度高的问题

如何让 loss 计算的更快一点呢,就是取近似,与其计算整个数据集中剩下的数据的 loss,不如直接选择一些作为负样本来计算 loss 就可以了,这也就是 estimation 的含义。

如果选取样本少,近似性会越差,如果选的越多,那么近似性会越高,效果会更好,所以 MOCO 希望字典足够大,足够大的字典就能更好的近似整个空间

所以 NCE loss 就把超级多分类的问题,变成了一系列的二分类问题,从而还可以使用 softmax

这里的 InfoNCE 就是一个变体,觉得如果把数据看成二分类,可能对模型学习也不是很友好,毕竟在那么多噪声样本里,大家也不是一个类,还是看成更多的类比较好

InfoNCE[46] 公式如下

在这里插入图片描述

  • 这里的 q . k q.k q.k 其实就是 logits
  • τ \tau τ:温度超参,用于控制分布的形状,温度值越大,分布会越平缓,如果温度设的越大,对比损失对所有负样本都一视同仁,导致模型的学习没有轻重,如果设置过小,会让模型只关注特别困难的负样本,其实那么负样本可能是潜在的正样本,也会导致模型不好泛化
  • 忽略温度系数,就是 cross entropy loss,唯一不同的是,这里的 K 指代的是负样本数
  • K K K:negative samples 的数量
  • 1 1 1:positive samples 的数量
  • 这里的 sum 其实是在 K 个负样本和 1 个正样本上做的,因为 i 是从 0 到 K 的,是 K+1 个数,也就是字典里所有的 key
  • 直观来想,这个 NCE loss 其实就是 K+1 类的分类,就是想把 q 分成 k+ 这个类别

2.2 动量对比函数

作者之所以提出动量对比是因为作者认为使用大的字典能够引入更丰富的负样本,但大的字典中不同 batch 提取特征的模型参数是一直在更新的,这样就难以使用梯度反传的方式来更新 key encoder,之前有些方法使用 query encoder 的参数来当做 key encoder 的参数,但这样就会导致特征不连续,因为 encoder 是随着训练的进行会变化的,所以会导致特征不在相同的特征空间内。

1、Dictionary as a queue:如何把一个字典看成一个队列

本文思想的核心在于将字典当做数据的队列,这样的好处就是能够对前面 batch 的编码特征进行重复使用,可以将字典大小和 batch 大小解耦开来,字典的大小可以远远大于 batch 的大小,且大小可以设置为可调节的超参数。

字典中的样本可以被逐步的替代,当前 batch 的特征入队,最老的 batch 的特征出队

这个字典中的特征其实是整个数据集中的子集,前面说过计算 loss 的时候不是在整个数据集中来计算 loss,而使用这个字典中的特征来计算,其实就可以看做一个近似

而且字典的长度基本不会影响训练时间,且队列的先进先出的特性,能让新的特征加入计算,更有时效性

2、momentum update:如何用动量的方式来更新 key encoder 的参数

使用队列可以使得字典变大,但不同 batch 提取特征的模型参数是一直在更新的,这样就难以使用梯度反传的方式来更新 key encoder,因为需要给队列中的所有 samples 传递梯度。

一个简单的做法是直接从 query encoder 来复制得到 key encoder,但效果不好

作者猜测这不好的效果来源于 encoder 剧烈的变化会降低 key 表达特征的一致性,所以提出了动量更新的方法。

所以,动量 encoder 更新 θ k \theta_k θk 的公式如下:

在这里插入图片描述

  • f k f_k fk:key encoder,参数为 θ k \theta_k θk,首次的时候是使用 query encoder 来更新的
  • f q f_q fq:query encoder,参数为 θ q \theta_q θq
  • m ∈ [ 0 , 1 ) m \in [0, 1) m[0,1) 是动量系数
  • 只对 θ q \theta_q θq 使用梯度反传来更新参数
  • θ k \theta_k θk 使用公式 2 进行更新,会使得其更新的更加平滑,尽管 queue 中的 key 是使用不同的 encoder 来编码的(因为是在不同的 batch 中得到的),这些 encoder 的差异也可以变得很小
  • 在实验中,作者使用大的动量(m=0.999)就比小的动量(m=0.9)表现更好,99.9% 都是原来的参数,只有 0.01% 是来源于更新后的 query encoder,这也说明使用好 queue 的核心就在于 encoder 的变化要缓慢

3、Relations to previous mechanisms:MOCO 和之前的方法到底有什么不同呢,又是如何使用动量的方式解决问题的

MoCo 和之前两种方法的对比见图 2,主要的差别就在于字典的尺寸和参数更新的一致性。

在这里插入图片描述

  • 图 2a 是端到端的梯度反传方法:牺牲字典大小,保证了一致性

    就是每次都通过梯度回传来更新 encoder k 的参数,这里 q 和 k 可以相同也可以不同,之所以可以相同是因为这种训练方式,每次的正负样本都是来源于当前训练的 batch 中,这些样本是高度一致的,因为是使用相同的编码器得到的,这也就意味着字典的大小和 batch 的大小是耦合的,会被 GPU 显存限制,且大的 batch size 也很难收敛

  • 图 2b 是 memory bank 方法:牺牲了一致性,提高了字典大小,但扩展性也不好

    memory bank 包括所有数据集中 sample 的特征表达,这里只有一个 query encoder,key 是不需要 encoder 的,而是从存储好特征的 memory bank 中随机采样的,如果是 ImageNet 就有 128w 个特征,每个特征 128 维,所以只需要 600M 空间来存储这些 key,查询也很高效,也能支持很大尺寸的字典。而且他的做法会更新 memory bank 中的样本特征,做法就是每次选择 m 个样本来作为负样本计算 loss,query encoder 就会更新权重,然后会使用更新后的 encoder 对这 m 个负样本生成新的特征,这样的过程会缺少一致性,每次的 encoder 都会被更新,都是不同时刻得到的,就意味着存储的特征缺乏一致性。

2.3 Pretext Task

pretext task 也叫前置任务或代理任务,也就是该任务不是目标任务,但执行该目标可以更好的执行目标任务,本质就是迁移学习。

参考论文 [61],作者也将来源于同一张图像的 query 和 key 当做一组 positive pair,其他都是 negative pair

query 和 key 都被其各自的 encoder 进行编码 f q f_q fq f k f_k fk,编码器可以是卷积神经网络。

Algorithm 1 展示了 MoCo 的伪代码

  • 代理任务:对 x 做两次数据增强,得到 x_q 和 x_k,作为一对正样本对

  • batch size 为 256,queue 长度为 65536,编码器输出特征为 128 维,也就是每个样本被编码为 128 维的特征

  • 对 x_q 使用 f_q 提取特征得到 [256,128] 特征 q,x_k 使用 f_k 提取特征得到 [256,128] 特征 k,k 不进行梯度回传

  • 计算 query 和正样本的相似度 logits:每个 q 对应一个 k,得到 256x1,即每个 query 样本和它对应的正样本 k+ 的相似度的

  • 计算 query 和负样本的相似度 logits:每个 q 对应 K 个负样本(K=65536),得到 256x65536 的特征,表示每个 query 和负样本相似度

  • 计算完 query 和正样本/负样本的相似度后,concat 起来,得到 256x65537 维特征向量,表示每个 query 和其他样本的相似度,第一列就是正样本,最好的期望就是第一列全为 1,后面所有列都为 0 ,那就学的非常好了

  • 真值:作者使用 256 维全 0 向量表示真值 gt,放到真正的分类任务上来说,就是说每个样本学习的真正类别就是 0 类,也就是说在 0 位置上的 one-hot 值为 1,其他位置(其他 65536 个位置)上的 one-hot 值都为 0。

    为什么这样来定义真值呢,因为 query 和正样本的相似度矩阵 256x1 后面 concat 了负样本相似度矩阵,也就是 query 和正样本的相似度在整个矩阵的第一列,也就是位置 0,所以对于正样本,如果找对了 key,在分类任务上找到的正确类别就是类别 0

    再想想对比学习学的是啥,学的是样本的相似度,而不是真正的类别,再换句话说就是,要在 0 位置上的相似度为 1,因为 0 位置上永远是正样本。

  • 梯度回传更新后,再使用动量方法更新 key encoder

  • 然后更新 queue

在这里插入图片描述

三、效果

3.1 数据集

1、ImageNet-1M(IN-1M)

约有 1.28 million 数据,共 1000 个类别

2、Instagram-1B (IG-1B)

约有 1 billion 数据,来源于 Instagram

3.2 训练细节

优化器:SGD

  • weight decay:0.0001
  • momentum:0.9

IN-1M :

  • batch: IN-1M 使用 256(8 卡训练)
  • 初始学习率:0.03
  • epoch:200,在 120 和 160 时分别乘以 0.1
  • 训练时间:ResNet50 训练时间大约为 53 小时

IG-1B:

  • batch:1024(64卡训练)
  • 学习率:0.12(每 62.5k iter 时降低到 0.9)
  • iter:1.25M
  • 训练时间:ResNet50 训练时间约为 6天

3.3 实验

1、不同 loss 的对比

作者冻结训练好的特征(在 IN-1M 上无监督预训练),在后面接了一个 linear classification,就是一个全连接+average pooling 的操作,训练了 100 个 epoch,只训练这个分类器,也表现的较好。表明 MOCO 可以很好的迁移到下游任务,有效的弥补了有监督和无监督的鸿沟。

但作者说使用 grid search 的方式发现最优的初始化学习率是 30,所以这是很奇怪的现象,暗示无监督学习到的分布和有监督学习的分布是非常不同的。

对比不同对比学习 loss 的结果见图 3,横坐标表示负样本的数量,纵坐标是 top-one 准确率:

  • 黑色是端到端学习,最大 batch size 为 1024
  • memory bank 的方式是蓝色,虽然可以走到很远,但有限
  • 橘色的是 MOCO 的,效果最好,对硬件要求最低,可扩展性最好

在这里插入图片描述

2、Momentum 的效果

动量从小到大带来的影响,使用 0.999 的效果最好,变小的时候下降的明显,尤其是去掉动量后整个模型都无法收敛,loss 一致是震荡。非常有利的证明了作者的论点。

K=4096

在这里插入图片描述

3、ImageNet 分类上的结果

所有的方法都是在 linear protocal 方式下进行的,都没有更新 backbone

上面是不用对比学习的,下面是用对比学习的

而且在无监督里边,模型的大小是很重要的,所以使用的网络结构和参数量对比也很重要。

之前的方法使用 R50 最好的效果是 58.8,MOCO 达到了 60.6,高两个点

不对模型做限制,变宽 channel 后,得到的效果也很好

在这里插入图片描述

4、在 PASCAL VOC 上的目标检测效果

要验证下游任务能否有好的效果才是更重要的

对比使用 ImageNet 有监督训练和 MOCO 无监督训练的模型来初始化模型做不同的下游任务

COCO 文章中有证明说,当使用较大的数据集如 COCO 的时候,随机初始化然后训练 6x~9x (1x = 12 epoch)的模型仍然可以达到很好的效果,那么也就没法证明预训练模型的效果了

但是,如果训练时间短的话呢,预训练模型还是有用的,所以作者只对比 1x~2x 的训练效果

检测器:Faster RCNN

  • 第一行是随机初始化
  • 第二行是有监督模型初始化
  • 第三行是 moco 使用 imagenet 预训练模型初始化,只有这个 81.8 有一点低,其他指标都高于有监督预训练模型
  • 第四行是 moco 使用 ins 预训练模型做初始化

在这里插入图片描述

端到端、memory bank、moco 的对比,效果都是 moco 更好,且超越了有监督训练结果

在这里插入图片描述

在这里插入图片描述

5、在 COCO 上的检测和分割

模型:Mask RCNN(with FPN)

在这里插入图片描述

下游任务:

在这里插入图片描述

总结:

  • moco 在大多数任务上都超越了
  • moco 在实例分割和语义分割上表现不好,所有就思考是不是在 dense 预测上有些不好
  • 使用 ins 预训练普遍比使用 imagenet 预训练的效果都更好,那么是不是就说明 moco 扩展性比较好呢,这也和 NLP 上的结论是一样的

四、代码

训练无监督学习的方式:

# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim) # 构建 query encoder
        self.encoder_k = base_encoder(num_classes=dim) # 构建 key encoder

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize           # 初始时是将 query 的参数拷贝到 key encoder 作为初始参数
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self): # key encoder 的参数不进行梯度反传的更新,而是使用动量法来更新
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        # 爱因斯坦求和约定 einsum:第一个参数:equation 中的箭头左边表示输入张量,以逗号分割每个输入张量,箭头右边则表示输出张量
        #                       第二个参数:表示实际的输入张量列表
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) # 这个 batch 中有 N 个样本,每个样本有一个正样本 logits 得分
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) # 这个 batch 中有 N 个样本,有 K 个负样本,共有 NxK 个相似度得分

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 
        # 大小为 batch size 的一组全 0 向量,作为 label
        # labels 中的元素实际上意味着在进行 CrossEntropyLoss 计算时,标签为 1 的 ground truth 的索引是多少,而不是 gt 为 0

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【无监督学习】1、MOCOv1 的相关文章

  • 【程序人生】底层程序员,出局

    底层程序员 出局 不如去送外卖 这是徐亮和同事们常开的一个玩笑 入职两三个月 最初的激情退去 在加完班的夜晚 他疲惫地躺在床上 经常自嘲式地想起这个玩笑 送外卖是搬运食物 自己是搬运代码 都不产出新的东西 在深圳 每个人都走得很快 这是徐亮
  • kafka的安装和使用

    ZooKeeper简介 ZooKeeper 是一个为分布式应用所设计的分布的 开源的 java 协调服务 分布式的应用可以建立在同步配置管理 选举 分布式锁 分组和命名等服务的更高级别的实现的基础之上 ZooKeeper 意欲设计一个易于编
  • C语言(二十一)

    1 查找指定字符 本题要求编写程序 从给定字符串中查找某指定的字符 输入 输入待查找的字符c以及字符串s 输出 找到则输出字符c在字符串s中所对应的最大下标index 否则输出 Not Found 优化目标 无 include
  • TCP/IP详解 卷1:协议 学习笔记 第二十九章 网络文件系统

    NFS 网络文件系统 使客户可以透明地访问服务器上的文件和文件系统 NFS的基础是RPC 两个常用的网络编程API socket和TLI 运输层接口 Transport Layer Interface 通信的双方可使用不同的API RPC可

随机推荐

  • 蚁剑的使用以及用蚁剑做一道ctf题

    一 蚁剑的介绍及下载 1 蚁剑是一款和菜刀相像的shell控制端软件 主要面向于合法授权的渗透测试安全人员以及进行常规操作的网站管理员 2 蚁剑的下载 这是gethub的官方下载地址 供大家下载 3 蚁剑的安装 点击初始化就完成安装 再次点
  • Linux进程管理:deadline调度器

    一 概述 实时系统是这样的一种计算系统 当事件发生后 它必须在确定的时间范围内做出响应 在实时系统中 产生正确的结果不仅依赖于系统正确的逻辑动作 而且依赖于逻辑动作的时序 换句话说 当系统收到某个请求 会做出相应的动作以响应该请求 想要保证
  • Jib使用小结(Maven插件版)

    小结三 多次构建后 积累的无用镜像 如下所示 构建多次后 本地会遗留多个名为 tag也是的镜像 root maven hellojib docker images REPOSITORY TAG IMAGE ID CREATED SIZE b
  • 懒人式迁移服务器深度学习环境(完全不需要重新下载)

    换服务器了 想迁移原来服务器上的深度学习环境 但又觉得麻烦懒得重新安装一遍anaconda pytorch 有没有办法能不费吹灰之力直接迁移 接下来跟着我一起 懒汉式迁移 本方法适用于在同一内网下的两台服务器之间互相迁移 不在同一局域网下的
  • 【华为OD统一考试B卷

    在线OJ 已购买本专栏用户 请私信博主开通账号 在线刷题 运行出现 Runtime Error 0Aborted 请忽略 华为OD统一考试A卷 B卷 新题库说明 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一
  • C++ primer智能指针(HasPtr)实现

    智能指针显然是C 吸引人的地方之一 必须掌握 看了 C primer 里面着重讲了智能指针的实现方式 书中说到 HasPtr 注 就是自定义的智能指针 在其它方面的行为与普通指针一致 具体而言 复制对象时 副本和原对象将指向同一基础对象 如
  • linux下libxml库的安装及编译

    linux下libxml库的安装及编译 1 下载和安装LIBXML2 Libxml2是个C语言的XML程式库 能简单方便的提供对XML文件的各种操作 并且支持XPATH查询 及部分的支持XSLT转换等功能 Libxml2的下载地址是 htt
  • Mysql8.0出现this is incompatible with sql_mode=only_full_group_by

    MySQL的sql mode模式说明及设置 sql mode是个很容易被忽视的变量 默认值是空值 在这种设置下是可以允许一些非法操作的 比如允许一些非法数据的插入 在生产环境必须将这个值设置为严格模式 所以开发 测试环境的数据库也必须要设置
  • phabricator mysql_搭建 Phabricator 我遇到的那些坑 - 简书

    一 可能会用到的命令 1 重启phd守护线程 先进入到Fabricator文件夹下面 然后 bin phd log 2 删除一个代码仓库 bin remove destroy rMOBILE 代码库的前缀名字 3 重启mysql数据库 su
  • 数据结构:力扣OJ题

    目录 编辑题一 链表分割 思路一 题二 相交链表 思路一 题三 环形链表 思路一 题四 链表的回文结构 思路一 链表反转 查找中间节点 本人实力有限可能对一些地方解释的不够清晰 可以自己尝试读代码 望海涵 题一 链表分割 现有一链表的头指针
  • Java8 新特性 之 lambda 表达 和 函数式接口

    lambda 表达式 概念 lambda 表达式是一个匿名函数 可以把 lambda 表达式理解为是一段可以传递的代码 更简洁 更灵活 使 Java 的语言表达能力得到了提升 lambda 表达式是作为接口的实现类的对象 万事万物皆对象 使
  • Java取模运算中余数的符号选择问题

    Java取模运算中 余数 的符号和 被除数 符号相同 除号前面的数 即与第一个数的符号相同 public class MyTestProgram public static void main String args 被除数 除数 商 被除
  • idea连接mysql注册登录_idea配置连接数据库的超详细步骤

    学习时 使用IDEA的时候 需要连接Database 连接时遇到了一些小问题 下面记录一下操作流程以及遇到的问题的解决方法 一 连接操作 简介 介绍如何创建连接 具体连接某个数据库的操作流程 1 1 创建连接 打开idea 点击右侧的 Da
  • 并行程序设计作业7/7

    目录 两个线程 一个生产者一个消费者 2k个线程 奇数消费者偶数生产者 2k个线程 每个既可以是生产者又可以是消费者 两个线程 一个生产者一个消费者 include
  • cmake policy

    1 cmake policy是什么 cmake policy可以理解为cmake的语法标准 也就是说 它规定了cmake在解析CMakeLists txt文件时的行为 2 cmake policy的用途是什么 cmake在进化的过程中 需要
  • CAN分析仪 USBCAN USB转CAN CAN转换调试器接口卡使用指导

    USBCAN系列便携式CAN分析仪 通过USB接口快速扩展一路CAN通道 使接入CAN网络非常容易 它具有一体式和小巧紧凑的外形 特别适合于随身携带 第一步 将usbcan卡连接电脑如图 usb灯亮红灯 打开 USBCAN系列便携式CAN总
  • 编程之美2015初赛第二场AB

    题目1 扑克牌 时间限制 2000ms 单点时限 1000ms 内存限制 256MB 描述 一副不含王的扑克牌由52张牌组成 由红桃 黑桃 梅花 方块4组牌组成 每组13张不同的面值 现在给定52张牌中的若干张 请计算将它们排成一列 相邻的
  • 2023.02

    2023 02 01 将mpu写到dxReagion中的数据打印到文件中 调试解决mpu2ipu和ipu2mpu同时跑线程未关掉导致的异常 2023 02 02 学习2102 spec文档和mpu设计文档 将mpuipu测试用例加到回归测试
  • SpringMVC访问静态资源问题

    搭建Spring MVC环境时 如果在Spring MVC的配置文件中DispatcherServlet拦截 则会对 html js jpg等静态文件的访问也会被拦截 想要访问这些静态资源必须要进行相应的配置这里推荐两中比较简单的方法 1
  • 【无监督学习】1、MOCOv1

    文章目录 一 背景 二 方法 2 1 对比学习 字典查表 2 2 动量对比函数 2 3 Pretext Task 三 效果 3 1 数据集 3 2 训练细节 3 3 实验 四 代码 论文 Momentum Contrast for Unsu