PyTorch 官方教程:撸一个神经网络

2023-11-01

本文为 PyTorch 官方教程中:如何构建神经网络。基于 PyTorch 专门构建神经网络的子模块 torch.nn 构建一个简单的神经网络。

完整教程运行 codelab

torch.nn 文档

神经网络由对数据执行操作的层/模块组成。torch.nn 提供了构建神经网络所需的所有模块。

PyTorch 中的每个模块都是 nn.module 的子类。
在下面的部分中,我们将构建一个神经网络来进行10种类别的分类。

建立神经网络

神经网络由对数据执行操作的层/模块组成。torch.nn 提供了构建神经网络所需的所有模块。PyTorch 中的每个模块都是 nn.module 的子类。
在下面的部分中,我们将构建一个神经网络来进行10种类别的分类。

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

加载训练设备

我们希望能够在硬件加速器,比如 GPU 上训练我们的模型。可以通过 torch.cuda 来检测 GPU 是否可用。

device = 'cuda' if torch.cuda.is_available() else 'cpu' #检测gpu是否可用,不可用使用cpu
print('Using {} device'.format(device)) #输出使用设备类型

定义类

我们通过 nn.Module 来定义神经网络,并在__init__ 中初始化神经网络。每个 nn.Module 子类在 forward 方法中实现对输入数据的操作。

class NeuralNetwork(nn.Module):
    def __init__(self): #定义网络结构
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x): #前向传播
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

在使用模型前需要先实例化模型,并将其移动到 GPU 上

model = NeuralNetwork().to(device) #实例化模型
print(model)

为了在模型的输入和输出之间创建复杂的非线性映射,需要使用非线性的激活函数。

它们在线性变换后引入非线性,帮助神经网络学习各种各样的复杂映射。在这个模型中,我们在线性层之间使用 nn.ReLU,也可以使用其他激活函数来引入非线性。

X = torch.rand(1, 28, 28, device=device)  #生成(1,28,28)的数据
logits = model(X) #向模型输入数据
pred_probab = nn.Softmax(dim=1)(logits) #调用softmax 将预测值映射为(0,1)间的概率
y_pred = pred_probab.argmax(1) #最大概率对应分类
print(f"Predicted class: {y_pred}")

神经网络各层说明

接下来,我们分解网络来具体讲述每一层的功能。

为了说明这一点,我们将取小批量的3个尺寸为28x28的图像样本输入网络

input_image = torch.rand(3,28,28) #生成(3,28,28)的数据
print(input_image.size())

nn.Flatten 层

Flatten 层用来把多维的输入一维化,常用在从卷积层到全连接层的过渡。

nn.Flatten 层,可以将每个 28x28 图像转换 784 ( 28 × 28 = 784 28\times 28=784 28×28=784)个像素值的连续数组(批量维度保持为3)。

flatten = nn.Flatten() 
flat_image = flatten(input_image) #(3,28,28)转换为(3,784)
print(flat_image.size())

nn.Linear 层

nn.Linear 层,即线性层,是一个使用权重和偏差对输入数据作线性变换的模块。

layer1 = nn.Linear(in_features=28*28, out_features=20) #输入(3,28*28) 输出(3,20)
hidden1 = layer1(flat_image)
print(hidden1.size())

nn.ReLU 层

为了在模型的输入和输出之间创建复杂的非线性映射,需要使用非线性的激活函数。它们在线性变换后引入非线性,帮助神经网络学习各种各样的复杂映射。

在这个模型中,我们在线性层之间使用 nn.ReLU,也可以使用其他激活函数来引入非线性。

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

nn.Sequential 层

神经网络的最后一个线性层返回 logits,即值域区间在 [ − ∞ , ∞ ] [-\infty,\infty] []中的原始值。这些值传递给nn.Softmax模块后,logit被缩放为 [ 0 , 1 ] [0,1] [01]区间中,表示模型对每个类的预测概率。

dim参数表示每一维度进行运算的位置,运算结果相加为1。

softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)

输出模型结构

神经网络中的许多层都是参数化的,即具有相关联的权重和偏差,这些参数在训练中被迭代优化。

子类 nn.Module 自动跟踪模型对象内部定义的所有字段,并使用模型的 parameters() 或 named_parameters() 方法访问所有参数。

我们可以通过模型迭代每个参数,并输出其尺寸和值。

print("Model structure: ", model, "\n\n")

for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

最终输出结果可访问完整教程

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch 官方教程:撸一个神经网络 的相关文章

  • 通过 Scrapy 抓取 Google Analytics

    我一直在尝试使用 Scrapy 从 Google Analytics 获取一些数据 尽管我是一个完全的 Python 新手 但我已经取得了一些进展 我现在可以通过 Scrapy 登录 Google Analytics 但我需要发出 AJAX
  • Python 中的 Lanczos 插值与 2D 图像

    我尝试重新缩放 2D 图像 灰度 图像大小为 256x256 所需输出为 224x224 像素值范围从 0 到 1300 我尝试了两种使用 Lanczos 插值来重新调整它们的方法 首先使用PIL图像 import numpy as np
  • 使 django 服务器可以在 LAN 中访问

    我已经安装了Django服务器 可以如下访问 http localhost 8000 get sms http 127 0 0 1 8000 get sms 假设我的IP是x x x x 当我这样做时 从同一网络下的另一台电脑 my ip
  • OpenCV Python cv2.mixChannels()

    我试图将其从 C 转换为 Python 但它给出了不同的色调结果 In C Transform it to HSV cvtColor src hsv CV BGR2HSV Use only the Hue value hue create
  • PyUSB 1.0:NotImplementedError:此平台不支持或未实现操作

    我刚刚开始使用 pyusb 基本上我正在玩示例代码here https github com walac pyusb blob master docs tutorial rst 我使用的是 Windows 7 64 位 并从以下地址下载 z
  • 使用 on_bad_lines 将 pandas.read_csv 中的无效行写入文件

    我有一个 CSV 文件 我正在使用 Python 来解析该文件 我发现文件中的某些行具有不同的列数 001 Snow Jon 19801201 002 Crom Jake 19920103 003 Wise Frank 19880303 l
  • python 相当于 R 中的 get() (= 使用字符串检索符号的值)

    在 R 中 get s 函数检索名称存储在字符变量 向量 中的符号的值s e g X lt 10 r lt XVI s lt substr r 1 1 X get s 10 取罗马数字的第一个符号r并将其转换为其等效整数 尽管花了一些时间翻
  • 根据列值突出显示数据框中的行?

    假设我有这样的数据框 col1 col2 col3 col4 0 A A 1 pass 2 1 A A 2 pass 4 2 A A 1 fail 4 3 A A 1 fail 5 4 A A 1 pass 3 5 A A 2 fail 2
  • 是否可以忽略一行的pyright检查?

    我需要忽略一行的pyright 检查 有什么特别的评论吗 def create slog group SLogGroup data Optional dict None SLog insert one SLog group group da
  • OpenCV 无法从 MacBook Pro iSight 捕获

    几天后 我无法再从 opencv 应用程序内部打开我的 iSight 相机 cap cv2 VideoCapture 0 返回 并且cap isOpened 回报true 然而 cap grab 刚刚返回false 有任何想法吗 示例代码
  • 绘制方程

    我正在尝试创建一个函数 它将绘制我告诉它的任何公式 import numpy as np import matplotlib pyplot as plt def graph formula x range x np array x rang
  • 如何在ipywidget按钮中显示全文?

    我正在创建一个ipywidget带有一些文本的按钮 但按钮中未显示全文 我使用的代码如下 import ipywidgets as widgets from IPython display import display button wid
  • 如何在seaborn displot中使用hist_kws

    我想在同一图中用不同的颜色绘制直方图和 kde 线 我想为直方图设置绿色 为 kde 线设置蓝色 我设法弄清楚使用 line kws 来更改 kde 线条颜色 但 hist kws 不适用于显示 我尝试过使用 histplot 但我无法为
  • 每个 X 具有多个 Y 值的 Python 散点图

    我正在尝试使用 Python 创建一个散点图 其中包含两个 X 类别 cat1 cat2 每个类别都有多个 Y 值 如果每个 X 值的 Y 值的数量相同 我可以使用以下代码使其工作 import numpy as np import mat
  • 解释 Python 中的数字范围

    在 Pylons Web 应用程序中 我需要获取一个字符串 例如 关于如何做到这一点有什么建议吗 我是 Python 新手 我还没有找到任何可以帮助解决此类问题的东西 该列表将是 1 2 3 45 46 48 49 50 51 77 使用
  • 有人用过 Dabo 做过中型项目吗? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我们正处于一个新的 ERP 风格的客户端 服务器应用程序的开始阶段 该应用程序是作为 Python 富客户端开发的 我们目前正在评估 Dabo
  • 发送用户注册密码,django-allauth

    我在 django 应用程序上使用 django alluth 进行身份验证 注册 我需要创建一个自定义注册表单 其中只有一个字段 电子邮件 密码将在服务器上生成 这是我创建的表格 from django import forms from
  • 使用 Python 的 matplotlib 选择在屏幕上显示哪些图形以及将哪些图形保存到文件中

    我想用Python创建不同的图形matplotlib pyplot 然后 我想将其中一些保存到文件中 而另一些则应使用show 命令 然而 show 显示all创建的数字 我可以通过调用来避免这种情况close 创建我不想在屏幕上显示的绘图
  • Rocket UniData/UniVerse:ODBC 无法分配足够的内存

    每当我尝试使用pyodbc连接到 Rocket UniData UniVerse 数据时我不断遇到错误 pyodbc Error 00000 00000 Rocket U2 U2ODBC 0302810 Unable to allocate
  • 导入错误:没有名为 site 的模块 - mac

    我已经有这个问题几个月了 每次我想获取一个新的 python 包并使用它时 我都会在终端中收到此错误 ImportError No module named site 我不知道为什么会出现这个错误 实际上 我无法使用任何新软件包 因为每次我

随机推荐

  • 用IDEA创建第一个SpringBoot程序,并开发一个JSON接口

    1 打开idea主界面选择 Create New Project 2 在弹出的页面中我们选择左侧的 Spring Initializr jdk版本选择自己安装的版本 PS jdk版本要1 8以上哦 3 下一个页面 在Group栏输入组织名
  • IDEA代码覆盖率测试

    代码覆盖率测试 1 使用idea自带的代码覆盖率工具 1 创建test文档 右击将 test 目录设置为测试文档 2 选中需要测试的类 按Ctrl shift T 创建测试类 并选中要测试的方法 在测试案例中 编写测试代码 点击Edit C
  • 小程序分包实现

    目录 一 使用场景 二 操作方式 1 建立分包文件夹 2 文件构建 3 文件配置 三 总结 一 使用场景 微小程序分包常用于代码量较大的小程序 发布时会受到大小限制 二 操作方式 1 建立分包文件夹 在项目根目录下创建分包文件夹 此处我创建
  • L1-8 乘法口诀数列

    本题要求你从任意给定的两个 1 位数字 a1 和 a2 开始 用乘法口诀生成一个数列 an 规则为从 a1 开始顺次进行 每次将当前数字与后面一个数字相乘 将结果贴在数列末尾 如果结果不是 1 位数 则其每一位都应成为数列的一项 输入格式
  • ad电阻原理图_光敏电阻的基础知识介绍

    39G电子技术 电路 电子元件等 全套资料免费领 干货下载 十天学会单片机完整版 100个实例 PPT 点击上方红字 即可获取 一 光敏电阻 光敏电阻是用硫化隔或硒化隔等半导体材料制成的特殊电阻器 表面还涂有防潮树脂 具有光电导效应 二 特
  • TCP 拥塞窗口原理

    学过网络相关课程的 都知道TCP中 有两个窗口 滑动窗口 在我们的上一篇文章中有讲 接收方通过通告发送方自己的可以接受缓冲区大小 这个字段越大说明网络吞吐量越高 从而控制发送方的发送速度 拥塞窗口 也就是本文要讲的 概念 一个连接的TCP双
  • element-plus elplus el-tree三种图标自定义 并且点击图标展开收起 点击文字获取数据

    前言 公司需求 需要实现如下样式的树形列表 基于vue3 element plus 当节点展开时 显示展开的文件夹图标 当节点收起时显示收起的文件夹 最后一级显示文件样式 废话没有了 代码如下
  • C规范编辑笔记(九)

    往期文章 C规范编辑笔记 一 C规范编辑笔记 二 C规范编辑笔记 三 C规范编辑笔记 四 C规范编辑笔记 五 C规范编辑笔记 六 C规范编辑笔记 七 C规范编辑笔记 八 正文 今天我们来分享一下C规范编辑笔记第九篇 话不多说 我们直接来看
  • 树莓派数据远程传输学习记录——TCP/IP协议连接OneNet云平台传输数据的方法

    目录 项目场景 问题描述 解决方案 OneNet云平台前期项目搭建准备 以网络调试助手模拟树莓派建立连接并发送数据 树莓派与OneNet云平台进行对接 最后总结 项目场景 本人在进行树莓派项目开发时进行数据远程传输 4G WiFi通信 过程
  • Spark 3.0.3 源码阅读及 idea 调试环境搭建

    目录 1 源码下载 2 源码解压并编译 3 使用 Idea 打开或导入 4 idea 调试环境设置 Master 设置 Worker 设置 1 源码下载 Downloads Apache Spark 2 源码解压并编译 编译前建议在环境变量
  • ingress 400 Bad Request The plain HTTP request was sent to HTTPS port

    问题现象 访问时返回400 Bad Request 并提示The plain HTTP request was sent to HTTPS port 问题原因 Ingress Controller到后端Pod请求使用了默认的HTTP请求 但
  • 效果:网页页面随机改变颜色+自定义样式背景颜色随机改变+20秒倒计时+时间一到马上跳转新页面

  • linux 安装flash

    2 將下载好的包拷到某个目录下并解压得到文件 得到如下libflashplayer so文件与usr文件夹 3 将libflashplayer so拷到firefox的插件目录 usr lib mozilla plugin 下 sudo c
  • 个人总结-基础算法

    文章目录 基础算法 各个算法的复杂度及稳定性等 冒泡排序 蛮力法 理解 函数代码 测试用例 选择排序 蛮力法 理解 函数代码 测试用例 归并排序 分治法 理解 函数代码 测试用例 快速排序 分治法 理解 函数代码 测试用例 插入排序 减治法
  • GitHub、GIT、Intellij集成github初探

    一 什么是Git 刚接触Git或github的童鞋可能会把它们的概念搞混淆 所以在这里稍微解释一下 Git和github是两个完全不同的概念 Git是一个版本管理系统 Version Control System 简称 VCS 早期版本管理
  • kafka-offset手动提交和自动提交

    目录 首先回顾之前的知识点 自动提交offset 手动提交 消费者poll消息的细节 完整代码 按照新方法进行消费消息 1 指定时间进行消息的消费 2 指定分区开始从头消费 指定分区的偏移量开始消费 新消费组的消费offset规则 首先回顾
  • 国内如何申请到Twitter API

    Tip Twitter Developer Platform 申请只能申请一次 被拒后该账户就不能再申请了 一点要做好详细的准备再提交申请 网上的申请教程有的很坑 几句话就提交申请通过了 几率很小 Twitter开发者平台的申请 记录三次申
  • vue+element文本域设置自适应和默认高度

  • ⛳ TCP 协议面试题

    目录 TCP 协议面试题 一 为什么关闭连接的需要四次挥 建 连接却只要三次握 呢 二 为什么连接建 的时候是三次握 可以改成两次握 吗 三 为什么主动断开 在TIME WAIT状态必须等待2MSL的时间 四 如果已经建 了连接 但是Cli
  • PyTorch 官方教程:撸一个神经网络

    本文为 PyTorch 官方教程中 如何构建神经网络 基于 PyTorch 专门构建神经网络的子模块 torch nn 构建一个简单的神经网络 完整教程运行 codelab torch nn 文档 神经网络由对数据执行操作的层 模块组成 t