梯度下降法(BGD,SGD,MSGD)python+numpy具体实现

2023-11-13

梯度下降是一阶迭代优化算法。为了使用梯度下降找到函数的局部最小值,一个步骤与当前位置的函数的梯度(或近似梯度)的负值成正比。如果相反,一个步骤与梯度的正数成比例,则接近该函数的局部最大值;该程序随后被称为梯度上升。梯度下降也被称为最陡峭的下降,或最快下降的方法。(from wikipad)

首先,大家要明白它的本质:这是一个优化算法!!!它是可以用来解决很多问题的,一般学习机器学习的朋友都会在线性回归的遇到这个名词,但是要声明的是,它和最小二乘法类似,是用于求解线性回归问题的一种方法。同时它的功能又不仅于此,它在线性回归中的意义在于通过寻找梯度最大的方向下降(或上升)来找到损失函数最小时候对应的参数值。

好了,绕来绕去的就拿线性回归的例子来和大家讲讲吧。

梯度下降方法

本质是每次迭代的时候都沿着梯度最大的地方更新参数。现在假设有函数(Rosenbrock函数:是一个用来测试最优化算法性能的非凸函数,由Howard Harry Rosenbrock在1960年提出[1]。也称为Rosenbrock山谷或Rosenbrock香蕉函数,也简称为香蕉函数)如下定义:

f(x,y)=(1x)2+100(yx2)2

很明显,其最小最为 f(1,1)=0 ,其三维图片如下:
这里写图片描述
函数 f 分别对 x y 求导得到
f(x,y)x=2(1x)2100(yx2)2x

f(x,y)y=2100(yx2)

在实现的过程中可以给出x, y初始值(例如设置为 0, 0) 然后计算函数在这个点的梯度,并按照梯度方向更新x, y的值。

这里给出通过梯度下降法计算上述函数的最小值对应的x 和 y

import numpy as np


def cal_rosenbrock(x1, x2):
    """
    计算rosenbrock函数的值
    :param x1:
    :param x2:
    :return:
    """
    return (1 - x1) ** 2 + 100 * (x2 - x1 ** 2) ** 2


def cal_rosenbrock_prax(x1, x2):
    """
    对x1求偏导
    """
    return -2 + 2 * x1 - 400 * (x2 - x1 ** 2) * x1

def cal_rosenbrock_pray(x1, x2):
    """
    对x2求偏导
    """
    return 200 * (x2 - x1 ** 2)

def for_rosenbrock_func(max_iter_count=100000, step_size=0.001):
    pre_x = np.zeros((2,), dtype=np.float32)
    loss = 10
    iter_count = 0
    while loss > 0.001 and iter_count < max_iter_count:
        error = np.zeros((2,), dtype=np.float32)
        error[0] = cal_rosenbrock_prax(pre_x[0], pre_x[1])
        error[1] = cal_rosenbrock_pray(pre_x[0], pre_x[1])

        for j in range(2):
            pre_x[j] -= step_size * error[j]

        loss = cal_rosenbrock(pre_x[0], pre_x[1])  # 最小值为0

        print("iter_count: ", iter_count, "the loss:", loss)
        iter_count += 1
    return pre_x

if __name__ == '__main__':
    w = for_rosenbrock_func()  
    print(w)

如果大家想运行这个算法,建议使用默认的参数,效果还不错。不要把step_size设置过大,会出问题的(可能是实现过程有问题,请指正)。

线性回归问题

这里关于回归的前导介绍我建议大家取看周志华老师的西瓜书,介绍得通透明亮,但是周老师对线性回归问题给出的解决方法是通过最小二乘法来做的,而我们在这里要用梯度下降。

这里给出一般的定义吧~

一般的线性回归方程如下:

y=θ1x1+θ2x2++θnxn+b

转换为:
y=θ1x1+θ2x2++θnxn+θ0b

这里 θ0=1 转换为向量的形式 y=θTx θ x ,均为为行向量。

现在需要定义损函数,用于判断最后得到的预测参数的预测效果。常用的损失函数是均方误差:

J(θ)=12mj=1m(h(θ)iyi)2

i 是维度索引 j 是样本索引,接下来对 θ 求导得到
J(θ)θj=1mj=1m(h(θ)iyi)xij

更新公式为:
θi=θiα1mj=1m(h(θ)iyi)xij

α 就是学习的步长。

BGM(批量梯度下降法)

import numpy as np

def gen_line_data(sample_num=100):
    """
    y = 3*x1 + 4*x2
    :return:
    """
    x1 = np.linspace(0, 9, sample_num)
    x2 = np.linspace(4, 13, sample_num)
    x = np.concatenate(([x1], [x2]), axis=0).T
    y = np.dot(x, np.array([3, 4]).T)  # y 列向量
    return x, y

def bgd(samples, y, step_size=0.01, max_iter_count=10000):
    sample_num, dim = samples.shape
    y = y.flatten()
    w = np.ones((dim,), dtype=np.float32)
    loss = 10
    iter_count = 0
    while loss > 0.001 and iter_count < max_iter_count:
        loss = 0
        error = np.zeros((dim,), dtype=np.float32)
        for i in range(sample_num):
            predict_y = np.dot(w.T, samples[i])
            for j in range(dim):
                error[j] += (y[i] - predict_y) * samples[i][j]

        for j in range(dim):
            w[j] += step_size * error[j] / sample_num

        for i in range(sample_num):
            predict_y = np.dot(w.T, samples[i])
            error = (1 / (sample_num * dim)) * np.power((predict_y - y[i]), 2)
            loss += error

        print("iter_count: ", iter_count, "the loss:", loss)
        iter_count += 1
    return w

if __name__ == '__main__':
    samples, y = gen_line_data()
    w = bgd(samples, y)
    print(w)  # 会很接近[3, 4]

SGB(随机梯度下降法)

import numpy as np

def gen_line_data(sample_num=100):
    """
    y = 3*x1 + 4*x2
    :return:
    """
    x1 = np.linspace(0, 9, sample_num)
    x2 = np.linspace(4, 13, sample_num)
    x = np.concatenate(([x1], [x2]), axis=0).T
    y = np.dot(x, np.array([3, 4]).T)  # y 列向量
    return x, y

def sgd(samples, y, step_size=0.01, max_iter_count=10000):
    """
    随机梯度下降法
    :param samples: 样本
    :param y: 结果value
    :param step_size: 每一接迭代的步长
    :param max_iter_count: 最大的迭代次数
    :param batch_size: 随机选取的相对于总样本的大小
    :return:
    """
    sample_num, dim = samples.shape
    y = y.flatten()
    w = np.ones((dim,), dtype=np.float32)
    loss = 10
    iter_count = 0
    while loss > 0.001 and iter_count < max_iter_count:
        loss = 0
        error = np.zeros((dim,), dtype=np.float32)
        for i in range(sample_num):
            predict_y = np.dot(w.T, samples[i])
            for j in range(dim):
                error[j] += (y[i] - predict_y) * samples[i][j]
                w[j] += step_size * error[j] / sample_num

        # for j in range(dim):
        #     w[j] += step_size * error[j] / sample_num

        for i in range(sample_num):
            predict_y = np.dot(w.T, samples[i])
            error = (1 / (sample_num * dim)) * np.power((predict_y - y[i]), 2)
            loss += error

        print("iter_count: ", iter_count, "the loss:", loss)
        iter_count += 1
    return w

if __name__ == '__main__':
    samples, y = gen_line_data()
    w = sgd(samples, y)
    print(w)  # 会很接近[3, 4]

MBGB(小批量梯度下降法)

import numpy as np
import random

def gen_line_data(sample_num=100):
    """
    y = 3*x1 + 4*x2
    :return:
    """
    x1 = np.linspace(0, 9, sample_num)
    x2 = np.linspace(4, 13, sample_num)
    x = np.concatenate(([x1], [x2]), axis=0).T
    y = np.dot(x, np.array([3, 4]).T)  # y 列向量
    return x, y

def mbgd(samples, y, step_size=0.01, max_iter_count=10000, batch_size=0.2):
    """
    MBGD(Mini-batch gradient descent)小批量梯度下降:每次迭代使用b组样本
    :param samples:
    :param y:
    :param step_size:
    :param max_iter_count:
    :param batch_size:
    :return:
    """
    sample_num, dim = samples.shape
    y = y.flatten()
    w = np.ones((dim,), dtype=np.float32)
    # batch_size = np.ceil(sample_num * batch_size)
    loss = 10
    iter_count = 0
    while loss > 0.001 and iter_count < max_iter_count:
        loss = 0
        error = np.zeros((dim,), dtype=np.float32)

        # batch_samples, batch_y = select_random_samples(samples, y,
        # batch_size)

        index = random.sample(range(sample_num),
                              int(np.ceil(sample_num * batch_size)))
        batch_samples = samples[index]
        batch_y = y[index]

        for i in range(len(batch_samples)):
            predict_y = np.dot(w.T, batch_samples[i])
            for j in range(dim):
                error[j] += (batch_y[i] - predict_y) * batch_samples[i][j]

        for j in range(dim):
            w[j] += step_size * error[j] / sample_num

        for i in range(sample_num):
            predict_y = np.dot(w.T, samples[i])
            error = (1 / (sample_num * dim)) * np.power((predict_y - y[i]), 2)
            loss += error

        print("iter_count: ", iter_count, "the loss:", loss)
        iter_count += 1
    return w

if __name__ == '__main__':
    samples, y = gen_line_data()
    w = mbgd(samples, y)
    print(w)  # 会很接近[3, 4]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

梯度下降法(BGD,SGD,MSGD)python+numpy具体实现 的相关文章

  • 在 Python 中解析 TCL 列表

    我需要在双括号上拆分以空格分隔的 TCL 列表 例如 OUTPUT 172 25 50 10 01 01 Ethernet 172 25 50 10 01 02 Ethernet Traffic Item 1 172 25 50 10 01
  • Pandas set_levels,如何避免标签排序?

    我使用时遇到问题set levels多索引 from io import StringIO txt Name Height Age Metres A 1 25 B 95 1 df pd read csv StringIO txt heade
  • 计算另一个字符串中多个字符串的出现次数

    在 Python 2 7 中 给定以下字符串 Spot是一只棕色的狗 斑点有棕色的头发 斑点的头发是棕色的 查找字符串中 Spot brown 和 hair 总数的最佳方法是什么 在示例中 它将返回 8 我正在寻找类似的东西string c
  • 在 Python 中将列表元素作为单独的项目返回

    Stackoverflow 的朋友们大家好 我有一个计算列表的函数 我想单独返回列表的每个元素 如下所示 接收此返回的函数旨在处理未定义数量的参数 def foo my list 1 2 3 4 return 1 2 3 4 列表中的元素数
  • VSCode Settings.json 丢失

    我正在遵循教程 并尝试将 vscode 指向我为 Scrapy 设置的虚拟工作区 但是当我在 VSCode 中打开设置时 工作区设置 选项卡不在 用户设置 选项卡旁边 我还尝试通过以下方式手动转到文件 APPDATA Code User s
  • Django Rest Framework 是否有第三方应用程序来自动生成 swagger.yaml 文件?

    我有大量的 API 端点编写在django rest framework并且不断增加和更新 如何创建和维护最新的 API 文档 我当前的版本是 Create swagger yaml文件并以某种方式在每次端点更改时自动生成 然后使用此文件作
  • 如何从Python中的函数返回多个值? [复制]

    这个问题在这里已经有答案了 如何从Python中的函数返回多个变量 您可以用逗号分隔要返回的值 def get name you code return first name last name 逗号表示它是一个元组 因此您可以用括号将值括
  • 使用主题交换运行多个 Celery 任务

    我正在用 Celery 替换一些自制代码 但很难复制当前的行为 我期望的行为如下 创建新用户时 应向tasks与交换user created路由键 该消息应该触发两个 Celery 任务 即send user activate email
  • PyQt 使用 ctrl+Enter 触发按钮

    我正在尝试在我的应用程序中触发 确定 按钮 我当前尝试的代码是这样的 self okPushButton setShortcut ctrl Enter 然而 它不起作用 这是有道理的 我尝试查找一些按键序列here http ftp ics
  • 矩形函数的数值傅里叶变换

    本文的目的是通过一个众所周知的分析傅里叶变换示例来正确理解 Python 或 Matlab 上的数值傅里叶变换 为此 我选择矩形函数 这里报告了它的解析表达式及其傅立叶变换https en wikipedia org wiki Rectan
  • 如何将特定范围内的标量添加到 numpy 数组?

    有没有一种更简单 更节省内存的方法可以单独在 numpy 中执行以下操作 import numpy as np ar np array a l r ar c a a 0 l ar tolist a r 它可能看起来很原始 但它涉及获取给定数
  • Django 视图中的“请求”是什么

    在 Django 第一个应用程序的 Django 教程中 我们有 from django http import HttpResponse def index request return HttpResponse Hello world
  • Pandas 组合不同索引的数据帧

    我有两个数据框df 1 and df 2具有不同的索引和列 但是 有一些索引和列重叠 我创建了一个数据框df索引和列的并集 因此不存在重复的索引或列 我想填写数据框df通过以下方式 for x in df index for y in df
  • 如何使用 Python 3 检查目录是否包含文件

    我到处寻找这个答案但找不到 我正在尝试编写一个脚本来搜索特定的子文件夹 然后检查它是否包含任何文件 如果包含 则写出该文件夹的路径 我已经弄清楚了子文件夹搜索部分 但检查文件却难倒了我 我发现了有关如何检查文件夹是否为空的多个建议 并且我尝
  • Spider 必须返回 Request、BaseItem、dict 或 None,已“设置”

    我正在尝试从以下位置下载所有产品的图像 我的蜘蛛看起来像 from shopclues items import ImgData import scrapy class multipleImages scrapy Spider name m
  • python 中的“槽包装器”是什么?

    object dict 和其他地方的隐藏方法设置为这样的
  • 如何以正确的方式为独立的Python应用程序制作setup.py?

    我读过几个类似的主题 但还没有成功 我觉得我错过或误解了一些基本的事情 这就是我失败的原因 我有一个用 python 编写的 应用程序 我想在标准 setup py 的帮助下进行部署 由于功能复杂 它由不同的 python 模块组成 但单独
  • 如果 PyPy 快 6.3 倍,为什么我不应该使用 PyPy 而不是 CPython?

    我已经听到很多关于PyPy http en wikipedia org wiki PyPy项目 他们声称它比现有技术快 6 3 倍CPython http en wikipedia org wiki CPython口译员开启他们的网站 ht
  • python 对浮点数进行不正确的舍入

    gt gt gt a 0 3135 gt gt gt print 3f a 0 314 gt gt gt a 0 3125 gt gt gt print 3f a 0 312 gt gt gt 我期待 0 313 而不是 0 312 有没有
  • 如何将Python3设置为Mac上的默认Python版本?

    有没有办法将 Python 3 8 3 设置为 macOS Catalina 版本 10 15 2 上的默认 Python 版本 我已经完成的步骤 看看它安装在哪里 ls l usr local bin python 我得到的输出是这样的

随机推荐

  • 1.还不会部署高可用的kubernetes集群?看我手把手教你使用二进制部署v1.23.6的K8S集群实践(上)

    关注 WeiyiGeek 设为 特别关注 每天带你玩转网络安全运维 应用开发 物联网IOT学习 本章目录 0x00 前言简述 0x01 环境准备 主机规划 软件版本 网络规划 0x02 安装部署 1 基础主机环境准备配置 2 负载均衡管理工
  • 家用 NAS 服务器搭建

    1 前言 使用NAS 一般除了在家里通过局域网访问 还会有外网访问的需求 即在外面通过移动网络或者其他网络访问家中的NAS 正常情况下在外面是没有办法访问家庭网络的 甚至是nas 因为nas获取的是局域网IP 而不是广域网IP 全球唯一地址
  • Unable to Create Process

    Error Unable to create process OK Details gt gt 如果你的操作系统是Win7而你又直接点击运行按钮的话会提示此错误 错误的原因是你程序中有对注册表的处理或一些底层操作 所以会提示这个错误 解决办
  • 循环队列(Java实现)

    Java数据结构学习笔记2 循环队列 核心逻辑代码如下 class CircleQueue private int maxSize 0 private int front 指向队列的第一个元素 private int rear 指向队列的最
  • 多模态融合 2022

    论文题目 DeepFusion Lidar Camera Deep Fusion for Multi Modal 3D Object Detection 前融合 单位 google 注 4D Net和3D CVF也研究了lidar和相机两个
  • H2数据库--转载

    一 H2数据库介绍 常用的开源数据库有 H2 Derby HSQLDB MySQL PostgreSQL 其中H2和HSQLDB类似 十分适合作为嵌入式数据库使用 而其它的数据库大部分都需要安装独立的客户端和服务器端 H2的优势 1 h2采
  • 使用 ffmpeg 转换视频格式 mp4 webm

    ffmpeg 是 nix 系统下最流行的音视频处理库 功能强大 并且提供了丰富的终端命令 实是日常视频处理的一大利器 实例 flac 格式转 mp3 音频格式转换非常简单 Python span class wp keywordlink a
  • 【多方安全计算】差分隐私(Differential Privacy)解读

    多方安全计算 差分隐私 Differential Privacy 解读 文章目录 多方安全计算 差分隐私 Differential Privacy 解读 1 介绍 2 形式化 3 差分隐私的方法 3 1 最简单的方法 加噪音 3 2 加高斯
  • 汉诺塔问题【C语言实现】

    目录 一 前言 二 动图演示 三 打印步骤 四 打印步数 一 前言 汉诺塔 Tower of Hanoi 又称河内塔 是一个源于印度古老传说的益智玩具 大梵天创造世界的时候做了三根金刚石柱子 在一根柱子上从下往上按照大小顺序摞着64片黄金圆
  • ORA-01427问题的分析和解决

    前几天开发的同事反馈一个问题 说前台系统报出了ORA错误 希望我们能看看是什么原因 java sql SQLException ORA 01427 single row subquery returns more than one row
  • /usr/bin/ld cannot find -lGL

    ubuntu 16 04虚拟机 装的Qt 5 10 随便写了个带UI的Demo 然后报错如下 解决如下 很多Linux发行版本 Qt安装完成后如果直接编译或者运行项目 会出现 cannot find lGL 错误 这是因为Qt找不到Open
  • 【matlab 斩波电路仿真】

    斩波电路仿真 要求 斩波电路原理 基本斩波电路 降压斩波电路搭建 结果分析 要求 斩波电路仿真 斩波电路原理 斩波电路的功能是将直流电变为另一固定电压或者可调电压的直流电 包括直接直流变流电路和简介直流变流电路 其中 直流变流电路也称为斩波
  • 某在线学习平台《数据挖掘》第八章课后习题

    此文章是本人结合课程内容和网上资料整理 难免有误差 仅供参考 1 下面哪种距离度量方法为欧几里得距离 2 以下哪个算法将两个簇的邻近度定义为不同簇的所有点对的平均逐对邻近度 它是一种凝聚层次聚类技术 AMIN 单链 BMAX 全链 C 组平
  • HK32F030MF4P6 SWD管脚功能复用GPIO

    由于电暖控制器项目上管脚不够 需要将SWD管脚复用 使用网上购买的JLINK 下载和串口调试特别方便 应用场景 往往GPIO管脚不够使用 需要将SWD下载管脚复用GPIO功能 需要用到以下设置 下载器需要接上RESET管脚 TSSOP20
  • (四)索引与数据完整性

    一 索引 1 索引的作用 快速存取数据 既可以改善数据库性能 又可以保证列值的唯一性 实现表与表之间的参照完整性 在使用ORDER BY GROUP BY子句进行数据检索时 利用索引可以减少排序和分组的时间 2 索引的分类 索引按照存储方法
  • PyQt5探索-0 用Pycharm配置PyQt5环境

    感觉是时候学习一下PyQt了 决定直接从PyQt5开始 用Pycharm做开发环境 因为之前用的Eric实在感觉不爽 今天先从配置环境开始 先安装好Pycharm Qt 安装Pycharm插件sip PyQt5 pyqt5 tools si
  • 【计算机视觉】论文单词理解—bells and whistles

    一 问题来源 最后在阅读论文的时候 遇到了一个单词 不是很理解 这个单词是bells and whistles 二 单词的理解 bells and whistles 它的含义并不是指 铃铛和口哨 其真正的含义是指 bells and whi
  • Mac 中的sublime text3 如何安装插件

    Mac 中的sublime text3 如何安装插件 相信大家在Windows系统中试用sublime text 的体验非常不错 我也是在Windows系统中使用了两年的时间 才转战Mac系统的 但是说实话 Mac系统好多东西都是十分不习惯
  • 时间序列学习(6)——LSTM中Layer的使用

    文章目录 1 复习一下 nn RNN 的参数 2 LSTM的 init 函数 3 LSTM forward 4 动手写一个简单的lstm层 1 复习一下 nn RNN 的参数 参数介绍 1 input size The number of
  • 梯度下降法(BGD,SGD,MSGD)python+numpy具体实现

    梯度下降是一阶迭代优化算法 为了使用梯度下降找到函数的局部最小值 一个步骤与当前位置的函数的梯度 或近似梯度 的负值成正比 如果相反 一个步骤与梯度的正数成比例 则接近该函数的局部最大值 该程序随后被称为梯度上升 梯度下降也被称为最陡峭的下