MindSpore图像分类训练resnet50实现

2023-05-16

目录

一、mindspore简介

二、训练环境

三、数据集与数据加载

四、模型训练和验证

五、迁移学习

六、模型测试和导出


一、mindspore简介

         MindSpore是华为开源的全场景深度学习框架,旨在实现易开发、高效执行、全场景覆盖三大目标,其中易开发表现为API友好、调试难度低,高效执行包括计算效率、数据预处理效率和分布式训练效率,全场景则指框架同时支持云、边缘以及端侧场景。

       MindSpore支持的Windowslinux系统,其中Windows版本仅支持CPU运行,linux版本则支持GPUNPU(华为昇腾系列处理器)。

       MindSpore官网地址:https://www.mindspore.cn/。官网包含了安装说明、教程、文档、官方开源模型等资源,方便初学者快速入门。

二、训练环境

硬件环境:cpu(i7-1165G7)、内存16G;

软件环境:windows10、python3.7、pycharm、mindspore1.5

三、数据集与数据加载

         MindSpore提供API接口直接加载Cirfar10、ImageNet、coco等开源数据集,对图像分类自定义数据集加载也十分方便。这里准备训练一个识别哈士奇和拉布拉多犬的二分类模型,首先需要准备图像并存入对应文件夹。如下:

 

         数据准备:数据集分为训练集和测试集,两种类别的图片数量尽量一致,训练集husky(399)、labrador(400),验证集:husky(51)、labrador(49)。数据集文件结构:

dataset:

                  train:

                          husky:1.jpg...

                          labrador:1.jpg...

                  val:

                          husky:1.jpg...

                          labrador:1.jpg...

         加载数据集:将不同类别图像放在不同文件夹下,mindspore.dataset .ImageFolderDataset()接口可以直接对数据集进行加载和标注。        

train_data_path = 'dataset/train'

data_set = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=True)

         图像预处理:图像解码、调整大小、标准化、矩阵转置。       

  image_size = [224, 224]

  mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]

  std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

  trans = [

        CV.Decode(),

        CV.Resize(image_size),

        CV.Normalize(mean=mean, std=std),

        CV.HWC2CHW()

    ]

         数据增强:如果数据集比较小,为了增强模型泛化能力,可以通过修改tran配置进行数据增强。        

trans = [

            CV.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),

            CV.RandomHorizontalFlip(prob=0.5),

            CV.Normalize(mean=mean, std=std),

            CV.HWC2CHW()

        ]

数据的map映射、批量处理和数据重复的操作:

data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set=data_set.map(operations=type_cast_op,input_columns="label",num_parallel_workers=8)

data_set = data_set.batch(batch_size, drop_remainder=True)

data_set = data_set.repeat(repeat_num)

四、模型训练和验证

         使用MindSpore官方resnet.py脚本构建一个resnet50网络。        

net = resnet50(2)

num_epochs=5

         定义优化器和损失函数:        

opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)

loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

实例化模型:

model = Model(net, loss, opt, metrics={"Accuracy": nn.Accuracy()})

模型训练:

eval_param_dict = {"model":model,"dataset":val_ds,"metrics_name":"Accuracy"}

eval_cb = EvalCallBack(apply_eval, eval_param_dict,)

model.train(num_epochs,train_ds, callbacks=[eval_cb,TimeMonitor()],dataset_sink_mode=False)

         训练过程中,对每一个epoch进行验证,保留验证精度最好的模型参数。

         训练结束后,使用训练过程保存的精度最好的参数对验证集进行验证,并对验证结果可视化。

五、迁移学习

         MindSpore实现迁移学习流程:定义网络并加载预训练模型;删除预训练模型最后一层参数;给网络加载加载预训练参数;冻结除最后一层外所有参数。

# 加载预训练模型

param_dict = load_checkpoint('resnet50.ckpt')

# 获取最后一层参数的名字

filter_list = [x.name for x in net.end_point.get_parameters()]

# 删除预训练模型最后一层的参数

filter_checkpoint_parameter_by_list(param_dict, filter_list)

# 给网络加载参数

load_param_into_net(net, param_dict)

# 冻结除最后一层外的所有参数

for param in net.get_parameters():

    if param.name not in ["end_point.weight","end_point.bias"]:

         param.requires_grad = False

六、模型测试和导出

         测试模型:模型训练完成后,通过推理代码和测试集对模型进行评估。

推理实现代码:

import os

import numpy as np

import cv2

import mindspore.nn as nn

from mindspore import dtype as mstype

import mindspore.dataset.vision.c_transforms as CV

from mindspore import Model, Tensor, context, load_checkpoint, load_param_into_net

from resnet import resnet50

#设置使用设备,CPU/GPU/Ascend

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

def normalize(image):

    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]

    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    image = cv2.resize(image, [224, 224], cv2.INTER_LINEAR)

    image = image / 1.0

    image = (image[:, :] - mean) / std

    image = image[:, :, ::-1].transpose((2, 0, 1))  # HWC-->CHW

    return image

def pre_deal(data_path):

    image = cv2.imread(data_path)

    norm_img = normalize(image)

    #norm_img = ms_normalize(image)

    images = [norm_img]

    images = Tensor(images, mstype.float32)

    return images

def infer(ckpt_path, data_path, num_class):

    image = pre_deal(data_path)

    net = resnet50(num_class)

    param_dict = load_checkpoint(ckpt_path)

    load_param_into_net(net, param_dict)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    model = Model(net, loss, metrics={"Accuracy": nn.Accuracy()})

    output = model.predict(image)

    print(output)

    pred = np.argmax(output.asnumpy(), axis=1)

    return pred

if __name__ == '__main__':

    ckpt_path = 'transfer_best.ckpt'

    data_path = 'test'

    class_name = {0: 'husky', 1: 'labrador'}

    for path in os.listdir(os.path.join(data_path)):

        path = os.path.join(data_path) + '/' + path

        print(path)

        result = infer(ckpt_path, path, 2)

        print(class_name[result[0]])

  为了方便推理部署,MindSpore支持导出MINDIR、AIR、ONNX三种格式。

from mindspore import export, load_checkpoint, load_param_into_net

from mindspore import Tensor

import numpy as np

from resnet import resnet50

net = resnet50(2)

# 将模型参数存入parameter的字典中

param_dict = load_checkpoint("best.ckpt")

# 将参数加载到网络中

load_param_into_net(net, param_dict)

input = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32)

#导出模型,可导出ONNX、MINDIR、AIR格式

export(net, Tensor(input), file_name='resnet50_best', file_format='ONNX')

全部实现代码:

https://gitee.com/chen-jian51/mindspore_resnet50_husky_labrador/tree/master

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

MindSpore图像分类训练resnet50实现 的相关文章

  • 用C#打造自己的实体转换器

    说明 尽管随着NoSQL的普及 xff0c 数据库访问的性能已经非常关注的重点了 xff08 可以通过架构来解决这个瓶颈 xff09 xff0c 所以有越来越多的项目使用了ORM来访问和操作数据库 xff0c 在周公的博客上有一个系列的文章
  • Java工程师考试题

    Java工程师考试题 一 填空题 xff08 本大题10小题 xff0c 每小题2分 xff0c 共20分 xff09 1 当Java对象不再被引用变量引用时 时 将被垃圾回收器回收 2 用POS方法的HTTP包 xff0c HTTP头与P
  • 云原生|kubernetes|ingress-nginx插件部署(kubernetes-1.23和最新版controller-1.6.4)

    前言 xff1a ingress是kubernetes内的一个重要功能插件 xff0c 这个使得服务治理成为一个可能 xff0c 当然 xff0c 结合微服务更为妥当了 不管是什么插件 xff0c 还是服务 xff0c 第一步当然是要能顺利
  • 企业私有云

    企业私有云 企业私有云 xff08 Private Cloud xff09 的定义 xff1a 针对特定的企业 组织和团体提供云服务 xff0c 不对外开放的云计算数据中心 企业私有云的特点 xff1a 1 用户拥有完整的云计算IT系统 x
  • 关于linux下VNC服务的一些介绍(本文章是基于tigervnc)

    一 为什么要写这篇文章 近期在项目上遇到一个很尴尬的现象 xff0c 项目上唯一的一台跳板机不能通过堡垒机进行VNC登录了 xff0c 该跳板机平时用于访问内网web界面做测试 xff1b 但是跳板机内部的VNC服务和端口都正常 xff08
  • Java对象类型转换:向上转型和向下转型

    将一个类型强制转换成另一个类型的过程被称为类型转换 对象类型转换 xff0c 是指存在继承关系的对象 xff0c 不是任意类型的对象 当对不存在继承关系的对象进行强制类型转换时 xff0c 会抛出 Java 强制类型转换 xff08 jav
  • 华为云服务器(linux系统)完整配置流程(包含jdk、Tomcat配置、网页配置等)

    去年华为云服务器做活动 xff0c 白嫖了一个弹性云服务器 xff0c 一直没有用 xff0c 今天着手来配置一下 xff0c 不然要过期了 一边配置一边记录流程 xff0c 亲测有效哦 xff01 首先 xff0c 需要安装一个远程登陆软
  • sql获取两个时间戳之间的时间差以及报错 [Err] 1292 - Truncated incorrect time value: '932:13:47'

    前段时间再项目开发过程中写到一个update语句 xff0c 需求两个时间戳之差作为where条件但是用了 HOUR TIMEDIFF expr1 expr2 方法成功了 UPDATE work order complaint SET 96
  • HTML5 基础知识总结(全)

    文章目录 1 文档类型2 字符集3 标签 lt h1 gt 到 lt h6 gt 4 文本格式化标签 xff08 熟记 xff09 5 标签属性6 图像标签img7 链接标签8 锚点定位9 base标签10 特殊字符11 注释标签12 相对
  • IntelliJ IDEA集成maven

    一 idea中maven的配置 1 maven配置 首先需要在idea中对maven进行集成 xff0c 目录为File Setting Build Execution Deployment Build Tools maven xff0c
  • centos7防火墙配置详细(转载)

    一 条件防火墙是开启的 systemctl start firewalld 1 查看防火墙的配置 firewall cmd state firewall cmd list all 2 开放80端口 firewall cmd permanen
  • JAVA简单快速排序讲解

    首先 xff0c 我们来了解一下什么是快速排序 xff1a 所谓快速排序 xff0c 就是在冒泡排序的基础上进行改进 xff0c 延伸出来的一种跳跃性的排序方法 xff0c 我们都知道 xff0c 冒泡排序 xff0c 就是相邻两个数之间进
  • 基于 CentOS7 的 KVM 部署 + 虚拟机创建

    目录 一 实验环境二 部署 KVM三 创建虚拟机四 远程管理 KVM 虚拟机FAQ 一 实验环境 实验环境 xff1a VMware Workstation 16 Pro 打开虚拟机之前 xff0c 首先开启 VMware Workstat
  • 云原生|kubernetes|网络插件flannel二进制部署和calico的yaml清单部署总结版

    前言 xff1a 前面写了一些关于calico的文章 xff0c 但感觉好像是浅尝辄止 xff0c 分散在了几篇文章内 xff0c 并且很多地方还是没有说的太清楚云原生 kubernetes kubernetes的网络插件calico和fl
  • 在PROC程序中出现 "error: break statement not within loop or switch" 的原因。

    今天碰到一个问题 xff0c 如果proc预编译后生成的 c文件中有下面代码 xff1a if sqlca sqlcode 61 61 1403 break 如果在gcc编译时出现 error break statement not wit
  • Ubuntu18.04下ROS安装

    提示 xff1a 文章写完后 xff0c 目录可以自动生成 xff0c 如何生成可参考右边的帮助文档 文章目录 前言一 Ubuntu操作系统版本对应二 安装ROS1 换源2 添加国内源3 设置密钥4 安装ROS5 初始化 rosdepc6
  • jdbc连接MySQL数据库的简单应用

    jdbc连接MySQL数据库和连接Oracle数据库大体步骤一样 xff0c 首先加载数据库驱动包 xff0c 然后创建数据库连接 xff0c 接着执行sql语句 xff0c 最后返回结果集 但连接MySQL我们需要导入的驱动包是mysql
  • Matlab学习笔记4——readtable

    Matlab学习笔记4 readtable 基于文件创建表 xff0c 第一行就作为表头 xff0c 如果取的表头符合matlab的命名规则 xff0c 那么该列直接如此命名 语法 T 61 readtable filename T 61
  • 【Mysql细节】插入日期数据报格式错误:Data truncation: Incorrect datetime value

    看图吧 xff1a 为什么会在插入的第四条数据报格式错误呢 xff1f 首先这些插入数据是直接复制的 xff08 不是纯手写 xff09 看到报错第一反应是不是觉得自己的数据格式有问题啊 xff0c 细看又没有发现有啥问题 那我是如何解决的
  • 联合索引的最左前缀匹配原则

    目录 联合索引 最左前缀匹配原则 最左匹配原则的成因 联合索引 所谓的联合索引就是指 xff0c 由两个或以上的字段共同构成一个索引 本文测试用例的数据表结构如下 xff0c 一张简简单单的学生信息表 tb student xff0c 仅包

随机推荐

  • KEIL下载程序失败系列问题

    发现问题 xff1a 例如 xff1a 当你使用keil下载程序时 xff0c 往往会出现以下类似问题 xff0c 下面带你解决问题 1 电源 xff1a one 首先 xff0c 当你做嵌入式方面工作 xff0c 出了问题重中之重就是检查
  • js各进制之间的相互转换

    size 61 medium 十进制转二进制 parseInt num toString 2 十进制转八进制parseInt num toString 8 十进制转十六进制parseInt num toString 16 二进制转十进制pa
  • linux|奇怪的知识---账号安全加固,ssh安全加固

    前言 xff1a 一般情况下 xff0c 我们对于账号的安全是比较随意的 xff0c 因为在生产环境里 xff0c 基本都是使用堡垒机这样的带有安全审计功能的工具对各个主机进行监控 xff0c 管理 xff0c 并且结合prometheus
  • Python 自学笔记

    前言 此Python3笔记仅为本人自学网络教学视频总结的笔记 xff0c 初衷仅为个人的学习和复习使用 xff0c 本人使用的编译器为Pycharm xff0c 内容仅供参考 xff08 俺是小白 xff0c 有不对的地方希望各位大佬指出
  • libcli工具的使用-命令行修改输入参数

    libcli工具的使用 命令行修改输入参数 libcli工具介绍 Libcli 提供了一个共享的 C 库 xff0c 用于将类似 Cisco 的命令行界面包含到其他软件中 它是一个 telnet 接口 xff0c 支持用户可定义的功能树的命
  • STM32无法连接JLink(Flash读写保护) 解决方法

    By Ailson Jack Date 2020 12 12 个人博客 xff1a 首页 说好一起走 本文在我博客的地址是 xff1a STM32无法连接JLink Flash读写保护 解决方法 说好一起走 xff0c 排版更好 xff0c
  • Softmax到AMSoftmax(附可视化代码和实现代码)

    Softmax nbsp 个人理解 在训练的时候 加上角度margin 把预测出来的值减小 往0那里挤压 离标注距离更大 减少训练得分 加大loss 增加训练收敛难度 不明白的有个问题 减去m后 如果出现负数怎么办 nbsp nbsp 以下
  • linux python保存mp4

    解决 python调用OpenCV 保存视频时使用 avc1 格式出现 Could not find encoder for codec id 27 Encoder not found的错误 此错误不能保存视频文件 以及使用 mpeg 格式
  • 在树莓派4B安装 scipy 笔记,不需要删除numpy,不需要mkl

    在树莓派4B安装 scipy 笔记 xff0c 不需要删除numpy xff0c 不需要mkl 参考官网 xff1a 不要用sudo xff0c 带上 user xff0c 否则有问题 xff0c 官网 最好用pip安装 python sp
  • LINUX基础试题大全(4)

    说明 xff1a 此文章由于题数庞大 xff0c 为方便阅读本人将其分为四篇文章为大家分享 xff01 答案会今后不断进行更新 xff01 LINUX基础试题大全 xff08 1 xff09 填空题题 LINUX基础试题大全 xff08 2
  • Oracle 创建用户

    span class token comment 查看表空间 span span class token keyword select span span class token operator span span class token
  • Win10 (mstsc)局域网远程桌面连接,超全面设置。(附带,外网远程连接mstsc)

    TCP IP xff08 Transmission Control Protocol Internet Protocol xff0c 传输控制协议 网际协议 xff09 是指能够在多个不同网络间实现信息传输的协议簇 TCP IP协议不仅仅指
  • postgresql |数据库 |数据库的常用备份方法总结

    前言 xff1a 数据库的重要性就不需要在这里重复了 xff0c 那么 xff0c 不管是测试环境 xff0c 还是开发环境 xff0c 亦或者是生产环境 xff0c 数据库作为系统内 xff08 项目内 xff09 的一个非常重要的组件
  • USB连接到centos7虚拟机出现错误(VMware USB Arbitration Service无法启动)

    USB连接到centos7虚拟机出现错误 有可能是因为VMware USB Arbitration Service出现无法自启动的问题 windows主机中WIN 43 R打开运行界面 xff0c 输入 接着查看VMware USB Arb
  • Centos7下配置网卡网桥

    简介 xff1a Linux系统下开启虚拟机需要配置网桥 xff0c 从而使得Linux系统里虚拟化软件的虚拟交换机与宿主机的物理网卡绑定一起 虚拟机与宿主机互相独立的IP 物理网卡监听这些IP xff0c 从而达到虚拟机与物理机在同一个局
  • 【玩转cocos2d-x之二十五】数据结构CCArray

    原创作品 xff0c 转载请标明 xff1a http blog csdn net jackystudio article details 16938787 CCArray是从cocos2d中移植过来的 xff0c 类似于Apple的NSM
  • 5、钉钉平台

    文章目录 文章目录 应用管理运行环境应用类型编程模式风神工作台基础应用信息接口权限应用发布 钉钉集成参数设置钉钉接口地址常量 钉钉集成登录钉钉事件回调接口注册通讯录事件回调群会话事件回调签到事件回调审批事件回调 日志管理签到事件回调审批事件
  • CentOS常用防火墙命令

    systemctl启动 停止 查看防火墙状态 systemctl从CentOS7 x开始引入的一个服务管理工具命令 xff0c 集 service和chkconfig的功能于一体 启动防火墙 systemctl start firewall
  • ARDUINO LCD显示简单的汉字、符号(二 已写成的字模和基于Python的检索系统)

    北 61 0x0A 0x0A 0x0A 0x1B 0x0A 0x0A 0x0A 0x1B 京 61 0x04 0x1F 0x0E 0x0A 0x0E 0x15 0x15 0x0C 市 61 0x04 0x1F 0x04 0x1F 0x15
  • MindSpore图像分类训练resnet50实现

    目录 一 mindspore简介 二 训练环境 三 数据集与数据加载 四 模型训练和验证 五 迁移学习 六 模型测试和导出 一 mindspore简介 MindSpore 是华为开源的全场景深度学习框架 xff0c 旨在实现易开发 高效执行