对比损失的PyTorch实现详解

2023-05-16

对比损失的PyTorch实现详解

本文以SiT代码中对比损失的实现为例作介绍。

论文:https://arxiv.org/abs/2104.03602
代码:https://github.com/Sara-Ahmed/SiT

对比损失简介

作为一种经典的自监督损失,对比损失就是对一张原图像做不同的图像扩增方法,得到来自同一原图的两张输入图像,由于图像扩增不会改变图像本身的语义,因此,认为这两张来自同一原图的输入图像的特征表示应该越相似越好(通常用余弦相似度来进行距离测度),而来自不同原图像的输入图像应该越远离越好。来自同一原图的输入图像可做正样本,同一个batch内的不同输入图像可用作负样本。如下图所示(粗箭头向上表示相似度越高越好,向下表示越低越好)。
在这里插入图片描述

论文中的公式

l c o n t r x i , x j ( W ) = e s i m ( S i T c o n t r ( x i ) , S i T c o n t r ( x j ) ) / τ ∑ k = 1 , k ≠ i 2 N e s i m ( S i T c o n t r ( x i ) , S i T c o n t r ( x k ) ) / τ                    ( 1 ) l^{x_i,x_j}_{contr}(W)=\frac{e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_j))/\tau}}{\sum_{k=1,k\ne i}^{2N}e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_k))/\tau}} \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1) lcontrxi,xj(W)=k=1,k=i2Nesim(SiTcontr(xi),SiTcontr(xk))/τesim(SiTcontr(xi),SiTcontr(xj))/τ                  (1)

L = − 1 N ∑ j = 1 N l o g l x j , x j ˉ ( W )                    ( 2 ) \mathcal{L}=-\frac{1}{N}\sum_{j=1}^Nlogl^{x_j,x_{\bar{j}}}(W) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) L=N1j=1Nloglxj,xjˉ(W)                  (2)

SiT论文中的对比损失公式如上所示。其中 x i x_i xi x j x_j xj分别表示两个不同的输入图像, s i m ( ⋅ , ⋅ ) sim(\cdot,\cdot) sim(,)表示余弦相似度,即归一化之后的点积, τ \tau τ是超参数温度, x j x_j xj x j ˉ x_{\bar{j}} xjˉ是来自同一原图的两种不同数据增强的输入图像, S i T c o n t r ( ⋅ ) SiT_{contr}(\cdot) SiTcontr() 表示从对比头中得到的图像表示,没看过原文的话,就直接理解为输入图像经过一系列神经网络,得到一个 d i m dim dim 维度的特征向量作为图像的特征表示,网络不是本文的重点,重点是怎样根据得到的特征向量计算对比损失

与最近很火的infoNCE对比损失基本一样,只是写法不同。

代码实现

class ContrastiveLoss(nn.Module):
    def __init__(self, batch_size, device='cuda', temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.register_buffer("temperature", torch.tensor(temperature).to(device))			# 超参数 温度
        self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float())		# 主对角线为0,其余位置全为1的mask矩阵
        
    def forward(self, emb_i, emb_j):		# emb_i, emb_j 是来自同一图像的两种不同的预处理方法得到
        z_i = F.normalize(emb_i, dim=1)     # (bs, dim)  --->  (bs, dim)
        z_j = F.normalize(emb_j, dim=1)     # (bs, dim)  --->  (bs, dim)

        representations = torch.cat([z_i, z_j], dim=0)          # repre: (2*bs, dim)
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)      # simi_mat: (2*bs, 2*bs)
        
        sim_ij = torch.diag(similarity_matrix, self.batch_size)         # bs
        sim_ji = torch.diag(similarity_matrix, -self.batch_size)        # bs
        positives = torch.cat([sim_ij, sim_ji], dim=0)                  # 2*bs
        
        nominator = torch.exp(positives / self.temperature)             # 2*bs
        denominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature)             # 2*bs, 2*bs
    
        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))        # 2*bs
        loss = torch.sum(loss_partial) / (2 * self.batch_size)
        return loss

以下是SiT论文的对比损失代码实现,笔者已经将debug过程中得到的张量形状在注释中标注了出来,供大家参考,其中dim是得到的特征向量的维度,bs是批尺寸batch size。

笔者简单画了一张similarity_matrix的图示来说明整个过程。本图以bs==4为例, a , b , c , d a,b,c,d a,b,c,d分别代表同一个batch内的不同样本,下表0和1表示两种不同的图像扩增方法。图中每个方格则是对应行列的图像特征(dim维的向量)表示计算相似度的结果值。

在这里插入图片描述

  1. emb_i,emb_j 是来自同一图像的两种不同的预处理方法得到的输入图像的特征表示。首先是通过F.normalize()emb_iemb_j进行归一化。

  2. 然后将二者拼接起来的到维度为2*bs的representations。再将representations分别转换为列向量和行向量计算相似度矩阵similarity_matrix(见图)。

  3. 在通过偏移的对角线(图中蓝线)的到sim_ijsim_ji,并拼接的到positives。请注意蓝线对应的行列坐标,分别是 a 0 , a 1 a_0,a_1 a0,a1 b 0 , b 1 b_0,b_1 b0,b1等,即蓝线对应的网格即是来自同一张原图的不同处理的输入图像。这在损失的设计中即是我们的正样本。

  4. 然后nominator(分子)即可根据公式计算的到。

  5. 而在计算denominator时需注意要乘上self.negatives_mask。该变量在__init__中定义,是对2*bs的方针对角阵取反,即主对角线全是0,其余位置全是1 。这是为了在负样本中屏蔽自己与自己的相似度结果(图中红线),即使得similarity_matrix的主对角钱全为0。因为自己与自己的相似度肯定是1,加入到计算中没有意义。

  6. 再到后面loss_partial的计算(第22行)其实是计算出公式(1),torch.sum()计算的是(1)中分母上的 ∑ \sum 符号。

  7. 第23行就是计算公式(2),其中与公式相比分母上多了除了个2,是因为本实现为了方便将similarity_matrix的维度扩展为2*bs。即相当于将公式(2)中的 l c o n t r x j , x j ˉ l_{contr}^{x_j,x_{\bar{j}}} lcontrxj,xjˉ l c o n t r x j ˉ , x j l_{contr}^{x_{\bar{j}},x_j} lcontrxjˉ,xj 分别计算了一遍。所以要多除个2。

自行验证

大家可以将上面的ContrastiveLoss类复制到自己的测试的文件中,并构造几个输入进行测试,打印中间结果,验证自己是否真正地理解了对比损失的代码实现计算过程。

loss_func = losses.ContrastiveLoss(batch_size=4)
emb_i = torch.rand(4, 512).cuda()
emb_j = torch.rand(4, 512).cuda()

loss_contra = loss_func(emb_i, emb_j)
print(loss_contra)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

对比损失的PyTorch实现详解 的相关文章

  • Golang 计算两个时间相差多少分钟

    stime int64 etime int64 时间戳 starttime 61 time Unix stime 0 endtime 61 time Unix etime 0 costtime 61 decimal NewFromFloat
  • Spring自动配置原理

    文章目录 一 概念二 自动配置原理二 自动配置生效总结 一 概念 spring集成其他框架中 xff0c 需要编写大量的xml配置文件 xff0c 编写这些配置文件十分繁琐 xff0c 常常出行错误 xff0c 导致开发效率低 Spring
  • centos 上容器配置X11

    系统 xff1a centos 7 9 连接工具 xff1a 同一个局域网内win10电脑上安装的MobaXterm Personal 步骤 xff1a 找到对应的包 96 yum whatprovides xhost 安装yum y in
  • Java 实现线程安全的方式

    1 创建线程的三种方式 通过实现 Runnable 接口 xff1b 通过继承 Thread 类本身 xff1b 通过 Callable 和 Future 创建线程 2 线程的生命周期 新建状态 使用 new 关键字和 Thread 类或其
  • 生产者消费者问题c语言_C中的生产者消费者问题

    生产者消费者问题c语言 Here you will learn about producer consumer problem in C 在这里 xff0c 您将了解C语言中的生产者消费者问题 Producer consumer probl
  • 2.15多生产者多消费者问题

    视频链接 xff1a https www bilibili com video BV1YE411D7nH p 61 24 一 xff0c 问题描述 桌上有一只盘子 xff0c 每次只能向其中放入一个水果 爸爸向盘子中只放苹果 xff0c 妈
  • 【已解决】xterm: Xt error: Can‘t open display:

    项目场景 xff1a 在MobaXterm中 xff0c 使用Ubuntu 18 04的gdb来debug MPI并行的C 43 43 代码 问题描述 Debug时 xff0c 输入 mpiexec span class token ope
  • mariadb安装

    1 配置官方的mariadb的yum源 手动创建 mariadb repo仓库文件 touch etc yum repos d mariadb repo 然后写入如下内容 mariadb name 61 MariaDB baseurl 61
  • Java 给某段代码加超时时间

    问题原因 xff1a 使用HuTool 的DbTtil 不能设置数据库连接超时时间 xff0c 可能数据库挂了 xff0c 会导致连接一直卡在那 xff0c 也没有异常抛出 xff0c 导致线程一直占着 所以给该段代码加超时时间处理 spa
  • 使用reserve来避免不必要的内存重新分配

    STL容器的内存分配策略是 xff0c 他们会自动增长以便容纳下你放入其中的数据 xff0c 只要没有超过它的最大限制就可以 xff08 要查看最大限制可调用名为max size的成员函数 xff09 对于vector和string xff
  • zmq发布-订阅模式c++实现

    上一篇讲到zmq的安装及简单的请求 应答模式 xff0c 本篇主要来看一下zmq的pub sub代码如何实现 发布 订阅模式的特点 xff1a 1 一个发布者可以被多个订阅者订阅 xff0c 即发布者和订阅者是1 xff1a n的关系 2
  • git commit之后如何撤销

    git正常提交代码的的操作为 xff1a git add 将本地的所有文件改变添加至暂存区 git commit m 34 fix xx update xx 34 进行commit的提交 git push 推送到远端仓库 如果在git co
  • dockerfile中多个FROM指令的意义(multistage)

    从docker17 05版本开始 xff0c dockerfile中允许使用多个FROM指令 multistage 这是docker17 05版本的release note xff1a https docs docker com engin
  • C++多个头文件中可以定义同名的namespace吗?

    结论 xff1a c 43 43 是支持在多个 h文件中定义同名的namespace的 分两种情况测试 xff1a 1 两个 h文件中namespace名字相同 xff0c 命名空间中成员名称无重复 xff0c 那么他们会合并为一个命名空间
  • k8s pod OOMKilled 错误原因

    k8s oomkilled 错误原因 xff1a 容器使用的内存资源超过了限制 只要节点有足够的内存资源 xff0c 那容器就可以使用超过其申请的内存 xff0c 但是不允许容器使用超过其限制的资源 在yaml文件的resources li
  • 如何在Android上自定义Google窗口小部件

    Google Search is one of the most popular widgets on Android smartphones and tablets It most likely even came preloaded o
  • 解决 OpenCV Error: Insufficient memory (Failed to allocate 3221225472 bytes) in cv::OutOfMemoryError

    现象 xff1a 调用cvLoadImage加载图片时报OpenCV Error Insufficient memory Failed to allocate 3221225472 bytes in cv OutOfMemoryError
  • git:解决git pull/push 每次都要重新输入用户名密码的问题

    使用git pull或者git push默认是每次都要输入用户名密码的 xff0c 使用起来很不方便 xff0c 在Git代码的根目录下执行以下命令实现保存用户名和密码 xff0c 不用每次都输入 xff1a git config glob
  • 解决docker commit报错:invalid reference format: repository name must be lowercase

    现象 xff1a 使用docker commit命令将容器导出为镜像时报错 xff1a invalid reference format repository name must be lowercase docker commit Map
  • c++11智能指针

    智能指针的原理 RAII RAII xff08 Resource Acquisition Is Initialization xff09 是一种利用对象生命周期来控制程序资源 xff08 如内存 文件句柄 网络连接 互斥量等等 xff09

随机推荐