Pytorch中的torch.nn.Linear()方法的详解

2023-11-08

torch.nn.Linear()作为深度学习中最简单的线性变换方法,其主要作用是对输入数据应用线性转换,先来看一下官方的解释及介绍:

class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript

这里我们主要看__init__()方法,很容易知道,当我们使用这个方法时一般需要传入2~3个参数,分别是in_features: int, out_features: int, bias: bool = True,第三个参数是说是否加偏置(bias),简单来讲,这个函数其实就是一个'一次函数':y = xA^T + b,(T表示张量A的转置),首先super(Linear, self).__init__()就是老生常谈的方法,之后初始化in_features和out_features,接下来就是比较重要的weight的设置,我们可以很清晰的看到weight的shape是(out_features,in_features)的,而我们在做xA^T时,并不是x和A^T相乘的,而是x和A.weight^T相乘的,这里需要大大留意,也就是说先对A做转置得到A.weight,然后在丢入y = xA^T + b中,得出结果。

接下来奉上一个小例子来实践一下:

import torch

# 随机初始化一个shape为(128,20)的Tensor
x = torch.randn(128,20)
# 构造线性变换函数y = xA^T + b,且参数(20,30)指的是A的shape,则A.weight的shape就是(30,20)了
y= torch.nn.Linear(20,30)
output = y(x)
# 按照以上逻辑使用torch中的简单乘法函数进行检验,结果很显然与上述符合
# 下面的y.weight可以理解为一个shape为(30,20)的一个可学习的矩阵,.t()表示转置
# y.bias若为TRUE,则bias是一个Tensor,且其shape为out_features,在该程序中应为30
# 更加细致的表达一下y = (128 * 20) * (30 * 20)^T + (if bias (1,30) ,else: 0)
ans = torch.mm(x,y.weight.t())+y.bias
print('ans.shape:\n',ans.shape)
print(torch.equal(ans,output))

由于本人水平有限,难免出现错误,欢迎大佬批评指正~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch中的torch.nn.Linear()方法的详解 的相关文章

随机推荐

  • vscode html文件自动补充html骨架失效

    vscode html文件自动补充html骨架失效 输入 Tab键补全html骨架失效 解决办法 1 让html文件处于编辑状态 按下快捷键ctrl shift p 2 在跳转的对话框里面输入change language mode 在下拉
  • 什么是动态库?什么又是静态库?(如何生成/如何使用 ! ! !)

    动态库 静态库 目录 一 gcc g 的链接方式 1 动态链接 2 静态链接 二 库的优缺点 1 动态库的优点 2 静态库的优点 三 库的生成 四 库的使用 目录 一 gcc g 的链接方式 对于我们编译一段程序经常会需要调用一个函数库 就
  • 浅谈对于servlet的见解

    众所周知 我们创建一个javaweb项目后 在客户端想要访问服务器 得发起http请求 服务器对请求会进行响应 看似简单的请求和响应有很大的门道 虽然我们都会用servlet但是不乏有人不懂其中的原理 接下来我就浅谈一下servlet的一些
  • 获取当前IP地址,跳转到对应城市网站。

    博客迁移 时空蚂蚁http cui cuihongbo com 1 代码 index php
  • COCO-stuff用法

    COCO stuff API 1 是 COCO API 2 的扩展 安装见 3 这里研究一下 COCO stuff 的用法 Files 下载链见 4 image 训练集图片 train2017 zip 验证集图片 val2017 zip 分
  • Vue+ElementUI实现将数据库中的数字展示成对应汉字

    需求 数据库中存的是数字类型 需要展示成对应的汉字 其中 1 gt 部级 2 gt 省级 3 gt 市级 4 gt 其他 dvIdxIndexList里面是从后台查到的结果集 我们首直接用map遍历后台返回的结果集 利用里面的回调对数据进行
  • 在windows10的系统下安装MySQL

    简单介绍一下 Mysql workbench的安装教程 官方网址 https www mysql com downloads 下拉到最下面 点击进去 下载这两个软件 分别是 mysql的具体网址 https dev mysql com do
  • 记一次在forEach循环中使用异步代码无效

    背景 代码如下 const res1 await getOrderPackage XM LX 95 入院检查套餐 const res2 await getOrderPackage XM LX 98 入院检验套餐 const res res1
  • 将windows10 的编码修改为UTF-8

    临时修改 只作用于当前窗口 先进入cmd命令窗口 快捷键win键 R 直接输入 chcp 65001 然后回车键 Enter键 执行 这时候本次打开的窗口编码就已经是UTF 8了 永久修改 win键 R 然后在输入框输入regedit 确定
  • mediapipe教程5:在安卓上运行mediapipe的handTracking

    一 前言和准备见mediapipe教程4 这篇博客开门见山 直接来步骤 二 在安卓上运行mediapipe的handTracking 参考网址 步骤 https google github io mediapipe getting star
  • 手游服务器微信互通,9月14日部分服务器数据互通公告

    尊敬的轩辕勇士们 轩辕传奇手游 开放测试以来人气沸腾 各种战斗的激烈程度也随之升级 为了让勇士们尽享更刺激 更热血的战斗 我们计划于9月14日6 00 9 00期间进行数据互通操作 数据互通期间 相关服务器将暂时无法进入 造成您的不便 恳请
  • 卷积神经网络的深入理解-归一化篇(Batch Normalization具体实例)

    卷积神经网络的深入理解 归一化篇 标准化 归一化 神经网络中主要用在激活之前 卷积之后 持续补充 归一化在网络中的作用 1 线性归一化 进行线性拉伸 可以增加对比度 2 零均值归一化 像素值 均值 方差 3 Batch Normalizat
  • 多路复用select、poll、epoll总结

    多路复用select poll epoll总结 一 多路复用 IO多路复用是指使用单个线程同时处理多个IO请求 在IO多路复用模型中一个线程可以监视多个文件描述符 fd 一旦某个fd就绪 读 写就绪 或者超时 就能够通知应用程序进行相应的读
  • 获取宝塔Linux面板登陆地址账号和密码

    在ssh终端输入 etc init d bt default
  • Realtime Multi-Person 2D Pose Estimation Using Part Affinity Fields

    Realtime Multi Person 2D Pose Estimation Using Part Affinity Fields 1 文章概要 文章实现了图片中的多人姿态检测 与已有的方法相比 最大的优势在于检测的速度对人物的数量不敏
  • python三方库是什么_python第三方库有哪几种

    在对于python的使用上 除了要掌握基本的操作方法外 如果有一些好用的工具辅助 效果也是非常明显的 为了能够给大家提供最大的帮助 python中的第三方库的种类也是非常多的 本篇挑选了使用功能强大 且比较好用的第三方类 整理出了它们的一些
  • R数据处理——按符号分割数据&统计两列数据组合的频数

    初始数据格式 数据格式如图所示 每个Keywords里面都含有多个关键词 使用分号 间隔开 一共有ABCDEF六个group 并且关键词有重复 最终想要的数据格式 统计所有不重复的关键词在六个group中出现的频次 使用R来处理 rm li
  • SAP CO TCODE

    CO 主数据 利润中心主数据维护 标准层次 KCH5N KCH6N 利润中心 组 非标准层次 KE51 KE52 KE53 利润中心 KCH1 KCH2 KCH3 利润中心组 成本中心主数据维护 标准层次 OKEON OKENN 成本中心
  • Spring Boot:让你轻松掌握自动装配的奥秘

    Spring Boot是基于Spring框架开发的一种应用框架 它通过自动装配机制 大大简化了Spring应用的开发和部署 使开发者可以更加专注于业务逻辑的实现 而无需过多关注Bean的实例化和装配过程 本文将从以下几个方面介绍Spring
  • Pytorch中的torch.nn.Linear()方法的详解

    torch nn Linear 作为深度学习中最简单的线性变换方法 其主要作用是对输入数据应用线性转换 先来看一下官方的解释及介绍 class Linear Module r Applies a linear transformation