Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()

2023-11-08

神经网络的模型参数

model.parameters(), model.named_parameters(), model.state_dict() 这三个方法都可以查看神经网络的参数信息,用于更新参数,或者用于模型的保存。作用都类似,写法略有出入

就以Pytorch之经典神经网络(一) —— 全连接网络(MNIST) 来举例 Pytorch之经典神经网络CNN(一) —— 全连接网络 / MLP (MNIST) (trainset和Dataloader & batch training & learning_rate)_hxxjxw的博客-CSDN博客   

print(*[name for name, _ in self.model.named_parameters()], sep='\n')
print(*set([name.split('.')[0] for name, _ in self.named_parameters()]), sep='\n')
查看网络模型参数是否可训练
print(*[_.requires_grad for name, _ in model.named_parameters()], sep='\n')

model.named_parameters()

net.named_parameters()中param是len为2的tuple
param[0]是name,fc1.weight、fc1.bias等
param[1]是fc1.weight、fc1.bias等对应的值

一直是0,1,2,......, 这种序号

for _,param in enumerate(net.named_parameters()):
    print(param[0])
    print(param[1])
    print('----------------')

model.parameters()

net.parameters()中param就是fc1.weight、fc1.bias等对应的值,没带名字

for _,param in enumerate(net.parameters()):
    print(param)
    print('----------------')

model.state_dict()

net.state_dict() 中的param就只是str字符串 fc1.weight, fc1.bias等等

但它们可以作为参数来输出对应的值

for _,param in enumerate(net.state_dict()):
    print(param)
    print(net.state_dict()[param])
    print('----------------')

神经网络的各个层

当神经网络是这么定义的时候,即没有用nn.Sequential()

此时 print(net)

net = Net()
print(net)

输出单个的网络层

net = Net()
print(net.fc1)
print(net.fc2)
print(net.fc3)

输出各个网络层的weight,bias参数

net = Net()
print(net.fc1.weight)
print(net.fc1.bias)
print(net.fc2.weight)
print(net.fc2.bias)
print(net.fc3.weight)
print(net.fc3.bias)

当使用nn.Sequential定义的时候

import torch
import torchvision
from torchvision import transforms
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F

from utils import plot_image,plot_curve,one_hot

# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
#
#         #三层全连接层
#         #wx+b
#         self.fc1 = nn.Linear(28*28, 256)
#         self.fc2 = nn.Linear(256,64)
#         self.fc3 = nn.Linear(64,10)
#
#     def forward(self, x):
#         x = F.rule(self.fc1(x)) #F.relu和torch.relu,用哪个都行
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc(3))
#
#         return x


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

        def forward(self, x):
            # x: [b, 1, 28, 28]
            # h1 = relu(xw1+b1)
            x = self.fc(x)

            return x

batch_size = 512
#一次处理的图片的数量
#gpu一次可以处理并行多张图片

transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])


trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)
#train=True表示是训练数据,train=False是测试数据

train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True  #在加载的时候将图片随机打散
)

testset = torchvision.datasets.MNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=True
)

net = Net()
print(net.fc)
print(net.fc[0])
print(net.fc[1])
print(net.fc[2])
print(net.fc[3])
print(net.fc[4])
print(net.fc[0].weight)
print(net.fc[0].bias)


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict() 的相关文章

  • 复旦计算机学硕408,又一所院校专业课改投408——复旦大学

    原标题 又一所院校专业课改投408 复旦大学 复旦大学简称 复旦 位于中国上海 位列211工程 985工程 入选双一流 是一所综合性研究型的全国重点大学 学校现有一级学科国家重点学科11个 二级学科国家重点学科19个 国家重点 培育 学科3
  • 类的加载详解

    到目前为止 我们已经写了无数个类了 但是具体它在Java虚拟机中到底是怎么实现的 我们还从未探索过 今天就带着大家一起初探一下jvm对类加载的过程 目前博主技术水平有限 以后随着技术的更加成熟 会更新博客内容的 也欢迎更多小伙伴持续关注 和
  • Ubuntu10下SSH2协议安装

    Ubuntu10下SSH2协议安装 SSH2是一套安全通讯协议框架 早期的SSH1由于存在安全漏洞 现在已经不用了 基于SSH2协议的产品目前主要有openssh putty SSH Secure Shell Client等 安装了SSH2
  • 基础配置Tomcat及使用

    配置Tomcat 背景简介 目前很多网站由java编写 所以解析Java程序需要有相关的软件来编写完成 Tomcat是其中之一 Tomcat技术先进 性能稳定且免费 是目前比较流行的web应用服务器 Tomcat是一个轻量化级应用服务器 实
  • jaspersoft studio动态图片传输

    业务需求简述 在实际业务开发中需要动态生成PDF 其中包含客户签字图片 技术栈 JasperReport Jaspersoft Studio软件 动态图片传输流程 jaspersoft studio 拖入image到工作区 选择最后一项点击
  • java爬取人人网数据

    通过httpclient何httpparser两个类爬人人网中得数据 其中的详细步骤以及文档下面详细介绍 爬人人网相关代码 SuppressWarnings deprecation public class RenRen 输入用户名及密码
  • 1.6 起步 - 初次运行 Git 前的配置

    1 6 起步 初次运行 Git 前的配置 版本说明 版本 作者 日期 备注 0 1 loon 2019 3 19 初稿 目录 文章目录 1 6 起步 初次运行 Git 前的配置 版本说明 目录 初次运行 Git 前的配置 1 用户信息 2
  • 怀旧服服务器荣誉系统是啥,魔兽世界怀旧服:荣誉系统要开了?大元帅吸引人,军衔要不要冲?...

    魔兽世界怀旧服马上开荣誉系统了 相信很多pvp玩家都会有冲军衔的目标 因为军衔是实力和荣誉的象征 有玩家清晰记得 到了R13更新那天 跟另外两个元帅法师 一共3个元帅套 站在铁炉堡银行门口 围观的人超多 那种自豪和成就感难以言喻 大家知道大

随机推荐

  • echarts 饼图的指示线(labelline) 问题

    数据过多 且几个比较小的数据在一块扎堆 series name type pie center 25 50 radius 45 60 minAngle 10 设置每块扇形的最小占比 avoidLabelOverlap false hover
  • linux一次性创建多个文件/文件夹

    1 创建多个文件 touch file 1 10 注 创建10个文件 文件名file0 file1 file10 2 创建多个目录 mkdir folder 1 10 注 一次性创建10个文件夹 目录名为folder1 folder2 fo
  • 【Hive报错】Hive报错Expression Not In Group By Key解决方法

    SQL例如以下会报错 select sum time as time roadCoding upstreamOrDownstream from historicalroaddata where 报以下roadcoding upstreamO
  • 10个实用的Python数据可视化图表总结

    可视化是一种方便的观察数据的方式 可以一目了然地了解数据块 我们经常使用柱状图 直方图 饼图 箱图 热图 散点图 线状图等 这些典型的图对于数据可视化是必不可少的 除了这些被广泛使用的图表外 还有许多很好的却很少被使用的可视化方法 这些图有
  • 各向异性(anisotropic)浅提

    文章目录 各向异性 anisotropic 定义 哪种物体具有各向异性反射 什么导致各向异性反射 总结 各向异性 anisotropic 定义 它指一种存在方向依赖性 这意味着在不同的方向不同的特性 相对于该属性各向同性 当沿不同轴测量时
  • [Anaconda]——Linux下conda虚拟环境缺“msvcrt”

    问题 这里是在使用不同节点的系统时 一个是普通的节点 一个是GPU节点 在普通节点下准备好了所有的环境 使用Linux的NIS功能 利用网络把硬盘挂载到不同的节点 这个时候普通节点和GPU节点就做到了数据同步 但是发现在使用conda虚拟环
  • ThinkPHP6 框架 对接 ChatGPT应用

    ThinkPHP6是一款优秀的PHP开发框架 它提供了丰富的功能和易于使用的API 使得开发人员可以快速构建高质量的Web应用程序 本文将介绍如何使用ThinkPHP6框架对接ChatGPT应用 实现智能聊天机器人的功能 首先 我们需要在T
  • 部署mac os渗透测试环境

    一 序言 每次重装系统后配置环境都是需要耗费大量时间 特此写一篇mac os部署渗透测试环境 二 过程 一 系统设置 1 常用设置 SSD 开启 TRIM 支持 sudo trimforce enable APP安装开启任何来源 sudo
  • Lua coroutine.create

    Lua coroutine creat 相当于在C 中使用lua newthread Equivalent of Lua coroutine create in C using lua newthread 问 题 I have a call
  • 基于Qt的OpenGL编程(3.x以上GLSL可编程管线版)---(二十)面剔除

    Vries的教程是我看过的最好的可编程管线OpenGL教程 没有之一 其原地址如下 https learnopengl cn github io 04 20Advanced 20OpenGL 04 20Face 20culling 关于面剔
  • # HTB-Tier2- Vaccine

    HTB Tier2 Vaccine Web Network Vulnerability Assessment Databases Injection Custom Applications Protocols Source Code Ana
  • 毕业设计 基于Arduino的计算器

    0 前言 这两年开始毕业设计和毕业答辩的要求和难度不断提升 传统的毕设题目缺少创新和亮点 往往达不到毕业答辩的要求 这两年不断有学弟学妹告诉学长自己做的项目系统达不到老师的要求 为了大家能够顺利以及最少的精力通过毕设 学长分享优质毕业设计项
  • 暑期实训日志11——webstorm+chrome实时浏览插件

    在网上看到一个webstorm chrome里JetBrains IDE Support能够实现实时浏览的小工具 感觉非常实用 一 JetBrains IDE Support下载 下载地址 直接从谷歌商店下载也可 前提是进得去 下载好后打开
  • sqli-labs(28-28a)

    Less 28 1 测试http 127 0 0 1 sqli labs Less 28 id 1 27 页面回显不正常 但又没有错误提示 报错注入没戏 尝试闭合语句 加单引号回显不正常 说明sql语句闭合至少有 可能有 判断有无 在Les
  • log4cplus基础知识

    一 简介 log4cplus是C 编写的开源的日志系统 具有线程安全 灵活 以及多粒度控制的特点 通过将信息划分优先级使其可以面向程序调试 运行 测试 和维护等全生命周期 你可以选择将信息输出到屏幕 文件 NT event log 甚至是远
  • SSM controller要能跳转页面又要能返回字符串

    SpringMVC因为添加了下面这个bean 视图解析器 当你方法返回的是 json 字符串等其它值时 会404 跳转 jsp jsp页面
  • 回文数的判断

    文章目录 题目 一 方案一 二 方案二 三 方案三 四 方案四 题目 判断一个整数是否是回文数 回文数是指正序 从左向右 和倒序 从右向左 读都是一样的整数 提示 下面案例可供参考 一 方案一 public boolean palindro
  • 二叉树 深度优先搜索(DFS)、广度优先搜索(BFS)

    深度优先搜索算法 Depth First Search DFS是搜索算法的一种 它沿着树的深度遍历树的节点 尽可能深的搜索树的分支 当节点v的所有边都己被探寻过 搜索将回溯到发现节点v的那条边的起始节点 这一过程一直进行到已发现从源节点可达
  • pytorch Embedding模块,自动为文本加载预训练的embedding

    pytorch 提供了一个简便方法torch nn Embedding from pretrained 可以将文本与预训练的embedding对应起来 词 embedding word1 0 2 3 4 word2 1 2 3 4 word
  • Pytorch(六)(模型参数的遍历) —— model.parameters() & model.named_parameters() & model.state_dict()

    神经网络的模型参数 model parameters model named parameters model state dict 这三个方法都可以查看神经网络的参数信息 用于更新参数 或者用于模型的保存 作用都类似 写法略有出入 就以P