pytorch 深度学习入门代码 (三)Logistic 回归代码实现

2023-05-16

"""Logistic 回归的代码实现"""

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np


class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.lr = nn.Linear(2, 1)
        self.sm = nn.Sigmoid()

    def forward(self, x):
        x = self.lr(x)
        x = self.sm(x)
        return x


if __name__ == '__main__':
    with open('data.txt', 'r', encoding='utf8') as f:
        data_list = f.readlines()
        data_list = [i.split('\n')[0] for i in data_list]
        data_list = [i.split(',') for i in data_list]
        data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]
        data = torch.Tensor(data)

    logistic_model = LogisticRegression()
    if torch.cuda.is_available():
        logistic_model.cuda()

    criterion = nn.BCELoss()
    optimizer = torch.optim.SGD(logistic_model.parameters(), lr=1e-3, momentum=0.9)

    for epoch in range(10000):
        if torch.cuda.is_available():
            x = Variable(data[:, 0:2]).cuda()
            y = Variable(data[:, 2]).cuda().unsqueeze(1)
        else:
            x = Variable(data[:, 0:2])
            y = Variable(data[:, 2]).unsqueeze(1)
        # forward
        out = logistic_model(x)
        loss = criterion(out, y)
        print_loss = loss.data.item()
        mask = out.ge(0.5).float()
        correct = (mask == y).sum()
        acc = correct.item() / x.size(0)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 1000 == 0:
            print('*' * 10)
            print('epoch {}'.format(epoch + 1))
            print('loss is {:.4f}'.format(print_loss))
            print('acc is {:.4f}'.format(acc))


    w0, w1 = logistic_model.lr.weight[0]
    w0 = w0.item()
    w1 = w1.item()
    b = logistic_model.lr.bias.item()
    plot_x = np.arange(30, 100, 0.1)
    plot_y = (-w0 * plot_x - b) / w1
    plt.plot(plot_x, plot_y)

    x0 = list(filter(lambda x: x[-1] == 0.0, data))
    x1 = list(filter(lambda x: x[-1] == 1.0, data))
    plot_x0_0 = [i[0] for i in x0]
    plot_x0_1 = [i[1] for i in x0]
    plot_x1_0 = [i[0] for i in x1]
    plot_x1_1 = [i[1] for i in x1]

    plt.plot(plot_x0_0, plot_x0_1, 'ro', label='x_0')
    plt.plot(plot_x1_0, plot_x1_1, 'bo', label='x_1')
    plt.legend()
    plt.show()

少量测试数据集

34.62365962451697,78.0246928153624,0
30.2867107622687,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.3855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750265,1
69.07014406283025,52.74046973016765,1
67.9468554771161746,67.857410673128,0

这里写图片描述

附100个训练数据集
https://download.csdn.net/download/jgt_insect/10575505

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch 深度学习入门代码 (三)Logistic 回归代码实现 的相关文章

  • (原创)解决APP进程被杀死出现的一些坑

    这几天在开发中遇到了这样一个问题 当打开的App数量达到一定数目时 原先的App进程会被系统杀死 然后再次进入这个被杀死的App时 发现了一些异常 经过排查 xff0c 是viewpager的getitem方法没有被调用 通过网上的一些信息
  • 8位字节对齐算法

    参考文章 8bit对齐算法 void testAlign for int i 61 0 i lt 61 10 i 43 43 int align 61 align8Bit i printf 34 the align is d n 34 al
  • Ubuntu22.04安装libudev-dev时的Bug

    新安装了Ubuntu22 04 xff0c 然后安装libudev dev xff1a sudo apt install libudev dev 发现了非常奇怪的事情 xff1a 正在读取软件包列表 完成 正在分析软件包的依赖关系树 完成
  • ubuntu服务器安装python3

    输入python查看python2是否安装 输入python3查看python3是否安装 一般都会安装python2和3 xff0c 系统服务也会调用python xff0c 所以没事不要乱卸载 sudo apt autoremove py
  • 个人面试经历经验谈

    到昨天接到金蝶得Offer xff0c 我想我为期三个星期的找工作面试之旅应该是告一段落了 原以为接到Offer会有点高兴 xff0c 但是一回味这三个星期的起起落落 xff0c 便实在是高兴不起来 xff0c 虽然手上有好几个Offer可
  • mysql 运行sql报错1118 - Row size too large (> 8126). Changing some columns to TEXT or BLOB

    innodb file per table 61 1 innodb file format 61 Barracuda innodb file format check 61 ON innodb log file size 61 512M i
  • Git 操作源地址(查看 添加 修改 删除)

    查看源地址 git remote v 修改源地址 git remote set url origin git地址 添加源地址 git remote add NAME GIT URL NAME 为新的Git库源地址名 xff0c GIT UR
  • [CentOS] 四、安装 ranger

    四 安装 ranger 作者 xff1a 解琛 时间 xff1a 2020 年 9 月 15 日 ranger ranger 安装Nerd Fonts字体 span class token function git span clone h
  • java实习两个月总结

    实习两个月总结 刚开始实习的时候激情满满 慢慢的激情也退却了 在杭州月薪3000干了两个月我自己都觉得不可思议 杭州的物价大家有目共睹 先谈谈收获 认识了java8的新特性 了解了开发中常用的工具和工具包 持续集成部署的jenkins sw
  • chrome各种版本下载地址:

    Download older versions of Google Chrome for Windows Linux and Mac Download older versions of Google Chrome for Windows
  • idea 注入mapper报错报红的几种解决方案

    方法1 xff1a 为 64 Autowired 注解设置required 61 false 使用 64 Autowired 注解时 xff0c 若希望允许null值 xff0c 可设置required 61 false 像这样 xff1a
  • 当node遇上Egg遇上TypeScript

    快速入门 通过骨架快速初始化 xff1a npx egg init type 61 ts showcase cd showcase amp amp npm i npm run dev 上述骨架会生成一个极简版的示例 xff0c 更完整的示例
  • fastjson 导致 swagger 页面无法显示

    问题 xff1a 增加swagger后 xff0c 无法访问 http localhost 8080 swagger ui index html xff0c 去除fastjson配置后确可以访问 相关配置信息 xff1a lt fastjs
  • springboot发送HTTP请求

    1 添加依赖 使用RestTemplate进行发送请求 xff0c 添加相关依赖 lt 发送请求的依赖 gt lt dependency gt lt groupId gt org apache httpcomponents lt group
  • 启动docker容器一致提示端口被占用,即使是已经删除相关端口的进程

    1 重启docker服务 systemctl restart docker 2 启动对应的docker容器 docker start tomcat 3 如果提示端口已被占用 xff0c 则查看占用进程并杀死 netstat ntulp gr
  • windows上的IDEA连接Docker

    docker中勾选 查看连接
  • MySQL将字段的值进行拼接

    应用场景 xff1a 1 同张表分组时将某个字段的值进行拼接 将学生按班级分组 xff0c 对同一班级的学生姓名进行拼接 SELECT classid group concat stu name as stu names FROM stud
  • Maven异常:Could not find artifact

    Maven异常 xff1a Could not find artifact 执行maven install的时候出现了以下异常 xff1a INFO Scanning for projects ERROR ERROR Some proble
  • git获取最新的tag

    获取git最新的tag标签 git tag n sort 61 taggerdate head n 1 这样获取到的会有合并信息 如果只需要获取到tag名称 xff0c 也可以这样 git tag sort 61 taggerdate he
  • 使用阿里云的函数计算来实现OSS资源的打包下载

    文档地址 xff1a 如何使用函数计算将多个文件打包下载到本地 对象存储 OSS 阿里云 计算函数可以通过对外公网域名进行访问 xff1a 计算函数的参数有几个 xff1a bucket xff1a 使用的OSS的bucket xff0c

随机推荐

  • linux系统上nodejs 报错:node: /lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.27‘ not found

    原因 xff1a 因为当前系统不支持GLIBC 2 27 xff0c 而且node的版本过高 xff0c 但是后来降低了版本还是报这个错误 xff0c 后来发现低版本的软链接在 usr bin xff0c 而高版本的软链接在 usr loc
  • 使用nvm控制nodejs版本

    原因 xff1a 由于项目需要用到两个版本的nodejs xff0c 如果只是一个版本的nodejs的话 xff0c 其中一个项目就会报错 xff0c 所以需要用到nvm进行nodejs版本控制 xff0c 使用不同版本的nodejs来进行
  • opencv 实战案例 (一)

    目录 xff1a 1 用 Canny 算子检测图像轮廓提取车道线任务 xff08 Canny xff09 2 用 findContours 发现硬币轮廓任务 Canny 43 findContours 3 用概率霍夫变换检测车道线任务 Ca
  • 企业微信-构造网页授权链接实现登录

    文档地址 xff1a 构造网页授权链接 接口文档 企业微信开发者中心 注意 xff1a 1 redirect uri xff1a 回调链接地址 xff0c 需要使用urlencode对链接进行处理 2 scope xff1a 如果需要获取成
  • dpkg: 处理软件包 xxx (--configure)时出错解决方法

    问题 xff1a dpkg 处理软件包 libicu dev configure 时出错 xff1a 依赖关系问题 仍未被配置 dpkg 依赖关系问题使得 libxml2 dev amd64 的配置工作不能继续 xff1a libxml2
  • oracle 删除表以及回复数据

    找回删除的表 select object name original name partition name type ts name createtime droptime from recyclebin WHERE original n
  • 银行卡信息查询

    银行卡bin 银行卡信息 请移步到github xff1a https github com burningmyself bank
  • ProcessDefinition是干这个用的

    流程定义ProcessDefinition是对业务过程的完整描述 xff0c 例如请假流程定义 报销流程定义等 流程定义的管理包括部署流程定义 查询流程定义 查看流程定义图和删除流程定义 1 部署流程定义 使用RepositoryServi
  • 截取字符串的三种方法

    众所周知 xff0c java提供了很多字符串截取的方式 下面就来看看大致有几种 span class hljs number 1 span span class hljs built in split span 43 正则表达式来进行截取
  • Iterator主要有三个方法:hasNext()、next()、remove()详解

    一 Iterator的API 关于Iterator主要有三个方法 xff1a hasNext next remove hasNext 没有指针下移操作 xff0c 只是判断是否存在下一个元素 next xff1a 指针下移 xff0c 返回
  • @ModelAttribute用法详解

    转载于 xff1a https blog csdn net harry zh wang article details 57329613 之前项目中并自己并没有怎么使用到过 64 ModelAttribute这个注解 xff0c 接手一个老
  • mysql除法运算保留小数的用法

    参照 xff1a https www cnblogs com owenma p 7097602 html 在工作中会遇到计算小数而且需要显现出小数末尾的0 xff0c 我们会用到DECIMAL这个函数 xff0c 这是一个函数非常强悍 xf
  • IDEA—点击文件代码与目录自动同步对应

    关注微信公众号 xff1a CodingTechWork xff0c 一起学习进步 引言 在使用IDEA的时候 xff0c 我们Ctrl 43 Shift 43 F搜索文件后 xff0c 总是要慢慢找文件在哪个包路径下 如查看路径顶端 xf
  • springboot打包完成之后无法读取到resources下的资源文件

    File privateKeyFile 61 ResourceUtils getFile classpath wx pfx PrivateKey privateKey 61 getPrivateKey privateKeyFile priv
  • 接口签名实现拦截的两种方式

    1 采用spring的aop思想进行拦截 需要自定义注解 xff0c 然后定义切面 xff08 五大类 xff09 然后在定义 xff0c 可以获取所有的参数 2 拦截器的实现方式 自定义拦截器 xff0c 然后对拦截器进行配置即可 配置
  • Java程序员利器,lombok神搭档:delombok插件

    Lombok是一款非常实用Java工具 xff0c 它可以帮助开发人员减少样板代码 xff0c 使开发人员专注业务逻辑 xff0c 在Java界几乎无人不知 但也有一些明显的缺点 xff0c 例如 xff1a 对插件强依赖 xff0c 在团
  • C++bind函数

    1 基本概念 bind函数定义在头文件 functional 中 可以将 bind 函数看作一个通用的函数适配器 xff0c 它接受一个可调用对象 xff0c 生成一个新的可调用对象来 适应 原对象的参数列表 C 43 43 Primer
  • C++值的分类 —— 摘自维基百科

    在C 43 43 11 xff0c 对于值的分类 xff0c 要考虑标识 xff08 identity xff09 与可移动性 xff08 movability xff09 xff0c 二者的组合产生了五种分类 xff1a 基础值类型 左值
  • pytorch 深度学习入门代码 (一)线性回归代码实现

    34 34 34 一维线性回归代码实现 34 34 34 import torch from torch autograd import Variable import matplotlib pyplot as plt import tor
  • pytorch 深度学习入门代码 (三)Logistic 回归代码实现

    span class hljs string 34 34 34 Logistic 回归的代码实现 34 34 34 span span class hljs keyword import span matplotlib pyplot spa