pytorch中网络loss传播和参数更新理解

2023-11-10

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252


在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn

criterion = nn.MSELoss().cuda()

output = model(input)

loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

                                                         \theta =\theta -\eta \times \frac{\partial \jmath \left ( \theta \right )}{\partial \theta }

其中\eta就是通常说的学习率,\frac{\partial \jmath \left ( \theta \right )}{\partial \theta }是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

           2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01)

           3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch

criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

output = model(input)

loss = criterion(output, target)

​optimizer.zero_grad()  # 将所有参数的梯度都置零
loss.backward()        # 误差反向传播计算参数梯度
optimizer.step()       # 通过梯度做一步参数更新


​
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch中网络loss传播和参数更新理解 的相关文章

随机推荐

  • QKL123区块链排行榜(2019年04月)

    QKL123区块链排行榜包括区块链项目 区块链交易平台 区块链媒体 区块链公众号 区块链矿机 区块链矿池 EOS Dapp ETH Dapp 区块链钱包九大榜单 目前 区块链项目榜单选取的客观指标包括流通市值 GitHub提交数 区块链交易
  • 嵌入式stm32基础项目开发:心率检测仪的设计与实现

    嵌入式stm32基础项目开发 心率检测仪的设计与实现 本教程主要给大家谅解了嵌入式stm32开发 心率检测仪的设计与实现 需要的朋友们可以下载来看看 作为参考 项目描述 通过心律传感器采集我们的心律数据 然后通过串口传送到上位机中 上位机用
  • 这篇文章教大家怎么生成ai图片

    在数字化时代 人工智能技术的发展正在改变我们的生活方式 其中之一就是在艺术领域的应用 ai绘画是人工智能技术在艺术领域的一种应用 它可以自动创作出各种各样的图片 为艺术家和设计师提供了更加便捷和高效的绘画工具 ai绘画的出现 不仅可以缩短绘
  • 素数环(回溯算法)

    回溯算法 在包含问题的所有可能解的解空间树中 从根节点出发 按照深度优先遍历的策略进行搜索 对于解空间树种的某个节点 如果该节点满足问题的约束条件 则进入该子树继续进行搜索 否则将以该节点为根节点的子树进行剪枝 回溯法常常可以避免所有的可能
  • layui table.js表格一直返回数据异常

    1 排查数据是否已经正常返回 2 layui table 返回格式默认不能自定义的 返回的分页json格式需要和table js中规定的返回键一致 如下 3 经过测试 其实最重要的是code需要和上图中statusName后的resultC
  • Cisco 路由器VOIP 配置解析

    在企业网络中推广 IP 语音技术有很多优点 例如可以控制数据流量 保证语音质量 充分利用企业租用的数据线路资源 节省传统的长途话费等等 企业使用 IP 语音技术 可以将语音 数据和多媒体通信融合在一个集成的网络中 并在一个企业解决方案中 把
  • 简易版的飞机大战(C语言)

    一 只会发射激光 画质不清晰的飞机大战 游戏的总体结构根据C语言的循环制作的 本来还想说点什么但是注释里面都有 代码 include
  • ansys18安装以后打不开_ansys18.0安装过程及常见问题解决方案【图文】

    1 首先打开ansys18 0安装文件夹 一般情况下通过网络渠道下载的ansys18 0安装包会有四个文件夹 crack文件夹为授权配置文件夹 disk1 disk2 disk3文件夹为安装程序包 我们首先打开disk1文件夹 双击setu
  • 物联网LoRa系列-31:通过LoRa终端实现远程抄表的原理与系统框架(水、电、气、热等通用)

    LoRa终端远程抄表的系统架构图 抄表系统由 无线电表 线集中器 业务数据中心组成 1 无线电表 又称为LoRa终端 内嵌LoRa模块 进行数据的采集 并LoRa WAN协议实现远程数据的传输 LoRa智能终端能将传统水表 电表等读数通过电
  • 可迭代(iterable)和类数组(array-like)

    可迭代 iterable 和类数组 array like 可迭代 iterable 是实现了 Symbol iterator 方法的对象 可以应用 for of 的对象被称为 可迭代的 类数组 array like 是有索引和 length
  • Redis主从复制的原理

    更多内容 欢迎关注微信公众号 全菜工程师小辉 公众号回复关键词 领取免费学习资料 在Redis集群中 让若干个Redis服务器去复制另一个Redis服务器 我们定义被复制的服务器为主服务器 master 而对主服务器进行复制的服务器则被称为
  • pyautogui库的使用教程(超详细)

    一 前言 PyAutoGUI 让您的 Python 脚本控制鼠标和键盘以自动与其他应用程序交互 官方文档 PyAutoGUI documentation 常用函数列表 函数名 功能 基本 pyautogui size 返回包含分辨率的元组
  • 在编辑操作时,el-select多选下拉组件,选中label标签后,框中无法回显选中的label,,,

    1 问题描述 在编辑操作时 页面的el select多选下拉组件 在选择新的label标签时 change事件和监听数组对象都能确定数据已发生改变 ngmodel绑定就是最新的id集合 但就是框中不显示最新选中的label 而change事
  • 论文导读

    图的最大独立集问题 MIS problem 是图论研究中的一个重要问题 具有广泛的应用 本文介绍了最大独立集求解相关的三篇工作 包括一篇启发式方法和两篇基于学习的方法 希望能让大家对这个问题有所了解 问题定义 一个图G V E 的顶点集子集
  • 放弃手写代码吧!用低代码你能生成各种源码

    很多同学不知道为什么要用Low code做开发 传统IT开发不行么 当然可以 传统IT自研软件开发 通过编程去写代码 还有数据库 API 第三方基础架构等 这个方式很好 但不可避免的会带来开发周期长 难度大 技术人员不易开发维护 因此价格及
  • EDUCODER---WEB__JavaScript学习手册十:正则表达式

    第一关 字符串字面量 请在此处编写代码 Begin var pattern js n End 第二关 字符类 请在此处编写代码 Begin var pattern1 a zA Z 0 9 var pattern2 A 0 9 End 第三关
  • Linux下Python环境安装与部署

    因为我是Python零基础 所以如何部署全靠百度 这边我把我查到的资料和安装使用过程中遇到写下来 如果有写的不对的或者有更好的方式 欢迎评论指出 一 Python环境安装 网上有很多安装教程 可以自行百度安装 我参考的是这个 仅第一步安装p
  • The Lost House【树形DP+期望+构造路径】

    题目链接 POJ 2057 题意 有一棵N的点的树 开始的时候蜗牛在1号结点 它不知道它的家在哪个叶子结点 树上的有些结点有虫虫 虫虫会告诉你 你的家是否在以它所在结点为根的子树上 现在需要你规划走的方案 使得找到哪个叶子结点才是家的所走路
  • python将word表格转写入excel

    Notes 想将一份 word 文件中的几个表格转写入 excel 文件中 后续用 excel 处理 用到 python docx 和 pandas 分别处理 word 和 excel 安装 python docx pip install
  • pytorch中网络loss传播和参数更新理解

    相比于2018年 在ICLR2019提交论文中 提及不同框架的论文数量发生了极大变化 网友发现 提及tensorflow的论文数量从2018年的228篇略微提升到了266篇 keras从42提升到56 但是pytorch的数量从87篇提升到