InceptionNext实战:使用InceptionNext实现图像分类任务(一)

2023-10-27

摘要

论文翻译:https://blog.csdn.net/m0_47867638/article/details/131630614
官方源码:https://github.com/sail-sg/inceptionnext

这是一篇来自颜水成团队的论文。作者提出InceptionNext,将大核深度卷积分解为沿通道维度的四个平行分支,即小方形核、两个正交带核和一个单位映射。通过这种新的Inception深度卷积,构建了一系列网络,不仅享有高吞吐量,而且保持有竞争力的性能。例如,InceptionNeXt-T实现了比convnext - t高1.6倍的训练吞吐量,并在ImageNet- 1K上实现了0.2%的top-1精度提高。

在这里插入图片描述

这篇文章使用InceptionNext完成植物分类任务,模型采用inceptionnext_small向大家展示如何使用InceptionNext。inceptionnext_small在这个数据集上实现了96+%的ACC,如下图:

请添加图片描述

请添加图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现InceptionNext模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?
  13. 如果使用Grad-CAM 实现热力图可视化?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

安装 grad-cam

pip install grad-cam

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:
     model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device='cpu',
            resume=resume)

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)


# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

InceptionNext_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  └─inceptionnext.py.py
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

models:来源官方代码,对面的代码做了一些适应性修改。
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练InceptionNext模型
cam_image.py:热力图可视化

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

InceptionNext实战:使用InceptionNext实现图像分类任务(一) 的相关文章

  • Jenkins流水线怎么做?

    问CHAT Jenkins流水线怎么做 CHAT回复 Jenkins流水线是一种创建 测试和部署应用程序的方法 以下是为Jenkins创建流水线的步骤 1 安装Jenkins 首先你需要在你的服务器上安装Jenkins 这个过程可能会根据你
  • AAAI 2024 一作讲者招募 | 持续报名中

    点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入 我们诚挚地邀请您来AI TIME分享您发表在AAAI 2024的工作 请您扫码填写以下问卷 如内容合适我们将会与您沟通相关分享事宜 AAAI 2024预讲会计划时间 2024年1
  • 【路径规划】基于A*算法路径规划研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码实现
  • 喜报|华测导航荣获“张江之星”领军型企业称号

    近日 2023年度 张江之星 企业培育名单发布 上海华测导航荣获2023年度 张江之星 领军型企业称号 据悉 张江之星 企业培育是上海科创办为落实 关于推进张江高新区改革创新发展建设世界领先科技园区的若干意见 张江高新区加快世界领先科技园区
  • 利用CHAT上传文件的操作

    问CHAT autox js ui 上传框 CHAT回复 上传文件的操作如果是在应用界面中的话 由于Android对于文件权限的限制 你可能不能直接模拟点击选择文件 一般来说有两种常见的解决方案 一种是使用intent来模拟发送一个文件路径
  • 扬帆证券:三只松鼠去年扣非净利预增超1.4倍

    在 高端性价比 战略驱动下 三只松鼠 300783 重拾增势 1月15日晚间 三只松鼠发布成绩预告 预计2023年度净赢利为2亿元至2 2亿元 同比增加54 97 至70 47 扣非后净赢利为1亿元至1 1亿元 同比增速达146 9 至17
  • 活动日程&直播预约|智谱AI技术开放日 Zhipu DevDay

    点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入 直播预约通道 关于AI TIME AI TIME源起于2019年 旨在发扬科学思辨精神 邀请各界人士对人工智能理论 算法和场景应用的本质问题进行探索 加强思想碰撞 链接全球AI学
  • 明日 15:00 | NeurIPS 2023 Spotlight 论文

    点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入 哔哩哔哩直播通道 扫码关注AITIME哔哩哔哩官方账号预约直播 1月17日 15 00 16 00 讲者介绍 黄若孜 腾讯AI LAB游戏AI研究员 2020年复旦大学硕士毕业后
  • 毕业设计:基于卷积神经网络的验证码识别系统 机器视觉 人工智能

    目录 前言 设计思路 一 课题背景与意义 二 算法理论原理 2 1 字符分割算法 2 2 深度学习 三 检测的实现 3 1 数据集 3 2 实验环境搭建 3 3 实验及结果分析 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实
  • 作物叶片病害识别系统

    介绍 由于植物疾病的检测在农业领域中起着重要作用 因为植物疾病是相当自然的现象 如果在这个领域不采取适当的护理措施 就会对植物产生严重影响 进而影响相关产品的质量 数量或产量 植物疾病会引起疾病的周期性爆发 导致大规模死亡 这些问题需要在初
  • 强烈推荐收藏!LlamaIndex 官方发布高清大图,纵览高级 RAG技术

    近日 Llamaindex 官方博客重磅发布了一篇博文 A Cheat Sheet and Some Recipes For Building Advanced RAG 通过一张图给开发者总结了当下主流的高级RAG技术 帮助应对复杂的生产场
  • 手把手教你用 Stable Diffusion 写好提示词

    Stable Diffusion 技术把 AI 图像生成提高到了一个全新高度 文生图 Text to image 生成质量很大程度上取决于你的提示词 Prompt 好不好 前面文章写了一篇文章 一份保姆级的 Stable Diffusion
  • 机器学习算法实战案例:时间序列数据最全的预处理方法总结

    文章目录 1 缺失值处理 1 1 统计缺失值 1 2 删除缺失值 1 3 指定值填充 1 4 均值 中位数 众数填充
  • 主流进销存系统有哪些?企业该如何选择进销存系统?

    主流进销存系统有哪些 企业该如何选择进销存系统 永久免费 的软件 这个可能还真不太可能有 而且就算有 也只能说是相对免费 因为要么就是数据存量有限 要么就是功能有限 数据 信息都不保障 并且功能不完全 免费 免费软件 免费进销存 诸如此类
  • 基于节点电价的电网对电动汽车接纳能力评估模型研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码 数据
  • 基于节点电价的电网对电动汽车接纳能力评估模型研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码 数据
  • 国产化率100%,北斗导航单日定位4500亿次,外媒:GPS将被淘汰

    追赶30年的技术差距 国产卫星导航系统 北斗 开始扬眉吐气 数据显示 北斗导航目前单日定位量达4500亿次 已经获得100多个国家的合作意向 甚至国际民航也摒弃以往 独宠 GPS的惯例 将北斗纳入参考标准 对此 有媒体直言 GPS多年来的技
  • 深度学习(5)--Keras实战

    一 Keras基础概念 Keras是深度学习中的一个神经网络框架 是一个高级神经网络API 用Python编写 可以在TensorFlow CNTK或Theano之上运行 Keras优点 1 允许简单快速的原型设计 用户友好性 模块化和可扩
  • 自动驾驶离不开的仿真!Carla-Autoware联合仿真全栈教程

    随着自动驾驶技术的不断发展 研发技术人员开始面对一系列复杂挑战 特别是在确保系统安全性 处理复杂交通场景以及优化算法性能等方面 这些挑战中 尤其突出的是所谓的 长尾问题 即那些在实际道路测试中难以遇到的罕见或异常驾驶情况 这些问题暴露了实车
  • 对中国手机作恶的谷歌,印度CEO先后向三星和苹果低头求饶

    日前苹果与谷歌宣布合作 发布了 Find My Device Network 的草案 旨在规范蓝牙追踪器的使用 在以往苹果和谷歌的生态形成鲜明的壁垒 各走各路 如今双方竟然达成合作 发生了什么事 首先是谷歌安卓系统的市场份额显著下滑 数年来

随机推荐

  • zi2zi Learning Chinese Character style with conditional GAN 相关解读

    1 zi2zi地址 github 2 网络结构 3 可能需要注意的问题 1 Constant Loss 主要根据参考文献 DTN 主要计算方法 const loss tf reduce mean tf square encoded real
  • STM32F103 入门篇 13-GPIO输入-按键检测

    PA0 PC13同时还具有唤醒功能 上升沿 电容作用 按键按下后会有20ms的抖动 待稳定后通过地线导出 驱动函数 初始化GPIO 使用浮空输入 STM32的四种输入方式 1 上拉输入 GPIO Mode IPU 上拉输入就是信号进入芯片后
  • react中redux的介绍与使用

    1 redux是什么 1 redux是一个专门用于做状态管理的js库 不是react插件库 2 它可以用在react angular vue等项目中 但基本与react配合使用 3 作用 集中式管理react应用中多个组件共享的状态 4 r
  • Laravel 远程代码执行漏洞(CVE-2021-3129)复现

    Laravel 远程代码执行漏洞 CVE 2021 3129 复现 一 漏洞概述 Laravel是一套简洁 优雅的PHP Web开发框架 PHP Web Framework 它可以让你从面条一样杂乱的代码中解脱出来 它可以帮你构建一个完美的
  • 全球及中国工程机械行业营销模式及销售格局分析报告2021-2027年版

    全球及中国工程机械行业营销模式及销售格局分析报告2021 2027年版 HS HS HS HS HS HS HS HS HS HS HS HS HS HS 修订日期 2021年10月 搜索鸿晟信合研究院查看官网更多内容 第一章 工程机械行业
  • 1305. 两棵二叉搜索树中的所有元素(中序遍历)

    给你 root1 和 root2 这两棵二叉搜索树 请你返回一个列表 其中包含 两棵树 中的所有整数并按 升序 排序 示例 1 输入 root1 2 1 4 root2 1 0 3 输出 0 1 1 2 3 4 示例 2 输入 root1
  • 刷脸支付更好地实现了社会资源的协同

    从移动支付到刷脸支付 争体验更是争用户入口 2年前 谈到移动支付 包括许多巨型公司 都只认为它不过是一个收款的入口 方便快捷 不必找零 消费者爱用 企业好清算对账 而如今 移动支付正在演变为各类本地生活服务的重要入口 各路巨头也不断加持 微
  • ISE约束文件UCF与Vivado约束文件XDC(FPGA不积跬步101)

    ISE约束文件UCF与Vivado约束文件XDC FPGA不积跬步101 随着FPGA技术的日益成熟 越来越多的工程师选择使用FPGA进行嵌入式系统的设计和开发 在FPGA的设计中 约束文件的编写是非常重要的一环 而在约束文件的编写中 IS
  • 蓝桥杯2023年第十四届省赛真题------第十四届蓝桥杯本科A组/研究生组2023年省赛题解--全部采用Java语言实现

    引言 今天现在这里 挖个坑 太忙了 这个专题不一样有时间补完 但我会尽力而为的 记录一下今天的日子 2023 04 21 看看这个坑要什么时候自己才能补完 题目pdf下载 第十四届蓝桥杯研究生组pdf下载 在此特别感谢博主int 我的文章参
  • xmind更改分支方向

    XMind如何调整分支主题位置 xmind怎么在左边创建 XMind分支主题任意移动方法 问题 问题解决 问题 在刚开始使用xmind的时候 觉得分支这么这么死板 不能随意拖动到想要的位置 分支朝向不能放到左边 右边 问题解决 这里使用的是
  • 你应该关注的十个智能硬件中文网站

    摘要 不论你是智能硬件从业者 还是智能硬件爱好者 不妨统一称之为 智能硬件er 在信息泛滥的时代 专注于智能硬件 能提供好的资讯 观点 资源的平台屈指可数 这是为您收集的值得关注的十大智能硬件中文网站 不论你是智能硬件从业者 还是智能硬件爱
  • 排序算法之冒泡排序(Bubble sort)

    冒泡排序 Bubble sort 是一个排序算法 可以将一组数列按从小到大或从大到小的顺序排列 操作步骤 从数列的开头开始比较相邻的元素 若前者比后者大 小 则调换二者的位置 依次重复执行1步骤 最终最大 小 的元素排列到了最后 除了已经排
  • C语言求s=a+aa+aaa+aaaa+....

    问题描述 求s a aa aaa aaaa 其中a是一个数字 n表示a的位数 a和n由键盘输入 代码描述 方法1 include
  • 网络安全(黑客)自学路线

    谈起黑客 可能各位都会想到 盗号 其实不尽然 黑客是一群喜爱研究技术的群体 在黑客圈中 一般分为三大圈 娱乐圈 技术圈 职业圈 娱乐圈 主要是初中生和高中生较多 玩网恋 人气 空间 建站收徒玩赚钱 技术高的也是有的 只是很少见 技术圈 这个
  • Python第九、十课

    枭 Python第九 十课 今天讲解了Python的 函数参数类型 变量作用域 函数参数类型 默认值参数 规则 在定义带有默认值参数的函数时 默认值参数必须全部出现在位置参数右侧 且任何一个默认参数右边都不能出现位置参数 默认值参数只在第一
  • 【云原生之kubernetes实战】kubernetes集群的HPA弹性伸缩

    云原生之kubernetes实战 kubernetes集群的HPA弹性伸缩 一 HAP介绍 1 HPA简介 2 HPA的实现原理 3 HPA自动伸缩示意图 4 HPA中影响 Pod 数量的因素 5 HPA改善服务的方式 二 检查本地k8s环
  • python学习主要应该学些什么

    近几年 python 正在成为最受欢迎的编程语言之一 无论是软件开发还是机器学习 python 都能够处理得游刃有余 人们喜欢使用 python 语言是因为它非常容易 相比于 c 语言和 java 等语言 它开发效率更高 它有着丰富的第三方
  • centos7安装或卸载RabbitMQ 3.8.4+Erlang 23.0详细步骤

    参考博客 Linux CentOS 7 下RabbitMQ的安装与配置 风萧萧1999的博客 CSDN博客 安装前准备 1 检查RabbitMQ Erlang版本 https www rabbitmq com which erlang ht
  • 【解决】使用mybatis解决sql报错: Parameter index out of range (1 > number of parameters, which is 0)

    当我们运用ssm框架做项目 并且要通过mybatis来做模糊查询时 sql语句中where条件语句 like后面 要用 like 条件名 而不是平常做where 时的like 条件名
  • InceptionNext实战:使用InceptionNext实现图像分类任务(一)

    文章目录 摘要 安装包 安装timm 安装 grad cam 数据增强Cutout和Mixup EMA 项目结构 计算mean和std 生成数据集 摘要 论文翻译 https blog csdn net m0 47867638 articl