【pytorch】Module.parameters()函数实现与网络参数管理

2023-11-02

我们知道可以通过Module.parameters()获取网络的参数,那这个是如何实现的呢?我先直接看看函数的代码实现:

    def parameters(self):
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

        """
        for name, param in self.named_parameters():
            yield param
    def named_parameters(self, memo=None, prefix=''):
        r"""Returns an iterator over module parameters, yielding both the
        name of the parameter as well as the parameter itself

        Yields:
            (string, Parameter): Tuple containing the name and parameter

        Example::

            >>> for name, param in self.named_parameters():
            >>>    if name in ['bias']:
            >>>        print(param.size())

        """
        if memo is None:
            memo = set()
        #本身模块的参数
        for name, p in self._parameters.items():
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p
        for mname, module in self.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            #递归取得子模块的参数
            for name, p in module.named_parameters(memo, submodule_prefix):
                yield name, p

可以看到是通过枚举模块和子模块(成员对象是Module类型)的成员_parameters,那_parameters是什么?我先不着急,我们先看Module的一些实现,首先看下初始化函数:

    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

可以看到_parameters(也留意_modules ) 其实是有序字典。

接着我们看下函数__setattr__(self, name, value)

   def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        #如果成员是Parameter类型
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            #如果成员是Module类型
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

我们知道如果类实现了该函数,赋值类成员时,将调用该函数。可以看到如果赋值类成员的对象是Parameter类型,那么将调用函数register_parameter注册参数,看该函数实现,其实是添加参数到有序字典成员_parameters中:

    def register_parameter(self, name, param):
        r"""Adds a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Args:
            name (string): name of the parameter. The parameter can be accessed
                from this module using the given name
            parameter (Parameter): parameter to be added to the module.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")

        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("parameter name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

所以通过调用Module.parameters()获取网络的参数,有一部分是类成员中的Parameter对象,是不是全部呢?我们后面看

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【pytorch】Module.parameters()函数实现与网络参数管理 的相关文章

  • Python、Tkinter、更改标签颜色

    有没有一种简单的方法来更改按钮中文本的颜色 I use button text input text here 更改按下后按钮文本的内容 是否存在类似的颜色变化 button color red Use the foreground设置按钮
  • Python PAM 模块的安全问题?

    我有兴趣编写一个 PAM 模块 该模块将利用流行的 Unix 登录身份验证机制 我过去的大部分编程经验都是使用 Python 进行的 并且我正在交互的系统已经有一个 Python API 我用谷歌搜索发现pam python http pa
  • 使用 openCV 对图像中的子图像进行通用检测

    免责声明 我是计算机视觉菜鸟 我看过很多关于如何在较大图像中查找特定子图像的堆栈溢出帖子 我的用例有点不同 因为我不希望它是具体的 而且我不确定如何做到这一点 如果可能的话 但我感觉应该如此 我有大量图像数据集 有时 其中一些图像是数据集的
  • 如何使用固定的 pandas 数据框进行动态 matplotlib 绘图?

    我有一个名为的数据框benchmark returns and strategy returns 两者具有相同的时间跨度 我想找到一种方法以漂亮的动画风格绘制数据点 以便它显示逐渐加载的所有点 我知道有一个matplotlib animat
  • DreamPie 不适用于 Python 3.2

    我最喜欢的 Python shell 是DreamPie http dreampie sourceforge net 我想将它与 Python 3 2 一起使用 我使用了 添加解释器 DreamPie 应用程序并添加了 Python 3 2
  • pandas 替换多个值

    以下是示例数据框 gt gt gt df pd DataFrame a 1 1 1 2 2 b 11 22 33 44 55 gt gt gt df a b 0 1 11 1 1 22 2 1 33 3 2 44 4 3 55 现在我想根据
  • 如何使用 Scrapy 从网站获取所有纯文本?

    我希望在 HTML 呈现后 可以从网站上看到所有文本 我正在使用 Scrapy 框架使用 Python 工作 和xpath body text 我能够获取它 但是带有 HTML 标签 而且我只想要文本 有什么解决办法吗 最简单的选择是ext
  • Spark的distinct()函数是否仅对每个分区中的不同元组进行洗牌

    据我了解 distinct 哈希分区 RDD 来识别唯一键 但它是否针对仅移动每个分区的不同元组进行了优化 想象一个具有以下分区的 RDD 1 2 2 1 4 2 2 1 3 3 5 4 5 5 5 在此 RDD 上的不同键上 所有重复键
  • 为 pandas 数据透视表中的每个值列定义 aggfunc

    试图生成具有多个 值 列的数据透视表 我知道我可以使用 aggfunc 按照我想要的方式聚合值 但是如果我不想对两列求和或求平均值 而是想要一列的总和 同时求另一列的平均值 该怎么办 那么使用 pandas 可以做到这一点吗 df pd D
  • Python tcl 未正确安装

    我刚刚为 python 安装了graphics py 但是当我尝试运行以下代码时 from graphics import def main win GraphWin My Circle 100 100 c Circle Point 50
  • 安装后 Anaconda 提示损坏

    我刚刚安装张量流GPU创建单独的后环境按照以下指示here https github com antoniosehk keras tensorflow windows installation 但是 安装后当我关闭提示窗口并打开新航站楼弹出
  • IRichBolt 在storm-1.0.0 和 pyleus-0.3.0 上运行拓扑时出错

    我正在运行风暴拓扑 pyleus verbose local xyz topology jar using storm 1 0 0 pyleus 0 3 0 centos 6 6并得到错误 线程 main java lang NoClass
  • feedparser 在脚本运行期间失败,但无法在交互式 python 控制台中重现

    当我运行 eclipse 或在 iPython 中运行脚本时 它失败了 ascii codec can t decode byte 0xe2 in position 32 ordinal not in range 128 我不知道为什么 但
  • 当玩家触摸屏幕一侧时,如何让 pygame 发出警告?

    我使用 pygame 创建了一个游戏 当玩家触摸屏幕一侧时 我想让 pygame 给出类似 你不能触摸屏幕两侧 的错误 我尝试在互联网上搜索 但没有找到任何好的结果 我想过在屏幕外添加一个方块 当玩家触摸该方块时 它会发出警告 但这花了很长
  • 使用 OpenPyXL 迭代工作表和单元格,并使用包含的字符串更新单元格[重复]

    这个问题在这里已经有答案了 我想使用 OpenPyXL 来搜索工作簿 但我遇到了一些问题 希望有人可以帮助解决 以下是一些障碍 待办事项 我的工作表和单元格数量未知 我想搜索工作簿并将工作表名称放入数组中 我想循环遍历每个数组项并搜索包含特
  • Python - 按月对日期进行分组

    这是一个简单的问题 起初我认为很简单而忽略了它 一个小时过去了 我不太确定 所以 我有一个Python列表datetime对象 我想用图表来表示它们 x 值是年份和月份 y 值是此列表中本月发生的日期对象的数量 也许一个例子可以更好地证明这
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo
  • 循环标记时出现“ValueError:无法识别的标记样式 -d”

    我正在尝试编码pyplot允许不同标记样式的绘图 这些图是循环生成的 标记是从列表中选取的 为了演示目的 我还提供了一个颜色列表 版本是Python 2 7 9 IPython 3 0 0 matplotlib 1 4 3 这是一个简单的代
  • Python:元类属性有时会覆盖类属性?

    下面代码的结果让我感到困惑 class MyClass type property def a self return 1 class MyObject object metaclass MyClass a 2 print MyObject

随机推荐

  • 【大模型】更强的 LLaMA2 来了,开源可商用、与 ChatGPT 齐平

    大模型 可商用且更强的 LLaMA2 来了 LLaMA2 简介 论文 GitHub huggingface 模型列表 训练数据 训练信息 模型信息 许可证 参考 LLaMA2 简介 2023年7月19日 Meta 发布开源可商用模型 Lla
  • 合并有序数组

    合并两个有序数组 描述 给你两个有序整数数组 nums1 和 nums2 请你将 nums2 合并到 nums1 中 使 num1 成为一个有序数组 说明 初始化 nums1 和 nums2 的元素数量分别为 m 和 n 你可以假设 num
  • Pytest+selenium+allure+Jenkins自动化测试框架搭建及使用

    一 环境搭建 1 Python下载及安装 Python可应用于多平台包括windows Linux 和 Mac OS X 本文主要介绍windows环境下 你可以通过终端窗口输入 python 命令来查看本地是否已经安装Python以及Py
  • 软件测试22种测试方法与详解

    黑盒测试 不基于内部设计和代码的任何知识 而是基于需求和功能性 白盒测试 基于一个应用代码的内部逻辑知识 测试是基于覆盖全部代码 分支 路径 条件 单元测试 最微小规模的测试 以测试某个功能或代码块 典型地由程序员而非测试员来做 因为它需要
  • 用js制作一个视觉差背景

    我在网上冲浪的时候看到了一个文字和背景下滑速度不一致的情况 这看起来背景会有一种3d的感觉 于是研究了一下 首先先写出大概的html和css div class box div class bg div h2 我是一个文字 h2 p 我是一
  • 算法实验题1

    第一题 由1 3 4 5 7 8这6个数字组成六位数中 能被11整除的最大的数是多少 解答 可以使用暴力枚举法 将1 3 4 5 7 8的所有排列组合情况求出来 判断它们是否能被11整除 然后取其中能被11整除的最大值 但是这个方法的时间复
  • 蓝桥杯 第6天 动态规划(4)

    目录 1 121 买卖股票的最佳时机 力扣 LeetCode leetcode cn com 1 暴力解法 2 动态规划 2 122 买卖股票的最佳时机 II 力扣 LeetCode leetcode cn com 3 123 买卖股票的最
  • uni-app 页面样式

    页面样式与布局 尺寸单位 uni app 支持的通用 css 单位包括 px rpx px 即屏幕像素 rpx 即响应式px 一种根据屏幕宽度自适应的动态单位 以750宽的屏幕为基准 750rpx恰好为屏幕宽度 屏幕变宽 rpx 实际显示效
  • C++整数转成二进制方法总结

    经常遇到要用到二进制的情况 这里我就记录下 1 逐次经典位操作 返回一个含有二进制数的vector include
  • 【深度学习之图像理解】图像分类、物体检测、物体分割、实例分割、语义分割的区别

    Directions in the CV 物体分割 Object segment 属于图像理解范畴 那什么是图像理解 Image Understanding IU 领域包含众多sub domains 如图像分类 物体检测 物体分割 实例分割
  • 前端zip.js实现加密打包上传文件

    背景 一方面 部分系统对文件的私密性和安全性要求较高 实现前端加密打包 服务端不存储密码 下载时手动输入密钥并解压文件 另一方面 传输压缩包到客户端 节约了带宽 节约了传输时间 使用的库 zip js Support of the Zip6
  • List写入Excel,poi操作

    前言 公司最近需要将所有的报表导出集中到报表中心系统中 需要做一个通用的Excel工具类 让各个业务系统简单高效的生成Excel报表 由于原先各个业务系统生成报表方式都不一样 有的地方还直接使用了CSV 因此需要统一生成Excel 本来想用
  • Android Studio 可以正常编译运行 但是代码爆红

    这段时间毕设选题 选了一个自己曾经做过的题目 因为之前是用Android Studio2 3 3写的 现在导入Android Studio 3 2 1 代码报错 但是能正常编译运行 很是奇怪 主要报错原因是 找不到有些类 之前用Androi
  • zookeeper(二)——2PC理论、zookeeper集群、ZAB 协议

    一 关于 2PC 提交 Two Phase Commitment Protocol 当一个事务操作需要跨越多个分布式节点的时候 为了保持事务处理的 ACID特性 就需要引入一个 协调者 TM 来统一调度所有分布式节点的执行逻辑 这些被调度的
  • CTFshow php特性 web111

    目录 源码 思路 题解 总结 源码
  • adworld-crypto-banana_princess

    拿到了一个打不开的pdf文件 用hex editor打开一下看看 再看一下正常的pdf文件 猜测是用了rot13映射了一下字母字符 解密脚本 def load data filename content with open filename
  • SpringBoot线程上下文传递数据

    1 底层实现 使用ThreadLocal 使用方法 public T get public void set T value public void remove 2 自定义上下文 package com ybw context confi
  • golang 闭包函数的应用技巧

    一 有名函数和匿名函数 函数变量类型初始值为nil 函数字面量类型的语法表达格式是 func InputTypeList OutputTypeList 无参函数 func fun var f func 无入参无返回值的函数对象声明 初始值为
  • 以太坊教程:入门学习开发以太坊dapp

    一 区块链 1 分布式去中心化 比特币设计的初衷就是要避免依赖中心化的机构 没有发行机构 也不可能操纵发行数量 既然没有中心化的信用机构 在电子货币运行的过程中 也势必需要一种机制来认可运行在区块链上的行为 包括比特币的运营 亦或是运行在区
  • 【pytorch】Module.parameters()函数实现与网络参数管理

    我们知道可以通过Module parameters 获取网络的参数 那这个是如何实现的呢 我先直接看看函数的代码实现 def parameters self r Returns an iterator over module paramet