# 代码部分 (code section)
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
# Build a noisy sine-wave dataset; NumPy's RNG is seeded so the noise is repeatable.
np.random.seed(666)
X = np.linspace(-2, 2, 1000)
y = np.sin(X) + 0.1 * np.random.normal(0, 1, X.size)

# Hold out 30% of the samples as the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=1024
)

# Convert to float32 tensors and add a trailing feature axis so each tensor is
# 2-D (num_samples, 1). y_test stays a NumPy array.
X_train = torch.from_numpy(X_train).float().unsqueeze(1)
y_train = torch.from_numpy(y_train).float().unsqueeze(1)
X_test = torch.from_numpy(X_test).float().unsqueeze(1)

# Hyper-parameters.
batchsz = 50
LR = 1e-3
epochs = 200

# Wrap the training tensors in a dataset and a shuffling mini-batch loader.
torch_data = Data.TensorDataset(X_train, y_train)
data = Data.DataLoader(dataset=torch_data, batch_size=batchsz, shuffle=True)
#建立自己的线性nn
class Net(nn.Module):
    """Fully connected regression network with one ReLU hidden layer.

    Args:
        n_feature: Size of each input sample.
        n_hidden: Number of units in the hidden layer.
        n_output: Size of each output sample.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = nn.Linear(n_feature, n_hidden)
        self.predict = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return ``predict(relu(hidden(x)))``; no activation on the output."""
        hidden_layer = F.relu(self.hidden(x))
        output_layer = self.predict(hidden_layer)
        return output_layer
# Build the model on the GPU when one is available, otherwise fall back to
# the CPU so the script also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = Net(n_feature=1, n_hidden=10, n_output=1).to(device)

# Adam optimizer over all network parameters.
optimizer = torch.optim.Adam(net.parameters(), lr=LR)

# Mean squared error as the regression loss.
loss_func = nn.MSELoss().to(device)

# Training
net.train()
for epoch in range(epochs):
    for batchidx, (x_, y_) in enumerate(data):
        x_, y_ = x_.to(device), y_.to(device)
        prediction = net(x_)
        loss = loss_func(prediction, y_)
        # Backward pass:
        # 1. clear gradients left over from the previous step
        optimizer.zero_grad()
        # 2. compute gradients
        loss.backward()
        # 3. apply the optimizer update
        optimizer.step()
    # Report every 10th epoch; the value shown is the loss of that epoch's
    # last mini-batch, not an average over the epoch.
    if epoch % 10 == 0:
        print('epoch {}: loss = {}'
              .format(epoch, loss.item()))

# Evaluation
net.eval()
X_test = X_test.to(device)
# Inference needs no gradients; no_grad avoids building the autograd graph
# and lets the result be converted to NumPy directly.
with torch.no_grad():
    predict = net(X_test).cpu().numpy()  # a CUDA tensor must be moved to the CPU first
x_plot = X_test.cpu().numpy()
plt.scatter(x_plot, y_test, label='origin')
plt.scatter(x_plot, predict, color='red', label='predict')
plt.legend()
plt.show()
# 运行结果 (run results)