Pytorch与Onnx模型的保存、转换与操作

2023-11-01

Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。

一、pytorch模型保存/加载

有两种方式可用于保存/加载pytorch模型 1)文件中保存模型结构和权重参数 2)文件只保留模型权重.

1、文件中保存模型结构和权重参数

模型保存与调用方式一(只保存权重):

保存:

torch.save(model.state_dict(), mymodel.pth)#只保存模型权重参数,不保存模型结构

调用:

model = My_model(*args, **kwargs)  #这里需要重新创建模型,My_model
model.load_state_dict(torch.load(mymodel.pth))#这里根据模型结构,导入存储的模型参数
model.eval()

模型保存与调用方式二(保存完整模型):

保存:

torch.save(model, mymodel.pth)#保存整个model的状态

调用:

model=torch.load(mymodel.pth)#这里已经不需要重构模型结构了,直接load就可以
model.eval()

.pt表示pytorch的模型,.onnx表示onnx的模型,后缀名为.pt, .pth, .pkl的pytorch模型文件之间其实没有任何区别

二、pytorch模型转ONNX模型

1、文件中保存模型结构和权重参数

import torch
torch_model = torch.load("save.pt") # pytorch模型加载
batch_size = 1  #批处理大小
input_shape = (3,244,244)   #输入数据


#set the model to inference mode
torch_model.eval()

x = torch.randn(batch_size,*input_shape)		# 生成张量(模型输入格式)
export_onnx_file = "test.onnx"					# 目的ONNX文件名

// 导出export:pt->onnx
torch.onnx.export(torch_model,					# pytorch模型
                    x,							# 生成张量(模型输入格式)
                    export_onnx_file,			# 目的ONNX文件名
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名(可略)
                    output_names=["output"],	# 输出名(可略)
                    dynamic_axes={"input":{0:"batch_size"},		# 批处理变量(可略)
                                    "output":{0:"batch_size"}}) 

注:dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.

2、文件中只保留模型权重

import torch
torch_model = selfmodel()  					# 由研究员提供python.py文件
batch_size = 1 								# 批处理大小
input_shape = (3, 244, 244) 				# 输入数据

#set the model to inference mode
torch_model.eval()

x = torch.randn(batch_size,*input_shape) 	# 生成张量(模型输入格式)
export_onnx_file = "test.onnx" 				# 目的ONNX文件名

// 导出export:pt->onnx
torch.onnx.export(torch_model,					# pytorch模型
                    x,							# 生成张量(模型输入格式)
                    export_onnx_file,			# 目的ONNX文件名
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名(可略)
                    output_names=["output"],	# 输出名(可略)
                    dynamic_axes={"input":{0:"batch_size"},	# 批处理变量(可略)
                                    "output":{0:"batch_size"}})

3、onnx文件操作

3.1 加载onnx文件

# "加载load"
model=onnx.load('net.onnx')

检查模型格式是否完整及正确

onnx.checker.check_model(model)

3.2 打印onnx模型文件信息

session=onnxruntime.InferenceSession('net.onnx')
inp=session.get_inputs()[0]


#conv1=session.get_inputs()['conv1']
#out1=session.get_outputs()[1]
out=session.get_provider_options()
#print(inp,conv1,out1)
print(inp)
#print(out)
"打印图信息:字符串信息"
graph=onnx.helper.printable_graph(model.graph)
print(type(graph))

3.3 获取onnx模型输入输出层

input=model.graph.input
output = model.graph.output
"""输入输出层"""
print(input,output)

3.4 推断结果

"""推断"""
session=onnxruntime.InferenceSession('net.onnx')
input_name = session.get_inputs()
print(input_name)
output_name=session.get_outputs()[0].name
res=session.run([output_name],{input_name[0].name:inputs.numpy()})
print(res)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch与Onnx模型的保存、转换与操作 的相关文章

随机推荐

  • 创建第一个servlet项目(简单版创建)--详细图文教程

    Servlet 是一种实现动态页面的技术 是一组 Tomcat 提供给程序猿的 API 帮助程序猿简单高效的开发一 个 web app 今天讲一下如何建立一个servlet项目 注意 基于meven创建servlet项目 前提meven要下
  • Clion远程调试树莓派并传递视频流

    Clion远程调试树莓派并传递视频流 0 前言 1 远程调试配置 1 1 远端配置 1 2 本地配置 2 视频流传输 环境 windows10 LTSC raspi 0 前言 近期学习opencv 并准备一些比赛项目 听学长介绍Clion可
  • Linux 网卡重新获取IP

    1 所有网卡驱动重新加载 service network restart 2 对单一网卡进行操作 ifconfig a 获取所有网卡信息 可以看到所有网卡的名字 ifconfig 网卡名称 down ifconfig 网卡名称 up 3 D
  • nvme测试工具:nvme_cli

    nvme cli工具是用于对nvme盘测试的一款通用工具 提供了读写块 查看control namespace信息等功能 下载路径 nvme cli工具是用于对nvme盘进行测试的一款通用工具 其它文档类资源 CSDN下载 如果需要交叉编译
  • 运行node出现“ operation not permitted”错误解决办法

    windows系统下使用node js在使用npm安装express时报错的解决方法 安装时出现如下错误 C Users admin gt npm uninstall express gnpm ERR Windows NT 10 0 143
  • 雷军写的代码上热搜了

    雷军写的代码 一词突然上了微博热搜 一瞬间 我想起了这张图 到底发生了什么 好奇的我点进去一看 原来是因为雷军预告年度演讲的微博里配了一张海报 这张海报信息量非常大 一眼就能看到有很多代码元素 放大一点看看局部 这还是16位实模式下的汇编语
  • JSONObject重复引用导致结果中出现$ref的问题

    转自链接 https blog csdn net baceng article details 92836486 解决办法 先把JSONObject转换成String 然后再转换回JSONObject 例 bussinessData JSO
  • MAYA基础知识和技巧总结

    目录 自定义工具架 自定义热盒 打开Maya时隐藏Output Window 快捷键 小技巧 元素选择技巧 隐藏和显示元素的几种方法 多切割工具 加线 切割 补面的几种方法 复制的几种方法 加入参考图并锁定不动 曲线建模技巧 双轨成型工具
  • CUDA 编程入门

    CUDA 编程入门 更好的阅读体验 CUDA 概述 CUDA 是 NVIDIA 推出的用于其发布的 GPU 的并行计算架构 使用 CUDA 可以利用 GPU 的并行计算引擎更加高效的完成复杂的计算难题 在目前主流使用的冯 诺依曼体系结构的计
  • mongoDB使用总结

    windows安装 zip压缩包方式安装 下载 注意 因为现在最新版的mongodb不兼容win7 对windows系统的最低要求是win10 所以win7系统要安装mongodb数据库必须考虑使用旧版安装 Download MongoDB
  • 一个强大的漏洞公告网站

    http seclists org 通过一篇文章 MySQL 严重 Bug 用户登陆漏洞 得知的这个网站 强大
  • Python不是一门伟大的语言

    作为一门简洁易用 生态蓬勃且具有高泛用性的编程语言 Python一直以来都被不少人称作 编程语言中的瑞士军刀 尤其随着近来AI热潮席卷全球 Python在编程语言圈中的地位也随之水涨船高 甚至一度被视作AI专用语言或大数据专用语言 然而从语
  • 数据结构---归并排序

    归并排序 第一步 分组 第二步 归并 归并操作 第一步 第二步 第三步 JAVA实现 总结 第一步 分组 第1层分成2个大组 每组n 2个元素 第2层分成4个小组 每组n 4个元素 第3层分成8个更小的组 每组n 8个元素 一直到每组只有一
  • 各种经纬度坐标系转换-百度坐标系、火星坐标系、国际坐标系

    各种经纬度坐标系转换 百度坐标系 火星坐标系 国际坐标系 文章代码参考网上 测试没什么问题 汇总整理希望对大家有帮助 dou WGS84 国际坐标系 为一种大地坐标系 也是目前广泛使用的GPS全球卫星定位系统使用的坐标系 GCJ02 火星坐
  • UE4中使用数据表(Data Table)

    本文依据官方文档数据驱动游戏性元素整理而来 做过游戏的应该都清楚 如果游戏稍微有点规模 那么使用数据驱动来做游戏一般是必不可少的一步 一般也就是策划通过本表的方式来解决 下面我们来简单说一下UE4中如何使用DataTable来实现数据驱动开
  • MobileViG实战:使用MobileViG实现图像分类任务(一)

    文章目录 摘要 安装包 安装timm 安装 grad cam 数据增强Cutout和Mixup EMA 项目结构 计算mean和std 生成数据集 摘要 论文翻译 https blog csdn net m0 47867638 articl
  • 获取cookie的两种方式EL表达式中判断数据是否为空

    1 使用java代码获取cookie Cookie cs request getCookies 通过请求获取 for Cookie c cs if c getName equals loginAct String loginAct c ge
  • js总结(二)--彻底理解js中this的指向,不必硬背。

    首先必须要说的是 this的指向在函数定义的时候是确定不了的 只有函数执行的时候才能确定this到底指向谁 实际上this的最终指向的是那个调用它的对象 这句话有些问题 后面会解释为什么会有问题 虽然网上大部分的文章都是这样说的 虽然在很多
  • 解决 vscode 登录微软账户同步设置 出现“vscode.dev 关闭了连接“ 问题

    我的电脑最近重装了系统 之前的软件都删除了 在重新安装vscode之后想同步之前的设置 主题时出现了问题 我的解决方法是 在当前页面 输入 https vscode dev 看能不能打开 如果能打开 再次点击vscode登录账号同步设置 我
  • Pytorch与Onnx模型的保存、转换与操作

    Open Neural Network Exchange ONNX 开放神经网络交换 格式 是一个用于表示深度学习模型的标准 可使模型在不同框架之间进行转移 一 pytorch模型保存 加载 有两种方式可用于保存 加载pytorch模型 1