2. 刘二大人《PyTorch深度学习实践》作业--梯度下降

2023-11-08

这里,我在刘老师的基础上做了改进,将线性函数改为了 y = w x + b y = wx+b y=wx+b,以下实现都是基于此线性函数做的。

1. 梯度下降

import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [3.0, 5.0, 7.0]

w = 1.0
b = 1.0

def forward(x):
    return x * w + b

# 损失函数
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / len(xs)

# 迭代,计算损失值
cost_list = []
print('Predict (before training)', 4, forward(4))
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_w, grad_b = gradient(x_data, y_data)
    w -= 0.01 * grad_w
    b -= 0.01 * grad_b
    cost_list.append(cost_val)
    print('Epoch:', epoch, 'w=', w,'b=', b, 'loss=', cost_val)
print('Predict (after training)', 4, forward(4))

# 绘制图像
epoches = np.arange(0, 100, 1)
plt.xlabel('epoch')
plt.ylabel('cost')
plt.plot(epoches, cost_list)
plt.grid()
plt.show()

在这里插入图片描述

2. 随机梯度下降

x_data = [1.0, 2.0, 3.0]
y_data = [3.0, 5.0, 7.0]

w = 1.0
b = 1.0

def forward(x):
    return x * w + b

# 随机梯度下降算法
def sgd(x, y):
    y_pred = forward(x)
    grad_w = 0
    grad_b = 0
    grad_w += 2 * x * (y_pred - y)
    grad_b += 2 * (y_pred - y)
    return grad_w, grad_b

# 损失函数
def cost_sgd(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2
    
# 迭代,计算损失值
cost_list = []
print('Predict (before training)', 4, forward(4))
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        cost_val = cost_sgd(x, y)
        grad_w, grad_b = sgd(x, y)
        w -= 0.01 * grad_w
        b -= 0.01 * grad_b
        cost_list.append(cost_val)
        print('Epoch:', epoch, 'w=', w, 'b=', b, 'loss=', cost_val)
print('Predict (after training)', 4, forward(4))

# 绘制图像
epoches = np.arange(0, 300, 1)
plt.xlabel('epoch')
plt.ylabel('cost')
plt.plot(epoches, cost_list)
plt.grid()
plt.show()

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

2. 刘二大人《PyTorch深度学习实践》作业--梯度下降 的相关文章

随机推荐

  • elasticsearch集群文件及路径设置

    es集群文件路径 1 数据目录 日志目录以及插件目录 默认情况下es会将plugin log data config file都放在es的安装目录中 这有一个问题 就是在进行es升级的时候 可能会导致这些目录被覆盖掉使我们集群中的文件或数据
  • Postman应用——下载注册和登录

    文章目录 下载安装 注册登录 注册账号 登录账号 下载安装 Postman下载 https www postman com 访问链接后 进入首页 根据自己的操作系统下载对应的版本 找到下载到的目录直接双击 exe文件 会默认安装在C盘 安装
  • LeetCode(力扣)题目中二叉树的如何生成?根据给定顺序列表生成二叉树(python)

    在刷 leetcode 二叉树相关的题目时 经常有这样给定的例子 例如 检查平衡性 实现一个函数 检查二叉树是否平衡 在这个问题中 平衡树的定义如下 任意一个节点 其两棵子树的高度差不超过 1 示例 1 给定二叉树 3 9 20 null
  • Mybatis-plus 分页排序 错乱-丢失

    今天生产环境出行了一个分页排序错乱的问题 当时有点懵 用的mybatis plus的分页插件实现的 往常也用但是没有出现这个 分页排序 错乱 丢失问题 说实话当时有点懵 经过排查分析 得出了结果 Mybatis plus 分页排序 错乱 丢
  • C语言变参数函数详解

    文章目录 一 前言 二 printf函数源码 三 C语言函数调用堆栈过程 调用约定 压栈过程 那么再来看看其他情况 四 C语言实现可变参数详解 五 需要关注的一些问题 一 前言 在C语言中 我们不管是使用标准库函数还是使用自定义的函数 我们
  • 帧同步(LockStep)该如何反外挂

    在中国的游戏环境下 反挂已经成为了游戏开发的重中之重 甚至能决定一款游戏的生死 吃鸡就是一个典型的案例 目前参与了了一款动作射击的MOBA类游戏的开发 同步方案上选择了帧同步技术 LockStep而非snapshots以下同 那么就有很多人
  • LUA实现麻将胡牌判定

    用LUA实现麻将胡牌的一个思路 hand table 41 42 43 22 21 43 22 11 11 11 42 33 33 33 手牌 card count table 1 1 0 2 0 3 0 4 0 5 0 6 0 7 0 8
  • 【若依】线程池,分页工具,定时任务,aop日志,全局异常处理功能实现

    若依 线程池 分页工具 定时任务 aop日志 全局异常处理功能实现 1 分页工具 使用方法 在调用sql语句前 调用 PageHelper startPage 方法就行了 若依包装过了 调用startPage 方法 1 pagehelper
  • 用c++编写的植物大战僵尸

    源码如下 include
  • 显卡3080设备CentOS 7.9 环境安装最新anconda、tensorflow-gpu 、cudatoolkit、cudnn、 python

    目标 使用3080显卡搭建环境 系统安装 显卡驱动安装 安装anconda 安装 python 安装 cuda 安装 cudnn 安装 tensorflow 一 系统安装 详见历史文档 二 显卡驱动安装 详见历史 三 整理自己需要安装的环境
  • 期货不变的本质是什么意思(期货不变的本质是什么意思呀)

    期货的本质是什么 本质是一个风险转移工具 通过把风险转移给愿意投机获利 亏损的人 产业方得以获得确定的盈利预期 这是期货的核心价值 狭义理论认为期货市场是 零和 负和 游戏 但如果把眼光放宽 把实体产业加进来 可以发现期货是市场环境里一项必
  • web python识花_TensorFlow迁移学习识花实战案例

    TensorFlow 迁移学习识花实战案例 本文主要介绍如何使用迁移学习训练图片识别花朵的模型 即识别出图片上是何种花朵 本文档中涉及的演示代码和数据集来源于网络 你可以在这里下载到 TRANSFER LEARNING zip 本模块将通过
  • cocos2d-x客户端与Java服务器的通信(一)

    o 貌似自己已经有一段时间没有写博客了 其实主要原因还是觉得自己水平有限 加上上班实在是太忙 实在抽不出时间来写博客 言归正传 大家都知道 在网络游戏开发中 网络通信一直是个比较大的难题 一个服务器可能要同时处理几千上万甚至上百万的用户数据
  • 14款开源或免费的GIS软件

    1 QGIS 原称Quantum GIS QGIS 原称Quantum GIS 是一个跨平台的桌面GIS软件 它提供数据的显示 编辑和分析功能 可以自动生成地图 并且能够处理地理空间数据 最后形成你期待的地图数据 它于2004年成为地理空间
  • Idea使用工具Statistic进行代码数量,注释的统计

    Statistic是idea上面的一个用于统计代码数量 注释数量的工具 1 安装 重启idea后 工具就能直接使用 如过没有结果的话 刷新下 扩展设置
  • 中缀表达式转后缀表达式

    一 什么是中缀表达式 后缀表达式 二 中缀表达式转后缀表达式 例题 中缀表达式 10 20 2 3 2 8 转换为对应的后缀为 后缀表达式 10 20 2 3 2 8 解题思路 1 观察两表达式 后缀明显没有 这两个符号 其次与中缀相比数字
  • git push origin master报错的解决方法 & 常见git命令(待更新)

    git push origin master报错的解决方法 常见git命令 待更新 参考Git常用命令 文章目录 git push origin master报错的解决方法 常见git命令 待更新 1 git push origin mas
  • c语言:折半查找法(二分查找法)

    折半查找法 half interval search 优点 比较次数少 查找速度快 平均性能好 缺点 是要求待查表为有序表 且插入删除困难 因此 折半查找方法适用于不经常变动而查找频繁的有序列表 注意 折半查找法仅适用于对已有顺序的数组 数
  • JDBC连接mysql的url的写法和常见属性

    URL jdbc mysql host port database 其后可以添加性能参数 propertyName1 propertyValue1 propertyName2 propertyValue2 MySQL 8 0 以上版本的数据
  • 2. 刘二大人《PyTorch深度学习实践》作业--梯度下降

    这里 我在刘老师的基础上做了改进 将线性函数改为了 y w x b y wx b y wx b 以下实现都是基于此线性函数做的 1 梯度下降 import n