BP神经网络(输出层采用Softmax激活函数、交叉熵损失函数)公式推导

2023-10-28

本篇博客主要介绍经典的三层BP神经网络的基本结构及反向传播算法的公式推导。

我们首先假设有四类样本,每个样本有三类特征,并且我们在输出层与隐藏层加上一个偏置单元。这样的话,我们可以得到以下经典的三层BP网络结构:

这里写图片描述

当我们构建BP神经网络的时候,一般是有两个步骤,第一是正向传播(也叫做前向传播),第二是反向传播(也就是误差的反向传播)。

Step1 正向传播
在正向传播之前,可以先给W,b赋初始值,最好不要全设置为0,不然后面会出现问题。赋完初值后,下面开始正向传播:

neth1=W11i1+W12i2+W13i3+b1 n e t h 1 = W 11 ∗ i 1 + W 12 ∗ i 2 + W 13 ∗ i 3 + b 1
Outh1=11+eneth1 O u t h 1 = 1 1 + e − n e t h 1 ——>激活函数为sigmoid函数: y=11+ex y = 1 1 + e − x

隐含层到输出层:
netO1=W11h1+W12h2+W13h3+W14h4+b1=4j=1W1jOuthj+b1 n e t O 1 = W 11 ′ ∗ h 1 + W 12 ′ ∗ h 2 + W 13 ′ ∗ h 3 + W 14 ′ ∗ h 4 + b 1 ′ = ∑ j = 1 4 W 1 j ′ ∗ O u t h j + b 1 ′
OutO1=enetO1enetO1+enetO2+enetO3+enetO4 O u t O 1 = e n e t O 1 e n e t O 1 + e n e t O 2 + e n e t O 3 + e n e t O 4
——>激活函数为Softmax型、常用于多分类问题

到这里我们已经完成了正向传播,这里我之所以没给出向量的形式,是因为我觉得标量的形式容易理解。下面我们开始反向传播。

Step2 反向传播
1.计算总误差
这里我们采用交叉熵损失函数,我看到网上大部分都是采用均方误差形式的损失函数,使用交叉熵的相对较少,且使用交叉熵损失函数具有较多优点。

Etotal=4i=1targetOilnOutOi E t o t a l = − ∑ i = 1 4 t a r g e t O i ∗ l n O u t O i
这里其实只有一项,因为 targetOi t a r g e t O i 无论在何时都只有一项为1,也就是说,要么第一类,此时 Etotal=targetO1lnOutO1 E t o t a l = − t a r g e t O 1 ∗ l n O u t O 1 . 要么第二类,此时 Etotal=targetO2lnOutO2 E t o t a l = − t a r g e t O 2 ∗ l n O u t O 2 ,要么是第三类,第四类等情况。

2.隐含层到输出层的权值更新
W11 W 11 ′ 为例,我们想知道 W11 W 11 ′ 对整体误差产生了多少影响,可用整体误差对 W11 W 11 ′ 求偏导得出。
EtotalW11=EtotalOutO1OutO1netO1netO1W11 ∂ E t o t a l ∂ W 11 ′ = ∂ E t o t a l ∂ O u t O 1 ∗ ∂ O u t O 1 ∂ n e t O 1 ∗ ∂ n e t O 1 ∂ W 11 ′

下面我们依次来计算每个式子(不要着急,一步一步算):
EtotalOutO1=targetO11OutO1 ∂ E t o t a l ∂ O u t O 1 = − t a r g e t O 1 ∗ 1 O u t O 1
OutO1netO1=OutO1(1OutO1) ∂ O u t O 1 ∂ n e t O 1 = O u t O 1 ∗ ( 1 − O u t O 1 )
netO1netW11=Outh1 ∂ n e t O 1 ∂ n e t W 11 ′ = O u t h 1

然后将三者相乘,就可以了嘛??
答案是否定的,之前我也是这么推导的,结果在迭代时发现,权值一直在增大,后来经过很长时间的分析才发现,原来这里的 EtotalW11 ∂ E t o t a l ∂ W 11 ′ 我求错了。

问题出在哪里呢?

是因为采用了交叉熵的损失函数,在更新 W11 W 11 ′ 时,误差不仅仅来自于 O1 O 1 ,还与其他所有的输出层的节点有关系。咋一看非常不可思议,但是仔细一想,你会发现因为在计算 OutOi O u t O i 是,分母中e的指数涉及到了其他所有的神经元的输出,即 netO2 n e t O 2 netO3 n e t O 3 等。
所以,我们对 W11 W 11 ′ 的偏导就应该是:
EtotalW11=EtotalOutO1OutO1netO1netO1W11+EtotalOutO2OutO2netO1netO1W11+EtotalOutO3OutO3netO1netO1W11+EtotalOutO4OutO4netO1netO1W11 ∂ E t o t a l ∂ W 11 ′ = ∂ E t o t a l ∂ O u t O 1 ∗ ∂ O u t O 1 ∂ n e t O 1 ∗ ∂ n e t O 1 ∂ W 11 ′ + ∂ E t o t a l ∂ O u t O 2 ∗ ∂ O u t O 2 ∂ n e t O 1 ∗ ∂ n e t O 1 ∂ W 11 ′ + ∂ E t o t a l ∂ O u t O 3 ∗ ∂ O u t O 3 ∂ n e t O 1 ∗ ∂ n e t O 1 ∂ W 11 ′ + ∂ E t o t a l ∂ O u t O 4 ∗ ∂ O u t O 4 ∂ n e t O 1 ∗ ∂ n e t O 1 ∂ W 11 ′

因为我这里有四类,所以显得式子很长,但是经过化简,可以得到以下式子:

EtotalW11=(OutO1targetO1)Outh1 ∂ E t o t a l ∂ W 11 ′ = ( O u t O 1 − t a r g e t O 1 ) ∗ O u t h 1

我们可以令 OutO1targetO1 O u t O 1 − t a r g e t O 1 δO1 δ O 1 ,意思是在该神经元输出点的误差值,那么我们就可以很容易的得到权值的更新公式:
W11=W11ηδO1Outh1 W 11 ′ = W 11 ′ − η δ O 1 O u t h 1

同理,我们可以得到偏置的更新公式:
b1=b1ηδO1 b 1 ′ = b 1 ′ − η δ O 1

其中, η η 表示学习率,这是一个可以自己调节的变量,看自己的数据分布情况,可设为0.1、0.05等等。

从输入层到隐含层的权值更新:
EtotalW11=EtotalOuth1Outh1neth1neth1W11 ∂ E t o t a l ∂ W 11 = ∂ E t o t a l ∂ O u t h 1 ∗ ∂ O u t h 1 ∂ n e t h 1 ∗ ∂ n e t h 1 ∂ W 11

其中, EtotalOuth1=EO1Outh1+EO2Outh1+EO3Outh1+EO4Outh1 ∂ E t o t a l ∂ O u t h 1 = ∂ E O 1 ∂ O u t h 1 + ∂ E O 2 ∂ O u t h 1 + ∂ E O 3 ∂ O u t h 1 + ∂ E O 4 ∂ O u t h 1 是因为输出层的每一个神经元都对隐藏层的第一个神经元有误差的传递。

由此我们可以得到:
EtotalW11=(δO1W11+δO2W12+δO3W13+δO4W14)Outh1(1Outh1)i1 ∂ E t o t a l ∂ W 11 = ( δ O 1 ∗ W 11 ′ + δ O 2 ∗ W 12 ′ + δ O 3 ∗ W 13 ′ + δ O 4 ∗ W 14 ′ ) ∗ O u t h 1 ∗ ( 1 − O u t h 1 ) ∗ i 1
我们将
(δO1W11+δO2W12+δO3W13+δO4W14)Outh1(1Outh1) ( δ O 1 ∗ W 11 ′ + δ O 2 ∗ W 12 ′ + δ O 3 ∗ W 13 ′ + δ O 4 ∗ W 14 ′ ) ∗ O u t h 1 ∗ ( 1 − O u t h 1 )
记作 δh1 δ h 1

那么,我们的输入层到隐含层的权值及偏置的更新策略为:
W11=W11ηδh1i1 W 11 = W 11 − η δ h 1 i 1

b1=b1ηδh1 b 1 = b 1 − η δ h 1

至此,我们已经将两层之间的权值和偏置都更新完了,可以根据以上写成向量或者矩阵的形式,方便后面的运算。

最后,我想说,虽然现在很多机器学习的框架用的如火如荼,但是我觉得对于刚入门的同学来说,一些基本的公式推导和证明还是要掌握的。如有错误,欢迎交流和指正。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

BP神经网络(输出层采用Softmax激活函数、交叉熵损失函数)公式推导 的相关文章

随机推荐

  • IDEA使用maven搭建java项目连接redis(图文)

    1 新建项目 2 添加依赖 对应的依赖配置项可以在https mvnrepository com artifact redis clients jedis 中找到 可以根据自己想要的版本进行配置 3 mevan下载jar 4 解决依赖包导入
  • 美国教育数据分析

    现有一份来自kaggle的美国教育相关的数据集 数据中一共有1497个样本 25个属性 我们先将这份数据的缺失值进行补充 并进行标准化 然后将这份数据中的学生数学成绩作为标签 利用其它的24个属性构建机器学习方法 来对学生成绩进行预测 通过
  • 《英语国家社会与文化入门》答案翻译

    Traslate by 东莞理工学院 14级软件2班 赖静朝 本资料仅供学习交流 不保证内容的绝对准确性 严禁使用于任何商业用途 下载地址 链接 http pan baidu com s 1nv6JhAp 密码 b3oc 如果发现错误 可以
  • 锐捷实操系列

    1 锐捷实操 Telnet方式登录 1 配置路由器以太网口的IP地址 Ruijie config interfaceinterface number 进入接口配置模式 Ruijie config if GigabitEthernet0 0
  • Blender基础操作:移动游标位置、快速设置原点

    1 游标 作用 用于物体定位 比如 1 作为出生点 新创建的物体的初始位置 2 在游标处汇合的操作 右键 吸附 选中项 游标 如何移动游标的位置 手工移动 选中游标工具 点到指定位置 不推荐 右键 吸附 游标 选中项 自动吸附 勾选磁铁 顶
  • 华为测开面试记,三面被吊打,所幸最后Offer已到手

    在互联网做了几年之后 去大厂 镀镀金 是大部分人的首选 大厂不仅待遇高 福利好 更重要的是 它是对你专业能力的背书 大厂工作背景多少会给你的简历增加几分竞争力 但说实话 想进大厂还真没那么容易 最近面试华为 结果三面被吊打 不甘心的我整理了
  • 简单使用 MySQL 索引

    MySQL 索引 1 什么是索引 在数据库表中 对字段建立索引可以大大提高查询速度 通过善用这些索引 可以令 MySQL 的查询和 运行更加高效 如果合理的设计且使用索引的 MySQL 是一辆兰博基尼的话 那么没有设计和使用索引的 MySQ
  • linux:vmware下docker容器无法ping通外部

    问题 如题 原因 可能是因为网络原因 之前用的梯子如果没有断开而直接关机 导致网卡出现异常 而vm ware也可能是同样原因 尝试解决 关闭vmware并重启 再一次进入到容器 问题解决
  • ant编译Tomcat8时报错 the archive file.tar.gz doesn't exist

    报错 testexist echo Testing for D project Tomcat 8 0 2 src share commons dbcp2 2 0 SNAPSHOT src build xml downloadgz 2 pro
  • Python tkinter 树形列表控件(Treeview)的使用方法

    1 方法 方法 描述 bbox item column None 返回指定item的框选范围 或者单元格的框选范围 column cid option None kw 设置或者查询某一列的属性 delete items 删除指定行或者节点
  • [专利与论文-11]:南京市职称申请 - 继续教育学时认定表如何填写

    2021年电子信息申报通知 中 高级 南京人力资源和社会保障学会 关于做好2020年度南京市专业技术人员继续教育工作的通知 关于做好2020年度南京市专业技术人员继续教育工作的通知 今年申报职称 关于学时 需要填写 南京市专业技术人员继续教
  • MySQL自带数据库

    文章目录 MySQL自带数据库 自带数据库介绍 1 mysql 2 information schema Server层统计信息字典表 Server层表级别对象字典表 Server层其它信息字典表 InnoDB层系统字典表 InnoDB层锁
  • tcp短连接TIME_WAIT问题解决方法大全(4)——tcp_tw_reuse

    tcp tw reuse选项的含义如下 http www kernel org doc Documentation networking ip sysctl txt tcp tw reuse BOOLEANAllow to reuse TI
  • arduino uno官方原理图_Arduino基础入门篇27—步进电机驱动库的使用

    本篇介绍步进电机驱动库的使用 通过读取电位器输入 控制步进电机转动相应角度 Stepper库是官方提供的驱动库 我们启动Arduino IDE 点击 文件 示例 就能找到Stepper库 官方提供了四个例程 关于Stepper库可参考官方介
  • vscode配置xdebug调试

    参考 vscode配置PHP调试xDebug wx61cd54ea3a202的技术博客 51CTO博客 Xdebug V3 不会停止 VSCode 中的断点 1 打印php信息 2 打开 Xdebug Support Tailored In
  • JavaFX制作餐厅管理系统(附源码)

    相信有很多同学在做毕业设计或者是提升自己的会选择做一个系统 下面从各个方面了解制作餐厅管理系统 以下均为up主个人思路 有错误的地方欢迎各路大佬指点 非常感谢 供各位同学参考学习 前言 制作思路 资料准备 功能实现 最后优化 注意 餐厅管理
  • MySQL:二、Table约束,多表联查,数据库备份、恢复

    目录 一 数据的完整性 约束 1 1 实体完整性 1 1 1 主键约束 primary key 1 1 2 唯一约束 1 1 3 自动增长列 1 2 域完整性 1 2 1 非空约束 not null 1 2 2 默认值约束 1 3 外键约束
  • 【git】LibreSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443

    1 概述 今天git 拉取一个项目报错 lcc lcc IdeaProjects third git clone https github com xxxx xxxx git Cloning into xxxx
  • SM4算法设计原理

    SM4分组密码算法描述 SM4分组密码算法是一个迭代分组密码算法 由加解密算法和密钥扩展算法组成 SM4分组密码算法采用非平衡Feistel结构 分组长度为128b密钥长度为128b 加密算法与密钥扩展算法均采用非线性迭代结构 加密运算和解
  • BP神经网络(输出层采用Softmax激活函数、交叉熵损失函数)公式推导

    本篇博客主要介绍经典的三层BP神经网络的基本结构及反向传播算法的公式推导 我们首先假设有四类样本 每个样本有三类特征 并且我们在输出层与隐藏层加上一个偏置单元 这样的话 我们可以得到以下经典的三层BP网络结构 当我们构建BP神经网络的时候