机器学习之网格搜索技术,如何在Auto-sklearn中应用网格搜索技术

2023-11-14

一,机器学习中的网格搜索技术是怎么回事

网格搜索(Grid Search)是一种常用的超参数调优方法,它通过遍历给定的超参数组合,从中寻找最优的超参数组合。

可以用于选择最佳模型参数。在机器学习模型的训练过程中,有许多参数需要设置,例如神经网络中的隐藏层数、每层的神经元数量等。这些参数会影响模型的性能和准确度。

网格搜索技术通过枚举一组超参数的可能取值,然后对每一组取值进行训练和评估来确定最优的超参数组合。它基于一个预定义的参数网格,并使用交叉验证来确定最佳参数集合,从而实现了自动调参

网格搜索技术在处理高维数据时非常有用,因为它可以在所有可能的参数组合中搜索最佳的超参数组合,从而提高模型的准确度和稳定性。

二,通俗解释

网格搜索技术其实就是一种智能的“试错”方法,可以帮助机器学习模型找到最佳的参数组合。假设我们有一个需要训练的机器学习模型,里面有很多参数需要设置,但我们不知道哪个参数组合会得到最好的结果。这时候,我们可以通过网格搜索技术来自动化地测试多种可能的参数组合,并找到最优的那个(类似于在一个网格中排列所有可能的参数组合,逐个尝试,直到找到最佳的那个)。这个过程需要计算机自动化地完成,并输出最优参数组合以供我们使用。这样,我们就可以节省大量时间和精力来手动调整参数,并且可以更容易地找到最佳的参数组合,从而让我们的机器学习模型变得更加准确和可靠。

三,在一般情况下使用网格搜索技术

  1. 首先,我们需要导入必要的python库,包括scikit-learn(一种流行的机器学习库)和numpy(用于数据处理和计算):
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
  1. 然后,我们需要准备一些训练数据,这里使用scikit-learn自带的iris数据集:
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data, iris.target
  1. 接下来,我们定义一个支持向量机(SVM)模型,并设置需要测试的超参数范围。在这个示例中,我们将测试不同的C值和gamma值(两个影响SVM性能的重要参数):
svm = SVC()
parameters = {'C': [0.1, 1, 10], 'gamma': [0.01, 0.1, 1]}
  1. 我们现在可以使用GridSearchCV来执行网格搜索。这里我们设置cv参数为5,表示使用5折交叉验证来评估每个参数组合的性能。
clf = GridSearchCV(svm, parameters, cv=5)
clf.fit(X, y)
  1. 训练完成后,我们可以使用best_params_属性来获取最佳的参数组合,以及使用best_score_属性来获取最佳的性能得分
print("Best parameters: ", clf.best_params_)
print("Best score: ", clf.best_score_)
  1. 最后,我们可以使用最佳的参数组合来训练最终的模型,并用它进行预测:
best_svm = SVC(C=1, gamma=0.01)
best_svm.fit(X, y)
predictions = best_svm.predict(X)

这个示例中,我们使用网格搜索技术寻找最佳的SVM超参数组合。当我们执行GridSearchCV时,它会自动测试每种可能的Cgamma值组合,并返回最佳的参数组合和相应的准确度得分。最后,我们使用最佳的超参数组合来训练最终的模型,并用它来进行新的数据预测。

四,GridSearchCV网格搜索技术的原理

scikit-learn中,GridSearchCV类是用来实现网格搜索技术的。它内部使用了交叉验证技术来评估每一个超参数组合的性能。

原理如下:

  1. 首先,将训练集数据划分为k个等分。

对于每一组超参数组合,使用k-1份数据进行模型训练,并在留出的k份数据上进行预测,得到该组超参数组合的平均准确度得分。

  1. 重复步骤2,直到所有的超参数组合都被测试完毕,得到每组超参数组合的平均准确度得分

  2. 最终选择具有最高平均准确度得分的超参数组合作为最佳超参数组合,并返回该组合对应的模型。

  3. 可以看出GridSearchCV核心思想就是穷举所有可能的超参数组合,通过交叉验证来计算每个组合的性能得分,从而找到最佳的超参数组合。这种技术虽然简单,但非常有效,能够帮助机器学习工程师快速找到最优的模型超参数,从而提高模型的准确度和稳定性。

五,如何在Auto-sklearn中使用网格搜索技术

1. Auto-sklearn实际用应用中一般不会使用网格搜索技术

在autosklearn中,不需要使用传统的网格搜索技术。相反,autosklearn使用一种叫做“贝叶斯优化”的方法或者随机搜索方法来寻找最佳超参数组合。

2. 不使用网格搜索技术的原因

网格搜索是一种传统的超参数优化方法,它通过定义一组超参数的值域范围,对所有可能的超参数组合进行穷举搜索。例如,对于两个超参数ab,如果它们分别在[1,10][0.01, 0.1]范围内,则可以使用网格搜索算法在100个不同的超参数组合中搜索最佳的超参数组合。网格搜索的缺点是当超参数数量较多或者取值范围较大时,搜索空间会变得非常庞大,计算成本也会增加。

3. 随机搜索和贝叶斯优化

  1. 随机搜索是一种比网格搜索更为高效的超参数优化方法,它直接在超参数空间中随机采样一定数量的超参数组合,并评估它们的性能,从而寻找最佳的超参数组合。相比于网格搜索,随机搜索无需对所有超参数进行穷举搜索,因此计算成本更低。

  2. 贝叶斯优化是一种基于贝叶斯统计学理论的超参数优化方法,它通过先验概率和观测数据来不断更新超参数空间的后验概率分布,并选择其中最有可能获得最佳性能的超参数组合进行评估。贝叶斯优化算法具有高效、智能和自适应的特点,因此在大规模或复杂的机器学习问题中通常表现出色。

  3. 与网格搜索相比,随机搜索和贝叶斯优化算法都可以降低计算成本,并且可以避免在非重要的区域内搜索。同时,贝叶斯优化算法还可以更好地处理噪声数据,并具有更好的全局优化能力。

4. 三种技术的应用场景

网格搜索:适用于超参数数量较少且范围已知的情况下,可以使用网格搜索来穷举搜索所有超参数的可能值。

随机搜索:适用于超参数空间比较大或无法确定超参数的取值范围时,随机搜索算法可以在超参数空间内随机采样一定数量的超参数组合进行评估。

贝叶斯优化:适用于需要优化时间和计算成本的场景,例如模型训练周期长、计算资源有限等。贝叶斯优化算法可以根据前几次试验的结果调整超参数搜索的方向,并尝试更精确地找出最佳超参数组合。

5. 在Auto-sklearn使用,网格搜索,随机搜索和贝叶斯优化

在autosklearn中,以分类任务为例:

1. 在Auto-sklearn使用,网格搜索

网格搜索:网格搜索需要指定每个超参数的所有可能值,然后将这些值排列成一个网格结构进行搜索。对于分类问题,通常需要调整的超参数包括分类器类型、正则化系数、C值、kernel等。

例如,可以使用以下代码定义多个超参数的值范围:

import autosklearn.classification as asc
cls = asc.AutoSklearnClassifier(time_left_for_this_task=360,
                                 per_run_time_limit=30,
                                 include_estimators=['random_forest', 'adaboost', 'libsvm_svc'],
                                 resampling_strategy='cv',
                                 resampling_strategy_arguments={'folds': 5},
                                 n_jobs=4)
cls_config_space = cls.get_configuration_space()
params = {'classifier:__choice__': ['random_forest', 'adaboost', 'libsvm_svc'],
          'classifier:random_forest:max_features': Int(1, 20),
          'classifier:adaboost:n_estimators': Categorical([50, 100, 200]),
          'classifier:adaboost:learning_rate': Float(0.01, 0.1, default=0.1, log=True),
          'classifier:libsvm_svc:C': Float(0.01, 10, default=1, log=True),
          'preprocessor:__choice__': ['no_preprocessing', 'select_percentile_classification'],
          'preprocessor:select_percentile_classification:percentile': Int(1, 100),
          'preprocessor:select_percentile_classification:score_func': ['f_classif', 'chi2']}

在这个例子中,我们指定了三种不同的分类器类型(随机森林、AdaBoost和支持向量机)和两种数据预处理方法(无预处理和百分位数特征选择),对于每种分类器,我们还需要调整它们的具体超参数。

2. 在Auto-sklearn使用,随机搜索

如果想要使用随机搜索算法而不是SMAC(一种贝叶斯优化方法)算法,那么应该将smac_scenario_args中的runcount_limit参数设置为None。这样做可以使autosklearn选择随机搜索算法来进行超参数搜索。

以下是使用随机搜索算法调用autosklearn时的示例代码:

import autosklearn.classification as asc
cls = asc.AutoSklearnClassifier(time_left_for_this_task=360,
                                 per_run_time_limit=30,
                                 include_estimators=['random_forest', 'adaboost', 'libsvm_svc'],
                                 resampling_strategy='cv',
                                 resampling_strategy_arguments={'folds': 5},
                                 n_jobs=4,
                                 ensemble_size=0,
                                 initial_configurations_via_metalearning=0,
                                 smac_scenario_args={'runcount_limit': None},
                                 random_state=42)
cls.fit(X_train, y_train, dataset_name='classification', metric=accuracy, optimize_metric=True)

在这个例子中,我们将smac_scenario_args中的runcount_limit参数设置为None,表示使用随机搜索算法来进行超参数搜索。

3. 在Auto-sklearn使用,贝叶斯优化

贝叶斯优化:贝叶斯优化算法通过先验概率和观测数据来不断更新超参数空间的后验概率分布,并选择其中最有可能获得最佳性能的超参数组合进行评估。
在autosklearn中,如果不指定任何超参数搜索算法,则默认使用贝叶斯优化算法。但是,我们可以通过调整以下参数来指定特定的搜索算法:

import autosklearn.classification as asc
cls = asc.AutoSklearnClassifier(time_left_for_this_task=360,
                                 per_run_time_limit=30,
                                 include_estimators=['random_forest', 'adaboost', 'libsvm_svc'],
                                 resampling_strategy='cv',
                                 resampling_strategy_arguments={'folds': 5},
                                 n_jobs=4,
                                 ensemble_size=0,
                                 initial_configurations_via_metalearning=0,
                                 smac_scenario_args={'runcount_limit': 10},
                                 random_state=42)
cls.fit(X_train, y_train, dataset_name='classification', metric=accuracy, optimize_metric=True)

在这个例子中,我们可以通过设置smac_scenario_args参数来选择使用SMAC算法(一种贝叶斯优化方法)或随机搜索算法。如果将runcount_limit设置为正整数,则表示使用SMAC算法;如果将其设置为None,则表示使用随机搜索算法。

4. runcount_limit参数设置作用

runcount_limit参数指定了SMAC算法运行的最大步数(即评估超参数组合的次数)。当runcount_limit设置为较小的值时,SMAC算法只能进行有限的尝试,可能无法找到最佳超参数组合。相反,如果将该参数设置得非常大,则SMAC算法将花费更长时间来搜索空间,并且可能会找到更好的超参数组合。

在autosklearn中,如果将runcount_limit参数设置为正整数,则SMAC算法将被用于搜索超参数空间。因此,runcount_limit参数的大小会影响超参数搜索的准确性和速度。如果您的数据集比较小,那么可以将runcount_limit设置为较小的值,例如10或20,这样可以节省时间并快速获得一组相对较好的超参数。如果您的数据集很大,建议将runcount_limit设置为较大的值,例如50或100,以便能够更全面地搜索超参数空间。

需要注意的是,使用随机搜索算法可以避免SMAC算法中的这些问题,因为随机搜索算法会直接在超参数空间中随机抽样超参数组合进行评估,而不是使用启发式算法。因此,如果您的任务比较简单,可以使用随机搜索算法代替SMAC算法。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习之网格搜索技术,如何在Auto-sklearn中应用网格搜索技术 的相关文章

  • (discord.py) 尝试更改成员角色时,“用户”对象没有属性“角色”

    因此 我正在尝试编写一个机器人 让某人在命令中指定的主持人指定的一段时间内暂停角色 我知道该变量称为 小时 即使它目前以秒为单位 我稍后会解决这个问题 基本上 它是由主持人在消息 暂停 personmention numberofhours
  • 使用 openCV 对图像中的子图像进行通用检测

    免责声明 我是计算机视觉菜鸟 我看过很多关于如何在较大图像中查找特定子图像的堆栈溢出帖子 我的用例有点不同 因为我不希望它是具体的 而且我不确定如何做到这一点 如果可能的话 但我感觉应该如此 我有大量图像数据集 有时 其中一些图像是数据集的
  • Pycharm Python 控制台不打印输出

    我有一个从 Pycharm python 控制台调用的函数 但没有显示输出 In 2 def problem1 6 for i in range 1 101 2 print i end In 3 problem1 6 In 4 另一方面 像
  • 如何打印没有类型的defaultdict变量?

    在下面的代码中 from collections import defaultdict confusion proba dict defaultdict float for i in xrange 10 confusion proba di
  • 如何在Windows上模拟socket.socketpair

    标准Python函数套接字 套接字对 https docs python org 3 library socket html socket socketpair不幸的是 它在 Windows 上不可用 从 Python 3 4 1 开始 我
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • Spark的distinct()函数是否仅对每个分区中的不同元组进行洗牌

    据我了解 distinct 哈希分区 RDD 来识别唯一键 但它是否针对仅移动每个分区的不同元组进行了优化 想象一个具有以下分区的 RDD 1 2 2 1 4 2 2 1 3 3 5 4 5 5 5 在此 RDD 上的不同键上 所有重复键
  • feedparser 在脚本运行期间失败,但无法在交互式 python 控制台中重现

    当我运行 eclipse 或在 iPython 中运行脚本时 它失败了 ascii codec can t decode byte 0xe2 in position 32 ordinal not in range 128 我不知道为什么 但
  • python 集合可以包含的值的数量是否有限制?

    我正在尝试使用 python 设置作为 mysql 表中 ids 的过滤器 python集存储了所有要过滤的id 现在大约有30000个 这个数字会随着时间的推移慢慢增长 我担心python集的最大容量 它可以包含的元素数量有限制吗 您最大
  • 当玩家触摸屏幕一侧时,如何让 pygame 发出警告?

    我使用 pygame 创建了一个游戏 当玩家触摸屏幕一侧时 我想让 pygame 给出类似 你不能触摸屏幕两侧 的错误 我尝试在互联网上搜索 但没有找到任何好的结果 我想过在屏幕外添加一个方块 当玩家触摸该方块时 它会发出警告 但这花了很长
  • HTTPS 代理不适用于 Python 的 requests 模块

    我对 Python 还很陌生 我一直在使用他们的 requests 模块作为 PHP 的 cURL 库的替代品 我的代码如下 import requests import json import os import urllib impor
  • 如何将 numpy.matrix 提高到非整数幂?

    The 运算符为numpy matrix不支持非整数幂 gt gt gt m matrix 1 0 0 5 0 5 gt gt gt m 2 5 TypeError exponent must be an integer 我想要的是 oct
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • Python 3 中“map”类型的对象没有 len()

    我在使用 Python 3 时遇到问题 我得到了 Python 2 7 代码 目前我正在尝试更新它 我收到错误 类型错误 map 类型的对象没有 len 在这部分 str len seed candidates 在我像这样初始化它之前 se
  • 从 pygame 获取 numpy 数组

    我想通过 python 访问我的网络摄像头 不幸的是 由于网络摄像头的原因 openCV 无法工作 Pygame camera 使用以下代码就像魅力一样 from pygame import camera display camera in
  • Nuitka 未使用 nuitka --recurse-all hello.py [错误] 编译 exe

    我正在尝试通过 nuitka 创建一个简单的 exe 这样我就可以在我的笔记本电脑上运行它 而无需安装 Python 我在 Windows 10 上并使用 Anaconda Python 3 我输入 nuitka recurse all h
  • 如何将 PIL 图像转换为 NumPy 数组?

    如何转换 PILImage来回转换为 NumPy 数组 这样我就可以比 PIL 进行更快的像素级转换PixelAccess允许 我可以通过以下方式将其转换为 NumPy 数组 pic Image open foo jpg pix numpy
  • 如何在 Django 中使用并发进程记录到单个文件而不使用独占锁

    给定一个在多个服务器上同时执行的 Django 应用程序 该应用程序如何记录到单个共享日志文件 在网络共享中 而不保持该文件以独占模式永久打开 当您想要利用日志流时 这种情况适用于 Windows Azure 网站上托管的 Django 应
  • 设置 torch.gather(...) 调用的结果

    我有一个形状为 n x m 的 2D pytorch 张量 我想使用索引列表来索引第二个维度 可以使用 torch gather 完成 然后然后还设置新值到索引的结果 Example data torch tensor 0 1 2 3 4
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo

随机推荐

  • MA模型简介及其相关性质

    文章目录 1 概述 1 1 定义 1 2 限制条件 1 3 中心化 M A q
  • 课时 16 自测题

    以下说法错误的是 单选题 A etcd 是一个商业软件 B etcd 使用 go 语言编写 C etcd 是一个分布式系统 通常由多个 server 组成一个集群 关于 etcd 重要时间节点 以下说法错误的是 单选题 A etcd 最初由
  • react 初级基础

    react基本使用 项目创建 项目的创建命令 npx create react app react basic 创建一个基本元素进行渲染 1 导入react 和 react dom import React from react impor
  • 竞赛 交通目标检测-行人车辆检测流量计数 - 竞赛

    文章目录 0 前言 1 目标检测概况 1 1 什么是目标检测 1 2 发展阶段 2 行人检测 2 1 行人检测简介 2 2 行人检测技术难点 2 3 行人检测实现效果 2 4 关键代码 训练过程 最后 0 前言 优质竞赛项目系列 今天要分享
  • 函数或变量 x 无法识别。_这个变量陷阱,连高手都躲不开

    点击上方 Python小白集训营 选 星标 公众号 重磅干货 第一时间送达 图 Pexels 日期 2021 1 2 你可能会好奇 是什么样的陷阱 连高段位的python选手也会频繁踩坑 讲这个topic前 先来讲一个例子 这是我前几个月在
  • 【华为OD考试真题】报数游戏(Python实现)

    前言 考试题目大同小异 练习真题是通过考试的捷径 思路仅供参考 如果有更好的思路 欢迎一起交流学习 创作不易 文章若对你有帮助 点个关注 谢谢 题目描述 100个人围成一圈 每个人有一个编码 编号从1开始到100 他们从1开始依次报数 报到
  • 利用LSB算法隐藏图片信息的MATLAB实现

    前一篇博客中介绍了利用LSB算法隐藏文字信息的MATLAB实现 http blog csdn net csdn moming article details 50936687 在此基础上 下面介绍利用LSB算法隐藏图片信息的MATLAB实现
  • Mt2015 lfsr

    Taken from 2015 midterm question 5 See also the first part of this question mt2015 muxdff Write the Verilog code for thi
  • python:使用unquote对url解码

    参考 python之urlencode quote 及unquote wf592523813的博客 CSDN博客 python unquote
  • 企业架构成功之道读书笔记

    企业架构成功之道读书笔记 原文 https www leanix net en enterprise architecture 企业架构成功之道 理解下一代企业架构的价值 降低成本 应用合理化 速赢 10 软件授权优化 项目合理化 应用下线
  • 图形视图(17):【类】QGraphicsWidget[官翻]

    文章目录 详述 公共类型 enum anonymous 属性 autoFillBackground bool focusPolicy Qt FocusPolicy font QFont geometry QRectF layout QGra
  • final定义类、方法、属性以及多态性

    1 在Java中final称为终结期 在java里面可以使用 不能有子类 2 使用final定义的方法不能被子类覆写 3 使用final定义的变量就成了常量 常量必须在定义的时候设置 多态性基本概念以及相关的使用限制 多态性的依赖 转载于
  • Sentinel客户端调用并发控制

    前言 当链路中某个应用出现不稳定 导致整个链路调用变慢 如果不加控制可能导致雪崩 这种情况如何处理呢 一 慢调用现象分析 在分布式链路中调用中 调用关系如下 methodA1与methodA2在同一个应用中 链路标号 调用链 链路1 met
  • luajit官方性能优化指南和注解

    luajit官方性能优化指南和注解 luajit是目前最快的脚本语言之一 不过深入使用就很快会发现 要把这个语言用到像宣称那样高性能 并不是那么容易 实际使用的时候往往会发现 刚开始写的一些小test case性能非常好 经常毫秒级就算完
  • 怎么解决Greenplum中用pg

    基本思路是为ns1 table1设置分布策略 root登陆master host切换到Greenplum的管理员用户 比如gpadmin su gpadmin使用psql连接数据库 psql databasename设置随机分布策略alte
  • 超好玩地铁跑酷游戏,内涵源代码

    直接上代码 include
  • java的多重循环和程序调试

    java的多重循环和程序调试 一 掌握Java二重循环 多重 嵌套 注意 1 外层循环控制行 内层循环控制列 每行打印的内容 2 外层循环执行一次 内层循环执行一遍 3 一般多重循环值的就是二重循环 二 使用跳转语句控制程序的流程 retu
  • 在线Plist文件格式转Json文件格式

    Plist文件是一种用于存储应用程序配置信息的文件格式 其中包含应用程序的各种设置和数据 在过去 Plist文件通常是以 plist 格式存储的 然而 随着时间的推移 人们开始使用 JSON 格式来存储更复杂的数据结构和数据 如果您需要将
  • 国人自研开源项目,一款简单易用的 GitLab 替代品

    公众号关注 GitHubDaily 设为 星标 每天带你逛 GitHub 今天跟大家介绍一个国人自研项目 可用做 GitLab 替代品 PS 本文来自作者本人投稿 OneDev 是一个开源的一体化的 DevOps 平台 目前项目在 GitH
  • 机器学习之网格搜索技术,如何在Auto-sklearn中应用网格搜索技术

    文章目录 一 机器学习中的网格搜索技术是怎么回事 二 通俗解释 三 在一般情况下使用网格搜索技术 四 GridSearchCV网格搜索技术的原理 五 如何在Auto sklearn中使用网格搜索技术 1 Auto sklearn实际用应用中