python机器学习之支持向量机——线性SVM决策过程的可视化案例

2023-11-17

线性SVM决策过程的可视化

1、导入需要的模块

from sklearn.datasets import make_blobs
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np

2、实例化数据集,可视化数据集

X,y = make_blobs(n_samples=50, centers=2, random_state=0,cluster_std=0.6)
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")#rainbow彩虹色
plt.xticks([])
plt.yticks([])
plt.show()

在这里插入图片描述

3、画决策边界:理解函数contour

matplotlib.axes.Axes.contour([X, Y,] Z, [levels], **kwargs)

在这里插入图片描述

#首先要有散点图
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
ax= plt.gca() #获取当前的子图,如果不存在,则创建新的子图

在这里插入图片描述

4、 画决策边界:制作网格

#获取平面上两条坐标轴的最大值和最小值
xlim = ax.get_xlim()
ylim = ax.get_ylim()
 
#在最大值和最小值之间形成30个规律的数据
axisx = np.linspace(xlim[0],xlim[1],30)
axisy = np.linspace(ylim[0],ylim[1],30)
 
axisy,axisx = np.meshgrid(axisy,axisx)
#我们将使用这里形成的二维数组作为我们contour函数中的X和Y
#使用meshgrid函数将两个一维向量转换为特征矩阵
#核心是将两个特征向量广播,以便获取y.shape * x.shape这么多个坐标点的横坐标和纵坐标
 
xy = np.vstack([axisx.ravel(), axisy.ravel()]).T
#其中ravel()是降维函数,vstack能够将多个结构一致的一维数组按行堆叠起来
#xy就是已经形成的网格,它是遍布在整个画布上的密集的点
 
plt.scatter(xy[:,0],xy[:,1],s=1,cmap="rainbow")
 
#理解函数meshgrid和vstack的作用
a = np.array([1,2,3])
b = np.array([7,8])
#两两组合,会得到多少个坐标?
#答案是6个,分别是 (1,7),(2,7),(3,7),(1,8),(2,8),(3,8)
 
v1,v2 = np.meshgrid(a,b)
 
v1
 
v2
 
v = np.vstack([v1.ravel(), v2.ravel()]).T

在这里插入图片描述
在这里插入图片描述

5、建模,计算决策边界并找出网格上每个点到决策边界的距离

#建模,通过fit计算出对应的决策边界
clf = SVC(kernel = "linear").fit(X,y)#计算出对应的决策边界
Z = clf.decision_function(xy).reshape(axisx.shape)
#重要接口decision_function,返回每个输入的样本所对应的到决策边界的距离
#然后再将这个距离转换为axisx的结构,这是由于画图的函数contour要求Z的结构必须与X和Y保持一致

#首先要有散点图
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
ax = plt.gca() #获取当前的子图,如果不存在,则创建新的子图
#画决策边界和平行于决策边界的超平面
ax.contour(axisx,axisy,Z
           ,colors="k"
           ,levels=[-1,0,1] #画三条等高线,分别是Z为-1,Z为0和Z为1的三条线
           ,alpha=0.5#透明度
           ,linestyles=["--","-","--"])
 
ax.set_xlim(xlim)#设置x轴取值
ax.set_ylim(ylim)

在这里插入图片描述

#Z的本质是输入的样本到决策边界的距离,而contour函数中的level其实是输入了这个距离
#让我们用一个点来试试看
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
plt.scatter(X[10,0],X[10,1],c="black",s=50,cmap="rainbow")

在这里插入图片描述

clf.decision_function(X[10].reshape(1,2))
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
ax = plt.gca()
ax.contour(axisx,axisy,Z
            ,colors="k"
            ,levels=[-3.33917354]
            ,alpha=0.5
            ,linestyles=["--"])

在这里插入图片描述

6、将绘图过程包装成函数

#将上述过程包装成函数:
def plot_svc_decision_function(model,ax=None):
    if ax is None:
        ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    
    x = np.linspace(xlim[0],xlim[1],30)
    y = np.linspace(ylim[0],ylim[1],30)
    Y,X = np.meshgrid(y,x) 
    xy = np.vstack([X.ravel(), Y.ravel()]).T
    P = model.decision_function(xy).reshape(X.shape)
    
    ax.contour(X, Y, P,colors="k",levels=[-1,0,1],alpha=0.5,linestyles=["--","-","--"]) 
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
 
#则整个绘图过程可以写作:
clf = SVC(kernel = "linear").fit(X,y)
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
plot_svc_decision_function(clf)

在这里插入图片描述

7、探索建好的模型

clf.predict(X)
#根据决策边界,对X中的样本进行分类,返回的结构为n_samples
 
clf.score(X,y)
#返回给定测试数据和标签的平均准确度
 
clf.support_vectors_
#返回支持向量坐标
 
clf.n_support_#array([2, 1])
#返回每个类中支持向量的个数

在这里插入图片描述

8、推广到非线性情况

from sklearn.datasets import make_circles
X,y = make_circles(100, factor=0.1, noise=.1)
 
X.shape
 
y.shape
 
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
plt.show()

在这里插入图片描述
用我们已经定义的函数来划分这个数据的决策边界:

clf = SVC(kernel = "linear").fit(X,y)
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
plot_svc_decision_function(clf)
clf.score(X,y)

在这里插入图片描述
明显,现在线性SVM已经不适合于我们的状况了是两种类别。这个时候,如果我们能够在原本的据,来看看添加维度让我们的数据如何变化。

9、为非线性数据增加维度并绘制3D图像

#定义一个由x计算出来的新维度r
r = np.exp(-(X**2).sum(1))
 
rlim = np.linspace(min(r),max(r),100)
 
from mpl_toolkits import mplot3d
 
#定义一个绘制三维图像的函数
#elev表示上下旋转的角度
#azim表示平行旋转的角度
def plot_3D(elev=30,azim=30,X=X,y=y):
    ax = plt.subplot(projection="3d")
    ax.scatter3D(X[:,0],X[:,1],r,c=y,s=50,cmap='rainbow')
    ax.view_init(elev=elev,azim=azim)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("r")
    plt.show()
    
plot_3D()

在这里插入图片描述
可以看见,此时此刻我们的数据明显是线性可分的了:我们可以使用一个平面来将数据完全分开,并使平面的上方的所有数据点为一类,平面下方的所有数据点为另一类。

10、 将上述过程放到Jupyter Notebook中运行

#如果放到jupyter notebook中运行
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
 
from sklearn.datasets import make_circles
X,y = make_circles(100, factor=0.1, noise=.1)
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
 
def plot_svc_decision_function(model,ax=None):
    if ax is None:
        ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    
    x = np.linspace(xlim[0],xlim[1],30)
    y = np.linspace(ylim[0],ylim[1],30)
    Y,X = np.meshgrid(y,x) 
    xy = np.vstack([X.ravel(), Y.ravel()]).T
    P = model.decision_function(xy).reshape(X.shape)
    
    ax.contour(X, Y, P,colors="k",levels=[-1,0,1],alpha=0.5,linestyles=["--","-","--"])
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
 
clf = SVC(kernel = "linear").fit(X,y)
plt.scatter(X[:,0],X[:,1],c=y,s=50,cmap="rainbow")
plot_svc_decision_function(clf)
 
r = np.exp(-(X**2).sum(1))
 
rlim = np.linspace(min(r),max(r),100)
 
from mpl_toolkits import mplot3d
 
def plot_3D(elev=30,azim=30,X=X,y=y):
    ax = plt.subplot(projection="3d")
    ax.scatter3D(X[:,0],X[:,1],r,c=y,s=50,cmap='rainbow')
    ax.view_init(elev=elev,azim=azim)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("r")
    plt.show()
 
from ipywidgets import interact,fixed
interact(plot_3D,elev=[0,30,60,90],azip=(-180,180),X=fixed(X),y=fixed(y))
plt.show()

elev和azim都是可调节的:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

明显我们可以用一个平面将两类数据隔开,这个平面就是我们的决策边界了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

python机器学习之支持向量机——线性SVM决策过程的可视化案例 的相关文章

  • 将二进制 Numpy 数组转换为无符号整数

    我有一个 Numpy 数组对象的长二维矩阵 其维度为 n x 12 这是该矩阵的前 10 行 b 0 0 0 0 1 1 1 1 1 0 1 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0
  • Numpy:具有特定条件的线性系统。无负解

    我正在使用 numpy 编写 Python 代码 在我的代码中 我使用 linalg solve 来求解 n 个变量中的 n 个方程的线性系统 当然 解决方案可以是积极的 也可以是消极的 我需要做的是始终有正解或至少等于 0 为此 我首先希
  • Mac OS X:“ModuleNotFoundError:没有名为“numpy”的模块”

    我重新安装后Anaconda https en wikipedia org wiki Anaconda Python distribution 我无法再在 Python 3 上导入 NumPy import numpy as np Outp
  • 如何保存 numpy 数组图像并将它们放入单个文件夹中?

    我有一个 numpy 数组 其中包含 5000 个 28 x 28 图像 5000 28 28 我想将所有这些图像保存为 jpg 文件并将它们全部保存在一个文件夹中 实现这一目标最快 最有效的方法是什么 我尝试使用以下命令将 50 000
  • 当存在多个条件时替换 numpy 数组中的元素

    这个问题与以下帖子相关 如果满足条件则替换 Numpy 元素 https stackoverflow com questions 19766757 replacing numpy elements if condition is met 假
  • 从 MySQL 将数字数据加载到 python/pandas/numpy 数组的最快方法

    我想从 MySQL 表中读取一些数字 双精度 即 float64 数据 数据大小约为 200k 行 MATLAB 参考 tic feature accel off conn database c fetch exec conn select
  • ValueError:当数组不是序列时设置带有序列的数组元素

    您好 此代码旨在存储使用 open cv 绘制的矩形的坐标 并将结果编译为单个图像 import numpy as np import cv2 im cv2 imread 1 jpg im3 im copy gray cv2 cvtColo
  • 使用 NumPy 函数计算 Pandas 的加权平均值

    假设我们有一个像这样的 pandas 数据框 a b id 36 25 2 40 25 3 46 23 2 40 22 5 42 20 5 56 39 3 我想执行一个操作 a div b 然后按 id 分组 最后使用 a 作为权重计算加权
  • Numpy:查找两个 3-D 数组之间的欧几里德距离

    给定两个维度为 2 2 2 的 3 D 数组 A 0 0 92 92 0 92 0 92 B 0 0 92 0 0 92 92 92 如何有效地找到 A 和 B 中每个向量的欧几里得距离 我尝试过 for 循环 但速度很慢 而且我正在按 g
  • 使用 numpy.distutils.core.setup 之前安装 numpy

    我在用numpy distutils设置具有 fortran 模块的包 mypackage 问题是如果我这样做pip install mypackage在没有 numpy 的环境中 出现以下错误 ModuleNotFoundError 没有
  • 内存高效的随机数迭代器,无需替换

    我觉得这应该很容易 但经过多次搜索和尝试后我找不到答案 基本上 我有大量的物品 我想以随机顺序进行采样 而不需要更换 在本例中 它们是二维数组中的单元 我用于较小数组的解决方案不会转换 因为它需要对内存数组进行改组 如果我必须采样的数量很小
  • 使用 python 中的硬件 rng

    是否有任何现成的库 以便 numpy 程序可以使用 intel 硬件 prng rdrand 来填充随机数缓冲区 如果做不到这一点 有人可以为我指明一些我可以改编或使用的 C 代码的正确方向 我将 CPython 和 Cython 与 nu
  • python中稀疏矩阵的相关系数?

    有谁知道如何从Python中的一个非常大的稀疏矩阵计算相关矩阵 基本上 我正在寻找类似的东西numpy corrcoef这将适用于 scipy 稀疏矩阵 您可以从协方差矩阵相当直接地计算相关系数 如下所示 import numpy as n
  • 使用 PIL 用附近的颜色填充空白图像空间(也称为修复)

    我用 PIL 创建一个图像 我需要填充空白区域 显示为黑色 我可以轻松地用静态颜色填充它 但我想做的是用附近的颜色填充像素 例如 边框之后的第一个像素可能是填充像素的高斯模糊 或者可能是中描述的推拉型算法Lumigraph Gortler
  • 添加和访问 numpy 结构化数组的对象类型字段

    我正在使用 numpy 1 16 2 简单来说 我想知道如何将对象类型字段添加到结构化数组中 标准方式通过recfunctions模块抛出错误 我想这是有原因的 因此 我想知道我的解决方法是否有问题 此外 我想了解为什么这种解决方法是必要的
  • 如何将 numpy.matrix 提高到非整数幂?

    The 运算符为numpy matrix不支持非整数幂 gt gt gt m matrix 1 0 0 5 0 5 gt gt gt m 2 5 TypeError exponent must be an integer 我想要的是 oct
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • 协方差矩阵的对角元素不是 1 pandas/numpy

    我有以下数据框 A B 0 1 5 1 2 6 2 3 7 3 4 8 我想计算协方差 a df iloc 0 values b df iloc 1 values 使用 numpy 作为 cov numpy cov a b I get ar
  • 如何在Python中对类别进行加权随机抽样

    给定一个元组列表 其中每个元组都包含一个概率和一个项目 我想根据其概率对项目进行采样 例如 给出列表 3 a 4 b 3 c 我想在 40 的时间内对 b 进行采样 在 python 中执行此操作的规范方法是什么 我查看了 random 模
  • Python:有类似matlab的反斜杠运算符吗?

    Matlab 和 Julia 有反斜杠运算符来求解线性系统 我真的不知道 Matlab 是做什么的 但是 Julia 不计算逆函数 但它计算逆函数对给定向量的影响 这在计算上更容易 我有一个 numpy 稀疏矩阵 我想将其伪逆应用于向量 P

随机推荐

  • Android SELinux

    Google参考链接 https source android com docs core architecture aidl aidl hals sepolicy A 通信框架SE文件修改 public attributes vendor
  • 【canal系】canal集群异常Could not find first log file name in binary log index file

    这里先说明下这边使用的canal版本号为1 1 5 在描述这个问题之前 首先需要简单对于canal架构有个基本的了解 canal工作原理 canal 模拟 MySQL slave 的交互协议 伪装自己为 MySQL slave 向 MySQ
  • 详解@Override注解

    目录 1 是什么 2 为什么用 3 举例说明 1 示例一 2 示例二 3 示例三 1 是什么 Override注解是伪代码 用于表示被标注的方法是一个重写方法 Override注解 只能用于标记方法 并且它只在编译期生效 不会保留在clas
  • QT中添加Q_OBJECT出现的问题

    Multiple Inheritance Requires QObject to Be First 多重继承QObject一定要放在前面 我在用class My Node public QGraphicsItem public QObjec
  • 产业互联网-构建智能+时代数字生态新图景

    在2019腾讯全球数字生态大会新闻发布会上 腾讯云联合腾讯研究院 共同发布了行业重磅报告 产业互联网 构建智能 时代数字生态新图景 报告首次阐述了产业互联网的战略框架和实践方法论 报告指出 产业互联网的实现 需要跨界共建数字生态共同体 形成
  • linux安装telnet工具下载,Linux下安装telnet的方法

    一 安装telnet 1 检测telnet server的rpm包是否安装 root localhost rpm qa telnet server 若无输入内容 则表示没有安装 出于安全考虑telnet server rpm是默认没有安装的
  • NestedScrolling机制(一)——概述

    http blog csdn net al4fun article details 53888990 如今 NestedScrolling机制 可以称为嵌套滚动或嵌套滑动 在各种app中的应用已经十分广泛了 下图是 饿了么 中的一个例子 当
  • 虹膜识别 Iris_Osiris_v4.1源码,mfc测试用例

    01 资源 win10 vs2015 git opencv3 3 0 cmake 参考虹膜识别文档 开源虹膜识别软件OSIRIS4 1的使用入门 将开源虹膜识别算法OSIRIS4 1移植到Windows opencv3 3 0的配置参考 也
  • Leetcode 202. 快乐数(找规律注意回环)

    快乐数 编写一个算法来判断一个数 n 是不是快乐数 快乐数 定义为 对于一个正整数 每一次将该数替换为它每个位置上的数字的平方和 然后重复这个过程直到这个数变为 1 也可能是 无限循环 但始终变不到 1 如果 可以变为 1 那么这个数就是快
  • 记录几个CentOS安装包(rpm)的下载地址-离线安装必备

    1 http rpmfind net linux RPM index html 2 https centos pkgs org 3 http mirror centos org centos 7 extras x86 64 Packages
  • Java处理SSH

    JSch 登录 密码方式 session setPassword password 公私秘钥方式 jsch addIdentity ssh id rsaxxx SFTP简介 SFTP是Secure File Transfer Protoco
  • 【YOLOv7/YOLOv5系列算法改进NO.49】模型剪枝、蒸馏、压缩

    文章目录 前言 一 解决问题 二 基本原理 三 剪枝操作 四 知识蒸馏操作 前言 作为当前先进的深度学习目标检测算法YOLOv7 已经集合了大量的trick 但是还是有提高和改进的空间 针对具体应用场景下的检测难点 可以不同的改进方法 此后
  • go 设置 GOROOT 和 GOPATH

    点击在我的博客 xuxusheng com 中查看 有更好的排版哦 发表失败全部丢失 写完了又重写一遍 csdn 都没个自动保存功能 强烈吐槽 go 里面有两个非常重要的环境变量 GOROOT 和 GOPATH 其中 GOROOT 是安装
  • linux CPU性能监控(进阶)和杂谈

    线程与进程的区别 进程 是执行一段程序 即一旦程序被载入到内存中准备执行 它就是一个进程 线程 单个进程中执行每一个任务就是一个线程 一个线程只属于一个进程 一个进程里可以有多个线程 上下文切换 在处理器执行期间 运行进程的信息被存储在处理
  • javax.net.ssl.SSLException: Received fatal alert: protocol_version

    最近需要第三方回传数据到自己的地址 发现调不通 如下 1 第三方错误提示 根据提示是请求时所用的tls协议版本与目标地址所能使用的不一致 2 第三方查看代码中所有的tls版本 查看目标地址所能支持的tls版本 nmap script ssl
  • Python的十二道编程题,码住战胜一切

    一 计算文件大小 import os def get size path size 0 l path while l path l pop lst os listdir path for name in lst son path os pa
  • Visuial Studio 打开 Unity 新建脚本时,新脚本继承MonoBehaviour暂时失效为白色的解决方法

    点击 文件 gt 最近使用的项目和解决方案 gt 点击当前项目 即可瞬间重载当前项目 这个时候 白色的MonoBehaviour会变成绿色 就可以了 当然最传统的方法就是关掉VS再打开 不过挺浪费时间的
  • umijs框架加载cesium

    创建umi项目 yarn create umi 选择app 选择是否使用typescript N 选择依赖 yarn yarn start 项目创建完成后 添加cesium yarn add cesium 下载版本是1 67 不同版本配置方
  • 【Android】替换系统默认字体

    android系统默认字体分类 DroidSans ttf 系统默认英文字体 DroidSans Bold ttf 系统默认英文粗字体 DroidSansFallback ttf 系统默认中文字体 为系统新增字体 1 复制字体到framew
  • python机器学习之支持向量机——线性SVM决策过程的可视化案例

    线性SVM决策过程的可视化 1 导入需要的模块 from sklearn datasets import make blobs from sklearn svm import SVC import matplotlib pyplot as