pytorch实现简易回归问题

2023-10-26

代码部分

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np

#建立数据
np.random.seed(666)
X = np.linspace(-2, 2, 1000)
y = np.sin(X) + 0.1 * np.random.normal(0, 1, X.size)

# 创建训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1024)

X_train = torch.from_numpy(X_train).type(torch.FloatTensor)
X_train = torch.unsqueeze(X_train, dim=1)  #转换成二维
y_train = torch.from_numpy(y_train).type(torch.FloatTensor)
y_train = torch.unsqueeze(y_train, dim=1)

X_test = torch.from_numpy(X_test).type(torch.FloatTensor)
X_test = torch.unsqueeze(X_test, dim=1)  #转换成二维

#设置参数
batchsz=50
LR = 1e-3
epochs = 200 

#装载数据
torch_data  = Data.TensorDataset(X_train, y_train)
data=Data.DataLoader(dataset=torch_data,batch_size=batchsz,shuffle=True)

#建立自己的线性nn
class Net(nn.Module):
    #重载初始化函数
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden=nn.Linear(n_feature,n_hidden)
        self.predict=nn.Linear(n_hidden,n_output)

    #构建前向传播过程
    def forward(self,x):
        hidden_layer=F.relu(self.hidden(x))
        output_layer=self.predict(hidden_layer)
        return output_layer

# 建立模型
device = torch.device('cuda')
net=Net(n_feature=1,n_hidden=10,n_output=1).to(device)

# 选择优化器
optimizer = torch.optim.Adam(net.parameters(), lr=LR)

#使用均方误差作为损失函数
loss_func = nn.MSELoss().to(device)

#训练
net.train()
for epoch in range(epochs):
    for batchidx,(x_,y_) in enumerate(data):
        x_,y_=x_.to(device),y_.to(device)
        prediction = net(x_)
        loss = loss_func(prediction, y_)
        # 反向传递步骤
        # 1、初始化梯度
        optimizer.zero_grad()
        # 2、计算梯度
        loss.backward()
        # 3、进行optimizer优化
        optimizer.step()
    if epoch % 10 == 0:
        print('epoch {}: loss = {}'
        .format(epoch, loss.item()))

#测试
net.eval()
X_test=X_test.to(device)
predict = net(X_test)
predict = predict.data.cpu().numpy()#cuda tensor 需要先转换为cpu
plt.scatter(X_test.cpu().numpy(), y_test, label='origin')
plt.scatter(X_test.cpu().numpy(), predict, color='red', label='predict')
plt.legend()
plt.show()

运行结果

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch实现简易回归问题 的相关文章

随机推荐

  • Javascript显示隐藏DIV

    1 创建一个showhidediv的方法 直接跟ID属性调用 2 HTML页面结构 a a
  • HTML-24:input表单元素

  • 基于python的入侵检测系统毕设_基于时空特征融合的入侵检测系统模型

    期刊 COMPUTERS SECURITY 期刊信息 JCR分区Q1 中科院分区2区 引用因子4 85 摘要 入侵检测系统可以通过分析网络流量的特征来区分正常流量和攻击流量 近年来 神经网络在自然语言处理 计算机视觉 入侵检测等领域得到了发
  • easyui datagrid columnMoving 列移动

    demo 页面 url columnMoving https github com wwwpalmercom jQuery duplicate tree master easyui datagrid columnMoving
  • c++整型与二进制的相互转化

    include
  • 打印100-200之间的所有素数。 (C语言)

    分析 判断是否是素数 1 能被1和自身整除的数是素数 2 试除法 代码 include
  • 服务器选择多大的带宽比较合适,如果遇到攻击了该怎么办

    首先要了解带宽是什么 对于服务器来说带宽分两种 上行带宽和下行带宽 上行带宽 是上传数据的速度 用户要获取到服务器程序搭建里面的内容就需要用到上传带宽 就好比搭建了网站别人要加载内容 服务器首先就要先把内容上传到用户的本地上 而这个步骤就需
  • 杰卡德系数(Jaccard Index)

    杰卡德系数 Jaccard Index 杰卡德系数 又称为杰卡德相似系数 用于比较两个样本之间的差异性和相似性 杰卡德系数越高 则两个样本相似度越高 定义 有两个集合A和B 那么这两个集合的杰卡德系数为A和B的交集除以A和B的并集 当集合A
  • VMware虚拟机启动错误(正在被占用、内部错误)等问题

    参考 VMware虚拟机启动错误 正在被占用 内部错误 等问题 作者 扫地僧 发布时间 2019 09 20 15 02 55 网址 https blog csdn net weixin 42119153 article details 1
  • Spring cache 注解详解

    spring cache注解的使用 CacheConfig 类级别的缓存注解 允许共享缓存名称 Caching 将多种缓存操作分组 Cacheable 触发缓存入口 CacahePut 更新缓存 CacheEvict 触发移除缓存 最常用的
  • 浅谈vscode以及解决官网下载速度慢的问题

    浅谈vscode VSCode 全称 Visual Studio Code 是一款由微软开发且跨平台的免费源代码编辑器 该软件支持语法高亮 代码自动补全 又称 IntelliSense 代码重构 查看定义功能 并且内置了命令行工具和 Git
  • DB-Engines 2017年8月数据库排名发布 总体走势复归平稳

    近日 DB Engines发布了2017年8月数据库排名 数据库排行 经历过此前一系列暴跌暴涨 8月数据库得分走势渐趋平缓 前二十名涨跌幅皆控制在十分以内 前十席位本月无变动 为首的三巨头自不必提 Oracle MySQL Microsof
  • Impala常见错误

    1 尽量少使用 invalidate metadata 尽量用REFRESH TABLE NAME 2 set APPX COUNT DISTINCT true 与 ndv 函数是一样的 都只是估值 Impala SQL 不支持的一个查询中
  • 计算机组成原理——存储系统の选择题整理

    存储器概述 1 存储器存取周期是指 A 存储器的读出时间 B 存储器的写入时间 C 存储器进行连续读或写操作所允许的最短时间间隔 D 存储器进行一次读或写操作所需的平均时间 解析 选C 存取周期是存储器进行连续读或写操作所允许的最短时间间隔
  • LeetCode 最热 100 题, 搜索旋转排序数组,search in rotated sorted array

    作者 Linux猿 简介 CSDN博客专家 华为云享专家 Linux C C 面试 刷题 算法尽管咨询我 关注我 有问题私聊 关注专栏 LeetCode面试必备100题 优质好文持续更新中 欢迎小伙伴们点赞 收藏 留言 目录 一 题目描述
  • 自然语言编程的尝试

    班上有30个学生 甲叫肖鹤云 乙叫李诗情 显示甲 乙的名字 这段代码明显不能运行 需要做一些修改 分配30个学生类至班 甲为班 0 乙为班 1 甲的名字为肖鹤云 乙的名字为李诗情 显示甲的名字 乙的名字 进一步转化 学生类 班 学生类 ma
  • 杨辉三角c语言实现

    在屏幕上打印杨辉三角 include
  • Spring Boot实战.Spring Boot核心原理剖析

    在上节中我们通过了一个小的入门案例已经看到了Spring Boot的强大和简单之处 本章将详细介绍Spring Boot的核心注解 基本配置和运行机制 笔者一直认为 精通一个技术一定要深入了解这个技术帮助我们做了哪些动作 深入理解它底层的运
  • VMware Workstation(虚拟机)安装英文版XP系统

    因需要写英文文档 里面的截图也要求全英文 所以打算在在原有的XP系统上安装一虚拟机 再在虚拟机里安装英文版XP系统 在此记录一下自己的安装过程 虚拟机的安装过程在此略过 首先要下载英文版XP操作系统iso镜像 本人下载网址 http www
  • pytorch实现简易回归问题

    代码部分 import torch from torch autograd import Variable import torch nn as nn import torch nn functional as F import torch