使用MATLAB的trainNetwork设计一个简单的LSTM神经网络

2023-11-14


前言

借助MATLAB的deepNetworkDesigner搭一个简单的LSTM,数据集使用mnist手写数字识别数据集。


一、数据集

mnist数据集包括60000组训练数据和对应的标签,10000组测试数据和对应标签。每个数据都是一个28x28的矩阵,可以将其看做28x28像素的灰度图像(黑底白字)。而LSTM的输入应当是一个序列,我们可以把矩阵的每一行当做一帧,把图像分为28帧输入到LSTM。
数据集可以在我上传的资源里找到。

数据的格式是这样的:

在这里插入图片描述
XTrain,即训练图像,是一个60000x1的cell,cell的每一个元素是一个28x28的矩阵。矩阵的每一列为一帧。直接将矩阵以图片显示是这样的:

 imshow(cell2mat(XTrain(8)))

在这里插入图片描述
这不是某希腊字母,而是手写数字3。我们希望按行输入,而MATLAB按列读取,因此我做了个转置。再转置一下就能看到正常的图像:

 imshow(cell2mat(XTrain(8))')

在这里插入图片描述
标签的格式为:

在这里插入图片描述
可以直接通过categorical函数实现数值到categorical的转换,比如:

在这里插入图片描述

输入训练数据的方式不唯一,我用的只是其中一种,详情见MathWorks官网:trainNetwork

二、网络结构

使用一层128个隐藏节点的LSTM,一层全连接,输出使用softmax。网络的输入是一个序列,输出是标签,在MATLAB中,此网络可以这样描述:

layers = [ ...
    sequenceInputLayer(inputSize)                   %sequence输入
    lstmLayer(numHiddenUnits,'OutputMode','last')   %lstm
    fullyConnectedLayer(numClasses)                 %全连接
    softmaxLayer                                    %softmax
    classificationLayer];                           %label输出

三、测试程序

完整的测试程序如下:

clear
clc
%加载数据
load('.\mnist_data_mat\XTrain.mat')
load('.\mnist_data_mat\YTrain.mat')
load('.\mnist_data_mat\XTest.mat')
load('.\mnist_data_mat\YTest.mat')

%设置参数
inputSize = 28;         %28个输入节点
numHiddenUnits = 128;   %128个隐藏节点
numClasses = 10;        %10种分类结果

layers = [ ...
    sequenceInputLayer(inputSize)                   %sequence输入
    lstmLayer(numHiddenUnits,'OutputMode','last')   %lstm
    fullyConnectedLayer(numClasses)                 %全连接
    softmaxLayer                                    %softmax
    classificationLayer];                           %label输出

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',5, ...
    'MiniBatchSize',60, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');

net=trainNetwork(XTrain,YTrain,layers, options);    %训练

Y_pred = classify(net, XTest);                      %测试
accy = sum(Y_pred == YTest) / length(YTest);        %计算准确度

准确度为97.73%
options里的参数可以修改一下,我用同样结构的网络不同的参数做出了98.74%的准确度,仍有提升空间。这里为了节省训练时间牺牲了一些精度。
训练好的网络也上传到了资源里。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用MATLAB的trainNetwork设计一个简单的LSTM神经网络 的相关文章

  • 按元素出现的频率对数组元素进行排序

    是否可以在 matlab octave 中使用sort函数根据元素的相对频率对数组进行排序 例如数组 m 4 4 4 10 10 10 4 4 5 应该产生这个数组 5 10 10 10 4 4 4 4 4 5是出现频率较低的元素 位于顶部
  • 检查Matlab中脚本需要使用的函数

    我有一个别人写的代码包 我正在运行一个脚本 它调用一些函数 这些函数又调用更多函数 等等 我想获取不是 MATLAB 内置函数但属于包的一部分的函数列表 我尝试使用matlab codetools requiredFilesAndProdu
  • 使用mat2cell将MxN的矩阵划分为1xN大小的M矩阵

    我有一个大小为 MxN 的矩阵 比方说 1867x3 1867 行和 3 列 我想将其分成 1867 个大小为 1x3 的单元格 我使用了mat2cell X 1 1866 这里X是矩阵 1867x3 结果给出了两个单元格 一个单元格的大小
  • Simulink 仿真引擎如何工作?

    我想了解 Simulink 仿真引擎的工作原理 它是否使用离散事件模拟机制 那么如何处理连续时间 它是否依赖于基于静态循环的代码生成 或者 在第一个周期之前 它会计算出块的执行顺序 从不需要任何其他块输入的块开始 每个周期 它都会根据输入和
  • Matlab没有优化以下内容吗?

    我有一个很长的向量 1xrv 和一个很长的向量w1xs 和一个矩阵Arxs 它是稀疏的 但维度非常大 我期望 Matlab 对以下内容进行优化 这样我就不会遇到内存问题 A v w 但看起来 Matlab 实际上是在尝试生成完整的v w矩阵
  • 有没有办法在matlab中进行隐式微分

    我经常使用 matlab 来帮助我解决数学问题 现在我正在寻找一种在 matlab 中进行隐式微分的方法 例如 我想区分y 3 sin x cos y exp x 0关于dy dx 我知道如何使用数学方法通常做到这一点 但我一直在努力寻找使
  • 我需要转义该 MATLAB 字符串中的字符吗?

    我想在 MATLAB 中调用以下 bash 命令 grep Up to test linux vision1 1 log awk print 7 I use system 在MATLAB中 但结果有错误 gt gt status strin
  • matlab中类库的全局变量

    我有一些matlab声明的类 我如何声明所有类中都可见的常量 例如 这些常量可以是在所有类的方法中使用的物理常量 首先想到的是使用全局变量 还有更好的办法吗 最好在单独的文件中声明这些常量 包含常量的类是执行此操作的一种很好的干净方法 请参
  • 检测植物图片中的所有分支

    我想知道有什么可以检测下图中的所有绿色树枝 目前我开始应用 Frangi 过滤器 options struct FrangiScaleRange 5 5 FrangiScaleRatio 1 FrangiBetaOne 1 FrangiBe
  • 理解高斯混合模型的概念

    我试图通过阅读在线资源来理解 GMM 我已经使用 K 均值实现了聚类 并且正在了解 GMM 与 K 均值的比较 以下是我的理解 如有错误请指出 GMM 类似于 KNN 在这两种情况下都实现了聚类 但在 GMM 中 每个簇都有自己独立的均值和
  • Pytorch LSTM:计算交叉熵损失的目标维度

    我一直在尝试在 Pytorch 中使用 LSTM LSTM 后跟自定义模型中的线性层 但在计算损失时出现以下错误 Assertion cur target gt 0 cur target lt n classes failed 我用以下函数
  • 了解 fminunc 参数和匿名函数、函数处理程序

    请多多包涵 问题在最后 我试图找出 fminunc 调用方式的差异 这个问题源于 Andrew Ng 在他的 Coursera 机器学习课程中的第 3 周材料 我正在回答这个问题 Matlab Andrew Ng 机器学习课程中 t cos
  • MATLAB 中的霍夫变换

    有谁知道如何使用霍夫变换来检测二值图像中最强的线 A zeros 7 7 A 6 10 18 24 36 38 41 1 使用 rho theta 格式 其中 theta 以 45 为步长 从 45 到 90 以及如何在 MATLAB 中显
  • Matlab 的 imresize 函数中用于插值的算法是什么?

    我正在使用 Matlab Octaveimresize 对给定的二维数组重新采样的函数 我想了解如何使用特定的插值算法imresize works 我在Windows上使用八度 e g A 1 2 3 4 是一个二维数组 然后我使用命令 b
  • 检测数据集中线性行为的算法

    我已经发布了一个关于对数据集的一部分进行多项式拟合的算法 https stackoverflow com q 17595932 2320757前一段时间收到一些建议去做我想做的事 但我现在面临另一个问题 我尝试应用答案中建议的想法 我的目标
  • 如何将复杂的 csv 文件导入到 Matlab 中的数值向量

    我想知道我们应该如何读取由字符串 双精度数和字符等组成的复杂 csv 文件 例如 您能否提供一个可以在此 csv 文件中提取数值的成功命令 Click here http www ecb europa eu stats money yc d
  • Matlab Builder JA - 将 Matlab 编译成 Java jar - 免费版本?

    请记住 我对 Matlab 一无所知 Matlab Builder JA 允许开发人员构建 Matlab 应用程序并将其导出到 Java jar 中 太棒了 我只需要生成一个 jar 然后就可以从其他 java 代码中使用它 有谁知道单罐包
  • matlab中求和函数句柄

    Hi我试图对两个函数句柄求和 但它不起作用 例如 y1 x x x y2 x x x 3 x y3 y1 y2 我收到的错误是 对于 function handle 类型的输入参数 未定义函数或方法 plus 这只是一个小例子 实际上我实际
  • 矩形函数的数值傅里叶变换

    本文的目的是通过一个众所周知的分析傅里叶变换示例来正确理解 Python 或 Matlab 上的数值傅里叶变换 为此 我选择矩形函数 这里报告了它的解析表达式及其傅立叶变换https en wikipedia org wiki Rectan
  • 读出 Matlab / Octave fft2() 函数输出的特定点

    我正在熟悉 Octave 及其功能fft2 在此玩具示例中 我的目标是生成以下 256 x 256 png 图像的 2D DFT 为了能够轻松理解输出 我尝试将此图像转换为 256 x 256 图像 消除颜色信息 Im imread cir

随机推荐

  • python消消乐游戏界面的实现:

    一 环境介绍 1 Python 版本 Python 消消乐游戏可以在 Python 2 7 和 Python 3 x 版本中运行 2 Pygame 模块 Python 消消乐游戏需要使用 Pygame 模块来实现游戏界面和图形绘制等功能 如
  • DP线和HDMI线区别,优缺点,传输显示器图像速率

    参考DP接口与HDMI接口各有什么优势 哪个更好 资料来源于网络 仅供参考 最近在x宝上买显示器的线 看到各种hdmi dp版本的线 2 0 dp1 4 4k 8k typec转三口hdmi 可把我看昏了 在网上收集了一些资料用于总结 以及
  • QMUI 学习一: 入门,如何添加QMUI框架到 android项目 ,并引入QMUI的主题Theme:

    用是最新的Android Studio 3 6 x的 下了新的QMUI Demo参考学习UI 先上效果图 如何添加QMUI框架并引用它的主题 1 添加框架 在app gradle里面添加依赖 implementation com qmuit
  • 感知器算法实现多类样本的线性分类(Matlab)

    原理 略 步骤 二分类问题 1 将第一类样本作为正样本 第二类样本作为负样本 首先 对样本的向量空间进行增广 即对n维向量x的首部或者尾部增加一个参数1 增广为 n 1 维向量 并对其进行规范化 即正样本不做处理 负样本的 n 1 维向量取
  • web buuctf [极客大挑战 2019]Knife1

    题目给出得信息量是一句话木马 文本里面有一个 菜刀 字眼 可以尝试一下用中国菜刀 现在大部分都是用蚁剑 测试连接 提示成功 将该数据添加到界面上 点进去 点到根目录 在最下面有一个flag文件 点开即可 这道题考点 1 一句话木马 2 中国
  • 液晶显示器汉字字模存储及显示

    一 3 种汉字字模存储和提取的方法 1 字模存放在程序存储器中 这种方法较为常用 针对程序不大或单片机无外部扩展数据存储区功能的情况 2 通过外扩的EEPROM 存储汉字字模数据 将其作为外部数据存储器进行寻址 采用哈佛结构的单片机 如80
  • 罗小黑用flash做的_Flash动画制作小黑人经典动画效果技巧介绍(图文)

    本教程是向大家介绍Flash动画制作小黑人经典动画效果技巧 教程很经典 介绍的非常详细 相信对学习Flash朋友有一定的帮助 转发过来 希望对大家有所帮助 解决思路 小黑人动作是典型的人物动作 我们利用小黑人可以练习我们对人物动作的掌握 因
  • 带优美外观的UserControl控件GroupBox

    http www myfirm cn news DotNetUserInterface 20080208095730391 html 写在前面 如果大家觉得 Net自带的GroupBox控件太差了 样子很不美观 而想用 Net强大的自定义功
  • 【MyBatis】 动态SQL——模糊查询 LIKE

    一 LIKE SELECT FROM t usr WHERE name like name SQL解析为 SELECT FROM t usr WHERE name like 海 可以看到 传参必须用 不能用 所以这样写的弊端就是不安全 不能
  • 腾讯家低调开发的良心工具?目前无任何付费机制还挺好用~

    去年 Tik Tok Clipping 和哔哩哔哩 Clipping 相继推出了自己的桌面编辑软件 相比专业首演 无论是操作逻辑还是内置素材库 它们都能让非专业人士更容易上手 大大降低创作门槛 但还是有朋友反馈电脑不能动 实在没办法动 只好
  • 在macOS上安装NodeJS多版本管理工具

    需求 现在Node js也有很多的版本啦 简单地使用某个版本 只需要去下载安装对应版本就可以了 如果需要多个版本在机器上共存 并在需要时切换到相应的版本环境 这时候就需要多版本的管理工具了 而 n nvm就是这个有效的工具 简介 NVM 即
  • 【SQL注入-无回显】时间盲注:原理、函数、利用过程

    目录 一 时间盲注 延时 1 1 简介 1 2 原理 二 常用函数 2 1 延迟函数 编辑 2 2 相关函数 2 3 示例语句 三 利用过程 3 1 第一步 判断注入点 3 2 第二步 判断可使用注入方法 3 3 第三步 猜数据库名称长度
  • springBoot 部署Docker环境中

    目录 一 准备 二 docker运行环境 三 DockerFile 四 制作镜像 五 启动容器 六 访问 一 准备 需要打好的 jar 包 这里不再赘述 docker 环境 二 docker运行环境 安装JDK docker pull pr
  • monkeyrunner之夜神模拟器的安装与使用(二)

    在上一篇文章 安卓开发环境搭建中 我们创建并启动了eclipse自带的安卓模拟器 该模拟器不仅启动慢 而且在使用过程中的反应速度也是出奇的差 经常出现卡机现象 为了解决这种现象 因此 我们又寻找到了更加合适的模拟器 夜神模拟器 该模拟器除了
  • JavaScript离线手册 w3c(w3school) 百度网盘

    听尚硅谷李超老师课 感觉离线文档特别实在 JavaScript w3school 离线分享 baidu wangpan 链接 https pan baidu com s 1AwMZy2MpvxzBtePtDp39nQ pwd imle 提取
  • 【Zabbix实战之部署篇】Zabbix监控windows系统配置方法

    Zabbix实战之部署篇 Zabbix监控windows系统配置方法 一 检查Zabbix监控平台状态 1 检查Zabbix各组件状态 2 检查Zabbix的首页 二 下载windows代理 1 访问Zabbix官网下载界面 2 查看下载安
  • Redis 分布式集群搭建

    转 https blog csdn net daybreak1209 article details 51493265 在Redis的安装和部署 Linux 一文中详细介绍了在Linux环境中搭建Redis服务 本文将介绍关于Redis分布
  • python如何不生成pyc文件(三种方式)

    python如何不生成pyc文件 三种方式 当 import导入另一个模块的时候会生成pyc文件 python3会生成 pycache 如何不生成编译文件呢 1 使用 B参数 即 python3 B test py 里面的包含的就不会生成p
  • 最强自动化测试框架Playwright(4)-控件操作

    文本输入 适用于input textarea 其他可编辑内容的元素 Text input page get by role textbox fill Peter Date input page get by label Birth date
  • 使用MATLAB的trainNetwork设计一个简单的LSTM神经网络

    文章目录 前言 一 数据集 二 网络结构 三 测试程序 前言 借助MATLAB的deepNetworkDesigner搭一个简单的LSTM 数据集使用mnist手写数字识别数据集 一 数据集 mnist数据集包括60000组训练数据和对应的