class LeNet5(nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
# 定义卷积层
self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=6, padding=2, kernel_size=5, stride=1),
# 24*24*6,图片大小变为 28+2*2 = 32 (两边各加2列0),保证输入输出尺寸相同
nn.Sigmoid(),
nn.MaxPool2d(kernel_size=2, stride=2), # 14*14*6
nn.Conv2d(6, 16, 5), # in_channels, out_channels, kernel_size 8*8*6
nn.Sigmoid(),
nn.MaxPool2d(2, 2) # kernel_size, stride, # 5x5x16
)
self.fc = nn.Sequential(
nn.Linear(in_features=5 * 5 * 16, out_features=120),
nn.Sigmoid(),
nn.Linear(120, 84),
nn.Linear(84, 10)
)
# 定义前项传播
def forward(self, img):
feature = self.conv(img)
# 全连接层均使用的nn.Linear()线性结构,输入输出维度均为一维,故需要把数据拉为一维
output = self.fc(feature.view(img.shape[0], -1))
return output
...
outputs=net(inputs)
...
语句指向outputs=net(inputs) 很有可能就是网络的问题,当我回去看我写的LeNet5这个网络模型时发现forward()函数名是灰色的,最后发现是排列的位置错了,前项传播留在了前头所写的继承类方法中,将forward()函数级别提前,就可以了
级别提前后的forward()
一般出现 raise NotImplementedError 的错误的时候,都是子类没有重写父类中的成员成员函数,然后子类对象调用该函数时,会提示这个错误!