随机梯度下降求解svm(MATLAB)

2023-10-26

本文转载自:http://blog.csdn.net/orangehdc/article/details/38682501

随机梯度下降法(Stochastic Gradient Descent)求解以下的线性SVM模型:


w的梯度为:


传统的梯度下降法需要把所有样本都带入计算,对于一个样本数为n的d维样本,每次迭代求一次梯度,计算复杂度为O(nd) ,当处理的数据量很大而且迭代次数比较多的时候,程序运行时间就会非常慢。

随机梯度下降法每次迭代不再是找到一个全局最优的下降方向,而是用梯度的无偏估计 来代替梯度。每次更新过程为:


由于随机梯度每次迭代采用单个样本来近似全局最优的梯度方向,迭代的步长应适当选小一些以使得随机梯度下降过程尽可能接近于真实的梯度下降法。


下面我用matlab写的一个demo,速度不是很快,跑USPS数据库(二进制格式)csdn下载链接(mat格式),要五分钟,准确率88%左右,效果一般:

[cpp]  view plain  copy
  1. clear;  
  2. load E:\dataset\USPS\USPS.mat;  
  3. % data format:  
  4. % Xtr n1*dim  
  5. % Xte n2*dim  
  6. % Ytr n1*1   
  7. % Yte n2*1  
  8. % warning: labels must range from 1 to n, n is the number of labels  
  9. % other label values will make mistakes  
  10. u=unique(Ytr);  
  11. Nclass=length(u);  
  12.   
  13. allw=[];allb=[];  
  14. step=0.01;C=0.1;  
  15. param.iterations=1;  
  16. param.lambda=1e-3;  
  17. param.biaScale=1;  
  18. param.t0=100;  
  19.   
  20. tic;  
  21. for classname=1:1:Nclass    
  22.     temp_Ytr=change_label(Ytr,classname);  
  23.     [w,b] = sgd_svm(Xtr,temp_Ytr, param);  
  24.     allw=[allw;w];  
  25.     allb=[allb;b];  
  26.     fprintf('class %d is done \n', classname);  
  27. end  
  28.   
  29. [accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb);  
  30. fprintf(' accuracy is  %.2f percent.\n' ,  accuracy*100 );  
  31. toc;  


[cpp]  view plain  copy
  1. function [temp_Ytr] = change_label(Ytr,classname)  
  2. temp_Ytr=Ytr;  
  3. tep2=find(Ytr~=classname);  
  4. tep1=find(Ytr==classname);  
  5. temp_Ytr(tep2)=-1;  
  6. temp_Ytr(tep1)= 1;  

[cpp]  view plain  copy
  1. function [true_W,b]=sgd_svm(X,Y,param)  
  2. % input:   
  3. % X is n*dim  
  4. % Y is n*1 (label is 1 or 0)  
  5. % output:  
  6. % true_W is dim*1 ,so the score is X*W'+b  
  7. % b      is 1*1 number  
  8. iterations=param.iterations;%10  
  9. lambda=param.lambda;%1e-3  
  10. biaScale=param.biaScale;%0  
  11. t0=param.t0;%100  
  12. t=t0;  
  13.   
  14. w=zeros(1,size(X,2));  
  15. bias=0;  
  16.   
  17. for k=1:1:iterations  
  18.     for i=1:1:size(X,1)  
  19.         t=t+1;  
  20.         alpha = (1.0/(lambda*t));  
  21.         if(Y(i)*(X(i,:)*w'+bias)<1)  
  22.             bias=bias+alpha*Y(i)*biaScale;  
  23.             w=w+alpha*Y(i,1).*X(i,:);  
  24.         end  
  25.     end  
  26. end  
  27. b=bias;  
  28. true_W=w;  


[cpp]  view plain  copy
  1. function [accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb)  
  2. % allw is nclass * dim  
  3. % allb is nclass * 1  
  4. % Yte must range from 1 to nclass, other label values will make mistakes  
  5. score = Xte * allw'+repmat(allb',[size(Bte,1),1]);  
  6. [bb  c]=sort(score,2,'descend');  
  7. predict_label=c(:,1);  
  8. temp = predict_label((predict_label-Yte)==0);  
  9. right=size( temp,1 );  
  10. accuracy=right/size(Yte,1);  
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

随机梯度下降求解svm(MATLAB) 的相关文章

  • 如何让MCR启动时间快

    我将 matlab 程序转换为 net 程序集 即 dll 文件 我制作了一个控制台 C 应用程序 添加了 dll 文件并从 php 调用它 每次调用 exe 时都会调用 MCR 如何使 MCR 在服务器启动时初始化 并且即使在一段时间后调
  • 在 MATLAB 中使用 FFT 的频率响应

    这是场景 使用频谱分析仪 我有输入值和输出值 样本数是32000采样率为2000样本 秒 输入是正弦波50 hz 输入为电流 输出为压力 单位 psi 我如何使用 MATLAB 根据这些数据计算频率响应 使用 MATLAB 中的 FFT 函
  • 检查Matlab中脚本需要使用的函数

    我有一个别人写的代码包 我正在运行一个脚本 它调用一些函数 这些函数又调用更多函数 等等 我想获取不是 MATLAB 内置函数但属于包的一部分的函数列表 我尝试使用matlab codetools requiredFilesAndProdu
  • MATLAB 滚动图

    我有一个脑电图数据库 我想绘制它 数据库是一个19 1000 134的矩阵 其中 19 是通道数 在第一种方法中 我只使用一个渠道 1000 个样本大小 采样率为 500 Hz 时为 1000 个点 即 2 秒数据 134 epochs的数
  • 如何在 Matlab 中使用谷歌翻译?

    我正在编写一个程序 使用 Matlab 列出电影字幕文件中的所有唯一单词 现在我有一个独特的单词列表 我想将其翻译成我的语言并在观看电影之前了解其含义 有谁知道如何在 Matlab 中使用 Google Translate 以便完成我的脚本
  • 在 matlab 中求 3d 峰的体积

    现在我有一个带有峰值的 3D 散点图 我需要找到其体积 我的数据来自图像 因此 x 和 y 值表示 xy 平面上的像素位置 z 值是每个像素的像素值 这是我的散点图 scatter3 x y z 20 z filled 我试图找到数据峰值的
  • Matlab 字段名索引[重复]

    这个问题在这里已经有答案了 所以我有一个包含多个表的元胞数组 我试图访问表的第一个列名称 c table1 table2 table3 以下两行都给了我错误 fieldnames c 1 1 fieldnames c 1 1 Error i
  • 如何在没有安装Visual Studio的另一台机器上使用Visual Studio生成的dll?

    我已经在 Visual Studio 2012 中生成了动态库 我想在另一台机器上使用该库 但我不想在远程机器上安装 Visual Studio 我有 mex 库和 dll 我想运行一个使用这两个库的脚本 当我运行脚本时 出现以下错误 缺少
  • Matlab 和 Python 中的优化算法(dog-leg trust-region)

    我正在尝试使用 Matlab 和 Python 中的狗腿信赖域算法求解一组非线性方程 在Matlab中有fsolve https www mathworks com help optim ug fsolve html其中此算法是默认算法 而
  • 如何使用 MATLAB 的 substruct 函数创建表示使用“end”的引用的结构?

    我想使用substruct http www mathworks com help matlab ref substruct html函数创建一个结构体以供使用subsref 目的是使用索引字符串subsref而不是通常的 符号 因为我正在
  • 通过 Matlab 访问 Physionet 的 ptbdb 中的数据库

    我首先设置系统 old path which rdsamp if isempty old path rmpath old path 1 end 8 end wfdb url http physionet org physiotools ma
  • 如何在文本集中创建所有字符组合?

    例如 我有这样的文本集 第 1 栏 a b 第 2 栏 l m n 第 3 栏 v w x y 我想将它们组合起来以获得如下输出 alv alw alx aly amv amw amx amy 这将输出 24 种文本组合 如果我只使用前两列
  • glpk.LPX 向后兼容性?

    较新版本的glpk没有LPXapi 旧包需要它 我如何使用旧包 例如COBRA http opencobra sourceforge net openCOBRA Welcome html 与较新版本的glpk 注意COBRA适用于 MATL
  • 有效地绘制大时间序列(matplotlib)

    我正在尝试使用 matplotlib 在同一轴上绘制三个时间序列 每个时间序列有 10 6 个数据点 虽然生成图形没有问题 但 PDF 输出很大 在查看器中打开速度非常慢 除了以栅格化格式工作或仅绘制时间序列的子集之外 还有其他方法可以获得
  • 图像处理 - 使用 opencv 进行服装分割

    我正在使用 opencv 进行服装特征识别 第一步 我需要通过从图像中移除脸部和手来分割 T 恤 任何建议表示赞赏 我建议采用以下方法 Use 阿德里安 罗斯布鲁克的用于检测皮肤的皮肤检测算法 谢谢罗莎 格隆奇以获得他的评论 在方差图上使用
  • Matlab 的 imresize 函数中用于插值的算法是什么?

    我正在使用 Matlab Octaveimresize 对给定的二维数组重新采样的函数 我想了解如何使用特定的插值算法imresize works 我在Windows上使用八度 e g A 1 2 3 4 是一个二维数组 然后我使用命令 b
  • 给定协方差矩阵,在Matlab中生成高斯随机变量

    Given a M x M期望的协方差 R 以及所需数量的样本向量 N计算一个N x M高斯随机向量 X在普通 MATLAB 中 即不能使用r mvnrnd MU SIGMA cases 不太确定如何解决这个问题 通常你需要一个协方差并且意
  • 检测数据集中线性行为的算法

    我已经发布了一个关于对数据集的一部分进行多项式拟合的算法 https stackoverflow com q 17595932 2320757前一段时间收到一些建议去做我想做的事 但我现在面临另一个问题 我尝试应用答案中建议的想法 我的目标
  • Matlab dec2bin 给出错误的值

    我正在使用 Matlab 的 dec2bin 将十进制数转换为二进制字符串 但是 我得到了错误的结果 例如 gt gt dec2bin 13339262925365424727 ans 101110010001111010010100111
  • 如何在MATLAB中显示由三个矩阵表示的图像?

    我有 3 个相同大小的 2D 矩阵 假设 200 行和 300 列 每个矩阵代表三种 基本 颜色 红色 绿色和蓝色 之一的值 矩阵的值可以在 0 到 255 之间 现在我想组合这些矩阵以将它们显示为彩色图像 200 x 300 像素 我怎样

随机推荐

  • URL、URI和URN之间的区别

  • 程序员应该掌握的 10 个搜索技巧

    在今天 用户可以通过搜索引擎轻松找出自己想要的信息 但还是难以避免结果不尽如人意的情况 实际上 用户仅需掌握几个常用技巧即可轻松化解这种尴尬 下面介绍 10 个在进行 Google 搜索时可以使用的便捷技巧 其他搜索引擎也支持这 10 种技
  • C++外观模式

    外观模式 1 外观模式简介及应用场景 外观者模式其实就是相当于对一组子系统功能的组合 对外提供统一的简单接口的模式 当我们在实际开发中 一般情况下是一个单独的子系统对应的是一个独立的功能模块 但是随着业务功能的不断增加 对应子系统的迭代必然
  • CentOS8 服务篇4:FTP文件传输服务搭建与配置

    FTP 文件传输服务三种配置模式 匿名模式 本地用户模式 虚拟用户模式 安装ftp服务 安装完后再根据不同模式进行配置 root localhost yum repos d yum install y vsftpd ftp vsftpd是搭
  • Qt中qss样式表

    qss样式表是用于设置QT程序UI界面中控件的背景图片 大小 字体颜色 字体类型 按钮状态变化等属性 美化UI界面 实现界面和程序的分离 可以快速切换皮肤 1 基本语法 selector attribute value 说明 selecto
  • Java生成exe执行文件

    一 准备工作 下载可将jar包转换的工具EXE4J工具 下载地址为 https www ej technologies com download exe4j files 下载完成 直接点击下一步安装 直到安装完成 导出项目jar包 按以下步
  • javaFile类知识点总结

    1 File类 Java io File类是文件和目录路径名的抽象表示 主要用于文件和目录的创建 查找 删除等操作 File中的静态成员变量 pathSeparator与系统有关的路径分隔符 File pathSeparator 代表路径分
  • android系统删除apk的广播,研究androidapk安装卸载等产生的系统广播

    想更加清楚的了解 android 系统在安装 卸载时产生的系统广播 于是写了一个 demo 来做监听 BroadReceiver 配置如下 html 这里有一点要注意 需配置 否则收不到广播 1 当你第一次安装某个应用的时候 java 10
  • 干货

    SpringCloud的从整体架构上看 相对来说是完整的 庞大的 它不仅仅是一个基础性架构工具 它为微服务架构提供了一个 全家桶 的套餐 每一个模块关注各自的职能 并且能够很好地配合与协作 能够帮助入门者快速搭建起一套微服务架构的服务 内容
  • MyBatis之使用JSONObject代替JavaBean优雅返回多表查询结果

    项目中需要返回多个表的查询结果 比如user表中的用户信息和user个人的所在班的班级信息 目前我们有user实体类和class实体类 一般情况下如果是单表查询 比如查询user信息 那么查询的返回值就是一个user对象或一个user对象列
  • Qt_Qt报错multiple target patterns

    去看看pro文件中的路径是否有问题
  • ARM7的三级流水线过程

    看到汇编中很多关于程序返回与中断返回时处理地址都很特别 仔细想想原来是流水线作用的效果 所以 决定总结学习下ARM流水线 ARM7处理器采用3级流水线来增加处理器指令流的速度 能提供0 9MIPS MHz的指令处理速度 PS MIPS Mi
  • Android RxJava第一弹之原理详解、使用详解、常用场景(基于Rxjava2.0)

    Android RxJava第一弹之原理详解 使用详解 常用场景 基于Rxjava2 0 Android RxJava第二弹之RxJava封装库 RxJava Animation RxJava Glide Android RxJava第三弹
  • C语言数据结构复杂度

    文章目录 前言 什么是数据结构 什么是算法 算法效率 算法的复杂度 时间复杂度 时间复杂度的概念 大O的渐进表示法 常见时间复杂度计算举例 空间复杂度 常见复杂度对比 前言 从这篇博客开始为数据结构与算法的相关内容 数据结构比较难 博主建议
  • Leecode初级算法字符串——验证回文串

    给定一个字符串 验证它是否是回文串 只考虑字母和数字字符 可以忽略字母的大小写 说明 本题中 我们将空字符串定义为有效的回文串 示例 1 输入 A man a plan a canal Panama 输出 true 解释 amanaplan
  • tcp三次握手

    在TCP IP协议中 TCP协议提供可靠的连接服务 采用三次握手建立一个连接 第一次握手 建立连接时 客户端发送syn包 syn j 到服务器 并进入SYN SEND状态 等待服务器确认 第二次握手 服务器收到syn包 必须确认客户的SYN
  • Ubuntu18.04安装教程

    Ubuntu18 04安装教程 一 准备工作 1 下载 Ubuntu 镜像 2 制作U盘启动盘 3 给 Ubuntu 分配硬盘空间 二 安装 Ubuntu18 04 1 设置启动项 2 正式安装 1 选择语言 2 键盘布局 3 无线连网 4
  • Python 字典10种意想不到的用途

    Python 字典10种意想不到的用途 1 switch case语句 2 记忆化 3 稀疏矩阵 4 图表 5 状态机 6 计数频率 7 XML HTML 解析 8 配置文件 9 缓存 API 响应 10 编码和解码数据 源码 参考 Pyt
  • GIS gentools jar包使用

    package ghgf import java io File import java io IOException import java io Serializable import java net MalformedURLExce
  • 随机梯度下降求解svm(MATLAB)

    本文转载自 http blog csdn net orangehdc article details 38682501 随机梯度下降法 Stochastic Gradient Descent 求解以下的线性SVM模型 w的梯度为 传统的梯度