matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6)

2023-11-18

此示例说明如何使用长短期记忆 (LSTM) 网络对序列数据进行分类。

要训练深度神经网络以对序列数据进行分类,可以使用 LSTM 网络。LSTM 网络允许您将序列数据输入网络,并根据序列数据的各个时间步进行预测。

此示例使用 [1] 和 [2] 中所述的日语元音数据集。此示例训练一个 LSTM 网络,旨在根据表示连续说出的两个日语元音的时序数据来识别说话者。训练数据包含九个说话者的时序数据。每个序列有 12 个特征,且长度不同。该数据集包含 270 个训练观测值和 370 个测试观测值。

源码:


%% 通用matlab脚本三连
clear
clc
close all

%% 加载序列数据
% 加载日语元音训练数据。XTrain 是包含 270 个不同长度的 12 维序列的元胞数组。
% Y 是对应于九个说话者的标签 "1"、"2"、...、"9" 的分类向量。
% XTrain 中的条目是具有 12 行(每个特征一行)和不同列数(每个时间步一列)的矩阵。
[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)

%% 在绘图中可视化第一个时序。每行对应一个特征。
figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
legend("Feature " + string(1:12),'Location','northeastoutside')

%% 准备要填充的数据
numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

%% 按序列长度对数据进行排序。
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

%% 在条形图中查看排序的序列长度。
figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

%% 选择小批量大小 27 以均匀划分训练数据,并减少小批量中的填充量。下图说明了添加到序列中的填充。
miniBatchSize = 27;

%% 定义 LSTM 网络架构
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]

maxEpochs = 100;
miniBatchSize = 27;

%% 指定训练选项
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

%% 训练 LSTM 网络
net = trainNetwork(XTrain,YTrain,layers,options);

%% 测试 LSTM 网络
[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)

%% LSTM 网络 net 已使用相似长度的小批量序列进行训练
numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

%% 对测试数据进行分类
miniBatchSize = 27;
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest');

%% 计算预测值的分类准确度。
acc = sum(YPred == YTest)./numel(YTest)

运行结果:

本文着重讲解一下matlab的深度学习训练方式

在matlab中,训练一个网络有一个常用的函数trainNetwork;此函数通常有四个参数

%% 训练 LSTM 网络
net = trainNetwork(XTrain,YTrain,layers,options);

 其中,XTrain为训练的输入数据集,YTrain为训练数据对应的标签;

layers为网络的架构,也就是指网络每层的处理模式。

options为超参数的设置,包括学习率,优化方法,迭代次数以及批量大小问题等。

1. 构建网络模型;

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize) %输入层
    bilstmLayer(numHiddenUnits,'OutputMode','last') %第一层隐层,LSTM架构
    fullyConnectedLayer(numClasses) % 第二层隐层,全连接层
    softmaxLayer %softmax处理
    classificationLayer] %分类层

 如上代码片段所示,逐行对每一层的网络处理模式进行说明,并将此构成的数组形式赋予一个变量(此中为layers)。

2. 指定训练的超参数

%% 指定训练选项
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

其中参数设定的格式必须由trainingOptions进行格式打包,即options的数据类型必须为TrainingOptions派,此示例中options的数据类型为TrainingOptionsADAM。另外本示例数据集方面的XTrain与XTest都是matlab中的cell数据类型。

 更详细的matlab深度学习训练方法(trainNetwork用法)请参考matlab官方文档或本博客Deep Learning ToolBox系列7。

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6) 的相关文章

  • 如何在 MATLAB 的 for 循环中读取多个图像?

    我已将结果分段放在一个文件夹中 这些需要在 for 循环中读取并在循环中进一步处理 我尝试阅读如下 for i 1 10 file name dir strcat C Users adminp Desktop dinosaurs im im
  • 图像分析-光纤识别

    我是图像分析新手 您知道如何以仅获取纤维的方式对该图像进行二值化吗 我尝试过不同的阈值技术等 但没有成功 我不介意应该使用什么工具 但我更喜欢 NET or Matlab PS 我不知道该把答案放在哪里 所以我把它放在StackOverfl
  • 如何在Matlab中自定义轮廓线?

    我正在准备一个等高线图 我应该在其中突出显示特定级别的等高线 例如 我的轮廓线值位于 1 和 1 之间 我想突出显示与值 0 相对应的线 我尝试使用以下过程来执行此操作 M c contourf longitude latitude del
  • 在 Excel 中打印 MATLAB 图窗并调整其大小

    我在 MATLAB 中有两个带有手柄的图形hFig1 and hFig2 我想将它们打印到 Excel 中的特定单元格 单元格 E3 和 I3 并将它们重新调整为 2 英寸 x 3 英寸 我尝试过使用 AddPictures对象处理程序和使
  • 如何让MCR启动时间快

    我将 matlab 程序转换为 net 程序集 即 dll 文件 我制作了一个控制台 C 应用程序 添加了 dll 文件并从 php 调用它 每次调用 exe 时都会调用 MCR 如何使 MCR 在服务器启动时初始化 并且即使在一段时间后调
  • 将自动生成的 Matlab 文档导出为 html

    我想为我开发的 Matlab 工具箱生成完整的帮助 我已经看到如何显示自定义文档 http www mathworks fr fr help matlab matlab prog display custom documentation h
  • Matlab Mex文件编译

    我正在尝试编译一个 mex 文件以在 matlab 中使用套接字连接 问题是它总是说我没有安装sdk或编译器 但我已经安装了 Visual Studio 2010 Express Visual Studio 2012 Express Vis
  • 如何使用matlab生成不同频率的正弦波?

    对于我的项目 我需要使用 matlab 生成一个正弦波 它有 100 000 个样本 并且频率在每 10 000 个样本后随机变化 采样率和频率可以根据方便而定 matlab中有没有函数可以生成这个 好的另一个例子 生成 5 个随机频率 r
  • 在 matlab 中求 3d 峰的体积

    现在我有一个带有峰值的 3D 散点图 我需要找到其体积 我的数据来自图像 因此 x 和 y 值表示 xy 平面上的像素位置 z 值是每个像素的像素值 这是我的散点图 scatter3 x y z 20 z filled 我试图找到数据峰值的
  • 什么是 ANN 中的纪元以及它如何转换为 MATLAB 中的代码?

    我试图理解 并可视化 训练人工神经网络的时代到底是什么 我们有一个包含约 7000 个产品的训练集 其中有 10 个特征 输入 这些产品必须根据这 10 个输入分为 7 个类别 我们的 ANN 有 10 个输入 这些输入进入由 10 个神经
  • 使用 MATLAB 进行线路跟踪

    我有一个图像 我想将其转换为逻辑图像 包括线条为黑色 背景为白色 当然 可以使用阈值方法来实现这一点 但我不想使用这种方式来做到这一点 我想通过使用线路跟踪方法或类似的方法来检测它 这是关于视网膜血管检测的 我找到了一个article ht
  • 如何在 Matlab 中将数组打印到 .txt 文件?

    我才刚刚开始学习Matlab 所以这个问题可能非常基本 我有一个变量 a 2 3 3 422 6 121 9 4 55 我希望将值输出到 txt 文件 如下所示 2 3 3 422 6 121 9 4 55 我怎样才能做到这一点 fid f
  • 有没有办法在matlab中进行隐式微分

    我经常使用 matlab 来帮助我解决数学问题 现在我正在寻找一种在 matlab 中进行隐式微分的方法 例如 我想区分y 3 sin x cos y exp x 0关于dy dx 我知道如何使用数学方法通常做到这一点 但我一直在努力寻找使
  • Matlab:如何更改矩阵的存储方式?从 1x1x3 到 1x3?

    我目前有 val 1 0 7216 val 2 0 7216 val 3 0 7216 但我想要 0 7216 0 716 0 721 我可以做什么样的操作来做到这一点 The reshape函数将在这里解决问题 Arrange the e
  • 我需要转义该 MATLAB 字符串中的字符吗?

    我想在 MATLAB 中调用以下 bash 命令 grep Up to test linux vision1 1 log awk print 7 I use system 在MATLAB中 但结果有错误 gt gt status strin
  • matlab中的排列函数是如何工作的

    这是一个有点愚蠢的问题 但我似乎无法弄清楚排列在 matlab 中是如何工作的 以文档为例 A 1 2 3 4 permute A 2 1 ans 1 3 2 4 到底是怎么回事 这如何告诉 matlab 3 和 2 需要交换 哇 这是我迄
  • 理解高斯混合模型的概念

    我试图通过阅读在线资源来理解 GMM 我已经使用 K 均值实现了聚类 并且正在了解 GMM 与 K 均值的比较 以下是我的理解 如有错误请指出 GMM 类似于 KNN 在这两种情况下都实现了聚类 但在 GMM 中 每个簇都有自己独立的均值和
  • 通过 Matlab 访问 Physionet 的 ptbdb 中的数据库

    我首先设置系统 old path which rdsamp if isempty old path rmpath old path 1 end 8 end wfdb url http physionet org physiotools ma
  • MATLAB - 冲浪图数据结构

    我用两种不同的方法进行了计算 对于这些计算 我改变了 2 个参数 x 和 y 最后 我计算了每种变体的两种方法之间的 误差 现在我想根据结果创建 3D 曲面图 x gt on x axis y gt on y axis Error gt o
  • 如何在文本集中创建所有字符组合?

    例如 我有这样的文本集 第 1 栏 a b 第 2 栏 l m n 第 3 栏 v w x y 我想将它们组合起来以获得如下输出 alv alw alx aly amv amw amx amy 这将输出 24 种文本组合 如果我只使用前两列

随机推荐