深度学习论文笔记(可解释性)——CAM与Grad-CAM

2023-11-03

主要工作

CAM与Grad-CAM用于解释CNN模型,这两个算法均可得出 c l a s s   a c t i v a t i o n   m a p p i n g class\ activation\ mapping class activation mapping(类似于热力图),可用于定位图像中与类别相关的区域(类似于目标检测),如下图所示:

在这里插入图片描述
五颜六色的区域即为类别相关的区域,表明了CNN为什么如此分类,比如CNN注意到了图中存在牙齿,因此将该图分为Brushing teeth。

阅读了三篇论文,总体来说收获有:

  1. 明白全局池化(Global Average Pooling)为什么有效
  2. 明白CAM与Grad-CAM可视化的原理

需注意,CAM与Grad-CAM的可视化只可以解释为什么CNN如此分类,但是不能解释CNN为什么可以定位到类别相关的区域


Global Average Pooling的工作机制

在这里插入图片描述

设类别数为 n n n,最后一层含有 n n n个特征图,求每张特征图所有像素的平均值,后接入一个有 n n n个神经元的全连接层,这里有两个疑问

为什么要有 n n n个特征图
论文的解释为“the feature maps can be easily interpreted as categories confidence maps.”。
这么做效果好是前提,对此的解释便是,每个特征图主要提取了某一类别相关的某些特征,例如第 i i i张特征图主要提取图中与飞机相关的部分,第 i + 1 i+1 i+1张特征图主要提取图中与汽车相关的部分。
论文在CIFAR10上训练完模型后,最后一层特征图可视化的结果如下:
在这里插入图片描述
从图来看,基本满足论文的解释

求完平均后接入全连接,这么做的理由亦或是好处是什么
下一节的“为什么如此计算可以得出类别相关区域”部分解释


CAM

CNN一般有特征提取器与分类器组成,特征提取器负责提取图像特征,分类器依据特征提取器提取的特征进行分类,目前常用的分类器为MLP,目前主流的做法是特征提取器后接一个GAP+类别数目大小的全连阶层。

CNN最后一层特征图富含有最为丰富类别语意信息(可以理解为高度抽象的类别特征),因此,CAM基于最后一层特征图进行可视化。

CAM将CNN的分类器替换为GAP+类别数目大小的全连接层(以下称为分类层)后重新训练模型,设最后一层有 n n n张特征图,记为 A 1 , A 2 , . . . A n A^1,A^2,...A^n A1,A2,...An,分类层中一个神经元有 n n n个权重,一个神经元对应一类,设第 i i i个神经元的权重为 w 1 i , w 2 i , . . . , w n i w_1^i,w_2^i,...,w_n^i w1i,w2i,...,wni,则第 c c c类的 c l a s s   a c t i v a t i o n   m a p p i n g class\ activation\ mapping class activation mapping(记为 L C A M c L_{CAM}^c LCAMc)的生成方式为:

L C A M c = ∑ i = 1 n w i c A i (式1.0) L_{CAM}^c=\sum_{i=1}^{n}w_i^cA^i\tag{式1.0} LCAMc=i=1nwicAi(1.0)

图示如下:
在这里插入图片描述
生成的Class Activation Mapping大小与最后一层特征图的大小一致,接着进行上采样即可得到与原图大小一致的Class Activation Mapping


为什么如此计算可以得出类别相关区域

用GAP表示全局平均池化函数,沿用上述符号,第 c c c类的分类得分为 S c S_c Sc,GAP的权重为 w i c w_{i}^c wic,特征图大小为 c 1 ∗ c 2 c_1*c_2 c1c2 Z = c 1 ∗ c 2 Z=c_1*c_2 Z=c1c2,第 i i i个特征图第 k k k行第 j j j列的像素值为 A k j i A^i_{kj} Akji,则有
S c = ∑ i = 1 n w i c G A P ( A i ) = ∑ i = 1 n w i c 1 Z ∑ k = 1 c 1 ∑ j = 1 c 2 A k j i = 1 Z ∑ i = 1 n ∑ k = 1 c 1 ∑ j = 1 c 2 w i c A k j i \begin{aligned} S_c&=\sum_{i=1}^{n}w_i^cGAP(A_i)\\ &=\sum_{i=1}^nw_i^c\frac{1}{Z}\sum_{k=1}^{c_1}\sum_{j=1}^{c_2}A_{kj}^i\\ &=\frac{1}{Z}\sum_{i=1}^n\sum_{k=1}^{c_1}\sum_{j=1}^{c_2}w_i^cA_{kj}^i \end{aligned} Sc=i=1nwicGAP(Ai)=i=1nwicZ1k=1c1j=1c2Akji=Z1i=1nk=1c1j=1c2wicAkji

特征图中的一个像素对应原图中的一个区域,而像素值表示该区域提取到的特征,由上式可知 S c S_c Sc的大小由特征图中像素值与权重决定,特征图中像素值与权重的乘积大于0,有利于将样本分到该类,即CNN认为原图中的该区域具有类别相关特征。式1.0就是计算特征图中的每个像素值是否具有类别相关特征,如果有,我们可以通过上采样,康康这个这个像素对应的是原图中的哪一部分

GAP的出发点也是如此,即在训练过程中让网络学会判断原图中哪个区域具有类别相关特征,由于GAP去除了多余的全连接层,并且没有引入参数,因此GAP可以降低过拟合的风险

可视化的结果也表明,CNN正确分类的确是因为注意到了原图中正确的类别相关特征


Grad-CAM

CAM的缺点很明显,为了得出GAP中的权重,需要替换最后的分类器后重新训练模型,Grad-CAM克服了上述缺点。

设第 c c c类的分类得分为 S c S_c Sc,GAP的权重为 w i c w_{i}^c wic,特征图大小为 c 1 ∗ c 2 c_1*c_2 c1c2 Z = c 1 ∗ c 2 Z=c_1*c_2 Z=c1c2,第 i i i个特征图第 k k k行第 j j j列的像素值为 A k j i A^i_{kj} Akji

计算 α i c = 1 Z ∑ k = 1 c 1 ∑ j = 1 c 2 ∂ S c ∂ A k j i \alpha_i^c=\frac{1}{Z}\sum_{k=1}^{c_1}\sum_{j=1}^{c_2}\frac{\partial S_c}{\partial A^i_{kj}} αic=Z1k=1c1j=1c2AkjiSc

Grad-CAM的Class Activation Mapping计算方式如下:
L G r a d − C A M c = R e L U ( ∑ i α i c A i ) L_{Grad-CAM}^c=ReLU(\sum_{i}\alpha_i^cA^i) LGradCAMc=ReLU(iαicAi)

之所以使用ReLU激活函数,是因为我们只关注对于类别有关的区域,即特征图取值大于0的部分

Grad-CAM为什么这么做呢?具体的推导位于快点我,我等不及了,推导比较简单,这里就不敲了,直接贴图
在这里插入图片描述
在这里插入图片描述

最后三个式子漏了符号 ∂ \partial ,总的来说还是非常惊喜的,如果CAM在可视化的过程中,将特征图进行了归一化,则有
L C A M c = 1 Z ∑ i = 1 n w i c A i = 1 Z ∑ i = 1 n ∑ k = 1 c 1 ∑ j = 1 c 2 ∂ S c ∂ A k j i A i L_{CAM}^c=\frac{1}{Z}\sum_{i=1}^{n}w_i^cA^i=\frac{1}{Z}\sum_{i=1}^n\sum_{k=1}^{c_1}\sum_{j=1}^{c_2}\frac{\partial S_c}{\partial A^i_{kj}}A^i LCAMc=Z1i=1nwicAi=Z1i=1nk=1c1j=1c2AkjiScAi

Grad-CAM是CAM的一般化。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习论文笔记(可解释性)——CAM与Grad-CAM 的相关文章

  • 标签平滑(label smoothing)

    目录 1 标签平滑主要解决什么问题 2 标签平滑是怎么操作的 3 标签平滑公式 4 代码实现 标签平滑 label smoothing 出自GoogleNet v3 关于one hot编码的详细知识请见 One hot编码 1 标签平滑主要
  • HDMI与TMDS接口

    目录 0 Xilinx的HDMI 1 4 2 0 Transmitter Subsystem Product Guide 1 HDMI是新一代的多媒体接口标准 2 HDMI向下兼容DVI 3 TMDS 最小化传输差分信号 4 TMDS编码算
  • DB2多行转一行【XML方式】

    分组然后合并 然后去除XML标签 SELECT replace replace replace xml2clob xmlagg xmlelement name A 字段 a 分隔符 a a from 表 group by 分组字段 如 SE
  • python通信仿真_通信协议TLV的介绍及在python下的代码实现及仿真

    TLV协议是一种通讯协议 一般将数据封装成TLV的形式 即Tag Length Value 协议就是指通信双方对数据传输控制的一种规定 规定了数据格式 同步方式 传送速度 传送步骤的问题作出统一的规定 可以理解为两个节点之间为了协同工作 协

随机推荐

  • Elasticsearch实战(十五)---查询query,filter过滤,结合aggs 进行局部/全局聚合统计

    Elasticsearch实战 查询query filter过滤 结合aggs 进行局部 全局聚合统计 文章目录 Elasticsearch实战 查询query filter过滤 结合aggs 进行局部 全局聚合统计 1 准备数据 2 ES
  • C++调用Python Win10 Miniconda虚拟环境配置

    目录 前言 1 Win10 安装 Miniconda 2 创建虚拟环境 3 配置C 调用python环境 4 C 调用Python带参函数 5 遇到的问题 6 总结 前言 本文记录了Win10 系统下Qt 应用程序调用Python时配置Mi
  • (译) 如何使用 React hooks 获取 api 接口数据

    点击上方 蓝字 带你每天阅读全栈前端精选好文 原文地址 robinwieruch 在本教程中 我想向你展示如何使用 state 和 effect 钩子在React中获取数据 你还将实现自定义的 hooks 来获取数据 可以在应用程序的任何位
  • t1服务器显示演示版,T1与T3经常找不到加密狗,及提示演示版本到期,同一个加密狗,WIN2008R2 64位系统。T1工贸版V11.5,T3是普及版本10.8,都已注册。服务器与用户都是同一网络。请...

    经济责任审计结果运用中存在的问题及对策经济责任审计结果运用中存在的问题及对策 近年来 各级高度重视经济责任审计结果运用工作 采取完善机制 健全制度 加强配合等有效措施不断加大结果运用力度 通过将审计结果进行科学合理的转化利用 在加强干部监管
  • mnist

    mnist是什么 它是在机器学习和计算机视觉领域相当于hello world的一个最基础的数据集 内容是手写的数字 从0到9 我们想通过这个数据集来让计算机进行图像识别和手写识别 from matplotlib import pyplot
  • 超详细

    本教程讲述在论文编写中使用ChatGPT进行辅助 提供思路 提升效率 祝看到本教程的小伙伴们都完成论文 顺利毕业 可以加QQ群交流 一群 123589938 第一章 论文框架搭建 1 1 明确论文题目 1 1 1 适合的研究方向 首先赋予它
  • shell-read读取控制台输入

    基本语法 read 选项 参数 选项 p 指定读取值时的提示符 t 指定读取值时等待的时间 秒 如果没有在指定的时间内输入 就不再等待了 参数 变量 指定读取值的变量名 编写一个shell 1 读取控制台输入一个NUM1值 2 读取控制台输
  • android设备外接多个usb摄像头

    转自 https youshaohua com post android device external multiple USB camera 代码访问 OTG USB camera https github com quantum6 A
  • DGL学习笔记03-消息传递机制

    DGL学习笔记03 消息传递机制 1 什么是消息传递 举个简单的例子 1 什么是消息传递 什么是消息传递机制 首先来看下官方的解释 也可以去看论文 对于这一节的话 我感觉如果没接触过Message Passing的人可能看了官方文档也不太容
  • 文件系统FATFS使用 总结

    最近在使用FATFS 现将使用的方法记录下来 f open 函数 此函数用来打开或创建文件 重点 是这个函数的最后一个参数所代表的访问的模式 例子 state f open mfileinfo bmp name FA WRITE FA RE
  • Socket 关于设置Socket连接超时时间

    做网络编程的人对setSoTimeout方法一定很熟悉 都知道是设置连接的超时时间 但是我在网上找资料时发现很多人把这个超时时间理解成了链路的超时时间 我看了一下JDK 关于这个方法的说明 其实根本不是链路的超时时间 Java代码 setS
  • Android ImageView使用详解(系列教程三)

    目录 一 ImageView简介 二 ImageView基本使用 三 ImageView常用属性 四 几种图片的加载方法 五 ImageView的缩放类型 一 ImageView简介 ImageView是Android开发中最常用的组件之一
  • 华为OD机试 Python 报数问题

    描述 你和你的朋友们围成一个圈玩游戏 从第一个人开始 依次报数 1 2 3 当数到3的时候 那个人就得退出游戏 然后从他的下一个朋友继续开始 1 2 3 同样的 数到3的人又得退出 这样一直进行下去 直到圈里只剩下一个人 人会是谁 任务 给
  • 华为OD机试真题 Java 实现【Linux 发行版的数量】【2023Q1 100分】

    目录 一 题目描述 二 输入描述 三 输出描述 四 解题思路 五 Java算法源码 六 效果展示 1 输入 2 输出 3 说明 一 题目描述 Linux 操作系统有多个发行版 distrowatch com 提供了各个发行版的资料 这些发行
  • 集合转换为Jsoin存入redis

    重温复习redis 要将对象集合存入redis作为缓存 上网找了个json串转集合的工具类 这里记录一下 import java io IOException import java util ArrayList import java u
  • js获取input上传文件的文件名和扩展名的方法

  • WLAN配置

    SW1 sysname SW1 修改名称 undo info center enable 关闭提示 vlan batch 100 to 102 批量创建vlan 100 101 102 interface GigabitEthernet0
  • Ethereum geth 同步区块的三种模式

    Ethereum 以太坊 当前交易多 截止当前 2018 02 04 已经有5029238个区块 区块大小在150G左右 如果全部同步 并且严格逐个验证 需要太多的时间和计算 作者曾经用一台实体机 8核 16GB内存 2TB机械硬盘的del
  • leetcode1921.消灭怪物的最大数量(中等)

    解法 排序 贪心 具体 计算出每个怪物到达城市的时间 然后排序 class Solution public int eliminateMaximum vector
  • 深度学习论文笔记(可解释性)——CAM与Grad-CAM

    文章目录 主要工作 Global Average Pooling的工作机制 CAM Grad CAM 主要工作 CAM与Grad CAM用于解释CNN模型 这两个算法均可得出 c l a s s