子类化 sklearn LinearSVC 以用作 sklearn GridSearchCV 的估计器

2023-12-01

我正在尝试创建一个子类sklearn.svm.LinearSVC用作估计器sklearn.model_selection.GridSearchCV。子类有一个额外的函数,在本例中不执行任何操作。然而,当我运行这个时,我最终遇到了一个我似乎无法调试的错误。如果您复制粘贴代码并运行,它应该会重现以以下结尾的完整错误ValueError: Input contains NaN, infinity or a value too large for dtype('float64')

一旦我让他工作,我希望为该方法添加更多功能transform_this().

有人可以告诉我哪里出了问题吗?基于this我首先认为这是由于我的数据存在一些问题。然而,由于我使用 sklearn 内置数据集重现了它,所以情况似乎并非如此。另外,我相信我根据我对上一个问题的回答正确地对此进行了子类化here。另外,我了解到 GridSearchCV 似乎没有以不同的方式初始化估计器(不知何故,它首先使用默认参数,正如我从这个帖子)

from sklearn.datasets import load_breast_cancer
from sklearn.svm import LinearSVC
from sklearn.model_selection import GridSearchCV

RANDOM_STATE = 123


class LinearSVCSub(LinearSVC):
    def __init__(self, penalty='l2', loss='squared_hinge', additional_parameter1=1, additional_parameter2=100,
                 dual=True, tol=0.0001, C=1.0, multi_class='ovr', fit_intercept=True, intercept_scaling=1,
                 class_weight=None, verbose=0, random_state=None, max_iter=1000):
        super(LinearSVCSub, self).__init__(penalty=penalty, loss=loss, dual=dual, tol=tol,
                                           C=C, multi_class=multi_class, fit_intercept=fit_intercept,
                                           intercept_scaling=intercept_scaling, class_weight=class_weight,
                                           verbose=verbose, random_state=random_state, max_iter=max_iter)

        self.additional_parameter1 = additional_parameter1
        self.additional_parameter2 = additional_parameter2

    def fit(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        super(LinearSVCSub, self).fit(X, y, sample_weight)

    def predict(self, X):
        X = self.transform_this(X)
        super(LinearSVCSub, self).predict(X)

    def score(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        super(LinearSVCSub, self).score(X, y, sample_weight)

    def decision_function(self, X):
        X = self.transform_this(X)
        super(LinearSVCSub, self).decision_function(X)

    def transform_this(self, X):
        return X


if __name__ == '__main__':
    data = load_breast_cancer()
    X, y = data.data, data.target

    # Parameter tuning with custom LinearSVC
    param_grid = {'C': [0.00001, 0.0001, 0.0005],
                      'dual': (True, False), 'random_state': [RANDOM_STATE],
                      'additional_parameter1': [0.90, 0.80, 0.60, 0.30],
                      'additional_parameter2': [20, 30]}

    gs_model = GridSearchCV(estimator=LinearSVCSub(), verbose=1, param_grid=param_grid,
                            scoring='roc_auc', n_jobs=-1)
    gs_model.fit(X, y)

你有几个问题:

  1. 定义的方法没有 return 语句
  2. 您选择的数据集不收敛LinearSVC

一旦您纠正了这些问题,您就可以开始:

from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC
from sklearn.model_selection import GridSearchCV

RANDOM_STATE = 123


class LinearSVCSub(LinearSVC):
    def __init__(self, penalty='l2', loss='squared_hinge', additional_parameter1=1, additional_parameter2=100,
                 dual=True, tol=0.0001, C=1.0, multi_class='ovr', fit_intercept=True, intercept_scaling=1,
                 class_weight=None, verbose=0, random_state=None, max_iter=100000):
        super(LinearSVCSub, self).__init__(penalty=penalty, loss=loss, dual=dual, tol=tol,
                                           C=C, multi_class=multi_class, fit_intercept=fit_intercept,
                                           intercept_scaling=intercept_scaling, class_weight=class_weight,
                                           verbose=verbose, random_state=random_state, max_iter=max_iter)

        self.additional_parameter1 = additional_parameter1
        self.additional_parameter2 = additional_parameter2

    def fit(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        super(LinearSVCSub, self).fit(X, y, sample_weight)
        return self

    def predict(self, X):
        X = self.transform_this(X)
        return super(LinearSVCSub, self).predict(X)

    def score(self, X, y, sample_weight=None):
        X = self.transform_this(X)
        return super(LinearSVCSub, self).score(X, y, sample_weight)

    def decision_function(self, X):
        X = self.transform_this(X)
        return super(LinearSVCSub, self).decision_function(X)

    def transform_this(self, X):
        return X


X, y = make_classification()

# Parameter tuning with custom LinearSVC
param_grid = {'C': [0.00001, 0.0001, 0.0005],
                  'dual': (True, False), 'random_state': [RANDOM_STATE],
                  'additional_parameter1': [0.90, 0.80, 0.60, 0.30],
                  'additional_parameter2': [20, 30]
             }

gs_model = GridSearchCV(estimator=LinearSVCSub(), verbose=1, param_grid=param_grid,
                        scoring='roc_auc', n_jobs=1)

gs_model.fit(X, y)
Fitting 5 folds for each of 48 candidates, totalling 240 fits
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 240 out of 240 | elapsed:    0.9s finished
GridSearchCV(estimator=LinearSVCSub(), n_jobs=1,
             param_grid={'C': [1e-05, 0.0001, 0.0005],
                         'additional_parameter1': [0.9, 0.8, 0.6, 0.3],
                         'additional_parameter2': [20, 30],
                         'dual': (True, False), 'random_state': [123]},
             scoring='roc_auc', verbose=1)

gs_model.predict(X)
array([0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
       1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1,
       1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0,
       0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

子类化 sklearn LinearSVC 以用作 sklearn GridSearchCV 的估计器 的相关文章

随机推荐

  • 如何使用 IResourceChangeListener 检测文件重命名并动态设置 EditorPart 名称?

    IResourceChangeListener监听项目工作区中的更改 例如编辑器零件文件名是否已更改 我想知道如何访问该特定的EditorPart并相应地更改其标题名称 例如 setPartName 或者刷新编辑器以便它自动显示新名称 理想
  • 在 Highchart 样条图上的最后一点显示指标

    我知道如何在最后一点显示标记 例如this 当数据是动态的时候 不知道如何标记最后一个点 plotOptions column stacking normal spline marker enabled true 当您动态添加新点时 您可以
  • 在缓冲区对象上运行并通过着色器更改其数据? [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心以获得指导 有没有办法在缓冲区对象上运
  • MSSQL 选择“垂直”-其中

    除了 垂直位置 之外 我真的不知道如何解释 想象一下下表 TAGID PRODUCTID SHOP ID 59 3418 7 38 61 3418 7 38 60 4227 4 38 61 4227 4 38 现在我想返回与标签 ID 相关
  • 实例成员不能用于类型

    我有以下课程 class ReportView NSView var categoriesPerPage Int var numPages Int return categoriesPerPage count 编译失败并显示消息 实例成员
  • 链接共享 C 库时 Android NDK 错误

    我正在尝试将一些 C 文件链接到我正在处理的 NDK 项目 并设置我的CMakeLists txt像下面这样归档 cmake minimum required VERSION 3 4 1 set CMAKE C FLAGS CMAKE C
  • 润滑 as_date 和。 as_datetime 行为差异

    我有一个数字向量 表示自 1970 年 1 月 1 日以来的毫秒数 我想使用以下方法将它们转换为日期时间对象lubridate 数据示例如下 raw times lt c 1139689917479 1139667123031 114036
  • 使用浏览器控制台使用 Javascript 在 Facebook 中发送聊天消息

    我尝试使用 Javascript 在 Facebook 中发送聊天消息 但不断收到错误消息 要么是TypeError Object
  • 我可以对 Linux 进程的地址空间中的每个页面进行写保护吗?

    我想知道是否有一种方法可以对 Linux 中的每个页面进行写保护 进程的地址空间 从进程本身的内部 通过mprotect 我所说的 每一页 实际上是指该网站的每一页 进程的地址空间可以被普通进程写入 程序在用户模式下运行 所以 程序文本 常
  • ServiceStack Javascript JsonServiceClient 缺少属性

    我正在尝试使用 Servicestack JsonServiceClient 连接到经过 JWT 身份验证的服务 但是文档仅描述了如何使用 C 客户端执行此操作 http docs servicestack net jwt authprov
  • 计时器不包含在 Xamarin.Forms 的 System.Threading 中

    I used System Threading Timer in Xamarin Android 我如何在中使用同一个类Xamarin Forms 我想从 Xamarin Forms 中的 Xamarin Android 转移我的项目 pu
  • 单击按钮更改颜色在重新加载或重新启动页面后保持不变

    我创建了锚标记 其中使用心形图标 单击后会更改颜色 但我想在重新加载或重新启动页面后保持相同的颜色 当我重新启动或重新加载页面时 它会恢复默认颜色 var btnvar document getElementById favorite fu
  • 如何从 C# 调用 MongoDb 中存储的 JavaScript

    我正在评估将 SQL Server 数据库移植到 MongoDb 问题是移动存储过程 我读到了有关 MongoDb 存储 JavaScript 的内容 我想在 Net 中进行一些测试 我已经安装了 MongoDb 驱动程序 2 4 0 并在
  • 搜索数组中的连续值

    在数组中搜索连续值的最佳方法是什么 例如 搜索array a b in array x a b c 会产生1 因为这些值首先连续出现在该索引处 还没有测试过这个 但类似这样的事情应该可以 function consecutive value
  • 使用 PHP 接收 JSON POST

    我尝试在支付接口网站上接收 JSON POST 但无法对其进行解码 当我打印时 echo POST I get Array 当我尝试这个时我什么也没得到 if POST foreach POST as key gt value echo l
  • 圆与圆的交点

    如何计算两个圆的交点 我希望在所有情况下都会有两个 一个或没有交点 我有中心点的 x 和 y 坐标以及每个圆的半径 python 中的答案是首选 但任何工作算法都是可以接受的 两个圆的交点 保罗 伯克 编剧 The following no
  • Linq to SQL 是如何工作的?

    我在项目中使用 Linq to SQL 我使用它从 SQL 存储过程中获取数据 它工作完美 但我不明白 LINQ SQL 内部如何与 SQL Server 通信 它在获取数据后将数据存储在哪里 它从哪里获取连接字符串 提前致谢 更好读 ht
  • 为什么使用不带 lambda 的内联

    我试图了解如何使用inline修改正确 我了解一般情况 当我们内联 lambda 以防止过度分配时 如中所述docs 我正在检查 kotlin stdlib 并发现 Strings kt下面这段代码 kotlin internal Inli
  • 在vBulletin中使用curl登录网站

    我一直在尝试登录某个网站 www siamchart 论坛 按照此链接上的说明进行操作 使用 PHP cURL 登录远程站点 我无法通过登录 运行以下脚本后 它将我重定向到相同的登录页面 www siamchart forum 但没有成功登
  • 子类化 sklearn LinearSVC 以用作 sklearn GridSearchCV 的估计器

    我正在尝试创建一个子类sklearn svm LinearSVC用作估计器sklearn model selection GridSearchCV 子类有一个额外的函数 在本例中不执行任何操作 然而 当我运行这个时 我最终遇到了一个我似乎无