《动手学深度学习 Pytorch版》 3.7 softmax回归的简单实现

2023-10-28

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256  # 保持批量大小为 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)  # 仍使用Fashion-MNIST数据集

3.7.1 初始化模型参数

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))  # flatten负责调整网络输入形状,添加一个有 10 个输出的全连接层

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)  # 使用正态分布中的随机值初始化权重

net.apply(init_weights)
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=10, bias=True)
)

3.7.2 重新审视 softmax 的实现

从计算的角度来讲,指数可能会造成数值稳定性问题,即可能会发生溢出。解决上溢问题可在 softmax 运算前先从所有 o k o_k ok 中减去 max ⁡ ( o k ) \max(o_k) max(ok)。但如若有些 o j − max ⁡ ( o k ) o_j-\max(o_k) ojmax(ok) 为过小的负值时则又可能发生下溢,此时可以尽量避免计算 exp ⁡ ( o j − max ⁡ ( o k ) ) \exp(o_j-\max(o_k)) exp(ojmax(ok))
log ⁡ ( y ^ i ) = log ⁡ ( exp ⁡ ( o j − max ⁡ ( o k ) ) ∑ k exp ⁡ ( o k − max ⁡ ( o k ) ) ) = log ⁡ ( exp ⁡ ( o j − max ⁡ ( o k ) ) ) − log ⁡ ( ∑ k exp ⁡ ( o k − max ⁡ ( o k ) ) ) = o j − max ⁡ ( o k ) − log ⁡ ( ∑ k exp ⁡ ( o k − max ⁡ ( o k ) ) ) \begin{align} \log(\hat{y}_i)&=\log\left(\frac{\exp(o_j-\max(o_k))}{\sum_k\exp(o_k-\max(o_k))}\right)\\ &=\log(\exp(o_j-\max(o_k)))-\log\left(\sum_k\exp(o_k-\max(o_k))\right)\\ &=o_j-\max(o_k)-\log\left(\sum_k\exp(o_k-\max(o_k))\right) \end{align} log(y^i)=log(kexp(okmax(ok))exp(ojmax(ok)))=log(exp(ojmax(ok)))log(kexp(okmax(ok)))=ojmax(ok)log(kexp(okmax(ok)))

loss = nn.CrossEntropyLoss(reduction='none')

3.7.3 优化算法

trainer = torch.optim.SGD(net.parameters(), lr=0.1)  # 使用学习率为 0.1 的小批量随机梯度下降作为优化算法

3.7.4 训练

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)  # 调用上节定义的训练函数来训练模型


在这里插入图片描述

练习

(1)尝试调增超参数,例如批量大小、轮数和学习率,并查看结果。

batch_size2 = 1024  # 将 batch_size 提高到 1024
train_iter2, test_iter2 = d2l.load_data_fashion_mnist(batch_size2)
net2 = nn.Sequential(nn.Flatten(),nn.Linear(784,10))
net2.apply(init_weights)

num_epochs = 10
trainer = torch.optim.SGD(net2.parameters(), lr=0.1)
d2l.train_ch3(net2, train_iter2, test_iter2, loss, num_epochs, trainer)  # 提高 batch size 会使 train loss 提高
---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

Cell In[6], line 8
      6 num_epochs = 10
      7 trainer = torch.optim.SGD(net2.parameters(), lr=0.1)
----> 8 d2l.train_ch3(net2, train_iter2, test_iter2, loss, num_epochs, trainer)


File c:\Software\Miniconda3\envs\d2l\lib\site-packages\d2l\torch.py:340, in train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    338     animator.add(epoch + 1, train_metrics + (test_acc,))
    339 train_loss, train_acc = train_metrics
--> 340 assert train_loss < 0.5, train_loss
    341 assert train_acc <= 1 and train_acc > 0.7, train_acc
    342 assert test_acc <= 1 and test_acc > 0.7, test_acc


AssertionError: 0.5225925379435221

在这里插入图片描述

batch_size3 = 256
train_iter3, test_iter3 = d2l.load_data_fashion_mnist(batch_size3)
net3 = nn.Sequential(nn.Flatten(),nn.Linear(784,10))
net3.apply(init_weights)

num_epochs = 40  # 将轮数提高到 40
trainer = torch.optim.SGD(net3.parameters(), lr=0.1)
d2l.train_ch3(net3, train_iter3, test_iter3, loss, num_epochs, trainer)  # 提高轮数会使 test acc 突然下降,应该是过拟合了。

在这里插入图片描述

batch_size4 = 256
train_iter4, test_iter4 = d2l.load_data_fashion_mnist(batch_size4)
net4 = nn.Sequential(nn.Flatten(),nn.Linear(784,10))
net4.apply(init_weights)

num_epochs = 10
trainer = torch.optim.SGD(net4.parameters(), lr=0.4)  # 将学习率提高到 0.4
d2l.train_ch3(net4, train_iter4, test_iter4, loss, num_epochs, trainer)  # 提高学习率会使 test acc 极不稳定,train loss 无法收敛。
---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

Cell In[9], line 8
      6 num_epochs = 10
      7 trainer = torch.optim.SGD(net4.parameters(), lr=0.4)  # 将学习率提高到 0.4
----> 8 d2l.train_ch3(net4, train_iter4, test_iter4, loss, num_epochs, trainer)


File c:\Software\Miniconda3\envs\d2l\lib\site-packages\d2l\torch.py:340, in train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    338     animator.add(epoch + 1, train_metrics + (test_acc,))
    339 train_loss, train_acc = train_metrics
--> 340 assert train_loss < 0.5, train_loss
    341 assert train_acc <= 1 and train_acc > 0.7, train_acc
    342 assert test_acc <= 1 and test_acc > 0.7, test_acc


AssertionError: 0.6684705724080404

在这里插入图片描述


(2)增加轮数,为什么测试精度会在一段时间后降低?我们如何解决这个问题?

因为发生了过拟合。可以增加样本数。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

《动手学深度学习 Pytorch版》 3.7 softmax回归的简单实现 的相关文章

随机推荐

  • Unity使用c#脚本代码编写基于AudioSource的音频淡入淡出渐变效果

    需求分析与类设计 编写能够通过一个函数调节音乐淡入淡出效果的类 使用时将脚本挂载在AudioSource的物体上 其他脚本的物体能够访问物体上的这个类进行音量调节 需要导入的外部成员变量 1 本物体的AudioSource 2 目标音量 实
  • Lua模拟C#的类继承

    写Lua的都知道Lua语言本身是不提供类继承这个概念的 但是我们可以根据Lua提供的设置元方法的特性来模拟类的建立 以下是我写的一个模拟C 类继承的Lua方法 即只能继承一个父类 但可以继承多个接口 我这个模拟构造类时 父类一定要放在第一个
  • 【22】CSS核心样式(3)——盒模型5种属性

    盒模型又叫框模型 包含了五个用来描述盒子位置 尺寸的属性 分别是宽度 width 高度 height 内边距 padding 边框 border 外边距 margin 为了更好理解 如下生活中的举例 常见盒模型区域 盒模型的属性中 根据不同
  • 数据分析流程

    数据分析流程 1 明确分析目的与框架 2 数据收集 3 数据处理 4 数据分析 5 数据展现 6 撰写报告 数据分析流程概括起来主要包括明确分析目的与框架 数据收集 数据处理 数据分析 数据展现和撰写报告6个阶段 1 明确分析目的与框架 明
  • AI厂工什么时候开始赛博搬砖?

    最近两个月 二次元们找到了AI的 正确用法 玩梗 以造梗最多的NovelAI为例 无论你投喂什么图片 AI都能二次元化 输出精美中不失离谱的图片 你猜它们的原图是什么 这只是大量AI作画正面案例里的一个少数 最近两个月 AI作画带着大量梗图
  • 二十四. Kubernetes 安全

    目录 一 一 官方文档 k8s中不管是外部通过ui管理端操作 还是通过命令行 再或者集群内部执行的操作指令 所有指令都会发送给ApiServer 即使是pod也会被集群认为是一个用户 会给这个用户颁发一个ServiceAccount服务账号
  • STM32 BootLoader跳转之前关闭全部中断

    关闭全局中断 DISABLE INT 关闭滴答定时器 复位到默认值 SysTick gt CTRL 0 SysTick gt LOAD 0 SysTick gt VAL 0 设置所有时钟到默认状态 使用HSI时钟 HAL RCC DeIni
  • CSS自己实现一个步骤条

    前言 步骤条是一种用于引导用户按照特定流程完成任务的导航条 在各种分步表单交互场景中广泛应用 例如 在HIS系统 门诊医生站中的接诊场景中 我们就可以使用步骤条来实现 她的执行步骤分别是 门诊病历 gt 遗嘱录入 gt 完成接诊 我们发现
  • 华为OD机试真题 Java 实现【货币单位换算】【2023Q1 100分】

    一 题目描述 记账本上记录了若干条多国货币金额 需要转换成人民币分 fen 汇总后输出每行记录一条金额 金额带有货币单位 格式为数字 单位 可能是单独元 或者单独分 或者元与分的组合要求将这些货币全部换算成人民币分 fen 后进行汇总 汇总
  • 使用docker进行部署hadoop

    使用docker进行部署hadoop 安装docker wget qO https get docker com sh 安装完成后 要启动docker服务 sudo service docker start 查看是否运行成功 ps aux
  • C++——函数指针

    在C 中 函数指针是指向函数的指针变量 它允许将函数作为参数传递给其他函数 动态选择调用的函数以及在运行时改变函数的行为 函数指针的声明和使用如下所示 1 声明函数指针类型 returnType pointerName parameterT
  • 我的一路走来@电子信息工程和嵌入式该怎么入门

    嵌入式该怎么学 嵌入式从何学起 嵌入式入门需不需要报培训机构 哪个培训机构好点 还有一些是咨询电子信息工程专业的情况等等 这些问题几乎每天都在我的 嵌入式的世界 百度知道团队会遇到和看到的一些问题 归根结底是咨询嵌入式该如何入门 电子信息工
  • js中通过window.location.href和document.location.href、document.URL获取当前浏览器的地址的值,它们的的区别

    1 document表示的是一个文档对象 window表示的是一个窗口对象 一个窗口下可以有多个文档对象 所以一个窗口下只有一个window location href 但是可能有多个document URL document locati
  • HTML+CSS字体文本

    声明 本人的所有博客皆为个人笔记 作为个人知识索引使用 因此在叙述上存在逻辑不通顺 跨度大等问题 希望理解 分享出来仅供大家学习翻阅 若有错误希望指出 感谢 HTML文本标签 文本级语义标签包括 a 超连接 em 侧重点的强调 可嵌套 表现
  • Laya实现控制杆控制3D模型旋转

    export default class JoyStick constructor mod this model mod 模型 this scale Laya Browser width 1920 this rockerBtnOrigin
  • uniapp App权限判断和提示

    1 下载组件App权限判断和提示 DCloud 插件市场 2 导出到需要判断的项目里面 import permision from js sdk wa permission permission js 3 判断是否开启权限 async re
  • 图书馆管理系统 Java

    目录 要求 代码 Operate接口 Book类 Reader类 BookList类 ReadList 类 Infor类 InforList类 main 功能实现 改进 错误 总结 要求 为图书管理人员编写一个图书管理系统 图书管理系统的设
  • 分布式接口幂等性设计实现

    面对分布式架构和微服务复杂的系统架构和网络超时服务器异常等带来的系统稳定性问题 分布式接口的幂等性设计显得尤为重要 本文简要介绍了几种分布式接口幂等性设计实现 包括Token去重机制 乐观锁机制 数据库主键和状态机实现等 以加深理解 1 分
  • WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!

    使用SCP命令时出现这个错误 解决办法 rm ssh known hosts
  • 《动手学深度学习 Pytorch版》 3.7 softmax回归的简单实现

    import torch from torch import nn from d2l import torch as d2l batch size 256 保持批量大小为 256 train iter test iter d2l load