PyTorch学习日志_20201030_神经网络

2023-11-19

日期:2020.10.30
主题:PyTorch入门
内容:

  • 根据PyTorch官方教程文档,学习PyTorch中神经网络
    包括:定义网络、损失函数、反向传播、更新权重。

  • 根据自己的理解和试验,为代码添加少量注解。

具体代码如下 ↓

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
{神经网络}
 通过torch.nn包来构建神经网络
    它依赖于autograd包来定义模型并对它们求导。
 一个nn.Module包含各个层和一个forward(input)方法,该方法返回output。
"""

"""
 一个神经网络的典型训练过程如下:
    1.定义包含一些可学习参数(或者叫权重)的神经网络;
    2.在输入数据集上迭代;
    3.通过网络处理输入;
    4.计算loss(输出和正确答案的距离);
    5.将梯度反向传播给网络的参数;
    6.更新网络的权重,一般使用一个简单的规则:
        weight = weight - learning_rate * gradient
"""

"""
【定义网络】
"""
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        
        # 输入图像channel:1;输出channel:6;5x5卷积核
        self.conv1 = nn.Conv2d(1, 6, 5)     # 定义二维卷积 1 -> 6
        self.conv2 = nn.Conv2d(6, 16, 5)    # 定义二维卷积 6 -> 16
        
        # an affine operation: y = Wx + b 仿射操作
        self.fc1 = nn.Linear(16 * 5 * 5, 120)   # 定义线性变换
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 2x2 Max pooling 池化操作
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # 如果是方阵,则可以只使用一个数字进行定义
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x)) 	# 非线性激励函数
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # 除去批处理维度的其他所有维度
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
# 只需要定义forward函数,backward函数会在使用autograd时自动定义,backward函数用来计算导数。
# 可以在 forward 函数中使用任何针对张量的操作和计算

net = Net()
print(net)
print('-'*40, '\n')

"""
 二维卷积
 torch.nn.Conv2d(in_channels: int, out_channels: int, 
                 kernel_size: Union[T, Tuple[T, T]])
    
    *channel
        即通道,最初是指电子图片中RGB通道这样的配色方案,
        例如一张RGB图片可以用一个64x64x3的张量来表示,其中,channel=3,
            分别为红色(Red)、绿色(Green)、蓝色(Blue)三个通道。
            进一步,对RGB图片进行卷积操作后,根据过滤器的数量可产生更多的通道——特征图。
        故,一个通道是对某个特征的检测,通道中某一处数值的强弱就是对当前特征强弱的反应。
        
    in_channels
        输入的四维张量[N, C, H, W]中的C,即输入张量的channels数。
        这个形参是确定权重等可学习参数的shape所必需的。
        
    out_channels
        即期望的四维输出张量的channels数。
        这里卷积层的权重和偏置初始化都是采用He初始化的,适合于ReLU函数。
    
    *kernel
        即核函数K(kernel function),指K(x, y) = <f(x), f(y)>,
            其中x和y是n维的输入值,f() 是从n维到m维的映射(通常而言,m>>n)。
            <x, y>是x和y的内积(inner product),亦称点积(dot product)
        它有助于省去在高维空间里进行繁琐计算的“简便运算法”。
            甚至,它能解决无限维空间无法计算的问题!
            (因为有时f()会把n维空间映射到无限维空间)
    
    kernel_size
        即卷积核的大小,一层卷积核的中心pixel可以“看到”输入图 a*b 的区域(连通性)
        一般使用5x5、3x3这种左右两个数相同的卷积核,
            因此这种情况只需要写kernel_size = 5即可。
        若左右两个数不同,比如3x5的卷积核,
            则写作kernel_size = (3, 5),
            注意需要写一个tuple,而不能写一个列表(list)。
        
"""

"""
 线性变换(连同偏置),即 y = x * (W)^T + b
 torch.nn.Linear(in_features: int, out_features: int)
    
    in_features 
        输入张量的大小
    
    out_features
        输出张量的大小
 
"""

# 通过net.parameters()返回模型的可学习参数
params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight
print('-'*40, '\n')


# 尝试一个随机的32x32的输入(该网络(LeNet)的期待输入是32x32的张量)。
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
print('-'*40, '\n')

#清零所有参数的梯度缓存,然后进行随机梯度的反向传播
net.zero_grad()
out.backward(torch.randn(1, 10))


"""
 torch.nn只支持小批量处理(mini-batches)。整个torch.nn包只支持小批量样本的输入,不支持单个样本的输入。
    比如,nn.Conv2d 接受一个4维的张量,即nSamples x nChannels x Height x Width
    如果是一个单独的样本,只需要使用input.unsqueeze(0)来添加一个“假的”批大小维度。
"""


"""
 <回顾>
 
 torch.Tensor
    一个多维数组,支持诸如backward()等的自动求导操作,同时也保存了张量的梯度。
    
 nn.Module
    神经网络模块。是一种方便封装参数的方式,具有将参数移动到GPU、导出、加载等功能。
    
 nn.Parameter
    张量的一种,当它作为一个属性分配给一个Module时,它会被自动注册为一个参数。
    
 autograd.Function
    实现了自动求导前向和反向传播的定义,
    每个Tensor至少创建一个Function节点,该节点连接到创建Tensor的函数并对其历史进行编码。
        
"""


"""
【损失函数】
 函数接受一对(output, target)作为输入,计算一个值来估计网络的输出和目标值相差多少。
"""
# nn.MSELoss是比较简单的一种损失函数,计算输出和目标的均方误差(mean-squared error)。
output = net(input)
target = torch.randn(10)  # 本例子中使用模拟数据
target = target.view(1, -1)  # 使目标值与数据值尺寸一致
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)
print('-'*40, '\n')


"""
 如果使用loss的.grad_fn属性跟踪反向传播过程,会看到计算图如下
 input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
       -> view -> linear -> relu -> linear -> relu -> linear
       -> MSELoss
       -> loss

 所以,当调用loss.backward(),整张图开始关于loss微分,
    图中所有设置了requires_grad=True的张量的.grad属性累积着梯度张量。    
"""
# 为了说明这一点,让我们向后跟踪几步
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU
print('-'*40, '\n')


"""
【反向传播】
 我们只需要调用loss.backward()来反向传播误差。
 我们需要清零现有的梯度,否则梯度将会与已有的梯度累加。
"""

# 调用loss.backward(),并查看conv1层的偏置(bias)在反向传播前后的梯度。
net.zero_grad()     # 清零所有参数(parameter)的梯度缓存

print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)
print('-'*40, '\n')


"""
【更新权重】
 最简单的更新规则是随机梯度下降法(SGD): 
    weight = weight - learning_rate * gradient
"""
learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)
    
    
# 然而,在使用神经网络时,可能希望使用各种不同的更新规则,如SGD、Nesterov-SGD、Adam、RMSProp等。
# 为此,torch.optim包实现了所有的这些方法。
import torch.optim as optim

# 创建优化器(optimizer)
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 在训练的迭代中:
optimizer.zero_grad()   # 清零梯度缓存
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()    # 更新参数



本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch学习日志_20201030_神经网络 的相关文章

  • Erlang:到 Python 实例的端口没有响应

    我正在尝试通过 Erlang 端口与外部 python 进程进行通信 首先 打开一个端口 然后通过 stdin 将消息发送到外部进程 我期待在进程的标准输出上得到相应的答复 我的尝试如下所示 open a port Port open po
  • Python Pandas 滚动聚合一列列表

    我有一个简单的数据框 df 和一列列表lists 我想根据以下内容生成一个附加列lists The df好像 import pandas as pd lists 1 1 2 1 2 3 3 2 9 7 9 4 2 7 3 5 create
  • 使用管理员权限打开cmd(Windows 10)

    我有自己的 python 脚本来管理我的计算机上的 IP 地址 它主要在命令行 Windows 10 中执行netsh命令 您必须具有管理员权限 这是我自己的计算机 我是管理员 运行脚本时我已经使用管理员类型的用户 Adrian 登录 我无
  • 使用 Python 和 lmfit 拟合复杂模型?

    我想适合椭偏仪 http en wikipedia org wiki Ellipsometry使用 LMFit 将数据转换为复杂模型 两个测量参数 psi and delta 是复杂函数中的变量rho 我可以尝试将问题分离为实部和虚部共享参
  • Scrapy 文件管道不下载文件

    我的任务是构建一个可以下载所有内容的网络爬虫 pdfs 在给定站点中 Spider 在本地计算机和抓取集线器上运行 由于某种原因 当我运行它时 它只下载一些但不是全部的 pdf 通过查看输出中的项目可以看出这一点JSON 我已经设定MEDI
  • 将整数系列转换为交替(双元)二进制系列

    我不知道如何最好地表达这个问题 因为在这里谷歌搜索和搜索总是让我找到更复杂的东西 我很确定这是基本的东西 但对于我的生活来说 我找不到一个好的方法来做到这一点下列 给定一个整数序列 比如说 for x in range 0 36 我想将这些
  • Python 内置对象的 __enter__() 和 __exit__() 在哪里定义?

    我读到每次使用 with 时都会调用该对象的 enter 和 exit 方法 我知道对于用户定义的对象 您可以自己定义这些方法 但我不明白这对于 打开 等内置对象 函数甚至测试用例是如何工作的 这段代码按预期工作 我假设它使用 exit 关
  • 如何用函数记录一个文件?

    我有一个带有函数 lib py 但没有类的python 文件 每个函数都有以下样式 def fnc1 a b c This fonction does something param a lalala type a str param b
  • 会话数据库表清理

    该表是否需要清除或者由 Django 自动处理 Django 不提供自动清除功能 然而 有一个方便的命令可以帮助您手动完成此操作 Django 文档 清除会话存储 https docs djangoproject com en dev to
  • 了解 Python 2.7 中的缩进错误

    在编写 python 代码时 我往往会遇到很多缩进错误 有时 当我删除并重写该行时 错误就会消失 有人可以为菜鸟提供 python 中 IndentationErrors 的高级解释吗 以下是我在玩 CheckIO 时收到的最近 inden
  • Eclipse/PyDev 中未使用导入警告,尽管已使用

    我正在我的文件中导入一个绘图包 如下所示 import matplotlib pyplot as plt 稍后我会在我的代码中成功使用此导入 fig plt figure figsize 16 10 然而 Eclipse 告诉我 未使用的导
  • 如何通过selenium中弹出的身份验证?

    我正在尝试使用带有 Selenium 的 Python 脚本加载需要身份验证的网页 options webdriver ChromeOptions prefs download default directory r download de
  • Python脚本从字母和两个字母组合生成单词

    我正在编写一个简短的脚本 它允许我使用我设置的参数生成所有可能的字母组合 例如 b a 参数 单词 5 个字母 第三 第五个字母 b a 第一个字母 ph sd nn mm 或 gh 第二 第四个字母 任意元音 aeiouy 和 rc 换句
  • AttributeError: 'super' 对象没有属性 '__getattr__' 在 Kivy 中使用带有多个 kv 文件的 BoxLayout 时出错

    我很清楚 这个问题已经被问过好几次了 但尝试以下解决方案后 Python Kivy AttributeError 尝试获取 self ids 时 super 对象没有属性 getattr https stackoverflow com qu
  • 如何将 URL 添加到 Telegram Bot 的 InlineKeyboardButton

    我想制作一个按钮 可以从 Telegram 聊天中在浏览器中打开 URL 外部超链接 目前 我只开发了可点击的操作按钮 update message reply text Subscribe to us on Facebook and Te
  • 如何检测一个二维数组是否在另一个二维数组内?

    因此 在堆栈溢出成员的帮助下 我得到了以下代码 data needle s which is a png image base64 code goes here decoded data decode base64 f cStringIO
  • PyQt5按钮lambda变量变成布尔值[重复]

    这个问题在这里已经有答案了 当我运行下面的代码时 它显示如下 为什么 x 不是 x 而是变成布尔值 这种情况仅发生在传递到用 lambda 调用的函数中的第一个参数上 错误的 y home me model some file from P
  • Chrome 驱动程序和 Chromium 二进制文件无法在 aws lambda 上运行

    我陷入了一个问题 我需要在 AWS lambda 上做一些抓取工作 所以我按照下面提到的博客及其代码库作为起点 这非常有帮助 并且在运行时环境 Python 3 6 的 AWS lambda 上对我来说工作得很好 https manivan
  • 超过两个点的Python相对导入

    是否可以使用路径中包含两个以上点的模块引用 就像这个例子一样 Project structure sound init py codecs init py echo init py nix init py way1 py way2 py w
  • 使用 python 将 CSV 文件上传到 Microsoft Azure 存储帐户

    我正在尝试上传一个 csv使用 python 将文件写入 Microsoft Azure 存储帐户 我已经发现C sharp https blogs msdn microsoft com jmstall 2012 08 03 convert

随机推荐

  • 牛客网左神算法中级班学习笔记(第三章)

    本文是牛客网左神算法中级班学习笔记 分析 宏观考虑 搞两个点A B 起始都在左上角 B往右走 走到最右边就往下走 A往下走 走到最下边就往右走 A B每次一起走一步 打印A B两点连线即可 用一个Boolean控制下 交替打印顺序 publ
  • java简易聊天程序

    目录 项目结构 TCP 窗体组成 server client properties 项目结构 TCP 窗体组成 server package cn itcast chat import javax swing import java awt
  • ChatGPTBox 沉浸式的感受ChatGPT带来的快感

    ChatGPT基础功能 1 自然流畅的对话 ChatGPT通过对海量对话数据的学习 具有自然流畅的对话能力 能够与用户进行逼真的自然语言交互 2 能够理解语境 ChatGPT能够理解语境 不仅能根据上下文生成回答 还能识别当前对话的主题 更
  • LabVIEW 读写和缩放音频文件

    LabVIEW 提供了多种方式来读取和写入 WAV 格式的音频文件 完成本模块后 您将能够使用位于 Programming Graphics Sound Sound Files 中的 Simple Read 和 Simple Write 用
  • 感性是什么意思

    感性是什么意思 2005 09 25 15 55 xinghuali 分类 恋爱 有人说自己很感性 不知到底是什么意思 人在这方面分两种 一种是理性 一种就是感性 理性是很理智的那种 就是做事都依据道理 不会冲动 而感性的就是凭着感觉来的那
  • 如何让学习变得有效率

    最近一直在反思这样一个问题 为什么我的学习如此的没有效率 来提高班近三年的时间 我几乎都在全日制学习中度过 可是我的速度并不快 原因在哪 在这里学习 米老师一遍遍强调 如何学习 如何打包 全局观才是我们在这里真正应该学的 可这些在我这些年的
  • redis HyperLogLog原理

    假设现在有一个这样的需求 我们想要实时统计有多少用户访问我们的网站 一个简单的解决方案是用一个set集合来存储用户ID 然后计算任意时刻集合中不同ID的个数即为网站实时访问量 这是一种简单可行的做法 但是假如这个网页访问量很大加上随着时间推
  • C++琐碎知识整理

    C 琐碎知识整理 二 1 C 与C一样 用终止符 terminator 将两条语句分开 终止符是一个分号 它是语句的结束标记 是语句的组成部分 而不是语句之间的标记 所以C 语句一定不能省略分号 2 通常 main 被启动代码调用 而启动代
  • HTML innerHTML属性用法及分析

    innerHTML 设置或返回表格行的开始和结束标签之间的 HTML 看它的英文单词也可以明白就是里面的字符按html标记的语言格式取出来或重新设置 innerHTML属性w3c标准不支持的 但是各大浏览器支持它的实现 innerHTML的
  • ModelAndView,Model和httpServletRequest

    一 参数绑定 1 默认支持类型 springmvc中 有支持默认类型的绑定 也就是说 直接在controller方法形参上定义默认类型的对象 就可以使用这些对象 HttpServletRequest对象 HttpServletRespons
  • WWW 2022 弯道超车:基于纯MLP架构的序列推荐模型

    作者 于辉 机构 中国科学院大学地质与地球物理研究所 研究方向 人工智能与固体地球物理学 作者 周昆 机构 中国人民大学信息学院 研究方向 序列表示学习 本文主要提出了一个基于纯MLP架构的序列化推荐模型 其通过可学习滤波器对用户序列进行编
  • 《Linux基础》02. 目录结构 · vi、vim · 关机 · 重启

    目录结构 1 目录结构 2 vi vim快速入门 2 1 vi 和 vim 的三种模式 2 1 1 一般模式 2 1 2 编辑模式 2 1 3 命令模式 2 2 常用快捷键 2 2 1 一般模式 2 2 2 命令模式 2 2 3 键盘图 3
  • 有关AngularJS请求Web API资源的思路

    页面部分大致如下 div div productManagement是页面module的名称 页面内容通过ng include加载productListView html这个页面 注意 ng include属性值是字符串 app produ
  • NUC980开源项目32-显示内核调试信息

    上面是我的微信和QQ群 欢迎新朋友的加入 编写一个简单的驱动代码 hello c include
  • C++ 面向对象三大特征总结(详解)

    1 面向对象的三大特征 1 封装 封装 将一个对象的全部的属性变量和行为方法进行包装 集中到一个类中 并用权限对其成员属性和成员方法加以限制 使得外部对其访问时 不能随意改变该包装 include
  • 5/26 博客 第四章 交换机基本原理与配置

    交换机 数据链路层的设备 数据帧数据链路层的作用 1 物理地址 网络拓扑的建立 维护 拆除 2 把数据封装在帧中 按顺序传送 3 差错恢复 重传 重新再发一次 4 流量控制 确保中间传输设备的稳定以及双发传输速率的匹配 数据链路层主要的工作
  • FPGM(Filter Pruning via Geometric Median)笔记

    原文地址 文章目录 1 创新点 2 解决了哪些问题 3 原理和算法流程 1 创新点 提出了一种新的过滤器剪枝方法 即通过几何中值的过滤器剪枝 FPGM 来压缩模型 与以前的方法不同 FPGM 通过修剪带有冗余信息的过滤器而不是那些重要性 相
  • docker 指定不同容器使用同一个网段

    问题描述 因为我使用了dockers compose 不同的服务在不同的docker compose文件中 所以当远程调用的时候出现了根据容器名访问失败的错误 基于此我准备在docker中创建一个network 然后让其他容器全都指定使用一
  • Faster Transformer

    背景 Transformer自2017年的 Attention is All you Need 提出以来 成为通用高效的特征提取器 虽然其在NLP TTS ASR CV等多个领域表现优异 但在推理部署阶段 其计算性能却存在巨大挑战 以BER
  • PyTorch学习日志_20201030_神经网络

    日期 2020 10 30 主题 PyTorch入门 内容 根据PyTorch官方教程文档 学习PyTorch中神经网络 包括 定义网络 损失函数 反向传播 更新权重 根据自己的理解和试验 为代码添加少量注解 具体代码如下 from fut