sklearn.model_selection.train_test_split

2023-10-27

数据集划分:sklearn.model_selection.train_test_split(*arrays, **options)

主要参数说明:

*arrays:可以是列表、numpy数组、scipy稀疏矩阵或pandas的数据框

test_size:可以为浮点、整数或None,默认为None

①若为浮点时,表示测试集占总样本的百分比

②若为整数时,表示测试样本样本数

③若为None时,test size自动设置成0.25

train_size:可以为浮点、整数或None,默认为None

①若为浮点时,表示训练集占总样本的百分比

②若为整数时,表示训练样本的样本数

③若为None时,train_size自动被设置成0.75

random_state:可以为整数、RandomState实例或None,默认为None

①若为None时,每次生成的数据都是随机,可能不一样

②若为整数时,每次生成的数据都相同

stratify:可以为类似数组或None

①若为None时,划分出来的测试集或训练集中,其类标签的比例也是随机的

②若不为None时,划分出来的测试集或训练集中,其类标签的比例同输入的数组中类标签的比例相同,可以用于处理不均衡的数据集

通过简单栗子看看各个参数的作用:

①test_size决定划分测试、训练集比例

  1. In [ 1]: import numpy as np
  2. ...: from sklearn.model_selection import train_test_split
  3. ...: X = np.arange( 20)
  4. ...: y = [ 'A', 'B', 'A', 'A', 'A', 'B', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'A
  5. ...: ', 'B', 'A', 'A']
  6. ...: X_train , X_test , y_train,y_test = train_test_split(X,y,test_size= 0.25
  7. ...: ,random_state= 0)
  8. ...:
  9. In [ 2]: X_test.shape
  10. Out[ 2]: ( 5,)
  11. In [ 3]: X_train.shape
  12. Out[ 3]: ( 15,)
  13. In [ 4]: X_test ,y_test
  14. Out[ 4]: (array([ 18, 1, 19, 8, 10]), [ 'A', 'B', 'A', 'B', 'A'])
②random_state不同值获取到不同的数据集

设置random_state=0再运行一次,结果同上述相同

  1. In [ 5]: import numpy as np
  2. ...: from sklearn.model_selection import train_test_split
  3. ...: X = np.arange( 20)
  4. ...: y = [ 'A', 'B', 'A', 'A', 'A', 'B', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'A
  5. ...: ', 'B', 'A', 'A']
  6. ...: X_train , X_test , y_train,y_test = train_test_split(X,y,test_size= 0.25
  7. ...: ,random_state= 0)
  8. ...: X_test ,y_test
  9. ...:
  10. Out[ 5]: (array([ 18, 1, 19, 8, 10]), [ 'A', 'B', 'A', 'B', 'A'])
设置random_state=None运行两次,发现两次的结果不同

  1. In [ 6]: import numpy as np
  2. ...: from sklearn.model_selection import train_test_split
  3. ...: X = np.arange( 20)
  4. ...: y = [ 'A', 'B', 'A', 'A', 'A', 'B', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'A
  5. ...: ', 'B', 'A', 'A']
  6. ...: X_train , X_test , y_train,y_test = train_test_split(X,y,test_size= 0.25
  7. ...: )
  8. ...: X_test ,y_test
  9. ...:
  10. Out[ 6]: (array([ 3, 18, 14, 7, 4]), [ 'A', 'A', 'A', 'B', 'A'])
  11. In [ 7]: import numpy as np
  12. ...: from sklearn.model_selection import train_test_split
  13. ...: X = np.arange( 20)
  14. ...: y = [ 'A', 'B', 'A', 'A', 'A', 'B', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'A
  15. ...: ', 'B', 'A', 'A']
  16. ...: X_train , X_test , y_train,y_test = train_test_split(X,y,test_size= 0.25
  17. ...: )
  18. ...: X_test ,y_test
  19. ...:
  20. Out[ 7]: (array([ 18, 6, 3, 14, 8]), [ 'A', 'A', 'A', 'A', 'B'])
③设置stratify参数,可以处理数据不平衡问题

  1. In [ 8]: import numpy as np
  2. ...: from sklearn.model_selection import train_test_split
  3. ...: X = np.arange( 20)
  4. ...: y = [ 'A', 'B', 'A', 'A', 'A', 'B', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'A
  5. ...: ', 'B', 'A', 'A']
  6. ...: X_train , X_test , y_train,y_test = train_test_split(X,y,test_size= 0.25
  7. ...: ,stratify=y)
  8. ...: X_test ,y_test
  9. ...:
  10. Out[ 8]: (array([ 18, 8, 3, 10, 11]), [ 'A', 'B', 'A', 'A', 'B'])
  11. In [ 9]: import numpy as np
  12. ...: from sklearn.model_selection import train_test_split
  13. ...: X = np.arange( 20)
  14. ...: y = [ 'A', 'B', 'A', 'A', 'A', 'B', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'A
  15. ...: ', 'B', 'A', 'A']
  16. ...: X_train , X_test , y_train,y_test = train_test_split(X,y,test_size= 0.25
  17. ...: ,stratify=y)
  18. ...: X_test ,y_test
  19. ...:
  20. Out[ 9]: (array([ 6, 19, 8, 17, 0]), [ 'A', 'A', 'B', 'B', 'A'])
  21. In [ 10]: X_train,y_train
  22. Out[ 10]:
  23. (array([ 7, 1, 11, 10, 15, 2, 3, 5, 4, 13, 12, 16, 18, 14, 9]),
  24. [ 'B', 'B', 'B', 'A', 'B', 'A', 'A', 'B', 'A', 'A', 'B', 'A', 'A', 'A', 'A'])
设置stratify=y时,我们发现每次划分后,测试集和训练集中的类标签比例同原始的样本中类标签的比例相同,都为2:3
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

sklearn.model_selection.train_test_split 的相关文章

  • DynamodB:如何更新排序键?

    该表有两个键 filename 分区键 和eventTime 排序键 我要更新eventTime对于某些filename Tried put item and update item 发送相同的filename与新的eventTime但这些
  • 使用管理员权限打开cmd(Windows 10)

    我有自己的 python 脚本来管理我的计算机上的 IP 地址 它主要在命令行 Windows 10 中执行netsh命令 您必须具有管理员权限 这是我自己的计算机 我是管理员 运行脚本时我已经使用管理员类型的用户 Adrian 登录 我无
  • 使用 Python 和 lmfit 拟合复杂模型?

    我想适合椭偏仪 http en wikipedia org wiki Ellipsometry使用 LMFit 将数据转换为复杂模型 两个测量参数 psi and delta 是复杂函数中的变量rho 我可以尝试将问题分离为实部和虚部共享参
  • Scrapy 文件管道不下载文件

    我的任务是构建一个可以下载所有内容的网络爬虫 pdfs 在给定站点中 Spider 在本地计算机和抓取集线器上运行 由于某种原因 当我运行它时 它只下载一些但不是全部的 pdf 通过查看输出中的项目可以看出这一点JSON 我已经设定MEDI
  • Python3将模块从文件夹导入到另一个文件夹

    我的结构字典是 mainFolder folder1 init py file1 py file2 py folder2 init py file3 py file4 py setup py init py 我需要将 file4 py 从f
  • 使用 Tkinter 打开网页

    因此 我的应用程序需要能够打开其中的单个网页 并且它必须来自互联网并且未保存 特别是我想使用 Tkinter GUI 工具包 因为它是我最熟悉的工具包 最重要的是 我希望能够在窗口中生成事件 例如单击鼠标 但无需实际使用鼠标 有什么好的方法
  • 会话数据库表清理

    该表是否需要清除或者由 Django 自动处理 Django 不提供自动清除功能 然而 有一个方便的命令可以帮助您手动完成此操作 Django 文档 清除会话存储 https docs djangoproject com en dev to
  • 如何知道python运行脚本的路径?

    sys arg 0 给我 python 脚本 例如 python hello py 返回 sys arg 0 的 hello py 但我需要知道 hello py 位于完整路径中的位置 我怎样才能用Python做到这一点 os path a
  • 我可以用关闭的文件对象做什么?

    当您打开文件时 它存储在一个打开的文件对象中 该对象使您可以访问该文件的各种方法 例如读取或写入 gt gt gt f open file0 gt gt gt f
  • 如何从 python 脚本执行 7zip 命令

    我试图了解如何使用 os system 模块来执行 7zip 命令 现在我不想用 Popen 或 subprocess 让事情变得复杂 我已经安装了 7zip 并将 7zip exe 复制到我的用户文件夹中 我只想提取我的测试文件 inst
  • Snakemake:将多个输入用于具有多个子组的一个输出的规则

    我有一个工作管道 用于下载 比对和对公共测序数据执行变体调用 问题是它目前只能在每个样本的基础上工作 i e作为每个单独测序实验的样本 如果我想对一组实验 例如样本的生物和 或技术复制 执行变体调用 则它不起作用 我试图解决它 但我无法让它
  • 将图与热图(可能是对数)配对?

    How to create a pair plot in Python like the following but with heat maps instead of points or instead of a hex bin plot
  • 如何将 URL 添加到 Telegram Bot 的 InlineKeyboardButton

    我想制作一个按钮 可以从 Telegram 聊天中在浏览器中打开 URL 外部超链接 目前 我只开发了可点击的操作按钮 update message reply text Subscribe to us on Facebook and Te
  • 数据损坏 C++ 和 Python 之间的管道

    我正在编写一些代码 从 Python 获取二进制数据 将其通过管道传输到 C 对数据进行一些处理 在本例中计算互信息度量 然后将结果通过管道传输回 Python 在测试时 我发现如果我发送的数据是一组尺寸小于 1500 X 1500 的 2
  • 根据标点符号列表替换数据框中的标点符号[重复]

    这个问题在这里已经有答案了 使用 Canopy 和 Pandas 我有数据框 a 其定义如下 a pd read csv text txt df pd DataFrame a df columns test test txt 是一个单列文件
  • 如何检测一个二维数组是否在另一个二维数组内?

    因此 在堆栈溢出成员的帮助下 我得到了以下代码 data needle s which is a png image base64 code goes here decoded data decode base64 f cStringIO
  • 为什么从 openAI 导入 Universe 模块时出现“无效语法”错误

    当我导入时universe来自 openAI 的模块 我收到以下错误 Traceback most recent call last File
  • PyQt5按钮lambda变量变成布尔值[重复]

    这个问题在这里已经有答案了 当我运行下面的代码时 它显示如下 为什么 x 不是 x 而是变成布尔值 这种情况仅发生在传递到用 lambda 调用的函数中的第一个参数上 错误的 y home me model some file from P
  • ProcessPoolExecutor 传递多个参数

    ESPN播放器免费 class ESPNPlayerFree def init self player id match id match id team 团队名单1 277906 cA2i150s81HI3qbq1fzi za1Oq5CG
  • Tkinter 将鼠标点击绑定到框架

    我一定错过了一些明显的东西 我的 Tkinter 程序中有两个框架 每个框架在网格布局中都有一堆标签 我想将鼠标点击绑定到其中一个而不是另一个 我目前使用 root bind

随机推荐