基于梯度下降算法求解线性回归

2023-11-16

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

01. 线性回归(Linear Regression)

梯度下降算法在机器学习方法分类中属于监督学习。利用它可以求解线性回归问题,计算一组二维数据之间的线性关系,假设有一组数据如下下图所示

其中X轴方向表示房屋面积、Y轴表示房屋价格。我们希望根据上述的数据点,拟合出一条直线,能跟对任意给定的房屋面积实现价格预言,这样求解得到直线方程过程就叫线性回归,得到的直线为回归直线,数学公式表示如下:

02. 梯度下降

 

03. 代码实现各步

训练数据读入

List<DataItem> items = new ArrayList<DataItem>();
File f = new File(fileName);
try {
    if (f.exists()) {
        BufferedReader br = new BufferedReader(new FileReader(f));
        String line = null;
        while((line = br.readLine()) != null) {
            String[] data = line.split(",");
            if(data != null && data.length == 2) {
                DataItem item = new DataItem();
                item.x = Integer.parseInt(data[0]);
                item.y = Integer.parseInt(data[1]);
                items.add(item);
            }
        }
        br.close();
    }
} catch (IOException ioe) {
    System.err.println(ioe);
}
return items;

归一化处理

float min = 100000;
float max = 0;
for(DataItem item : items) {
    min = Math.min(min, item.x);
    max = Math.max(max, item.x);
}
float delta = max - min;
for(DataItem item : items) {
    item.x = (item.x - min) / delta;
}

梯度下降

int repetion = 1500;
float learningRate = 0.1f;
float[] theta = new float[2];
Arrays.fill(theta, 0);
float[] hmatrix = new float[items.size()];
Arrays.fill(hmatrix, 0);
int k=0;
float s1 = 1.0f / items.size();
float sum1=0, sum2=0;
for(int i=0; i<repetion; i++) {
    for(k=0; k<items.size(); k++ ) {
        hmatrix[k] = ((theta[0] + theta[1]*items.get(k).x) - items.get(k).y);
    }
    for(k=0; k<items.size(); k++ ) {
        sum1 += hmatrix[k];
        sum2 += hmatrix[k]*items.get(k).x;
    }
    sum1 = learningRate*s1*sum1;
    sum2 = learningRate*s1*sum2;
    // 更新 参数theta
    theta[0] = theta[0] - sum1;
    theta[1] = theta[1] - sum2;
}
return theta;

价格预言 - theta表示参数矩阵

float result = theta[0] + theta[1]*input;
return result;

线性回归Plot绘制

int w = 500;
int h = 500;
BufferedImage plot = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
Graphics2D g2d = plot.createGraphics();
g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
g2d.setPaint(Color.WHITE);
g2d.fillRect(0, 0, w, h);
g2d.setPaint(Color.BLACK);
int margin = 50;
g2d.drawLine(margin, 0, margin, h);
g2d.drawLine(0, h-margin, w, h-margin);
float minx=Float.MAX_VALUE, maxx=Float.MIN_VALUE;
float miny=Float.MAX_VALUE, maxy=Float.MIN_VALUE;
for(DataItem item : series1) {
    minx = Math.min(item.x, minx);
    maxx = Math.max(maxx, item.x);
    miny = Math.min(item.y, miny);
    maxy = Math.max(item.y, maxy);
}
for(DataItem item : series2) {
    minx = Math.min(item.x, minx);
    maxx = Math.max(maxx, item.x);
    miny = Math.min(item.y, miny);
    maxy = Math.max(item.y, maxy);
}
// draw X, Y Title and Aixes
g2d.setPaint(Color.BLACK);
g2d.drawString("价格(万)", 0, h/2);
g2d.drawString("面积(平方米)", w/2, h-20);
// draw labels and legend
g2d.setPaint(Color.BLUE);
float xdelta = maxx - minx;
float ydelta = maxy - miny;
float xstep = xdelta / 10.0f;
float ystep = ydelta / 10.0f;
int dx = (w - 2*margin) / 11;
int dy = (h - 2*margin) / 11;
// draw labels
for(int i=1; i<11; i++) {
    g2d.drawLine(margin+i*dx, h-margin, margin+i*dx, h-margin-10);
    g2d.drawLine(margin, h-margin-dy*i, margin+10, h-margin-dy*i);
    int xv = (int)(minx + (i-1)*xstep);
    float yv = (int)((miny + (i-1)*ystep)/10000.0f);
    g2d.drawString(""+xv, margin+i*dx, h-margin+15);
    g2d.drawString(""+yv, margin-25, h-margin-dy*i);
}
// draw point
g2d.setPaint(Color.BLUE);
for(DataItem item : series1) {
    float xs = (item.x - minx) / xstep + 1;
    float ys = (item.y - miny) / ystep + 1;
    g2d.fillOval((int)(xs*dx+margin-3), (int)(h-margin-ys*dy-3), 7,7);
}
g2d.fillRect(100, 20, 20, 10);
g2d.drawString("训练数据", 130, 30);
// draw regression line
g2d.setPaint(Color.RED);
for(int i=0; i<series2.size()-1; i++) {
    float x1 = (series2.get(i).x - minx) / xstep + 1;
    float y1 = (series2.get(i).y - miny) / ystep + 1;
    float x2 = (series2.get(i+1).x - minx) / xstep + 1;
    float y2 = (series2.get(i+1).y - miny) / ystep + 1;
    g2d.drawLine((int)(x1*dx+margin-3), (int)(h-margin-y1*dy-3), (int)(x2*dx+margin-3), (int)(h-margin-y2*dy-3));
}
g2d.fillRect(100, 50, 20, 10);
g2d.drawString("线性回归", 130, 60);
g2d.dispose();
saveImage(plot);

04. 总结

本文通过最简单的示例,演示了利用梯度下降算法实现线性回归分析,使用更新收敛的算法常被称为LMS(Least Mean Square)又叫Widrow-Hoff学习规则,此外梯度下降算法还可以进一步区分为增量梯度下降算法与批量梯度下降算法,这两种梯度下降方法在基于神经网络的机器学习中经常会被提及,对此感兴趣的可以自己进一步探索与研究。

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于梯度下降算法求解线性回归 的相关文章

随机推荐

  • 移动端与服务端交互安全方案

    系统流程图 验签 解决问题 1 身份验证 是否是我规定的那个人 2 防篡改 是否被第三方劫持并篡改参数 3 防重放 是否重复请求 具体算法 1 约定appKey 保证该调用请求是平台授权过的调用方发出的 保证请求方唯一性 2 将appKey
  • 常用巡检命令

    思科设备 show version 查看系统软 硬件版本信息 show running config 查看设备运行的配置信息 show ip interfaces brief 查看所有接口摘要信息 show interfaces 查看全部接
  • Java中弹出对话框中的几种方式

    1 显示一个错误对话框 该对话框显示的 message 为 alert JOptionPane showMessageDialog null alert alert JOptionPane ERROR MESSAGE 2 显示一个内部信息对
  • JavaSE知识体系目录

    文章目录 Java基础语法知识 关键字 运算符 数据类型 流程控制语句 面向对象 异常和常用类 集合 Collection Map IO 字节流 字符流 线程 网络 Java基础语法知识 关键字 运算符 算数运算符 比较运算符 赋值运算符
  • CSS盒模型自适应布局——calc与box-sizing

    CSS盒模型 1 CSS中盒模型分为两种 第一种是W3C的标准模型 即盒子的宽高等于内容的宽高 盒子的padding和border不计算在内 第二种是IE的传统模型 IE6以下 不含IE6 称为怪异模式或者QuirksMode 即盒子的宽高
  • sklearn中的LASSO

    LASSO import numpy as np import matplotlib pyplot as plt np random seed 42 x np random uniform 3 0 3 0 size 100 X x resh
  • pytorch 笔记: Swin-Transformer 代码

    理论部分 论文笔记 Swin Transformer Hierarchical Vision Transformer using Shifted Windows UQI LIUWJ的博客 CSDN博客 源码部分 Swin Transform
  • Java占位符总结

    文章目录 实现方式 方式一 jdk1 8 java text MessageFormat 方式二 Log4j javaorg slf4j helpers MessageFormatter 方式三 commons text org apach
  • linux下搭建goprotobuf

    linux下搭建goprotobuf 1 搭建go语言环境 参考官网 http golang org doc install 主要是设置好GO PATH这个变量 这个就是你的工作环境目录 可以使用go env来查询设置好了没 2 搭建pro
  • python中列表概念,Python基本数据类型——List(列表)

    1 序列 1 1 序列的基本概念 序列是Python中最基本的一种数据结构 序列用于保存一组有序的数据 所有的数据在序列当中都有一个唯一的位置 索引 并且序列中的数据会按照添加的顺序来分配索引 数据结构是指计算机中数据存储的方式 1 2 序
  • Pinpoint--基础--04--请求追踪和字节码插装

    Pinpoint 基础 04 请求追踪和字节码插装 备注 背景 英文原文 https naver github io pinpoint 1 8 4 techdetail html Dapper原文 https ai google resea
  • 00后卷王自述,我真的很卷吗?

    前段时间我去面试了一个软件测试公司 成功拿到了offer 薪资也从10k涨到了18k 对于工作都还没两年的我来说 还是比较满意的 毕竟有些工作了3到4年的可能还没有我的高 在公司一段时间后大家都说我是卷王 其实我也没办法 自己家里条件不是很
  • Pytorch ----注意力机制与自注意力机制的代码详解与使用

    注意力机制的核心重点就是让网络关注到它更需要关注的地方 当我们使用卷积神经网络去处理图片的时候 我们会更希望卷积神经网络去注意应该注意的地方 而不是什么都关注 我们不可能手动去调节需要注意的地方 这个时候 如何让卷积神经网络去自适应的注意重
  • Java基础6--对象和类

    Java基础6 对象和类 文章目录 Java基础6 对象和类 概念 Java中的对象 Java 中的类 构造方法 创建对象 访问实例变量和方法 Java 内部类 非静态内部类 静态内部类 从内部类访问外部类成员 import 语句 概念 对
  • 异步编程CompletableFuture系列3 接口合并

    直接上代码 import java util concurrent CompletableFuture import java util concurrent TimeUnit public class Test3 public stati
  • 没有找到MSVCR90D.DLL的两种解决方法

    1 没有找到MSVCR90D DLL的简单解决方法之一 在VS2005 2008下写C C 程序时 偶然会出现这样的错误 这样的错误一般会出现在第一次运行项目时 或重装VS后 这里提供一种简单的解决办法 希望对初学者有用 打开项目的属性页
  • 【CCPC-2019】【江西省赛】【霖行】J-Worker

    CCPC 2019 江西省赛 霖行 J Worker 题目 Avin meets a rich customer today He will earn 1 million dollars if he can solve a hard pro
  • python中用pickle打开文件报错:EOFError: Ran out of input

    用pickle dump 保存文件之后如果不关闭文件就会出现此错误 f open test pkl wb pickle dump dict f f close 后面添加关闭就不会报错
  • JAVA - 判断两个浮点数相等

    背景知识 float型和double型是JAVA的基本类型 用于浮点数表示 在JAVA中float型占4个字节32位 double型占8个字节64位 一般比较适合用于工程测量计算中 其在内存里的存储结构如下 float型 符号位 1 bit
  • 基于梯度下降算法求解线性回归

    点击上方 小白学视觉 选择加 星标 或 置顶 重磅干货 第一时间送达 01 线性回归 Linear Regression 梯度下降算法在机器学习方法分类中属于监督学习 利用它可以求解线性回归问题 计算一组二维数据之间的线性关系 假设有一组数