Model parameters of a neural network
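model.parameters(), model.named_parameters() and model.state_dict() all expose a network's parameters, whether to hand them to an optimizer for updating or to save the model. They serve similar purposes but differ in what they return: parameters() yields the parameter tensors alone, named_parameters() yields (name, parameter) pairs, and state_dict() returns an ordered dictionary that maps names to tensors and also includes buffers such as BatchNorm running statistics.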
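A minimal sketch comparing what the three methods return. It assumes a small hypothetical model with a BatchNorm1d layer (not the MNIST network used below) so that buffers show up in state_dict() but not in the parameter iterators:

```python
import io

import torch
from torch import nn

# Hypothetical toy model: BatchNorm1d owns buffers (running_mean, running_var, num_batches_tracked)
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# parameters(): nn.Parameter tensors only, without names
params = list(model.parameters())

# named_parameters(): (name, nn.Parameter) pairs
param_names = [name for name, _ in model.named_parameters()]

# state_dict(): OrderedDict of name -> tensor, covering parameters and buffers
state_keys = list(model.state_dict().keys())

print(len(params))  # 4
print(param_names)
print(state_keys)

# Typical uses: parameters() goes to the optimizer, state_dict() is what gets saved
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save(model.state_dict(), buffer)
```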
The examples below use the network from the earlier post "PyTorch Classic Neural Networks CNN (1): Fully Connected Network / MLP (MNIST) (trainset and DataLoader & batch training & learning_rate)" on hxxjxw's CSDN blog.
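Print the names of all parameters:
print(*[name for name, _ in model.named_parameters()], sep='\n')
Print the distinct top-level submodules that own parameters (the part of each name before the first dot):
print(*{name.split('.')[0] for name, _ in model.named_parameters()}, sep='\n')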
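A self-contained sketch of the two one-liners above, assuming a hypothetical three-layer MLP whose attributes fc1, fc2, fc3 match the post's Net (the forward pass is omitted because only the parameters are inspected). The set is sorted because a set has no guaranteed order:

```python
from torch import nn

class MLP(nn.Module):
    """Hypothetical stand-in for the post's Net; forward() omitted."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

model = MLP()

# Full parameter names, in registration order
names = [name for name, _ in model.named_parameters()]
print(*names, sep='\n')

# Distinct top-level submodules that own parameters, sorted for a stable order
top_level = sorted({name.split('.')[0] for name, _ in model.named_parameters()})
print(*top_level, sep='\n')
```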
Check whether each parameter is trainable (requires_grad):
for name, param in model.named_parameters():
    print(name, param.requires_grad)
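A common reason to check requires_grad is freezing part of a network. This sketch, on a hypothetical two-layer nn.Sequential, freezes the first Linear layer and then passes only the still-trainable parameters to the optimizer:

```python
import torch
from torch import nn

# Hypothetical model; module index 0 is the layer we freeze
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# Freeze every parameter of the first Linear layer
for param in model[0].parameters():
    param.requires_grad_(False)

for name, param in model.named_parameters():
    print(name, param.requires_grad)

# Keep only the parameters that still require gradients
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)
```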
model.named_parameters()
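Each item yielded by net.named_parameters() is a tuple of length 2:
param[0] is the name, e.g. fc1.weight or fc1.bias
param[1] is the corresponding value, i.e. the parameter tensor itself
The index _ produced by enumerate simply counts 0, 1, 2, ... and is not used.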
for _, param in enumerate(net.named_parameters()):
    print(param[0])
    print(param[1])
    print('----------------')
![](https://img-blog.csdnimg.cn/20210206214225505.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h4eGp4dw==,size_16,color_FFFFFF,t_70)
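Since each item is a (name, tensor) tuple, it can be unpacked directly instead of indexing param[0] and param[1]. A sketch assuming a hypothetical stand-in with the same layer sizes as the post's Net, collecting each parameter's shape:

```python
from torch import nn

class MLP(nn.Module):
    """Hypothetical stand-in with the post's layer sizes."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

net = MLP()

# Unpack (name, Parameter) in the loop header
shapes = {name: tuple(param.shape) for name, param in net.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)
```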
model.parameters()
Each item yielded by net.parameters() is just the value of fc1.weight, fc1.bias, etc., without its name.
for _, param in enumerate(net.parameters()):
    print(param)
    print('----------------')
![](https://img-blog.csdnimg.cn/20210206214729426.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h4eGp4dw==,size_16,color_FFFFFF,t_70)
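Two practical points about parameters(), sketched on a hypothetical nn.Sequential with the post's layer sizes: it returns a generator, so it is exhausted after one pass, and it is the usual way to count a model's parameters:

```python
from torch import nn

net = nn.Sequential(
    nn.Linear(28 * 28, 256), nn.ReLU(),
    nn.Linear(256, 64), nn.ReLU(),
    nn.Linear(64, 10),
)

# parameters() is a generator: a second pass over the same object yields nothing
gen = net.parameters()
first_pass = len(list(gen))
second_pass = len(list(gen))

# Total and trainable parameter counts
total = sum(p.numel() for p in net.parameters())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(first_pass, second_pass)  # 6 0
print(total)  # 218058
```

The total is (784·256 + 256) + (256·64 + 64) + (64·10 + 10) = 218058.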
model.state_dict()
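Iterating over net.state_dict() yields only its keys, the strings fc1.weight, fc1.bias and so on, because state_dict() returns an OrderedDict. Each key can then be used to look up the corresponding tensor: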
for _, param in enumerate(net.state_dict()):
    print(param)
    print(net.state_dict()[param])
    print('----------------')
![](https://img-blog.csdnimg.cn/20210206215009734.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h4eGp4dw==,size_16,color_FFFFFF,t_70)
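Because state_dict() is a dictionary, .items() gives key and tensor together. The sketch below, on a hypothetical MLP with the post's layer sizes, also shows the save/load round trip that state_dict() exists for; an in-memory io.BytesIO buffer stands in for a file path:

```python
import io

import torch
from torch import nn

def make_net():
    # Hypothetical MLP with the post's layer sizes
    return nn.Sequential(
        nn.Linear(28 * 28, 256), nn.ReLU(),
        nn.Linear(256, 64), nn.ReLU(),
        nn.Linear(64, 10),
    )

net = make_net()

# items() yields (key, tensor); by default the tensors are detached (requires_grad=False)
for key, tensor in net.state_dict().items():
    print(key, tuple(tensor.shape), tensor.requires_grad)

# Save, then load into a freshly initialized model
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

restored = make_net()
restored.load_state_dict(torch.load(buffer))

same = all(
    torch.equal(a, b)
    for a, b in zip(net.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```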
The layers of a neural network
When the network is defined as follows, i.e. without nn.Sequential():
![](https://img-blog.csdnimg.cn/20210206220007400.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h4eGp4dw==,size_16,color_FFFFFF,t_70)
Printing the network in this case:
net = Net()
print(net)
![](https://img-blog.csdnimg.cn/20210206220115918.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h4eGp4dw==,size_16,color_FFFFFF,t_70)
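To loop over the layers programmatically rather than printing the whole network, named_children() yields each direct submodule with its attribute name, in the order the submodules were assigned in __init__. This sketch assumes a hypothetical stand-in Net with the same layers (fc1, fc2, fc3); its forward() is an assumption that flattens MNIST-shaped input first:

```python
import torch
from torch import nn
from torch.nn import functional as F

class Net(nn.Module):
    """Hypothetical stand-in for the post's non-Sequential Net."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten [b, 1, 28, 28] -> [b, 784]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

net = Net()

# Direct submodules, in registration order
for name, layer in net.named_children():
    print(name, layer)

out = net(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```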
Print individual layers:
net = Net()
print(net.fc1)
print(net.fc2)
print(net.fc3)
![](https://img-blog.csdnimg.cn/20210206220229951.png)
Print the weight and bias of each layer:
net = Net()
print(net.fc1.weight)
print(net.fc1.bias)
print(net.fc2.weight)
print(net.fc2.bias)
print(net.fc3.weight)
print(net.fc3.bias)
![](https://img-blog.csdnimg.cn/20210206220353810.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h4eGp4dw==,size_16,color_FFFFFF,t_70)
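When reading the printed weights, note that nn.Linear stores its weight as [out_features, in_features] because it computes y = x·Wᵀ + b. A quick check on a standalone layer with the same sizes as fc1 (the random input is only for illustration):

```python
import torch
from torch import nn

fc1 = nn.Linear(28 * 28, 256)  # same sizes as the post's fc1
print(fc1.weight.shape)  # torch.Size([256, 784])
print(fc1.bias.shape)    # torch.Size([256])

# Recompute the layer by hand: y = x @ W.T + b
x = torch.randn(4, 28 * 28)
manual = x @ fc1.weight.T + fc1.bias
matches = torch.allclose(fc1(x), manual, atol=1e-5)  # small tolerance for float rounding
print(matches)  # True
```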
When the network is defined with nn.Sequential:
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
# from utils import plot_image, plot_curve, one_hot  # helper module from the earlier post; not needed here
# The same network without nn.Sequential, kept for comparison:
# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
#
#         # three fully connected layers
#         # wx + b
#         self.fc1 = nn.Linear(28*28, 256)
#         self.fc2 = nn.Linear(256, 64)
#         self.fc3 = nn.Linear(64, 10)
#
#     def forward(self, x):
#         x = F.relu(self.fc1(x))  # F.relu and torch.relu are interchangeable
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc3(x))
#
#         return x
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        # x: [b, 1, 28, 28]
        x = x.view(x.size(0), -1)  # flatten to [b, 784] so nn.Linear(28 * 28, 256) accepts it
        # h1 = relu(xw1+b1)
        x = self.fc(x)
        return x
batch_size = 512
# number of images processed per batch;
# a GPU can process many images in parallel
transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  # if True, build the dataset from training.pt, otherwise from test.pt
    download=True,  # if True, download the dataset from the internet into root; an existing download is not fetched again
    transform=transform
)
# train=True selects the training data, train=False the test data
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True  # shuffle the images randomly when loading
)
testset = torchvision.datasets.MNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=transform
)
test_loader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=False  # the test set does not need shuffling
)
net = Net()
print(net.fc)
print(net.fc[0])
print(net.fc[1])
print(net.fc[2])
print(net.fc[3])
print(net.fc[4])
print(net.fc[0].weight)
print(net.fc[0].bias)
![](https://img-blog.csdnimg.cn/20210206222332921.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h4eGp4dw==,size_16,color_FFFFFF,t_70)
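With nn.Sequential, parameter names use the index inside the container (fc.0.weight rather than fc1.weight). The ReLU modules at indices 1 and 3 have no parameters, so those indices never appear in the names. A sketch restating the Sequential Net so it runs on its own, with random MNIST-shaped input standing in for a real batch:

```python
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(28 * 28, 256), nn.ReLU(),
            nn.Linear(256, 64), nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten [b, 1, 28, 28] -> [b, 784]
        return self.fc(x)

net = Net()

# Names skip indices 1 and 3: ReLU has no parameters
names = [name for name, _ in net.named_parameters()]
print(*names, sep='\n')

# Slicing an nn.Sequential returns another nn.Sequential
first_block = net.fc[:2]
print(first_block)

out = net(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```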