matlab练习程序(神经网络识别mnist手写数据集)

2023-11-02

记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对。

这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码。

mnist数据集训练数据一共有28*28*60000个像素,标签有60000个。

测试数据一共有28*28*10000个,标签10000个。

这里神经网络输入层是784个像素,用了100个隐含层,最终10个输出结果。

arc代表的是神经网络结构,可以增加隐含层,不过我试了没太大效果,毕竟梯度消失。

因为是最普通的神经网络,最终识别错误率大概在5%左右。

迭代曲线:

 

代码如下:

clear all;
close all;
clc;

load mnist_uint8;

train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
train_y = double(train_y);
test_y  = double(test_y);

mu=mean(train_x);    
sigma=max(std(train_x),eps);
train_x=bsxfun(@minus,train_x,mu);          %每个样本分别减去平均值
train_x=bsxfun(@rdivide,train_x,sigma);     %分别除以标准差

test_x=bsxfun(@minus,test_x,mu);
test_x=bsxfun(@rdivide,test_x,sigma);

arc = [784 100 10]; %输入784,隐含层100,输出10
n=numel(arc);

W = cell(1,n-1);    %权重矩阵
for i=2:n
    W{i-1} = (rand(arc(i),arc(i-1)+1)-0.5) * 8 *sqrt(6 / (arc(i)+arc(i-1)));
end

learningRate = 2;   %训练速度
numepochs = 5;      %训练5遍
batchsize = 100;    %一次训练100个数据

m = size(train_x, 1);       %数据总量
numbatches = m / batchsize;    %一共有numbatches这么多组

%% 训练
L = zeros(numepochs*numbatches,1);
ll=1;
for i = 1 : numepochs
    kk = randperm(m);
    for l = 1 : numbatches
        batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
        batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :);

       %% 正向传播
        mm = size(batch_x,1);
        x = [ones(mm,1) batch_x];
        a{1} = x;
        for ii = 2 : n-1
            a{ii} = 1.7159*tanh(2/3.*(a{ii - 1} * W{ii - 1}'));   
            a{ii} = [ones(mm,1) a{ii}];
        end
        
        a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));
        e = batch_y - a{n};
        L(ll) = 1/2 * sum(sum(e.^2)) / mm; 
        ll=ll+1;
       %% 反向传播
        d{n} = -e.*(a{n}.*(1 - a{n}));
        for ii = (n - 1) : -1 : 2
            d_act = 1.7159 * 2/3 * (1 - 1/(1.7159)^2 * a{ii}.^2);
            
            if ii+1==n    
                d{ii} = (d{ii + 1} * W{ii}) .* d_act; 
            else 
                d{ii} = (d{ii + 1}(:,2:end) * W{ii}).* d_act;
            end          
        end
         
        for ii = 1 : n-1
            if ii + 1 == n
                dW{ii} = (d{ii + 1}' * a{ii}) / size(d{ii + 1}, 1);
            else
                dW{ii} = (d{ii + 1}(:,2:end)' * a{ii}) / size(d{ii + 1}, 1);      
            end
        end
         
       %% 更新参数
        for ii = 1 : n - 1       
            W{ii} = W{ii} - learningRate*dW{ii};
        end
              
    end
end

%% 测试,相当于把正向传播再走一遍
mm = size(test_x,1);
x = [ones(mm,1) test_x];
a{1} = x;
for ii = 2 : n-1    
    a{ii} = 1.7159 * tanh( 2/3 .* (a{ii - 1} * W{ii - 1}'));  
    a{ii} = [ones(mm,1) a{ii}];
end
a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));

[~, i] = max(a{end},[],2);
labels = i;                         %识别后打的标签
[~, expected] = max(test_y,[],2);
bad = find(labels ~= expected);     %有哪些识别错了
er = numel(bad) / size(x, 1)       %错误率

plot(L);

测试数据可以在这里下载到:https://pan.baidu.com/s/19YPUe9S9xnztg9JGnoXxqw

转载于:https://www.cnblogs.com/tiandsp/p/9042908.html

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

matlab练习程序(神经网络识别mnist手写数据集) 的相关文章

  • 在 3d 空间中的两个平面之间进行插值

    我正在开发一种工具 可以让您在 3D 体积 上圈出 包围事物 我想通过标记 切片 1 和 3 并从该信息 填充 切片 2 来节省时间 两个简单的解决方案是 1 slice2 slice1 AND slice3 gets the overla
  • FFT 的功率谱密度

    我有一段代码可以获取部分信号的 FFT 现在我正在尝试获取 PSD Fs 44100 cj sqrt 1 T 6 dt 1 Fs left test 1 right test 2 time 45 interval 636 w range t
  • 在 Excel 中打印 MATLAB 图窗并调整其大小

    我在 MATLAB 中有两个带有手柄的图形hFig1 and hFig2 我想将它们打印到 Excel 中的特定单元格 单元格 E3 和 I3 并将它们重新调整为 2 英寸 x 3 英寸 我尝试过使用 AddPictures对象处理程序和使
  • 在 MATLAB 中使用 FFT 的频率响应

    这是场景 使用频谱分析仪 我有输入值和输出值 样本数是32000采样率为2000样本 秒 输入是正弦波50 hz 输入为电流 输出为压力 单位 psi 我如何使用 MATLAB 根据这些数据计算频率响应 使用 MATLAB 中的 FFT 函
  • MATLAB 中最有效的矩阵求逆

    在 MATLAB 中计算某个方阵 A 的逆矩阵时 使用 Ai inv A should be the same as Ai A 1 MATLAB 通常会通知我这不是最有效的求逆方法 那么什么是更有效率的呢 如果我有一个方程系统 可能会使用
  • 在另一列中添加具有特定条件的一列,如 excel 的 sumif

    我有一个像这样的矩阵 A 1 2 2 3 3 4 4 5 5 6 6 8 7 9 8 5 9 4 现在我想添加第二列 条件是如果 limit 0 interval 3 且 limit limit interval 或者换句话说 当第 1 列
  • MATLAB - 如何将子图一起缩放?

    我在一张图中有多个子图 每个图的 X 轴是相同的变量 时间 每个图上的 Y 轴都不同 无论是它所代表的内容还是数据的大小 我想要一种同时放大所有图的时间尺度的方法 理想情况下 可以在其中一张图上使用矩形缩放工具 并让其他图相应地更改其 X
  • 为什么旋转 3D 点云后顶点法线会翻转?

    我有两个人脸 3D 点云样本 蓝色点云表示目标面 红色点云表示模板 下图显示目标面和模板面在不同方向上对齐 目标面大致沿 x 轴 模板面大致沿 y 轴 Figure 1 The region around the nose is displ
  • 在 MATLAB 中重命名文件

    我正在尝试以编程方式重命名工作目录中的文件a temp txt to b hello txt 您建议如何这样做 MATLAB中有一个简单的文件重命名函数吗 我认为您正在寻找 MOVEFILE
  • 如何找到在matlab中重复的矩阵的每一行的索引?

    我想找到矩阵中所有有重复项的行的索引 例如 A 1 2 3 4 1 2 3 4 2 3 4 5 1 2 3 4 6 5 4 3 要返回的向量将是 1 2 4 很多类似的问题建议使用unique函数 我已经尝试过 但我能得到的最接近我想要的功
  • 如何在 Matlab 中对数组应用低通或高通滤波器?

    有没有一种简单的方法可以将低通或高通滤波器应用于 MATLAB 中的数组 我对 MATLAB 的强大功能 或数学的复杂性 有点不知所措 需要一个简单的函数或一些指导 因为我无法从文档或网络搜索中找到答案 看着那 这filter http w
  • 如何在没有安装Visual Studio的另一台机器上使用Visual Studio生成的dll?

    我已经在 Visual Studio 2012 中生成了动态库 我想在另一台机器上使用该库 但我不想在远程机器上安装 Visual Studio 我有 mex 库和 dll 我想运行一个使用这两个库的脚本 当我运行脚本时 出现以下错误 缺少
  • Matlab 和 Python 中的优化算法(dog-leg trust-region)

    我正在尝试使用 Matlab 和 Python 中的狗腿信赖域算法求解一组非线性方程 在Matlab中有fsolve https www mathworks com help optim ug fsolve html其中此算法是默认算法 而
  • 在 Matlab 中将 datenum 转换为 datetime 的最快方法

    我在 Matlab 中将 datenum 转换为 datetime 时遇到问题 Given dnum floor now floor now 1 我尝试了以下方法 datenum dnum 但这没有用 我发现有效的方法是 datetime
  • matlab中类库的全局变量

    我有一些matlab声明的类 我如何声明所有类中都可见的常量 例如 这些常量可以是在所有类的方法中使用的物理常量 首先想到的是使用全局变量 还有更好的办法吗 最好在单独的文件中声明这些常量 包含常量的类是执行此操作的一种很好的干净方法 请参
  • 理解高斯混合模型的概念

    我试图通过阅读在线资源来理解 GMM 我已经使用 K 均值实现了聚类 并且正在了解 GMM 与 K 均值的比较 以下是我的理解 如有错误请指出 GMM 类似于 KNN 在这两种情况下都实现了聚类 但在 GMM 中 每个簇都有自己独立的均值和
  • 如何在 MATLAB 中绘制 3D 曲面图?

    我有一个像这样的数据集 0 1 0 2 0 3 0 4 1 10 11 12 13 2 11 12 13 14 3 12 13 14 15 4 13 14 15 16 我想在 matlab 中绘制 3D 曲面图 使列标题位于 y 轴 行标题
  • glpk.LPX 向后兼容性?

    较新版本的glpk没有LPXapi 旧包需要它 我如何使用旧包 例如COBRA http opencobra sourceforge net openCOBRA Welcome html 与较新版本的glpk 注意COBRA适用于 MATL
  • 有效地绘制大时间序列(matplotlib)

    我正在尝试使用 matplotlib 在同一轴上绘制三个时间序列 每个时间序列有 10 6 个数据点 虽然生成图形没有问题 但 PDF 输出很大 在查看器中打开速度非常慢 除了以栅格化格式工作或仅绘制时间序列的子集之外 还有其他方法可以获得
  • 对数据进行分布拟合 - MATLAB

    我正在尝试对从显微镜图像中收集的一些数据进行分布 我们知道 152 左右的峰值是由于泊松过程造成的 我想将分布拟合到图像中心的大密度 同时忽略高强度数据 我知道如何将正态分布拟合到数据 红色曲线 但它不能很好地捕获右侧的重尾 尽管泊松分布应

随机推荐

  • 一分钟快速利用ChatGPT生成PPT

    目标 让AI给我们生成一篇PPT报告 首先介绍一下什么是ChatGPT ChatGPT是一种基于自然语言处理技术的人工智能应用 它使用OpenAI的GPT模型来自动生成自然语言的回复 可以作为虚拟助手 客服机器人等方面的应用 与其他机器学习
  • SOC芯片中VIP和IP之间的路由关系

    通用PAD是双向端口 inout 这就意味着每个通用PAD可以根据需要被配置成输入或输出 如图1所示 图1 ind是输入端口 do是输出端口 obe是输出使能信号 当obe为低电平时 PAD作为输入端口使用 三态门关闭 do高阻 片外数据通
  • nm 命令显示

    用途 显示关于对象文件 可执行文件以及对象文件库里的符号信息 语法 nm A C X 32 64 32 64 f h l p r T v B P e g u d o x t Format File 描述 nm 命令显示关于指定 File 中
  • FISCO BCOS工程师常用的性能分析工具推荐

    FISCO BCOS是完全开源的联盟区块链底层技术平台 由金融区块链合作联盟 深圳 简称金链盟 成立开源工作组通力打造 开源工作组成员包括博彦科技 华为 深证通 神州数码 四方精创 腾讯 微众银行 亦笔科技和越秀金科等金链盟成员机构 代码仓
  • Hibernate学习笔记 查询简介

    创建实体类 在介绍Hibernate查询语言之前 首先我们来建立一下数据库 这里直接使用了MySQL自带的样例数据库world 如果你没有安装MySQL那么需要安装一下 并且在安装的时候选择安装样例数据库 安装完成之后 应该能在MySQL中
  • 《区块链技术与应用》学习笔记13——ETH权益证明

    矿工挖矿是为了取得出块奖励 获取收益 而系统给予出块奖励的目的是激励矿工参与区块链系统维护 进行记账 而挖矿本质上是看矿工投入资金来决定的 投入资金买设备 gt 设备决定算力 gt 算力比例决定收益 那么 为什么不直接拼 钱 呢 现状是用钱
  • getchar与scanf的区别

    getchar getchar先读取一个字符放到ch里面去 如果这个字符不等于EOF 就进入循环 打印这个字符 当getchar读到文件末尾或者结束时 它会返回一个EOF 此时结束循环 输入A 输出A 输入b 输出b 当我们想要结束时 输入
  • 仿牛客社区项目(第五章)(上)

    文章目录 第三章 Kafka 构建TB级异步消息系统 一 阻塞队列 1 阻塞队列测试方法 2 测试结果 二 Kafka入门 1 Kafka下载 2 Kafka安装与配置 3 Kafka的启动 4 Kafka使用 三 Spring整合Kafk
  • SpringBoot:使用 @Lazy 注解懒加载

    为什么需要懒加载 我们知道 在 SpringBoot 应用程序启动的时候 会实例化一些对象加入到 IOC 容器里边 这个过程是非常耗时的 那我们想要减少这个耗时的过程就需要 Lazy 注解 对象加入容器的时机 如下代码 package co
  • [JQuery]分页插件jQuery pager plugin功能扩展

    原文地址 http blog csdn net starfd article details 25292019 http blog csdn net nz360 article details 52326232 牛逼分页 http www
  • python:使用 PythonMagick 生成 icon 图标

    目录 PythonMagick 下载与安装 把图片转化成 icon PythonMagick 下载与安装 使用 pip install PythonMagick是不行的 会提示没有这个模块 因此 需要到第三方去把该模块下载下来 再安装 下载
  • 【数据库】B树、B+树、索引

    B树 B 树 索引 二叉树是二分树 多分树是二叉树的推广 多分树主要适用于静态的索引数据文件 在插入和删除的时候需要把插入位置之后的每个记录都要向后移动 从而导致增加新的索引项和索引页块 需要对外存上的页块进行大量的调整 因此对于经常需要插
  • flutter聊天界面-TextField输入框buildTextSpan实现@功能展示高亮功能

    flutter聊天界面 TextField输入框buildTextSpan实现 功能展示高亮功能 最近有位朋友讨论的时候 提到了输入框的高亮展示 在flutter TextField中需要插入特殊样式的标签 比如 请 张三 回答一下 这一串
  • TypeScript Variable Type: never

    目录 never 的定义 never 的特点 never 的定义 never 是其它类型 包括 null 和 undefined 的子类型 代表从不会出现的值 never 通常有两种表现形式 抛出异常 返回值为 never 的函数可以是抛出
  • mysql查询课程的前三名

    看了下网上的查询课程前三名的 真是五花八门 真是 I服了U还各种错 看来啥事还是得自己动手 表结构 CREATE TABLE student id bigint 20 NOT NULL AUTO INCREMENT s no bigint
  • pygame安装教程(python)

    1 安装好python 2 打开cmd 3 输入 pip version 如果显示未安装pip 那么输入pip 等待安装完毕 检查pip是否安装 pip version 4 输入 pip install pygame 可能会出现这种错误 输
  • 软件产品设计的学习总结

    一个成功的软件产品通常需要包含以下几个方面 可靠性 软件产品需要稳定可靠 能够正确地运行 并且在用户使用中没有频繁崩溃或者其他问题 安全性 软件产品在使用过程中需要保证数据的安全性 包括用户的个人和商业隐私等方面 易用性 软件产品需要具有高
  • JAVA中的高并发,解决高并发的方案

    java高并发 如何解决 什么方式解决 一 什么是高并发 二 高并发的解决方法有两种 三 追加 一 什么是高并发 1 1 高并发 High Concurrency 是互联网分布式系统架构设计中必须考虑的因素之一 它通常是指 通过设计保证系统
  • 实现spawn-fcgi的守护监控功能

    http blog csdn net cleanfield article details 6409830 这几天做公司平台的api部分 决定采用c c 版本的fcgi 于是采用了spawn fcgi作为启动框架 不过遇到个问题就是spaw
  • matlab练习程序(神经网络识别mnist手写数据集)

    记得上次练习了神经网络分类 不过当时应该有些地方写的还是不对 这次用神经网络识别mnist手写数据集 主要参考了深度学习工具包的一些代码 mnist数据集训练数据一共有28 28 60000个像素 标签有60000个 测试数据一共有28 2