Pytorch可视化特征图(代码 亲测可用)

2023-05-16

2013年Zeiler和Fergus发表的《Visualizing and Understanding Convolutional Networks》

 

 

早期LeCun 1998年的文章《Gradient-Based Learning Applied to Document Recognition》中的一张图也非常精彩,个人觉得比Zeiler 2013年的文章更能给人以启发。从下图的F6特征,我们可以清楚地看到原始书写差异非常大的图片如何在深层特征中体现出其不变性。
 

pytorch 有专门的接口提取特征图

拿到特征图的方法有多种,有人可以从输入开始,一个一个算子地让网络做前向运算,直到想要的特征图处将其返回,这种方法尽管也可行,但略有些麻烦。实际上PyTorch给了一个专用接口可以在前向过程中获取到特征图,这个接口是torch.nn.Module.register_forward_hook。当我们拿到特征图后,PyTorch又有专门的画图和保存图片的接口:torchvision.utils.make_grid和torchvision.utils.save_image,非常方便。
 

展示代码,复制粘贴就可以使用

# -*- coding: utf-8 -*-
import os
import shutil
import time
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
import torchvision.utils as vutil
from torch.utils.data import DataLoader
import torchsummary

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCH = 1
LR = 0.001
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 32
BASE_CHANNEL = 32
INPUT_CHANNEL = 1
INPUT_SIZE = 28
MODEL_FOLDER = './save_model'
IMAGE_FOLDER = './save_image'
INSTANCE_FOLDER = None


class Model(nn.Module):
    def __init__(self, input_ch, num_classes, base_ch):
        super(Model, self).__init__()

        self.num_classes = num_classes
        self.base_ch = base_ch
        self.feature_length = base_ch * 4

        self.net = nn.Sequential(
            nn.Conv2d(input_ch, base_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(base_ch, base_ch * 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(base_ch * 2, self.feature_length, kernel_size=3,
                      padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1))
        )
        self.fc = nn.Linear(in_features=self.feature_length,
                            out_features=num_classes)

    def forward(self, input):
        output = self.net(input)
        output = output.view(-1, self.feature_length)
        output = self.fc(output)
        return output


def load_dataset():
    train_dataset = datasets.MNIST(root='./data',
                                   train=True,
                                   transform=transforms.ToTensor(),
                                   download=True)
    test_dataset = datasets.MNIST(root='./data',
                                  train=False,
                                  transform=transforms.ToTensor(),
                                  download=True)
    return train_dataset, test_dataset


def hook_func(module, input, output):
    """
    Hook function of register_forward_hook

    Parameters:
    -----------
    module: module of neural network
    input: input of module
    output: output of module
    """
    image_name = get_image_name_for_hook(module)
    data = output.clone().detach()
    data = data.permute(1, 0, 2, 3)
    vutil.save_image(data, image_name, pad_value=0.5)


def get_image_name_for_hook(module):
    """
    Generate image filename for hook function

    Parameters:
    -----------
    module: module of neural network
    """
    os.makedirs(INSTANCE_FOLDER, exist_ok=True)
    base_name = str(module).split('(')[0]
    index = 0
    image_name = '.'  # '.' is surely exist, to make first loop condition True
    while os.path.exists(image_name):
        index += 1
        image_name = os.path.join(
            INSTANCE_FOLDER, '%s_%d.png' % (base_name, index))
    return image_name


if __name__ == '__main__':
    time_beg = time.time()

    train_dataset, test_dataset = load_dataset()
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=TRAIN_BATCH_SIZE,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=TEST_BATCH_SIZE,
                             shuffle=False)

    model = Model(input_ch=1, num_classes=10, base_ch=BASE_CHANNEL).cuda()
    torchsummary.summary(
        model, input_size=(INPUT_CHANNEL, INPUT_SIZE, INPUT_SIZE))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    train_loss = []
    for ep in range(EPOCH):
        # ----------------- train -----------------
        model.train()
        time_beg_epoch = time.time()
        loss_recorder = []
        for data, classes in train_loader:
            data, classes = data.cuda(), classes.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, classes)
            loss.backward()
            optimizer.step()

            loss_recorder.append(loss.item())
            time_cost = time.time() - time_beg_epoch
            print('\rEpoch: %d, Loss: %0.4f, Time cost (s): %0.2f' % (
                ep, loss_recorder[-1], time_cost), end='')

        # print train info after one epoch
        train_loss.append(loss_recorder)
        mean_loss_epoch = torch.mean(torch.Tensor(loss_recorder))
        time_cost_epoch = time.time() - time_beg_epoch
        print('\rEpoch: %d, Mean loss: %0.4f, Epoch time cost (s): %0.2f' % (
            ep, mean_loss_epoch.item(), time_cost_epoch), end='')

        # save model
        os.makedirs(MODEL_FOLDER, exist_ok=True)
        model_filename = os.path.join(MODEL_FOLDER, 'epoch_%d.pth' % ep)
        torch.save(model.state_dict(), model_filename)

        # ----------------- test -----------------
        model.eval()
        correct = 0
        total = 0
        for data, classes in test_loader:
            data, classes = data.cuda(), classes.cuda()
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += classes.size(0)
            correct += (predicted == classes).sum().item()
        print(', Test accuracy: %0.4f' % (correct / total))

    print('Total time cost: ', time.time() - time_beg)

    # ----------------- visualization -----------------
    # clear output folder
    if os.path.exists(IMAGE_FOLDER):
        shutil.rmtree(IMAGE_FOLDER)

    model.eval()
    modules_for_plot = (torch.nn.ReLU, torch.nn.Conv2d,
                        torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)
    for name, module in model.named_modules():
        if isinstance(module, modules_for_plot):
            module.register_forward_hook(hook_func)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=1,
                             shuffle=False)
    index = 1
    for data, classes in test_loader:
        INSTANCE_FOLDER = os.path.join(
            IMAGE_FOLDER, '%d-%d' % (index, classes.item()))
        data, classes = data.cuda(), classes.cuda()
        outputs = model(data)

        index += 1
        if index > 20:
            break

运行结果;

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 28, 28]             320
              ReLU-2           [-1, 32, 28, 28]               0
         MaxPool2d-3           [-1, 32, 14, 14]               0
            Conv2d-4           [-1, 64, 14, 14]          18,496
              ReLU-5           [-1, 64, 14, 14]               0
         MaxPool2d-6             [-1, 64, 7, 7]               0
            Conv2d-7            [-1, 128, 7, 7]          73,856
              ReLU-8            [-1, 128, 7, 7]               0
 AdaptiveAvgPool2d-9            [-1, 128, 1, 1]               0
           Linear-10                   [-1, 10]           1,290
================================================================
Total params: 93,962
Trainable params: 93,962
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.74
Params size (MB): 0.36
Estimated Total Size (MB): 1.10
----------------------------------------------------------------
Epoch: 0, Mean loss: 0.6502, Epoch time cost (s): 5.53, Test accuracy: 0.9360
Total time cost:  13.36158013343811

Process finished with exit code 0

特征图被保存在当前文件的文件夹下:

MODEL_FOLDER = './save_model'
IMAGE_FOLDER = './save_image'

这两个文件夹要提前创建好

展示一下  在我电脑上运行的结果

打开一个文件夹  看一下

 

展示上面的一张图

 

 

代码中的重要的内容在visualization部分:

  1. test_loader的batch_size设置为1。
  2. 下面一段代码用来设置register_forward_hook,只要我们在跑前向之前将register_forward_hook设置好,那么在它就会在前向的时候被调用。

 

modules_for_plot = (torch.nn.ReLU, torch.nn.Conv2d,
                    torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)
for name, module in model.named_modules():
    if isinstance(module, modules_for_plot):
        module.register_forward_hook(hook_func)

 3.hook_func(module, input, output)需要自己实现。该函数的参数是固定的,不能自己随意加减,这就带来了一个比较麻烦的事情,我们没法把要保存的文件名等信息作为参数传给hook_func,所以这里我使用了一种比较粗暴的方式。文件名的前缀是module的类型名,后面会跟一个序号,序号是同类module在网络中出现的顺序。然而这个序号也没法传给hook_func,所以我直接从硬盘上从序号1开始,通过判断文件的存在性来决定下一个序号,由get_image_name_for_hook函数实现。比较粗暴和无奈的一种方法,如果哪位老兄有更好的方式欢迎留言指教。
4. hook_func中,在保存图片之前,需要将input或output的维度从[1, C, H, W]调整为[C, 1, H, W]
save_image()每一行的特征图数量默认是8,可以通过nrow参数进行修改。pad_value是特征图中间填充的间隙的颜色,尽管PyTorch将其定义为整数,但是由于Tensor一般是float型,所以还是输入个[0, 1]之间的浮点数才能真正生效。
 

感谢  点赞 收藏  + 关注

(25条消息) Pytorch可视化特征图_拜阳的博客-CSDN博客_pytorch特征图可视化

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch可视化特征图(代码 亲测可用) 的相关文章

随机推荐

  • 物理专线与虚拟专线的比较

    租用专用线路是连接两个或多个站点的专用通信渠道 它作为一个点到另一个点的专用隧道 xff0c 业务是固定的月租金 租赁线路用于互联网 数据甚至电话服务 他们通常在光缆上运行 xff0c 以提供更大的带宽和速度 物理学专线是指高速通道提供速安
  • 使用的是什么JDK和JAVA虚拟机?

    Oracle JDK之前被称为SUN JDK 2009年Oracle收购SUN公司之后命名为Oracle JDK Oracle JDK是基于OpenJDK源代码构建的 使用 java version 查看JDK的版本 OracleJDK 8
  • 云计算有哪些应用领域?

    云计算是基础设施 xff0c 基础设施是日常生活的一部分 xff0c 与人们的生活密切相关 现在云计算作为服务和生活的紧密结合 云计算应用之一 金融云 金融云是利用云计算的模型组成原理 xff0c 将金融产品 信息和服务分散到由大型分支机构
  • bash命令的使用方法

    小编给大家分享一下bash命令的使用方法 xff0c 相信大部分人都还不怎么了解 xff0c 因此分享这篇文章给大家参考一下 xff0c 希望大家阅读完这篇文章后大有收获 xff0c 下面让我们一起去了解一下吧 xff01 Bash xff
  • chmod命令详解

    chmod用于改变文件或目录的访问权限 用户用它控制文件或目录的访问权限 该命令有两种用法 一种是包含字母和操作符表达式的文字设定法 xff1b 另一种是包含数字的数字设定法 1 文字设定法 语法 xff1a chmod who 43 61
  • CDN视频存储解决方案

    一 方案背景 高清 超高清视频的蓬勃发展 xff0c 用户对高品质视频体验的渴望 xff0c 对网络的并发处理和内容平台的存储能力提出了更高的要求 作为产业链的重要一环 xff0c CDN xff08 内容分发网络 xff09 进入规范发展
  • vim中替换字符串的方法有哪些

    这篇文章为大家带来有关vim中替换字符串的方法介绍 xff0c 如果在日常学习或工作遇到这个问题 xff0c 希望大家通过这篇文章的几种方法解决替换字符串的问题 s str1 str2 g 替换每一行中所有str1为str2 常用 xff0
  • SSL连接中握手协议及握手过程

    SSL的主要目的是在两个通信应用程序之间提供私密信和可靠性 这个过程通过3个元素来完成 xff1a 1 握手协议 握手协议负责协商被用于客户机和服务器之间会话的加密参数 当一个SSL客户机和服务器第一次开始通信时 xff0c 它们在一个协议
  • SSL证书是什么?SSL运作方式?

    SSL证书创建加密连接并建立信任 在线业务最重要的组成部分之一是创建一个值得信赖的环境 xff0c 潜在客户对此充满信心 SSL证书通过建立安全连接来建立信任的基础 为了确保访问者的连接安全 xff0c 浏览器提供了特殊的视觉提示 xff0
  • 带宽叠加是什么意思?

    视频会议的清晰度 流畅性 xff0c 往往是用户最为看重的体验感 xff0c 而网络带宽速度如何 xff0c 直接影响到了视频会议的呈现效果 如何让企业级 政务级视频会议常用的局域网带宽更快 通常我们在企业网络或实际项目中 xff0c 随着
  • SQL语法整理(五)-视图

    视图 含义 xff1a 从数据库一个或多个表中导出的虚拟表 作用 xff1a 方便用户操作 要求所见即所需 xff0c 无需添加额外的查询条件 xff0c 直接查看 增加数据的安全性 xff1a 通过视图 xff0c 用户只能查看或修改指定
  • 【Dart快速入门】安装与运行程序

    Dart is a client optimized language for fast apps on any platform 下载安装 Dart SDK Dart SDK 可以用来开发 WEB 命令行程序 服务端程序等 如果是开发移动
  • MATLAB 曲线形状,粗细,颜色使用大全

    颜色的改变 可以 通过改变R G B 的值改变线条的颜色 xff1a matlab命令 xff1a plot x y Color R G B RGB颜色表如下 xff1a 二 改变曲线的粗细 xff1b 通过改变c 1 43 c 1 43
  • torch.zeros() 函数详解

    torch zeros 函数 返回一个形状为为size 类型为torch dtype xff0c 里面的每一个值都是0的tensor torch zeros size out 61 None dtype 61 None layout 61
  • Anaconda 删除虚拟环境

    删除虚拟环境 xff1a 删除环境 xff1a 使用命令conda remove n your env name 虚拟环境名称 all xff0c 即可删除 删除虚拟环境中的包 xff1a 使用命令conda remove name you
  • Python实现流星雨效果的代码

    绘制一颗流星 import numpy as np import matplotlib pyplot as plt from matplotlib collections import LineCollection x0 y0 61 1 1
  • python绘制散点图,非常全,非常详细(已验证)

    少废话 xff0c 直接上代码 import matplotlib pyplot as plt import numpy as np 1 首先是导入包 xff0c 创建数据 n 61 10 x 61 np random rand n 2 随
  • python读写 doc文件和docx文件

    背景 xff1a Python 中可以读取 word 文件的库有 python docx 和 pywin32 优点缺点python docx跨平台只能处理 docx 格式 xff0c 不能处理 doc格式pywin32仅限 windows
  • (清华源)ERROR: Could not find a version that satisfies the requirement pycocotools (unavailable)

    安装 pycocotools 的新方法 xff0c 清华源 ERROR Could not find a version that satisfies the requirement pycocotools unavailable from
  • Pytorch可视化特征图(代码 亲测可用)

    2013年Zeiler和Fergus发表的 Visualizing and Understanding Convolutional Networks 早期LeCun 1998年的文章 Gradient Based Learning Appl