PyTorch中nn.Module类简介

2023-11-20

      torch.nn.Module类是所有神经网络模块(modules)的基类,它的实现在torch/nn/modules/module.py中。你的模型也应该继承这个类,主要重载__init__、forward和extra_repr函数。Modules还可以包含其它Modules,从而可以将它们嵌套在树结构中。

      只要在自己的类中定义了forward函数,backward函数就会利用Autograd被自动实现。只要实例化一个对象并传入对应的参数就可以自动调用forward函数。因为此时会调用对象的__call__方法,而nn.Module类中的__call__方法会调用forward函数。

      nn.Module类中函数介绍:

      __init__:初始化内部module状态。

      register_buffer:向module添加buffer,不作为模型参数,可作为module状态的一部分。默认情况下,buffer是持久(persistent)的,将与参数一起保存。buffer是否persistent的区别在于这个buffer是否被放入self.state_dict()中被保存下来。

      register_parameter:向module添加参数。

      add_module:添加一个submodule(children)到当前module中。

      apply:将fn递归应用于每个submodule(children),典型用途为初始化模型参数。

      cuda:将所有模型参数和buffers转移到GPU上。

      xpu:将所有模型参数和buffers转移到XPU上。

      cpu:将所有模型参数和buffers转移到CPU上。

      type:将所有参数和buffers转换为所需的类型。

      float:将所有浮点参数和buffers转换为float32数据类型。

      double:将所有浮点参数和buffers转换为double数据类型。

      half:将所有浮点参数和buffers转换为float16数据类型。

      bfloat16:将所有浮点参数和buffers转换为bfloat16数据类型。

      to:将参数和buffers转换为指定的数据类型或转换到指定的设备上。

      register_backward_hook:在module中注册一个反向钩子。不推荐使用。

      register_full_backward_hook:在module中注册一个反向钩子。每次计算梯度时都会调用此钩子。使用此钩子时不允许就地(in place)修改输入或输出,否则会触发error。

      register_forward_pre_hook:在module中注册前向pre-hook。每次调用forward之前都会调用此钩子。

      register_forward_hook:在module中注册一个前向钩子。每次forward计算输出后都会调用此钩子。

      state_dict:返回包含了module的整个状态的字典。其中keys是对应的参数和buffer名称。

      load_state_dict:将参数和buffers从state_dict复制到module及其后代(descendants)中。

      parameters:返回module的参数的迭代器。

      named_parameters:返回module的参数的迭代器,产生(yield)参数的名称以及参数本身。不会返回重复的parameter。

      buffers:返回module的buffers的迭代器。

      named_buffers:返回module的buffers的迭代器,产生(yield)buffer的名称以及buffer本身。不会返回重复的buffer。

      children:返回直接子module的迭代器。

      named_children:返回直接子module的迭代器,产生(yield)子module的名称以及子module本身。不会返回重复的children。

      modules:返回网络中所有modules的迭代器。

      named_modules:返回网络中所有modules的迭代器,产生(yield)module的名称以及module本身。不会返回重复的module。

      train:将module设置为训练模式。这仅对某些module起作用。module.py实现中会修改self.training并通过self.children()来调整所有submodule的状态。

      eval:将module设置为评估模式。这仅对某些module起作用。module.py实现中直接调用train(False)。

      requires_grad_:更改autograd是否应记录对此module中参数的操作。此方法就地(in place)设置参数的requires_grad属性。

      zero_grad:将所有模型参数的梯度设置为零。

      share_memory:

      extra_repr:设置module的额外表示。你应该在自己的modules中重新实现此方法。

     测试代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F # nn.functional.py中存放激活函数等的实现

@torch.no_grad()
def init_weights(m):
    print("xxxx:", m)
    if type(m) == nn.Linear:
         m.weight.fill_(1.0)
         print("yyyy:", m.weight)

class Model(nn.Module):
    def __init__(self):
        # 在实现自己的__init__函数时,为了正确初始化自定义的神经网络模块,一定要先调用super().__init__
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5) # submodule(child module)
        self.conv2 = nn.Conv2d(20, 20, 5)
        self.add_module("conv3", nn.Conv2d(10, 40, 5)) # 添加一个submodule到当前module,等价于self.conv3 = nn.Conv2d(10, 40, 5)
        self.register_buffer("buffer", torch.randn([2,3])) # 给module添加一个presistent(持久的) buffer
        self.param1 = nn.Parameter(torch.rand([1])) # module参数的tensor
        self.register_parameter("param2", nn.Parameter(torch.rand([1]))) # 向module添加参数

        # nn.Sequential: 顺序容器,module将按照它们在构造函数中传递的顺序添加,它允许将整个容器视为单个module
        self.feature = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
        self.feature.apply(init_weights) # 将fn递归应用于每个submodule,典型用途为初始化模型参数
        self.feature.to(torch.double) # 将参数数据类型转换为double
        cpu = torch.device("cpu")
        self.feature.to(cpu) # 将参数数据转换到cpu设备上

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))

model = Model()
print("## Model:", model)

model.cpu() # 将所有模型参数和buffers移动到CPU上
model.float() # 将所有浮点参数和buffers转换为float数据类型
model.zero_grad() # 将所有模型参数的梯度设置为零

# state_dict:返回一个字典,保存着module的所有状态,参数和persistent buffers都会包含在字典中,字典的key就是参数和buffer的names
print("## state_dict:", model.state_dict().keys())

for name, parameters in model.named_parameters(): # 返回module的参数(weight and bias)的迭代器,产生(yield)参数的名称以及参数本身
    print(f"## named_parameters: name: {name}; parameters size: {parameters.size()}")

for name, buffers in model.named_buffers(): # 返回module的buffers的迭代器,产生(yield)buffer的名称以及buffer本身
    print(f"## named_buffers: name: {name}; buffers size: {buffers.size()}")

# 注:children和modules中重复的module只被返回一次
for children in model.children(): # 返回当前module的child module(submodule)的迭代器
    print("## children:", children)

for name, children in model.named_children(): # 返回直接submodule的迭代器,产生(yield) submodule的名称以及submodule本身
    print(f"## named_children: name: {name}; children: {children}")

for modules in model.modules(): # 返回当前模型所有module的迭代器,注意与children的区别
    print("## modules:", modules)

for name, modules in model.named_modules(): # 返回网络中所有modules的迭代器,产生(yield)module的名称以及module本身,注意与named_children的区别
    print(f"## named_modules: name: {name}; module: {modules}")

model.train() # 将module设置为训练模式
model.eval() # 将module设置为评估模式

print("test finish")

     GitHubhttps://github.com/fengbingchun/PyTorch_Test

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch中nn.Module类简介 的相关文章

随机推荐

  • API接口开发简述简单示例

    作为最流行的服务端语言PHP PHP Hypertext Preprocessor 在开发API方面 是很简单且极具优势的 API Application Programming Interface 应用程序接口 架构 已经成为目前互联网产
  • vue3.0删除node_modules 无用的依赖

    安装插件 npm i depcheck 查看无用的插件 npx depcheck 对应删除 npm uninstall kd layout
  • C++/C++11中变长参数的使用

    C C 11中的变长参数可以应用在宏 函数 模板中 1 宏 在C99标准中 程序员可以使用变长参数的宏定义 变长参数的宏定义是指在宏定义中参数列表的最后一个参数为省略号 而预定义宏 VA ARGS 则可以在宏定义的实现部分替换省略号所代表的
  • docker 安装tomcat遇到问题

    docker 安装 tomcat 启动 tomcat docker pull tomcat 8 默认启动 docker run d p 7788 8080 tomcat 8 进入容器 docker exec it 541d6c30c295
  • Spring源码分析refresh()第二篇

    invokeBeanFactoryPostProcessors方法 这个方法比较重要 实例化和调用所有已注册的BeanFactoryPostProcessor bean 如果有已经注入的BeanFactoryPostProcessor 则优
  • JavaScript 高级应用第一弹

    文章目录 Gorit 带你学全栈 JavaScript 高级应用 第一弹 一 数组篇 1 1 展开表达式 1 2 返回一个新数组 1 2 1 map 1 2 2 filter 1 2 3 concat 1 3 索引相关问题 1 4 返回 b
  • Qt 中遇到QLabel::setPixmap() 设置图片不起作用(图片被替换后还是显示替换前的图片)解决方法

    1 问题 当使用下面的命令设置图片后 第一次会成功显示图片 当我删除当前图片后并且用另一张图片重命名成先前删除的图片时 再次刷新显示还是先前删除的图片资源 重启软件又正常显示修改后的图片 ui gt label gt setPixmap Q
  • 3.[mybatis]的查询源码分析(执行流程、缓存、整合spring要点)

    目录 1 装饰器模式 2 sqlSession的创建 open 2 1 newExecutor 3 selectOne分析 3 1 二级缓存 3 2 一级缓存 4 数据库查询核心分析 queryFromDatabase 4 1 Simple
  • Wave x Incredibuild

    Wave 公司简介 Wave 是一家虚拟娱乐公司 致力于帮助艺术家和粉丝通过协作创造出世界上最具互动性的现场表演体验 Wave 整合了最顶尖的现场音乐 游戏和广播技术 将现场音乐表演转化为沉浸式虚拟体验 便于观众通过 YouTube Twi
  • java 模拟库存管理系统

    本案例要求编写一个程序 模拟库存管理系统 该系统内容主要包括 商品入库 商品显示 和删除商品功能 此程序用手机举例 此管理系统分别为两个类Phone 和Test类 Phone类 确定四个变量 类 1 生成空参数构造方法 2 全部参数的构造方
  • 经典的期货量化交易策略大全(含源代码)

    1 双均线策略 期货 双均线策略是简单移动平均线策略的加强版 移动平均线目的是过滤掉时间序列中的高频扰动 保留有用的低频趋势 它以滞后性的代价获得了平滑性 比如 在一轮牛市行情后 只有当价格出现大幅度的回撤之后才会在移动平均线上有所体现 而
  • 引介

    转载自 https ethfans org posts rlp encode and decode RLP编码和解码 RLP Recursive Length Prefix 递归的长度前缀 是一种编码规则 可用于编码任意嵌套的二进制数组数据
  • sqli-labs第26~28关

    第26关 查看源码 黑名单 过滤了 or and 空格 s 代表正则表达式中的一个空白字符 可能是空格 制表符 其他空白 即 s 用于匹配空白字符 我们常见的绕过空格的就是多行注释 但这里过滤了 不太行啊 将空格 or and 等各种符号过
  • [设计模式]模板方法模式(Template Method)

    1 意图 定义一个操作中的算法的骨架 而将一些步骤延迟到子类中 TemplateMethod使得子类可以不改变一个算法的结构即可重定义该算法的某些特定步骤 2 动机 其实就是如意图所描述的 算法的骨架是一样的 就是有些特殊步骤不一样 就可以
  • java一行代码实现RESTFul接口

    一 介绍spring data rest Spring Data REST是基于Spring Data的repository之上 可以把 repository 自动输出为REST资源 目前支持 Spring Data JPA Spring
  • vue3 vue-router 钩子函数

    全局路由守卫 vue router4 0中将next取消了 可写可不写 return false取消导航 undefined或者是return true验证导航通过 router beforeEach to from gt next是可选参
  • 大数据案例--电信日志分析系统

    目录 一 项目概述 1 概述 二 字段解释分析 1 数据字段 2 应用大类 3 应用小类 三 项目架构 四 数据收集清洗 1 数据收集 2 数据清洗 五 Sqoop使用 1 简介 2 Sqoop安装步骤 3 Sqoop的基本命令 六 数据导
  • 静态时序分析的三种分析模式(简述)

    经过跟行业前辈的探讨和参考一些书籍 本文中的 个人理解 部分有误 即 个人理解 在一个库中 尽管电路器件单元已经被综合映射 但是工具可以通过改变周围的环境来得到不同的单元延时 所以即使是同一个库 调用工艺参数不一样的情况下 其单元延时是不同
  • 黑客零基础入门方法有哪些?如何自学黑客技术?

    大家经常问我一个问题 黑客零基础入门方法有哪些 以及如何自学黑客技术 首先要说的是世界上大部分的网络黑客都是自学成才的 这与黑客这门技术有很大的原因 黑客是一个靠兴趣驱动的技术 大部分成为黑客的人一开始都是被黑客的酷炫身份所吸引从而成为黑客
  • PyTorch中nn.Module类简介

    torch nn Module类是所有神经网络模块 modules 的基类 它的实现在torch nn modules module py中 你的模型也应该继承这个类 主要重载 init forward和extra repr函数 Modul