求解线性回归的梯度下降法和正规方程法给出了不同的解

2023-12-28

我正在研究机器学习问题,并希望使用线性回归作为学习算法。我实现了两种不同的方法来查找参数theta线性回归模型:梯度(最速)下降和正态方程。对于相同的数据,他们应该给出大致相等的theta向量。然而他们没有。

Both theta除了第一个元素之外,所有元素上的向量都非常相似。这是用于将所有 1 与数据相乘的向量。

以下是如何thetas 看起来像(第一列是梯度下降的输出,第二列是正规方程的输出):

Grad desc Norm eq
-237.7752 -4.6736
-5.8471   -5.8467
9.9174    9.9178
2.1135    2.1134
-1.5001   -1.5003
-37.8558  -37.8505
-1.1024   -1.1116
-19.2969  -19.2956
66.6423   66.6447
297.3666  296.7604
-741.9281 -744.1541
296.4649  296.3494
146.0304  144.4158
-2.9978   -2.9976
-0.8190   -0.8189

什么可能导致差异theta(1, 1)通过梯度下降返回与theta(1, 1)由正规方程返回?我的代码有错误吗?

这是我在 Matlab 中实现的正规方程:

function theta = normalEque(X, y)
    [m, n] = size(X);
    X = [ones(m, 1), X];
    theta = pinv(X'*X)*X'*y;
end

这是梯度下降的代码:

function theta = gradientDesc(X, y)
    options = optimset('GradObj', 'on', 'MaxIter',  9999);
    [theta, ~, ~] = fminunc(@(t)(cost(t, X, y)),...
                    zeros(size(X, 2), 1), options);
end

function [J, grad] = cost(theta, X, y)
    m = size(X, 1);
    X = [ones(m, 1), X];
    J = sum((X * theta - y) .^ 2) ./ (2*m);
    for i = 1:size(theta, 1)
        grad(i, 1) = sum((X * theta - y) .* X(:, i)) ./ m;
    end
end

我传递完全相同的数据X and y对于这两个函数(我没有标准化X).

Edit 1:

根据答案和评论,我检查了一些代码并运行了一些测试。

首先,我想检查问题是否可能是由 X 接近单数引起的,如建议的那样@用户1489497的回答 https://stackoverflow.com/a/11271797/311865。所以我用 inv 替换了 pinv - 当运行它时我真的收到了警告Matrix is close to singular or badly scaled.。为了确保这不是问题,我获得了更大的数据集并使用这个新数据集运行测试。这次inv(X)没有显示警告并使用pinv and inv给出了相同的结果。所以我希望X不再接近奇异.

Then 我变了normalEque按照建议的代码 by 木屑 https://stackoverflow.com/users/85109/woodchips所以现在看起来像:

function theta = normalEque(X, y)
    X = [ones(size(X, 1), 1), X];
    theta = pinv(X)*y;
end

然而问题仍然存在. New normalEque不接近奇异的新数据上的函数给出了不同的theta as gradientDesc.

为了找出哪个算法有问题,我对相同的数据运行了数据挖掘软件 Weka 的线性回归算法。 Weka 计算的 theta 与输出非常相似normalEque但与输出不同gradientDesc。所以我猜normalEque是正确的并且有一个错误gradientDesc.

这是比较thetas 由 Weka 计算,normalEque and GradientDesc:

Weka(correct) normalEque    gradientDesc
779.8229      779.8163      302.7994
  1.6571        1.6571        1.7064
  1.8430        1.8431        2.3809
 -1.5945       -1.5945       -1.5964
  3.8190        3.8195        5.7486
 -4.8265       -4.8284      -11.1071
 -6.9000       -6.9006      -11.8924
-15.6956      -15.6958      -13.5411
 43.5561       43.5571       31.5036
-44.5380      -44.5386      -26.5137
  0.9935        0.9926        1.2153
 -3.1556       -3.1576       -1.8517
 -0.1927       -0.1919       -0.6583
  2.9207        2.9227        1.5632
  1.1713        1.1710        1.1622
  0.1091        0.1093        0.0084
  1.5768        1.5762        1.6318
 -1.3968       -1.3958       -2.1131
  0.6966        0.6963        0.5630
  0.1990        0.1990       -0.2521
  0.4624        0.4624        0.2921
-12.6013      -12.6014      -12.2014
 -0.1328       -0.1328       -0.1359

我还按照建议计算了错误贾斯汀·皮尔的回答 https://stackoverflow.com/a/11271748/311865。输出normalEque给出稍小的平方误差,但差异很小。更当我计算成本梯度时theta使用功能cost(与使用的相同gradientDesc)我的梯度接近于零。对输出执行同样的操作gradientDesc不给出接近于零的梯度。我的意思是:

>> [J_gd, grad_gd] = cost(theta_gd, X, y, size(X, 1));
>> [J_ne, grad_ne] = cost(theta_ne, X, y, size(X, 1));
>> disp([J_gd, J_ne])
  120.9932  119.1469
>> disp([grad_gd, grad_ne])
  -0.005172856743846  -0.000000000908598
  -0.026126463200876  -0.000000135414602
  -0.008365136595272  -0.000000140327001
  -0.094516503056041  -0.000000169627717
  -0.028805977931093  -0.000000045136985
  -0.004761477661464  -0.000000005065103
  -0.007389474786628  -0.000000005010731
   0.065544198835505  -0.000000046847073
   0.044205371015018  -0.000000046169012
   0.089237705611538  -0.000000046081288
  -0.042549228192766  -0.000000051458654
   0.016339232547159  -0.000000037654965
  -0.043200042729041  -0.000000051748545
   0.013669010209370  -0.000000037399261
  -0.036586854750176  -0.000000027931617
  -0.004761447097231  -0.000000027168798
   0.017311225027280  -0.000000039099380
   0.005650124339593  -0.000000037005759
   0.016225097484138  -0.000000039060168
  -0.009176443862037  -0.000000012831350
   0.055653840638386  -0.000000020855391
  -0.002834810081935  -0.000000006540702
   0.002794661393905  -0.000000032878097

这表明梯度下降根本没有收敛到全局最小值......但事实并非如此,因为我运行了数千次迭代。那么bug在哪里呢?


我终于有时间回到这个话题了。不存在“错误”。

如果矩阵是奇异的,则有无穷多个解。您可以从该组中选择任何解决方案,并获得同样好的答案。 pinv(X)*y 解是一个很好的解,很多人都喜欢,因为它是最小范数解。

永远没有充分的理由使用 inv(X)*y。更糟糕的是,对正规方程使用逆函数,因此 inv(X'*X)*X'*y 简直就是数值垃圾。我不在乎谁告诉你使用它,他们正在引导你到错误的地方。 (是的,它对于条件良好的问题来说是可以接受的,但大多数时候你不知道它什么时候会给你带来麻烦。那么为什么要使用它呢?)

即使您正在解决正则化问题,正规方程通常也是一件坏事。有一些方法可以避免对系统的条件数进行平方,尽管除非被问到,否则我不会解释它们,因为这个答案已经足够长了。

X\y 也会产生合理的结果。

绝对没有充分的理由对问题使用无约束的优化器,因为这会产生不稳定的结果,完全取决于您的起始值。

作为一个例子,我将从一个单一的问题开始。

X = repmat([1 2],5,1);
y = rand(5,1);

>> X\y
Warning: Rank deficient, rank = 1, tol =  2.220446e-15. 
ans =
                         0
         0.258777984694222

>> pinv(X)*y
ans =
         0.103511193877689
         0.207022387755377

pinv 和反斜杠返回略有不同的解决方案。事实证明,有一个基本解决方案,我们可以为 X 的行空间添加任意数量的零空间向量。

null(X)
ans =
         0.894427190999916
        -0.447213595499958

pinv 生成最小范数解。在可能产生的所有解决方案中,这个解决方案具有最小 2 范数。

相反,反斜杠生成的解决方案将一个或多个变量设置为零。

但如果您使用无约束优化器,它将生成完全依赖于您的起始值的解决方案。同样,可以将任意数量的零向量添加到您的解决方案中,并且您仍然拥有一个完全有效的解决方案,具有相同的误差平方和值。

请注意,即使没有返回奇点警告,但这并不意味着您的矩阵不接近奇异点。您对问题的更改很少,因此它仍然很接近,只是不足以触发警告。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

求解线性回归的梯度下降法和正规方程法给出了不同的解 的相关文章

  • Matlab - 如果值包含xxx,则删除元胞数组中的行

    在 Matlab 中 如何删除包含变量字符串的元胞数组中的元胞 假设我的元胞数组是 C svnTrunk RadarLib radarlb utilities scatteredInterpolant m C svnTrunk RadarL
  • 更改随机森林分类器的阈值

    我需要开发一个没有 或接近没有 假阴性值的模型 为此 我绘制了召回率 精度曲线 并确定阈值应设置为 0 11 我的问题是 如何定义模型训练时的阈值 稍后在评估时定义它是没有意义的 因为它不会反映新数据 X train X test y tr
  • 在 matlab 中求 3d 峰的体积

    现在我有一个带有峰值的 3D 散点图 我需要找到其体积 我的数据来自图像 因此 x 和 y 值表示 xy 平面上的像素位置 z 值是每个像素的像素值 这是我的散点图 scatter3 x y z 20 z filled 我试图找到数据峰值的
  • Matlab:保存后翻转图例顺序和图例重叠图

    我正在尝试根据以下内容反转我的图例条目顺序matlab条形图中图例颜色的逆序 https stackoverflow com questions 31178005 reverse ordering of legend colors in m
  • 如何使用WordNet或与wordnet相关的类别来实现基于类别的文本标记?

    如何使用wordnet按单词类别标记文本 java作为接口 Example 考虑以下句子 1 计算机需要键盘 显示器 CPU才能工作 2 汽车使用齿轮和离合器 现在我的目标是 例句必须标记为 第 1 句话 电脑 电子键盘 电子中央处理器 电
  • 如何在 Matlab 中对数组应用低通或高通滤波器?

    有没有一种简单的方法可以将低通或高通滤波器应用于 MATLAB 中的数组 我对 MATLAB 的强大功能 或数学的复杂性 有点不知所措 需要一个简单的函数或一些指导 因为我无法从文档或网络搜索中找到答案 看着那 这filter http w
  • 使用 MATLAB 进行线路跟踪

    我有一个图像 我想将其转换为逻辑图像 包括线条为黑色 背景为白色 当然 可以使用阈值方法来实现这一点 但我不想使用这种方式来做到这一点 我想通过使用线路跟踪方法或类似的方法来检测它 这是关于视网膜血管检测的 我找到了一个article ht
  • MATLAB 特征函数

    我很好奇哪里可以找到完整的描述FEATURE功能 它接受哪些论点 没有找到文档 我只听说过memstats and getpid 还要别的吗 gt gt which feature built in undocumented 注意 更完整的
  • Tensorflow推荐的系统规格?

    我开始在我的 RHEL 6 5 机器上安装 Tensorflow 但事实证明 Tensorflow 需要 glibc gt 2 17 而 rhel 6 5 上默认的 glibc 是 2 12 我想知道是否有人可以帮助我了解张量流的最低 推荐
  • MATLAB:具有复数的 printmat

    我想使用 MATLAB 的printmat显示带有标签的矩阵 但这不适用于复数 N 5 x rand N 1 y rand N 1 z x 1i y printmat x y z fftdemo N 1 2 3 4 5 x y x iy O
  • 使用符号求解器仅求解某些变量

    我正在尝试在 MATLAB 中求解包含 3 个变量和 5 个常量的方程组 是否可以使用solve求解三个变量 同时保持常量为符号而不用数值替换它们 当您使用SOLVE http www mathworks com access helpde
  • 如何告诉 mex 链接到 /usr/lib 中的 libstdc++.so.6 而不是 MATLAB 目录中的 libstdc++.so.6?

    现在 MATLAB 2012a 中的 mex 仅正式支持 gcc 4 4 6 但我想使用 gcc 4 7 风险自负 现在如果我直接用 mex 编译一些东西 它会抱怨 usr lib gcc i686 linux gnu 4 7 cc1plu
  • 池化与随时间池化

    我从概念上理解最大 总和池中发生的情况作为 CNN 层操作 但我看到这个术语 随时间变化的最大池 或 随时间变化的总和池 例如 用于句子分类的卷积神经网络 https arxiv org pdf 1408 5882 pdfYoon Kim
  • 括号中的波形符字符

    在 MATLAB 中 以下代码执行什么操作 m func returning matrix 波浪号运算符 的作用是什么 在 Matlab 中 这意味着不要将函数中相应的输出参数分配到赋值的右侧 因此 如果func returning mat
  • 在 Matlab 中高效获取像素坐标

    我想在 Matlab 中创建一个函数 给定一个图像 该函数将允许人们通过单击图像中的像素来选择该像素并返回该像素的坐标 理想情况下 人们能够连续单击图像中的多个像素 并且该函数会将所有相应的坐标存储在一个矩阵中 有没有办法在Matlab中做
  • PyTorch 中的标签平滑

    我正在建造一个ResNet 18分类模型为斯坦福汽车使用迁移学习的数据集 我想实施标签平滑 https arxiv org pdf 1701 06548 pdf惩罚过度自信的预测并提高泛化能力 TensorFlow有一个简单的关键字参数Cr
  • 如何在sklearn决策树中显示特征名称?

    我目前有一个决策树 将功能名称显示为X index i e X 0 X 1 X 2 etc from sklearn import tree from sklearn tree import DecisionTreeClassifier d
  • 提高SVM分类器准确率的技术

    我正在尝试使用 UCI 数据集构建一个分类器来预测乳腺癌 我正在使用支持向量机 尽管我尽最大努力提高分类器的准确性 但仍无法超过 97 062 我尝试过以下方法 1 Finding the most optimal C and gamma
  • MATLAB - 冲浪图数据结构

    我用两种不同的方法进行了计算 对于这些计算 我改变了 2 个参数 x 和 y 最后 我计算了每种变体的两种方法之间的 误差 现在我想根据结果创建 3D 曲面图 x gt on x axis y gt on y axis Error gt o
  • 访问图像的 Windows“标签”元数据字段

    我正在尝试进行一些图像处理 所以现在我正在尝试读取图像 exif 数据 有 2 个内置函数可用于读取图像的 exif 数据 问题是我想读取图像标签 exifread and imfinfo这两个函数都不显示图像标签 Is there any

随机推荐