【机器学习 - 4】:线性回归算法

2023-10-27

线性回归


线性回归的理解

线性回归:判断数据的特征和目标值之间具有一定的线性关系。
简单线性回归:样本的特征只有一个,用线性回归法进行预测,叫做简单线性回归。
多元线性回归:样本的特征有两个或两个以上,叫做多元线性回归。

如下图所示,为线性回归模型
在这里插入图片描述

损失函数

损失函数:np.sum((y`-y)**2),即预测值和真实值的差值之和。因为有复数的存在所以求平方,不用绝对值的原因:用平方方便后续的求导和求极值。

最小二乘法
在这里插入图片描述
一些推导过程:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
重要结论
在这里插入图片描述

简单线性回归


  1. 先画出数据的散点图
import numpy as np
import matplotlib.pyplot as plt

x = np.array([1,2,3,4,5])
y = np.array([2,1,3,2,5])

plt.scatter(x,y)
plt.axis([0,6,0,6])
plt.show()

在这里插入图片描述

  1. 对数据进行处理,求出a和b
# y = a * x + b
# 先求出平均值
x_mean = np.mean(x)
y_mean = np.mean(y)

num = 0.0 # 分子
d = 0.0 # 分母

for x_i, y_i in zip(x, y):
    num += (x_i-x_mean)*(y_i-y_mean)
    d += (x_i-x_mean)**2
a = num/d
b = y_mean-a*x_mean

在这里插入图片描述

  1. 求出y`,并画出预测直线,求出这条线,使得真实值与预测值的差值达到最小。
y_hat = a * x + b

plt.plot(x, y_hat, color='r')
plt.scatter(x,y)
plt.axis([0, 6, 0, 6])
plt.show()

在这里插入图片描述

封装线性回归算法

import numpy as np


class SimpleLinearRegression:
    def __init__(self):
        self.a_ = None
        self.b_ = None
        self.x_mean = None
        self.y_mean = None

    def fit(self, x_train, y_train):
        self.x_mean = np.mean(x_train)
        self.y_mean = np.mean(y_train)

        num = 0.0   # 分子
        d = 0.0     # 分母
        for x_i, y_i in zip(x_train, y_train):
            num += (x_i-self.x_mean) * (y_i-self.y_mean)
            d += (x_i-self.x_mean)**2
        self.a = num/d
        self.b = self.y_mean - self.a * self.x_mean

        return self

    def predict(self, x_test):
        return self.a * x_test + self.b

    def __repr__(self):
        return 'SimpleLinearRegression()'

在jupter notebook中导入运行
在这里插入图片描述

线性回归算法


使用线性回归算法的前提:数据具有一定的线性关系。

我们希望找到一条最佳拟合的直线方程,y=ax+b,对于每一个样本点,在这个直线方程上都有一个预测值,预测值和真实值有一定的差距,我们希望这些样本到直线方程的差距之和最小。

计算差距:sqrt(|y-y`|**2),使用平方并开根号的方式更适合我们进行求导或求值。

在sklearn中调用线性回归算法

  1. 导入模块
from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt
  1. 准备数据,训练模型
# 准备数据
x = np.array([1,2,3,4,5])
y = np.array([3,1,4,3,6])

lin_reg = LinearRegression()
lin_reg.fit(x.reshape(-1,1), y)	# 拟合,训练模型

在这里插入图片描述

  1. 画出散点图和预测直线
plt.scatter(x, y)
plt.plot(x, lin_reg.predict(x.reshape(-1,1)), color='r')
plt.axis([0,6,0,7])
plt.show()

在这里插入图片描述

向量化运算

如下图所示,向量化运算更加方便,向量点乘是先乘后加,原理一样。
在这里插入图片描述

x = np.array([1,2,3,4,5])
y = np.array([3,1,4,3,6])
def lin_fit(x, y):
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    num = 0.0
    d = 0.0
    num = (x-x_mean).dot(y-y_mean)
    d = (x-x_mean).dot(x-x_mean)
    a = num/d
    b = y_mean-a*x_mean
    return a, b

在这里插入图片描述

线性回归模型中的误差


在分类问题可以将score看成准确率,在回归问题将score看成模型的好坏程度。
在这里插入图片描述

均方误差 MSE

均方误差的公式如下图所示:
在这里插入图片描述
为什么要除以样本数量m?
举个例子,比如第一个团队有2个人,统计其工资的均方误差为800,第二个团队有100个人,工资的均方误差为1000,能说明第一个团队比较好吗?这是不行的,因为统计的个数不同,样本不同,导致量纲不一样,所以需要除以样本数量m,减少量纲的影响。

封装的函数

# 均方误差 MSE
def MSE(y_true, y_predict):
    return np.sum((y_true-y_predict)**2)/len(y_true)

在这里插入图片描述

均方根误差

在这里插入图片描述
在均方误差中进行开根号处理,可以消除量纲的影响。

封装的均方根误差

# 均方根误差
from math import sqrt
def RMSE(y_true, y_predict):
    return sqrt(np.sum((y_true-y_predict)**2)/len(y_true))

在这里插入图片描述

平均绝对误差

在这里插入图片描述
封装的平均绝对误差

# 平均绝对误差
def MAE(y_true, y_predict):
    return np.sum(np.absolute(y_true-y_predict))/len(y_true)

在这里插入图片描述

调用sklearn中的均方根误差和平均绝对误差函数

from sklearn.metrics import mean_squared_error, mean_absolute_error
mean_squared_error(x, y_hat)
mean_absolute_error(x, y_hat)

在这里插入图片描述

R squared error (常用)

R^2(以下用R2表示)分类的准确度在0和1之间,R2为1时,模型最优,即没有出现任何错误。

计算公式如下:
在这里插入图片描述
封装R squared error

import numpy as np
x = np.array([1,2,3,4,5])
y = np.array([3,1,4,3,6])
def r2_score(x_true, y_predict):
    return 1-((np.sum((x_true-y_predict)**2)/len(x_true))/np.var(x_true))

在这里插入图片描述

或调用均值方差 MSE
在这里插入图片描述
调用sklearn中的线性回归算法,计算预测值,最终的误差结果还是一样

from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
lin_reg.fit(x.reshape(-1,1), y)
y_predict = lin_reg.predict(x.reshape(-1,1))

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【机器学习 - 4】:线性回归算法 的相关文章

  • Python、Tkinter、更改标签颜色

    有没有一种简单的方法来更改按钮中文本的颜色 I use button text input text here 更改按下后按钮文本的内容 是否存在类似的颜色变化 button color red Use the foreground设置按钮
  • Python PAM 模块的安全问题?

    我有兴趣编写一个 PAM 模块 该模块将利用流行的 Unix 登录身份验证机制 我过去的大部分编程经验都是使用 Python 进行的 并且我正在交互的系统已经有一个 Python API 我用谷歌搜索发现pam python http pa
  • 如何在android上的python kivy中关闭应用程序后使服务继续工作

    我希望我的服务在关闭应用程序后继续工作 但我做不到 我听说我应该使用startForeground 但如何在Python中做到这一点呢 应用程序代码 from kivy app import App from kivy uix floatl
  • Python 多处理示例不起作用

    我正在尝试学习如何使用multiprocessing但我无法让它发挥作用 这是代码文档 http docs python org 2 library multiprocessing html from multiprocessing imp
  • SQL Alchemy 中的 NULL 安全不等式比较?

    目前 我知道如何表达 NULL 安全的唯一方法 SQL Alchemy 中的比较 其中与 NULL 条目的比较计算结果为 True 而不是 NULL 是 or field None field value 有没有办法在 SQL Alchem
  • 如何使用 Scrapy 从网站获取所有纯文本?

    我希望在 HTML 呈现后 可以从网站上看到所有文本 我正在使用 Scrapy 框架使用 Python 工作 和xpath body text 我能够获取它 但是带有 HTML 标签 而且我只想要文本 有什么解决办法吗 最简单的选择是ext
  • 从 scikit-learn 导入 make_blobs [重复]

    这个问题在这里已经有答案了 我收到下一个警告 D Programming Python ML venv lib site packages sklearn utils deprecation py 77 DeprecationWarning
  • 在 NumPy 中获取 ndarray 的索引和值

    我有一个 ndarrayA任意维数N 我想创建一个数组B元组 数组或列表 其中第一个N每个元组中的元素是索引 最后一个元素是该索引的值A 例如 A array 1 2 3 4 5 6 Then B 0 0 1 0 1 2 0 2 3 1 0
  • Python 中的二进制缓冲区

    在Python中你可以使用StringIO https docs python org library struct html用于字符数据的类似文件的缓冲区 内存映射文件 https docs python org library mmap
  • 当玩家触摸屏幕一侧时,如何让 pygame 发出警告?

    我使用 pygame 创建了一个游戏 当玩家触摸屏幕一侧时 我想让 pygame 给出类似 你不能触摸屏幕两侧 的错误 我尝试在互联网上搜索 但没有找到任何好的结果 我想过在屏幕外添加一个方块 当玩家触摸该方块时 它会发出警告 但这花了很长
  • Python:尝试检查有效的电话号码

    我正在尝试编写一个接受以下格式的电话号码的程序XXX XXX XXXX并将条目中的任何字母翻译为其相应的数字 现在我有了这个 如果启动不正确 它将允许您重新输入正确的数字 然后它会翻译输入的原始数字 我该如何解决 def main phon
  • Python - 按月对日期进行分组

    这是一个简单的问题 起初我认为很简单而忽略了它 一个小时过去了 我不太确定 所以 我有一个Python列表datetime对象 我想用图表来表示它们 x 值是年份和月份 y 值是此列表中本月发生的日期对象的数量 也许一个例子可以更好地证明这
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • 如何改变Python中特定打印字母的颜色?

    我正在尝试做一个简短的测验 并且想将错误答案显示为红色 欢迎来到我的测验 您想开始吗 是的 祝你好运 法国的首都是哪里 法国 随机答案不正确的答案 我正在尝试将其显示为红色 我的代码是 print Welcome to my Quiz be
  • Python:计算字典的重复值

    我有一本字典如下 dictA unit1 test1 alpha unit1 test2 beta unit2 test1 alpha unit2 test2 gamma unit3 test1 delta unit3 test2 gamm
  • VSCode:调试配置中的 Python 路径无效

    对 Python 和 VSCode 以及 stackoverflow 非常陌生 直到最近 我已经使用了大约 3 个月 一切都很好 当尝试在调试器中运行任何基本的 Python 程序时 弹出窗口The Python path in your
  • 从 Python 中的类元信息对 __init__ 函数进行类型提示

    我想做的是复制什么SQLAlchemy确实 以其DeclarativeMeta班级 有了这段代码 from sqlalchemy import Column Integer String from sqlalchemy ext declar
  • 循环标记时出现“ValueError:无法识别的标记样式 -d”

    我正在尝试编码pyplot允许不同标记样式的绘图 这些图是循环生成的 标记是从列表中选取的 为了演示目的 我还提供了一个颜色列表 版本是Python 2 7 9 IPython 3 0 0 matplotlib 1 4 3 这是一个简单的代
  • Pandas 与 Numpy 数据帧

    看这几行代码 df2 df copy df2 1 df 1 df 1 values 1 df2 ix 0 0 我们的教练说我们需要使用 values属性来访问底层的 numpy 数组 否则我们的代码将无法工作 我知道 pandas Data
  • PyAudio ErrNo 输入溢出 -9981

    我遇到了与用户相同的错误 Python 使用 Pyaudio 以 16000Hz 录制音频时出错 https stackoverflow com questions 12994981 python error audio recording

随机推荐

  • hibernate注解

    现在EJB3实体Bean是纯粹的POJO 实际上表达了和Hibernate持久化实体对象同样的概念 他们的映射都通过JDK5 0注释来定义 EJB3规范中的XML描述语法至今还没有定下来 注释分为两个部分 分别是逻辑映射注释和物理映射注释
  • 目标检测一阶段和二阶段对比图

    图片来源
  • 『学Vue2+Vue3』认识Vue3

    认识Vue3 1 Vue2 选项式 API vs Vue3 组合式API 特点 代码量变少 分散式维护变成集中
  • Linux下安装jre

    Linux下安装Java运行环境 现需要项目部署到Linux中 需要配置java运行环境 注 以下测试环境系统为centOS 用户为超级管理员 jre8 1 下载最新版的jre 服务器环境下不需要配置jdk 下载地址如下 http www
  • microsoft visual c++ 6.0中文版两种使用方法

    microsoft visual c 6 0 是一款语言编程软件 那么很多人都不知道microsoft visual c 6 0中文版怎么使用 其实使用方法很简单哦 只要打开microsoft visual c 6 0中文版就可以进行语言编
  • 《数据结构》 图的创建与遍历 代码表示

    测试数据 10 15 共10个顶点 15条边 0 1 0 8 0 0 第一 二个数表示连接两个顶点的起始顶点 第三个数1表示单通行 0表示双向通行 4 8 1 5 4 0 5 9 1 0 6 0 7 3 1 8 3 1 2 5 0 2 1
  • c#排列组合算法

    Combinatorics cs代码清单 using System using System Collections using System Data
  • 二叉树所有节点转换成大于该节点的平均值,没有最大值就转换成0

    import java util ArrayList import java util List import java util function ToIntFunction import java util stream Collect
  • CUnit(单元测试框架)

    CUnit是一个用C语言编写 管理和运行单元测试的轻量级系统 它为C程序员提供了具有灵活多样用户界面的基本测试功能 CUnit是作为一个静态库构建的 它与用户的测试代码链接在一起 它使用一个简单的框架来构建测试结构 并为测试公共数据类型提供
  • Buildroot制作根文件系统过程(基于MYD-AM335X开发板)

    buildroot的功能很强大 可以利用它制作交叉编译工具链 根文件系统 甚至可以构建多种嵌入式平台的bootloader linux 下面以米尔科技的MYD AM335X平台为例展示如何利用buildroot制作自己所需的根文件系统 一
  • 柔性OLED拼接屏有哪些场景化应用?

    柔性OLED拼接屏是一种新型的显示技术 它采用了柔性OLED屏幕 可以实现多个屏幕的拼接 形成一个大屏幕显示 这种技术可以应用于各种场合 如商业展示 广告宣传 会议演示等 柔性OLED屏幕是一种新型的显示技术 它采用了柔性材料作为基底 可以
  • java基于ssm+vue的共享充电宝管理系统 elementui

    随着时代的发展 人们的生活越来越离不开手机 但是因为技术水平等原因的限制 手机的电池并没有人们想象中的那么耐用 很多时候人们在外出的时候 很可能会遇到手机没电的情况发生 作为日常通讯的必备工具 如果没电了 很可能会影响一些重要的事情 尤其是
  • 3D相机调研

    最近因为自己实验需要配置一个3D相机 安装在机械臂上实现eye in arm的自动化引导过程 调研结果记录如下 3D相机又称为深度相机 即通过该相机能检测出拍摄空间的景深距离 与普通相机 2D的最大区别 普通彩色相机 2D相机 拍摄到的图片
  • JS -- input输入框只能输入正整数

    摘自文章 input输入框只能输入正整数 半城烟沙的技术博客 51CTO博客 one
  • 最大比例

    X星球的某个大奖赛设了M级奖励 每个级别的奖金是一个正整数 并且 相邻的两个级别间的比例是个固定值 也就是说 所有级别的奖金数构成了一个等比数列 比如 16 24 36 54 其等比值为 3 2 现在 我们随机调查了一些获奖者的奖金数 请你
  • 面试题深入思考01-----Arrays.sort()与Collections.sort()

    面试题深入思考01 Arrays sort 与Collections sort 1 Collections sort Collections本质是关于集合的一种工具类 其中包含对集合的各种api 例如排序 反转 交换和复制等 其中sort方
  • word怎么恢复保存前的文件,word文件恢复

    我们在使用word编辑文档时偶尔会有误删除文档的经历 word要怎么恢复保存前的文件呢 本文为你提供了五种解决思路 你可以通过搜索word文档的备份文档 自动恢复文件 临时文件 回收站 第三方数据恢复软件找到文档 方法一 搜索 Word 备
  • katex

    Katex Accents Accent functions inside text Delimiters Delimiter Sizing Environments Letters and Unicode Other Letters Un
  • Android ----蓝牙架构

    蓝牙 1 fromwork 2 service 3 driver Bluetooth apk bluedroid 芯片厂家 fromwork到service直接调用 service到driver利用service调用 fromwork到dr
  • 【机器学习 - 4】:线性回归算法

    文章目录 线性回归 线性回归的理解 损失函数 简单线性回归 封装线性回归算法 线性回归算法 在sklearn中调用线性回归算法 向量化运算 线性回归模型中的误差 均方误差 MSE 均方根误差 平均绝对误差 调用sklearn中的均方根误差和