【Pytorch】六行代码实现:特征图提取与特征图可视化

2023-11-18

前言

之前记录过特征图的可视化:Pytorch实现特征图可视化,当时是利用IntermediateLayerGetter 实现的,但是有很大缺陷,只能获取到一级的子模块的特征图输出,无法获取内部二级子模块的输出。今天补充另一种Pytorch官方实现好的特征提取方式,非常好用!



一、Torch FX

首先是Torch FX的介绍:FX Blog(具体可参考Reference)

FX based feature extraction is a new TorchVision utility that lets us access intermediate transformations of an input during the forward pass of a PyTorch Module. It does so by symbolically tracing the forward method to produce a graph where each node represents a single operation. Nodes are named in a human-readable manner such that one may easily specify which nodes they want to access.
Did that all sound a little complicated? Not to worry as there’s a little in this article for everyone. Whether you’re a beginner or an advanced deep-vision practitioner, chances are you will want to know about FX feature extraction. If you still want more background on feature extraction in general, read on. If you’re already comfortable with that and want to know how to do it in PyTorch, skim ahead to Existing Methods in PyTorch: Pros and Cons. And if you already know about the challenges of doing feature extraction in PyTorch, feel free to skim forward to FX to The Rescue.


也就是我们后面调用的特征提取函数是基于Torch FX实现的。总之一句话:基于FX的特征提取是一种新的TorchVision实用程序,它允许我们在PyTorch模块的前向传递过程中访问输入的中间值。


二、特征提取

1.使用get_graph_node_names提取各个节点

首先依然是查看各个网络的子层

#首先定义一个模型,这里直接加载models里的预训练模型
model = torchvision.models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
#查看模型的各个层,
for name in model.named_children():
    print(name[0])
#输出,相当于把ResNet的分成了10个层
"""
conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc"""

在这里插入图片描述


之前是利用IntermediateLayerGetter 实现的,但是有很大缺陷,只能获取到一级的子模块的特征图输出,无法获取内部二级子模块的输出。比如不能获取layer2内部第一个BasicBlock的特征图输出。现在可以利用 get_graph_node_names获取任意前向传播的子节点。

import torchvision
import torch
from torchvision.models.feature_extraction import get_graph_node_names

model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
nodes, _ = get_graph_node_names(model)
nodes
# 输出如下
"""
['x',
 'conv1',
 'bn1',
 'relu',
 'maxpool',
 'layer1.0.conv1',
 'layer1.0.bn1',
 'layer1.0.relu',
 'layer1.0.conv2',
 'layer1.0.bn2',
 'layer1.0.add',
 'layer1.0.relu_1',
 'layer1.1.conv1',
 'layer1.1.bn1',
 'layer1.1.relu',
 'layer1.1.conv2',
 'layer1.1.bn2',
 'layer1.1.add',
 'layer1.1.relu_1',
 'layer2.0.conv1',
 'layer2.0.bn1',
 'layer2.0.relu',
 'layer2.0.conv2',
 'layer2.0.bn2',
 'layer2.0.downsample.0',
 'layer2.0.downsample.1',
 'layer2.0.add',
 'layer2.0.relu_1',
 'layer2.1.conv1',
 'layer2.1.bn1',
 'layer2.1.relu',
 'layer2.1.conv2',
 'layer2.1.bn2',
 'layer2.1.add',
 'layer2.1.relu_1',
 'layer3.0.conv1',
 'layer3.0.bn1',
 'layer3.0.relu',
 'layer3.0.conv2',
 'layer3.0.bn2',
 'layer3.0.downsample.0',
 'layer3.0.downsample.1',
 'layer3.0.add',
 'layer3.0.relu_1',
 'layer3.1.conv1',
 'layer3.1.bn1',
 'layer3.1.relu',
 'layer3.1.conv2',
 'layer3.1.bn2',
 'layer3.1.add',
 'layer3.1.relu_1',
 'layer4.0.conv1',
 'layer4.0.bn1',
 'layer4.0.relu',
 'layer4.0.conv2',
 'layer4.0.bn2',
 'layer4.0.downsample.0',
 'layer4.0.downsample.1',
 'layer4.0.add',
 'layer4.0.relu_1',
 'layer4.1.conv1',
 'layer4.1.bn1',
 'layer4.1.relu',
 'layer4.1.conv2',
 'layer4.1.bn2',
 'layer4.1.add',
 'layer4.1.relu_1',
 'avgpool',
 'flatten',
 'fc']
"""

get_graph_node_names把前向传播的各个节点都列出来了形成了一个列表。比如列表中的x表示我们的输入;layer1.0.conv2表示layer1的第1个BasicBlock的conv2节点;layer3.1.conv2表示layer3的第2个BasicBlock的conv2节点;这些节点和我们上图方框中圈出来的是一一对应的,可以结合自己的网络结构具体分析。

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 4, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))


        self.fc=nn.Sequential(nn.Linear(256*6*6, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 100),
                                )

    def forward(self, x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        output=self.fc(x.view(-1, 256*6*6))
        return output
    
model=AlexNet()
nodes, _ = get_graph_node_names(model)
nodes
# 输出如下
['x',
 'conv1.0',
 'conv1.1',
 'conv1.2',
 'conv2.0',
 'conv2.1',
 'conv2.2',
 'conv3.0',
 'conv3.1',
 'conv3.2',
 'conv3.3',
 'conv3.4',
 'conv3.5',
 'conv3.6',
 'view',
 'fc.0',
 'fc.1',
 'fc.2',
 'fc.3',
 'fc.4',
 'fc.5',
 'fc.6']

如果是自定义网络结构,在__init__中初始化了self.conv1self.conv2self.conv3self.fc与输出列表相对应。
conv3为例:

 self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))

总共定义了7层,3个卷积层、3个激活层、1个池化层。 输出节点列表中的conv3.0就表示conv3的第一个节点即第一个卷积层nn.Conv2d(256, 384, 3, 1, 1),同理, conv3.1表示conv3的第二个节点即nn.ReLU()

2.使用create_feature_extractor提取输出

在获取节点信息之后,我么可以利用create_feature_extractor来获取对应节点层的输出。所以get_graph_node_names只是帮助我们获取节点层的信息。

比如,我只想获取layer3layer4内部的第一个卷积层的输出即layer3.0.conv1, layer4.0.conv1

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

# 根据get_graph_node_names得到的节点层信息
# 定义想要得到的输出层
features = ['layer3.0.conv1', "layer4.0.conv1"]

model = torchvision.models.resnet18(
					weights=torchvision.models.ResNet18_Weights.DEFAULT)
					
# return_nodes参数就是返回对应的输出
feature_extractor = create_feature_extractor(model, return_nodes=features)
# 定义输入
x=torch.ones(1, 3, 224, 224)
# 得到一个我们想要的输出层的字典
out = feature_extractor(x)
out

# tensor即对应的输出
"""
{'layer3.0.conv1': tensor(...),
 'layer4.0.conv1': tensor(...) }
"""

当然,并不是一定要完全按照get_graph_node_names得到的节点层信息来定义输出层。比如,我只想获取layer3整个层的输出特征图,我并不关心layer3内部子层的输出:

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

# 定义layer3即可
# 其他层同理
features = ['layer3']
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
feature_extractor = create_feature_extractor(model, return_nodes=features)
# 定义输入
x=torch.ones(1, 3, 224, 224)
# 得到一个我们想要的输出层的字典
out = feature_extractor(x)
out
"""
{'layer3': tensor(...)}
"""


return_nodes参数也可以传入一个字典,字典的键是节点层,值是自定义别名。比如{"layer3":"output1","layer4":"output2"}

features = {"layer3":"output1","layer4":"output2"}
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
feature_extractor = create_feature_extractor(model, return_nodes=features)
x=torch.ones(1, 3, 224, 224)
out = feature_extractor(x)
out
# 输出如下
"""
{'output1': tensor(...),
 'output2': tensor(...)}

"""

3.六行代码可视化特征图

import torch
import torchvision
from PIL import Image
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.models.feature_extraction import create_feature_extractor


transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 

model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)

feature_extractor = create_feature_extractor(model, return_nodes={"conv1":"output"})

original_img = Image.open("dog.jpg")

img=transform(original_img).unsqueeze(0)

out = feature_extractor(img) 

# 这里没有分通道可视化
plt.imshow(out["output"][0].transpose(0,1).sum(1).detach().numpy())

在这里插入图片描述

在这里插入图片描述

三、Reference

Torch FX官方文档:Torch FX官方文档介绍
Torch FX Blog:Feature Extraction in TorchVision using Torch FX
在这里插入图片描述
官方对四种获取特征输出的方式进行了对比,这篇Blog写的比较详细,可以仔细看看。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Pytorch】六行代码实现:特征图提取与特征图可视化 的相关文章

随机推荐

  • futureTask RunnableFuture Future 三者关系认知

    对于这三者首先我们看下源码 之后在分别写几个demo讲解下用法 public interface RunnableFuture
  • HTML5中的引用标记是什么元素?

    HTML5提供了多种元素用于引用文本内容 其中最常用的是 blockquote 元素和 blockquote
  • (二)TestNG 基础概念和执行时机注解

    入门的篇幅会写的比较长 毕竟基础要理解好 在学习TestNG注解前 我们先了解基本的名词 留个印象 TestNG名词解释 1 TestNG方法 method 是一个在代码内使用 Test注解标注的方法 下面代码中的isDuckMeal 就是
  • 机器学习常识 14: 半监督学习

    摘要 半监督学习强调的是一种学习场景 在该场景下 无标签数据可以协助带标签数据提升预测质量 1 基本概念 监督学习 训练数据都有标签 相应的任务为分类 回归等 无监督学习 训练数据都没有标签 相应的任务为聚类 特征提取 如 PCA 等 半监
  • MySQL索引篇

    目录 MySQL索引 一 怎么知道一条SQL语句有没有使用索引 二 如何排查慢查询 三 索引失效以及为什么失效 四 索引为什么能提高查询性能 五 为何选择B 树而不是B树 六 索引分类 七 什么时候创建以及什么时候不需要索引 八 索引优化
  • Python PyQt5(三)添加控件,绑定简单事件处理函数

    coding utf 8 Author BlueSand Email slxxfl000 163 com Web www lzmath cn Blog https blog csdn net weixin 41810846 Date 201
  • Leetcode 160. 相交链表 解题思路及C++实现

    解题思路 先将两个链表构建成一个环 定义两个快慢指针 当它们相遇时 将fast指针从头结点往后遍历 每次走一步 当这两个指针再次相遇时 该节点就是相交节点 Definition for singly linked list struct L
  • Verilog中forever、repeat、while、for四类循环语句(含Verilog实例)

    当搭建FPGA逻辑时 使用循环语句可以使语句更加简洁易懂 Verilog中存在四类循环语句 如标题 几种循环语句的具体介绍和用法如下 1 forever 连续的执行语句 语法格式 forever
  • 【算法入门】什么是时间复杂度和空间复杂度,最优解

    如何评价算法复杂度 时间复杂度 额外空间复杂度 常数操作 常数操作 常数操作 执行时间固定和数据量没有关系的运算操作 如果和数据量有关就不是常数操作 运算 数组寻址 数组里获取3位置和3000w位置数据时间相等 1 1 和100w 100w
  • unity3D期末作业捕鱼游戏,适合初学者学习使用,包含源程序所有文件

    虚拟现实期末作业捕鱼游戏 免积分下载 点我下载资源 有按钮 背景音乐 可以发射炮弹捕鱼 可以选择难度 可以调节音乐声音大小 有游戏加载进度条 详细情况请看如下动态图 点我下载资源
  • DataFrame添加列名,查看均值等,seaborn

    查看数据 seaborn画图简单好看 看两两特征的关系 对角线是自己和自己 dropna 处理缺失值
  • 设计模式 之 状态模式

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 设计模式 之 状态模式 概念 类的行为基于它的状态而改变 主体思想是将各种具体的状态类抽象出来 也就是会有很多状态类 使用场景 代码中包含大量与对象状态有关的条件语句 行为
  • C#多线程基础(一) PS:阅读C#多线程编程实战第一章总结

    一 基本概念 进程 Process 在操作系统中正在运行的应用程序被视为一个进程 包含着一个运行程序所需要的资源 进程可以包括一个或多个线程 线程 Thread 进程的基本执行单元 是操作系统分配CPU时间的基本单位 在进程入口执行的第一个
  • Git(1)

    步骤1 使用Git Bash 方法1 使用命令行进入Git安装目录的bin文件下 cd Program Files x86 Git bin 这样就可以使用Git Bash了 方法2 相比方法1更简便 步骤2 设置Git 配置email gi
  • ts类型体操 43 - Exclude

    43 Exclude by Zheeeng zheeeng easy built in union Question Implement the built in Exclude
  • Unity之Animation动画

    Unity之Animation动画 Unity之Animation绘制动画 这篇文章做最简单的动画 让一个立方体从左边移动到右边 1 创建一个Unity的新工程 名为TestAnimation 点击Create And Open按键 打开工
  • 机器学习之支持向量机: Support Vector Machines (SVM)

    机器学习之支持向量机 Support Vector Machines SVM 欢迎访问人工智能研究网 课程中心 网址是 http i youku com studyai 本篇博客介绍机器学习算法之一的支持向量机算法 理解支持向量机 Unde
  • 蓝桥杯:优秀的拆分

    蓝桥杯 优秀的拆分https www lanqiao cn problems 801 learning 目录 题目描述 输入描述 输出描述 输入输出样例 输入 输出 输入 输出 题目分析 位运算 AC代码 Java 题目描述 一般来说 一个
  • CSS font-family 中的苹方字体

    苹方提供了六个字重 font family 定义如下 苹方 简 常规体 font family PingFangSC Regular sans serif 苹方 简 极细体 font family PingFangSC Ultralight
  • 【Pytorch】六行代码实现:特征图提取与特征图可视化

    前言 之前记录过特征图的可视化 Pytorch实现特征图可视化 当时是利用IntermediateLayerGetter 实现的 但是有很大缺陷 只能获取到一级的子模块的特征图输出 无法获取内部二级子模块的输出 今天补充另一种Pytorch