Pytorch模型保存与加载模型继续训练

2023-11-06

1. 网络模型定义与模型参数保存

定义网络模型与基本参数,以及模型训练和模型保存

使用torch.save()方法保存模型

在save_dict={}中可以保存epoch,model,optimizer,scheduler,loss等参数。

my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()

if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()

for p in my_net.parameters():
    p.requires_grad = True
bestacc = 0.0
savepth = 'mySavepthPath'
for epoch in range(n_epoch):
    my_net.train()
    ....
    if acc > bestacc:
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(save_dict, savepth + '.pth')

2. 加载模型继续训练

使用torch.load加载模型,完整代码如下。

要注意的是,要先定义模型和优化器optimizer,把模型放到gpu上,然后再加载模型。
否则执行optimizer.step()时会出现下面这个错误。
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! 
my_net = VisionTransformer()
n_epoch = 200
lr = 0.001
optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epoch, eta_min=lr / 100)
loss_classification = torch.nn.CrossEntropyLoss()

if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()

Resume = True
start_epoch = -1
if Resume:
    path_checkpoint = 'mySavepthPath.pth'
    checkpoint = torch.load(path_checkpoint, map_location=torch.device('cuda'))
    my_net.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    print("start_epoch:", start_epoch)
    print('-----------------------------')


for p in my_net.parameters():
    p.requires_grad = True

bestacc = 0.0
savepth = 'mySavepthPath'

new_start = 0 if start_epoch == -1 else start_epoch
for epoch in range(start_epoch + 1, new_start+n_epoch):
    my_net.train()
    ....
    if acc > bestacc:
        save_dict = {
            'epoch': epoch,
            'model': my_net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(save_dict, savepth + '.pth')

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch模型保存与加载模型继续训练 的相关文章

随机推荐

  • 数据规约

    主成分的计算步骤 主成分的代码实现 设置工作空间 把 数据及程序 文件夹拷贝到F盘下 再用setwd设置工作空间 setwd F 数据及程序 chapter4 示例程序 数据读取 inputfile lt read csv data pri
  • z系列主板能装服务器系统吗,Intel Z390主板搭配8代酷睿现身:还能安装WIN7系统吗?...

    Intel今年为发烧友带来了最多18核心的Core X系列 搭配X299顶级主板 主流领域则有最多6核心的八代酷睿Coffee Lake S 搭配Z370主板 但坑爹的是 尽管八代和六七代酷睿都是LGA1151接口 但却被故意整成不兼容 因
  • 基于power bi上手业务数据可视化

    分析背景 偶然得到一份关于某连锁火锅品牌在2020年1月 8月的线上平台业务数据 如下图 心想正好利用这份数据 模拟实际业务中基于数据库与bi工具 实践开发可视化图表 一开始考虑用tableau 因为在大学跟刚工作的时候曾系统学习使用过 但
  • 栈的链性表的c语言实现方式 linkstack.h 和 linkstack.c

    linkstack h 文件 ifndef LINK STACK H define LINK STACK H include
  • 数据库 关系代数 投影概念理解

    关系R上的投影是从R 中选择出若干属性列组成新的关系 记作 A R t A t R 其中A 为R 中的属性列 投影操作是从列的角度进行的运算 例3 查询学生的姓名和所在系 即求Student关系在学生姓名和所在系两个属性上的投影 Sname
  • k8s集群新增节点

    如何动态的为k8s集群增加worknode节点 本文将详细介绍 kubeadm搭建k8集群详见 https blog csdn net wangqiubo2010 article details 101203625 一 VMWare xSp
  • 每日算法题(Day5)----取石子

    题目描述 有一种有趣的游戏 玩法如下 玩家 2 人 道具 N 颗石子 规则 游戏双方轮流取石子 每人每次取走若干颗石子 最少取 1 颗 最多取 K 颗 石子取光 则游戏结束 最后取石子的一方为胜 假如参与游戏的玩家都非常聪明 问最后谁会获胜
  • Linux Kafka 2.11-1.1.1 安装搭建

    Kafka是最初由Linkedin公司开发 是一个分布式 支持分区的 partition 多副本的 replica 基于zookeeper协调的分布式消息系统 它的最大的特性就是可以实时的处理大量数据以满足各种需求场景 比如基于hadoop
  • iframe无边框实现

  • Android 11 绕过反射限制

    1 问题出现的背景 腾讯视频在集成我们 replay sdk 的时候发现这么个错误 导致整个 db mock 功能完全失效 Accessing hidden field Landroid database sqlite SQLiteCurs
  • LeetCode1477-找两个和为目标值且不重叠的子数组

    给你一个整数数组 arr 和一个整数值 target 请你在 arr 中找 两个互不重叠的子数组 且它们的和都等于 target 可能会有多种方案 请你返回满足要求的两个子数组长度和的 最小值 请返回满足要求的最小长度和 如果无法找到这样的
  • 餐馆点餐系统(Java GUI + mysql)

    餐馆点餐系统 Java GUI mysql 开发环境 eclipse mysql 开发语言 Java SQL 本系统采用MVC模式开发的 果冻点餐系统 适合Java初级选手学习 本系统实现了用户注册登录 点餐 商家管理订单等一系列功能 首先
  • crc32碰撞_hash碰撞的概率和可能性比你直觉中大得多

    注 这篇文章源自我10年前写的博客 今天看到有人谈密码安全的 再发一遍和大家讨论下 我发现哪怕10年后 这文章也没过时 很多人还是没拎清 冲突概率和样本空间的关系 前段时间跟某大牛叽歪的时候 被提到我写的一篇文章 用CRC32实现短网址的一
  • 基于Spring Boot的酒店客房管理系统

    文章目录 项目介绍 主要功能截图 后台 前台 部分代码展示 设计总结 项目获取方式 作者主页 超级无敌暴龙战士塔塔开 简介 Java领域优质创作者 简历模板 学习资料 面试题库 关注我 都给你 文末获取源码联系 项目介绍 基于Spring
  • 奇偶校验c语言ascii,奇偶校验(parity check)

    parity check 奇偶校验 N a check made of computer data to ensure that the total number of bits of value 1 or 0 in each unit o
  • 查看Linux的用户权限(转载)

    转 Linux查看用户及其权限管理 查看用户 请打开终端 输入命令 who am i 或者 who mom likes 输出的第一列表示打开当前伪终端的用户的用户名 要查看当前登录用户的用户名 去掉空格直接使用 whoami 即可 第二列的
  • ASP.NET MVC - Model Binding

    Http Request 到Input Model的绑定按照model的类型可分为四种情况 Primitive type Collection of primitive type Complex type Collection of com
  • ROC曲线-阈值评价标准

    ROC曲线指受试者工作特征曲线 接收器操作特性曲线 receiver operating characteristic curve 是反映敏感性和特异性连续变量的综合指标 是用构图法揭示敏感性和特异性的相互关系 它通过将连续变量设定出多个不
  • UE4导入3dmax模型并在场景中添加第三人称角色

    1 3dmax安装Datasmith插件 插件下载位置 https www unrealengine com zh CN datasmith plugins 2 3dmax导出模型 3 UE4导入模型 从3dmax导出datasmith的格
  • Pytorch模型保存与加载模型继续训练

    1 网络模型定义与模型参数保存 定义网络模型与基本参数 以及模型训练和模型保存 使用torch save 方法保存模型 在save dict 中可以保存epoch model optimizer scheduler loss等参数 my n