ANN神经网络入门——分类问题(MATLAB)

2023-11-13

写在前面

 本篇博客的鸢尾花分类程序来源于博客http://www.cnblogs.com/heaad/archive/2011/03/07/1976443.html

上述博客中,作者主要介绍了以下三部分内容:

(1)神经网络基本原理

(2)AForge.NET实现前向神经网络的方法

(3)Matlab实现前向神经网络的方法

该博客对神经网络算法的入门和理解有很大的帮助,本博客主要给出神经网络算法的两个简单应用,分别是Fishr集上鸢尾花(Iris)的分类以及常见的分类问题-蠓虫分类,同时通过对两个分类程序(Code by MATLAB)进一步分析,加强对神经网络算法的理解。

--------------------------------------------------------------------------------------------

*本篇博客用到的所有程序和源文件下载请见以下链接:https://pan.baidu.com/s/1ggYS3gn  密码:tepb

*链接失效请发邮件至 2867555086@qq.com

*转载或引用请注明来源

--------------------------------------------------------------------------------------------


一、正文

1、Fishr集上鸢尾花Iris数据集的分类

①iris数据集简介

iris数据集的中文名是安德森鸢尾花卉数据集,英文全称是Anderson’s Iris data set。iris包含150个样本,对应数据集的每行数据。每行数据包含每个样本的四个特征和样本的类别信息,所以iris数据集是一个150行5列的二维表。
通俗地说,iris数据集是用来给花做分类的数据集,每个样本包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征(前4列),我们需要建立一个分类器,分类器可以通过样本的四个特征来判断样本属于山鸢尾、变色鸢尾还是维吉尼亚鸢尾(这三个名词都是花的品种)。
iris的每个样本都包含了品种信息,即目标属性(第5列,也叫target或label)。

样本局部截图:


MATLAB中可运行命令

load iris.dat

 将数据集载入到工作区,部分数据集如图所示。数据集的前四列分别为与鸢尾花种类相关的4个特征值,对应上图中的花萼长度、花萼宽度、花瓣长度及花瓣宽度;第五列为鸢尾花所属种类 ,分为1-Setosa、2-Versicolour、3-Virginica三类。

②数据预处理

这里的神经网络属于监督学习的模式,因此需要从上述数据集中分离出训练集和测试集,我们分别记为trainData和testData。我们从iris数据集中选取2/3数据作为训练集trainData,选取1/3数据作为测试集testData,并分别将其保存至trainData.txt和testData.txt文件,用于程序的数据导入源(两个文件均已上传至源代码)。

③分类源程序

%读取训练数据
clear
clc
%------本代码采用ANN对鸢尾花进行分类,程序运行前,请准备好#鸢尾花样本#的训练集和测试集(可在MATLAB中载入iris.dat查看数据)-----
%f1 f2 f3 f4是四个特征值
[f1,f2,f3,f4,class] = textread('trainData.txt' , '%f%f%f%f%f',150);
%特征值归一化
[input,minI,maxI] = premnmx( [f1 , f2 , f3 , f4 ]')  ;
%构造输出矩阵
s = length( class) ;
output = zeros( s , 3  ) ;
for i = 1 : s 
   output( i , class( i )  ) = 1 ;
end

%创建神经网络
net = newff( minmax(input) , [10 3] , { 'logsig' 'purelin' } , 'traingdx' ) ; 
%{
    minmax(input):获取4个输入信号(存储在f1 f2 f3 f4中)的最大值和最小值;
    [10,3]:表示使用2层网络,第一层网络节点数为10,第二层网络节点数为3;
    { 'logsig' 'purelin' }:
        表示每一层相应神经元的激活函数;
        即:第一层神经元的激活函数为logsig(线性函数),第二层为purelin(对数S形转移函数)
    'traingdx':表示学习规则采用的学习方法为traingdx(梯度下降自适应学习率训练函数)
%}
%设置训练參数
net.trainparam.show = 50 ;% 显示中间结果的周期
net.trainparam.epochs = 500 ;%最大迭代次数(学习次数)
net.trainparam.goal = 0.01 ;%神经网络训练的目标误差
net.trainParam.lr = 0.01 ;%学习速率(Learning rate)

%开始训练
%其中input为训练集的输入信号,对应output为训练集的输出结果
net = train( net, input , output' ) ;
%================================训练完成====================================%
%=============================接下来进行测试=================================%

%读取测试数据
[t1 t2 t3 t4 c] = textread('testData.txt' , '%f%f%f%f%f',150);

%测试数据归一化
testInput = tramnmx ( [t1,t2,t3,t4]' , minI, maxI ) ;
%[testInput,minI,maxI] = premnmx( [t1 , t2 , t3 , t4 ]')  ;
%仿真
%其中net为训练后得到的网络,返回的Y为
Y = sim( net , testInput ) 

%统计识别正确率
[s1 , s2] = size( Y ) ;
hitNum = 0 ;
for i = 1 : s2
    [m , Index] = max( Y( : ,  i ) ) ;
    if( Index  == c(i)   ) 
        hitNum = hitNum + 1 ; 
    end
end
sprintf('识别率是 %3.3f%%',100 * hitNum / s2 )

④代码的相关说明

A. 语句net = newff( minmax(input) , [10 3] , { 'logsig' 'purelin' } , 'traingdx' ) ;用于创建神经网络,其参数含义和用法如下:

    (1)minmax(input):获取4个输入信号(存储在f1 f2 f3 f4中)的最大值和最小值;
    (2) [10,3]:表示使用2层网络,第一层网络节点数为10,第二层网络节点数为3。其中最后一层的网络包含的节点数一定要与网络的理论输出个数保持一致,例如本例中鸢尾花的种类数为3,因此最后一层的网络节点数为3;
    (3){ 'logsig' 'purelin' }:表示每一层相应神经元的激活函数,即:第一层神经元的激活函数为logsig(线性函数),第二层为purelin(对数S形转移函数),其他激活函数和用法请参见神经网络与深度学习之激活函数
    (4) 'traingdx':表示学习规则采用的学习方法为traingdx(梯度下降自适应学习率训练函数)。常见的训练函数(学习方法)有:

  traingd :梯度下降BP训练函数(Gradient descent backpropagation)
  traingdx :梯度下降自适应学习率训练函数

    (5)创建的神经网络用MATLAB神经网络工具箱显示如图,图中更形象的展示了构造的神经网络模型。


B. 关于正确率的统计算法的说明

第一次看到这里的正确率统计算法时,我自己是不大明白的,之后又从网上搜了一些资料并查阅了MATLAB的帮助文档,才明白代码的含义

语句net = train( net, input , output' ) ;是对网络进行训练,该语句明确了网络的输出为output,通过对output矩阵的构造方式分析,我们可知网络的输出可以看成3个,我们不妨即为C1、C2、C3,分别代表鸢尾花的三个种类,例如:

(1)当output的某一行为1 0 0,则说明该花属于C1类

(2)当output的某一行为0 1 0,则说明该花属于C2类

(3)当output的某一行为0 1 0,则说明该花属于C3类

语句Y = sim( net , testInput ) 是对训练后的网络net进行仿真测试,测试用的数据为testInput;这里,Y返回的是网络训练后对测试输入的预测值,例如:

(1)当Y的某一行为1.0220  -0.0020  -0.0091,代表输出结果C1=1.0220, 输出结果C2=-0.0020,C3=-0.0091

(2)当Y的某一行为-0.0108  0.9884  -0.0216,代表输出结果C1=-0.0108,输出结果C2=0.9884,C3=-0.0216

输出结果中只包含一个1和两个0是理想情况下的结果,在进行仿真时,分类输出往往达不到这样的结果,但我们可以根据哪个结果对应的值与1的接近程度来进行判断,例如仿真结果(1)说明该花极有可能属于C1类,仿真结果(2)说明该花极有可能属于C2类。

*从以上描述中我们可以明白神经网络算法也可以应用于具体数值的预测,且应用广泛。

⑤仿真和运行结果


说明:以上程序的识别率稳定在96%左右,训练150次左右达到收敛,训练曲线如上图所示:


2、蠓虫分类

①背景介绍

生物学家WLGrogan和WWWirth发现,蠓虫的种类与它们的触角长度和翼长有关,且蠓虫大致分为两类,记为Apf和Af。

现有如下样本,下面需要通过神经网络模型找到一种有效的对蠓虫分类的方法。


②数据预处理

与鸢尾花分类类似,我们随机选取2/3作为训练集trainData,1/3作为测试集testData,考虑到这是一个分类问题,因此我们将目标值0.9替换为2,代表蠓虫属于Apf类;将目标值0.1替换为1,代表蠓虫属于Af类,处理后导入至文本文件trainData.txt和testData.txt,作为程序的数据导入源文件(两个文件均已上传至源代码)

分类源程序

%读取训练数据
clear
clc
%f1 f2 f3 f4是四个特征值
[f1,f2,class] = textread('trainData.txt' , '%f%f%f');
%特征值归一化
[input,minI,maxI] = premnmx( [f1 , f2 ]')  ;
%构造输出矩阵
s = length( class) ;
output = zeros( s , 2  ) ;
for i = 1 : s 
   output( i , class( i )  ) = 1 ;
end
output
%创建神经网络
net = newff( minmax(input) , [10 2] , { 'logsig' 'purelin' } , 'traingdx' ) ; 
%设置训练參数
net.trainparam.show = 50 ;% 显示中间结果的周期
net.trainparam.epochs = 500 ;%最大迭代次数(学习次数)
net.trainparam.goal = 0.01 ;%神经网络训练的目标误差
net.trainParam.lr = 0.01 ;%学习速率(Learning rate)

%开始训练
%其中input为训练集的输入信号,对应output为训练集的输出结果
net = train( net, input , output' ) ;
%================================训练完成====================================%
%=============================接下来进行测试=================================%

%读取测试数据
[t1 t2 c] = textread('testData.txt' , '%f%f%f');

%测试数据归一化
testInput = tramnmx ( [t1,t2]' , minI, maxI ) ;
%[testInput,minI,maxI] = premnmx( [t1 , t2]')  ;
%仿真
%其中net为训练后得到的网络,返回的Y为
Y = sim( net , testInput ) 
%{
Y返回预测值,输出有两个记为A、B,理想情况下输出为上述的output,输出结果只有1和0两种
即:output =
     0     1
     0     1
     0     1
     0     1
     1     0
     1     0
     1     0
     1     0
     1     0
     1     0
---------------------------------------------------------------------------------
例:Y =
    0.1432    0.4841    1.0754    1.2807    0.8405
    0.6034    0.5329    0.1427   -0.4194    0.0947
则说明:
    对于第一个测试数据,输出结果A=0.1432,输出结果B=0.6034
    对于第一个测试数据,输出结果A=0.4841,输出结果B=0.5329
因此,对于本例:若结果A更接近于1(左接近或右接近),那么说明该测试数据属于第一个分类;
               若结果B更接近于1(左接近或右接近),那么说明该测试数据属于第二个分类。
因此,ANN除了应用于分类问题,也可应用于对具体数值的预测问题。
---------------------------------------------------------------------------------
_______________________________________________________________2018.02.06 by_LeoHao
%}
%统计识别正确率
[s1 , s2] = size( Y ) ;
hitNum = 0 ;
for i = 1 : s2
    [m , Index] = max( Y( : ,  i ) ) ;
    if( Index  == c(i)   ) 
        hitNum = hitNum + 1 ; 
    end
end
sprintf('识别率是 %3.3f%%',100 * hitNum / s2 )

  神经网络可视化

  创建的神经网络用MATLAB神经网络工具箱显示如图,图中更形象的展示了构造的神经网络模型。


  ⑤仿真和运行结果



*说明:从上图来看,识别率近似于100%,训练在120次左右仅稍微达到收敛,这可能是由于训练集样本规模太小导致的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ANN神经网络入门——分类问题(MATLAB) 的相关文章

随机推荐

  • Unity的2D图集处理,并切割出一张张小图片

    转载请注明出处 http blog csdn net hongyouwei article details 45011315 在我们找资源的时候 有时候会遇到那种把一堆图片放进一张png图片里面的情况 在做2D游戏的时候 我们往往需要将里面
  • jackson自定义全局序列化、反序列化

    需要自定义Jackson序列化和反序列化有两种方式 一种是全局定义 一种是非全局定义 先来看看全局定义 全局定义的步骤如下 以定义一个localDateTime的序列化和反序列化为例 一 创建序列化类 创建一个序列化类然后继承JsonSer
  • 基于JavaWeb(JSP+Servlet+MySQL)编程实现员工信息的添加、修改、删除、列表显示。

    1 项目结构 2 页面主要代码 2 1员工添加页面代码以及效果图 add jsp 员工添加
  • linux 修改用户密码的几种方法

    1 passwd 命令 手动修改 root localhost testuser passwd testuser Changing password for user testuser New password Retype new pas
  • Spring Factories

    该文章转载自 https blog csdn net lvoyee article details 82017057 Spring Boot中有一种非常解耦的扩展机制 Spring Factories 这种扩展机制实际上是仿照Java中的S
  • 简说数据库事务的ACID

    事务是应用程序中一系列严密的操作 所有操作必须成功完成 否则在每个操作中所作的所有更改都会被撤消 也就是事务具有原子性 一个事务中的一系列的操作要么全部成功 要么一个都不做 原子性 Atomicity 一致性 Consistency 隔离性
  • postman 模拟请求中添加 header,post请求中传json参数

    1 GET 请求 2 Post 请求 请求参数为Json header中带有参数 问题延伸 GET请求不能够 添加 Body 吗 答案 转载于 https www cnblogs com zacky31 p 8808110 html
  • 【网络安全】——区块链安全和共识机制

    区块链安全和共识机制 摘要 区块链技术作为一种分布式去中心化的技术 在无需第三方的情况下 使得未建立信任的交易双方可以达成交易 因此 区块链技术近年来也在金融 医疗 能源等多个行业得到了快速发展 然而 区块链为无信任的网络提供保障的同时 也
  • 《算法导论》总结与分析

    算法导论总结与分析 分治 strassen算法 介绍 步骤 正确性证明 复杂度分析 排序 堆排序 介绍 步骤 构建 排序 优先队列 复杂度分析 快速排序 介绍 步骤 复杂度分析 最坏情况 最好情况 线性时间排序 介绍 步骤 复杂度分析 数据
  • 课堂小作业之3位水仙花数计算

    3位水仙花数计算 描述 3位水仙花数 是指一个三位整数 其各位数字的3次方和等于该数本身 例如 ABC是一个 3位水仙花数 则 A的3次方 B的3次方 C的3次方 ABC
  • 组合优化技术

    组合优化是指在离散领域内 寻找最优解的问题 在通信工程中 组合优化的应用非常广泛 例如在无线通信系统中 可以使用组合优化算法来优化信道资源分配 功率控制 调制方式等问题 组合优化问题通常包含以下要素 决策变量 表示问题的解 通常是一个离散的
  • spring cloud版本由1.5.x升级到2.x所遇到的坑

    众所知周 spring cloud 1 5版本与2 x版本差异很大 官方没有做向下兼容 导致大家对于升级spring cloud版本都非常慎重 此处 首先推荐阅读官方给出的迁移手册 Spring Boot 2 0 Migration Gui
  • ChatGPT学习相关资料整理

    ChatGPT学习相关资料整理 关于ChatGPT的相关咨询和新闻 ChatGPT能力起源 https mp weixin qq com s 4l0ADjdsCxSVvBeVKxSqWA ChatGPT的发展历程 https zhuanla
  • 生产数据采集MDC的总体思路

    一 数控机床通过网口连到局域网 MDC服务器与数控机床通讯 定时取得所需数据 将数据写入数据库 二 MES对数据库中的数据进行分析 展示到大屏上 我这里是机械制造型企业 以上步骤已经完成 有相同需求的朋友 欢迎一起交流细节
  • SQL INSERT INTO 语句

    INSERT INTO 语句用于向表中插入新记录 语法 指定列插入数据 INSERT INTO table name colnum1 colnum2 column3 VLAUES value1 value2 value3 不指定列插入数据
  • java - JVM CPU100%,问题排查

    前段时间我们新上了一个新的应用 因为流量一直不大 集群QPS大概只有5左右 写接口的rt在30ms左右 因为最近接入了新的业务 业务方给出的数据是日常QPS可以达到2000 大促峰值QPS可能会达到1万 所以 为了评估水位 我们进行了一次压
  • FormData实现文件上传

    应用场景 FormData Ajax技术实现文件上传 1 FormData使用 FormData是一个构造函数 首先new FormData 得到一个FormData对象 可以直接使用 直接console会是一个空白的对象 有append方
  • 未解决-联想拯救者r7000 CTRL+C复制键无法使用

    情况描述 突然不能使用 不知道是什么操作导致不能使用 也不知道什么操作解决了问题 发生次数 四次以上 判断过程 鼠标右键复制可以使用 外接键盘复制可以使用 无QQ 微信等热键冲突 网页 记事本都无法使用 ctrlA ctrlX ctrlV可
  • 手把手教你画活动图,再无难搞的流程分析

    上次介绍了 用例图这样画 3步让你做需求分析有理有据 这次聊聊活动图 也许你对活动图并不了解 不过 说起流程图 想必你不会陌生 你可以暂且把活动图 看成 UML 中的流程图 都知道 做产品要分析流程 可怎么把流程理清楚呢 当然不能凭空想象
  • ANN神经网络入门——分类问题(MATLAB)

    写在前面 本篇博客的鸢尾花分类程序来源于博客http www cnblogs com heaad archive 2011 03 07 1976443 html 在上述博客中 作者主要介绍了以下三部分内容 1 神经网络基本原理 2 AFor