用法比较简单,不过容易忘,记录一下。假设已定义好模型,名为model。
查看模型结构:
>>> print(model)
查看网络参数:
for name, parameters in model.named_parameters():
print(name, ':', parameters.size())
来个具体的简单的例子:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
def forward(self, x):
x = self.conv1(x)
return x
model = Net()
print(model)
print()
for name, parameters in model.named_parameters():
print(name, ':', parameters.size())
输出:
Net(
(conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
)
conv1.weight : torch.Size([6, 1, 3, 3])
conv1.bias : torch.Size([6])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)