如何正确使用 scikit-learn 的高斯过程进行 2D 输入、1D 输出回归?

2023-12-24

在发帖之前我做了很多搜索并发现这个问题 https://stackoverflow.com/questions/21320964/how-to-make-a-2d-gaussian-process-using-gpml-matlab-for-regression这可能正是我的问题。但是,我尝试了答案中提出的内容,但不幸的是这并没有解决它,而且我无法添加评论来请求进一步的解释,因为我是这里的新成员。

无论如何,我想在 Python 中将高斯过程与 scikit-learn 结合使用,从一个简单但真实的案例开始(使用 scikit-learn 文档中提供的示例)。我有一个 2D 输入集(8 对 2 个参数),称为X。我有 8 个相应的输出,聚集在一维数组中y.

#  Inputs: 8 points 
X = np.array([[p1, q1],[p2, q2],[p3, q3],[p4, q4],[p5, q5],[p6, q6],[p7, q7],[p8, q8]])

# Observations: 8 couples
y = np.array([r1,r2,r3,r4,r5,r6,r7,r8])

我定义了一个输入测试空间x:

# Input space
x1 = np.linspace(x1min, x1max) #p
x2 = np.linspace(x2min, x2max) #q
x = (np.array([x1, x2])).T

然后我实例化 GP 模型,将其拟合到我的训练数据中(X,y),并进行一维预测y_pred在我的输入空间上x:

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

kernel = C(1.0, (1e-3, 1e3)) * RBF([5,5], (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=15)
gp.fit(X, y)
y_pred, MSE = gp.predict(x, return_std=True)

然后我制作了一个 3D 绘图:

fig = pl.figure()
ax = fig.add_subplot(111, projection='3d')
Xp, Yp = np.meshgrid(x1, x2)
Zp = np.reshape(y_pred,50)

surf = ax.plot_surface(Xp, Yp, Zp, rstride=1, cstride=1, cmap=cm.jet,
linewidth=0, antialiased=False)
pl.show()

这是我得到的:

当我修改内核参数时,我得到这样的信息,类似于我上面提到的海报得到的信息:

这些图甚至与原始训练点的观察结果不匹配([65.1,37] 获得较低响应,[92.3,54] 获得最高响应)。

我对 2D GP 相当陌生(不久前也开始使用 Python),所以我想我在这里遗漏了一些东西......任何答案都会有帮助并且非常感谢,谢谢!


您正在使用两个特征来预测第三个特征。而不是像这样的 3D 绘图plot_surface,如果您使用能够显示有关第三维信息的二维图,通常会更清楚,例如hist2d or pcolormesh。这是一个使用与问题中类似的数据/代码的完整示例:

from itertools import product
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

X = np.array([[0,0],[2,0],[4,0],[6,0],[8,0],[10,0],[12,0],[14,0],[16,0],[0,2],
                    [2,2],[4,2],[6,2],[8,2],[10,2],[12,2],[14,2],[16,2]])

y = np.array([-54,-60,-62,-64,-66,-68,-70,-72,-74,-60,-62,-64,-66,
                    -68,-70,-72,-74,-76])

# Input space
x1 = np.linspace(X[:,0].min(), X[:,0].max()) #p
x2 = np.linspace(X[:,1].min(), X[:,1].max()) #q
x = (np.array([x1, x2])).T

kernel = C(1.0, (1e-3, 1e3)) * RBF([5,5], (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=15)

gp.fit(X, y)

x1x2 = np.array(list(product(x1, x2)))
y_pred, MSE = gp.predict(x1x2, return_std=True)

X0p, X1p = x1x2[:,0].reshape(50,50), x1x2[:,1].reshape(50,50)
Zp = np.reshape(y_pred,(50,50))

# alternative way to generate equivalent X0p, X1p, Zp
# X0p, X1p = np.meshgrid(x1, x2)
# Zp = [gp.predict([(X0p[i, j], X1p[i, j]) for i in range(X0p.shape[0])]) for j in range(X0p.shape[1])]
# Zp = np.array(Zp).T

fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111)
ax.pcolormesh(X0p, X1p, Zp)

plt.show()

Output:

看起来有点普通,但我的示例数据也是如此。一般来说,您不应期望通过这几个数据点得到特别有趣的结果。

另外,如果你确实想要曲面图,你可以直接替换pcolormesh与您最初拥有的一致(或多或少):

ax = fig.add_subplot(111, projection='3d')            
surf = ax.plot_surface(X0p, X1p, Zp, rstride=1, cstride=1, cmap='jet', linewidth=0, antialiased=False)

Output:

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何正确使用 scikit-learn 的高斯过程进行 2D 输入、1D 输出回归? 的相关文章

随机推荐

  • TypeError:使用 pytest 固定装置时缺少 1 个必需的位置参数 [重复]

    这个问题在这里已经有答案了 我已在文件中编写了测试类 并且正在尝试使用pytest 装置 https docs pytest org en 6 2 x fixture html这样我就不必在每个测试函数中创建相同的输入数据 下面是最小的工作
  • 如何在代码中找到点和抛物线之间的距离

    我试图为 DirectX 像素着色器找到抛物线上距离 2d 中任意点最近的点 大量的谷歌搜索告诉我 这是一个常见的微积分预科作业问题 不幸的是 数百个相关答案都说 一旦你有了这个方程 使用图形计算器的最小函数 它会告诉你答案是 6 我承认我
  • 在运行时以编程方式注册 HttpModule

    我正在编写一个应用程序 第三方供应商可以在其中编写插件 DLL 并将它们放入 Web 应用程序的 bin 目录中 我希望这些插件能够在必要时注册自己的 HttpModule 无论如何 我是否可以在运行时在管道中添加或删除 HttpModul
  • 填充形状边缘的间隙

    是否有一种算法在填充样本图像上的孔洞方面表现良好 膨胀效果不佳 因为在我最终设法连接这些曲线的末端之前 曲线变得非常厚 我想避免加粗线条 感谢您的任何帮助 是的 图像中可以有任何字母或形状带有类似的孔 另一种更简单的方法可能会更好地转化为O
  • 如何在php中比较不区分大小写的两个字符串

    我需要比较两个不区分大小写的字符串 这是我的代码 if strcasecmp genderseek both 0 gender2 Ugender MALE Ugender FEMALE else gender2 genderseek Uge
  • 我无法在不避免出现令人不安的噪音的情况下更改音乐文件的速度

    我正在尝试更改音频文件的速度 如果我使用无符号值进行操作 一切都可以 但是一旦我开始使用双值 事情就会变得混乱 例如 我的代码适用于所有 x 5 数字但它不与任何其他带小数的数字 就我而言 我想将速度提高 1 3 点 但我得到的只是一个文件
  • 尝试生成预签名 url 链接以便用户可以下载 Amazon S3 对象,但收到无效请求

    我目前正在使用 Ruby aws sdk 版本 2 gem 以及服务器端客户提供的加密密钥 SSE C 我可以毫无问题地将对象从 Rails 表单上传到 Amazon S3 def s3 Aws S3 Object new bucket n
  • Gradle/Eclipse:使用相等时德语“Umlaute”的不同行为?

    在使用 Java 的相等性检查 直接或间接 时 我遇到了德语 Umlaute 的奇怪行为 从 Eclipse 运行 调试或测试时 一切都按预期工作 并且包含 Umlaute 的输入被视为相等或不按预期处理 然而 当我使用 Spring Bo
  • 如何使用R的pixmap包提取像素数据?

    如何使用R的pixmap包提取像素数据 所以我使用以下方法读取图像文件 图片 如何将像素数据提取到某个矩阵中 您可以通过以下方式获取灰度图像的 2 D 矩阵数据或彩色图像的 3 D 数组数据 getChannels gt x lt read
  • 使用 querySelector 获取包含某个类的所有元素

    为了改变我正在使用的类中的一些样式querySelector el querySelector fa fa car style display none 这对于一个元素来说效果很好 但如果有更多元素包含此类 并且我想将所有元素的显示更改为无
  • WooCommerce 无法从产品类别访问购物车

    我有一个自定义的 WooCommerce 产品类型 我需要从其中访问购物车 URL 看起来很简单 class WC Product My Product extends WC Product Simple public function s
  • 使用 PIL 更改 OpenCV Python 中的字体系列

    上面的答案没有解决我的问题 我在用cv2 putText 将文本放在视频上 这按预期工作 但我正在尝试使用不同的字体 在 OpenCV 中不可用 据我了解 OpenCV 仅限于cv2 FONT HERSHEY字体 所以我使用 PIL 和 O
  • 在渲染中传递参数 - Rails 3

    我看到了几个关于此的问题 但无法解决 我试图在渲染部分时传递参数 类似于domainname com memory books new fbookupload yes 现在 我使用这一行 在部分中 我尝试使用以下方式获取 fbookuplo
  • 如何使用 axios 重定向后获取登陆页面 URL

    使用 NodeJS 如果我使用 maxRedirects 5 的 axios 如果我输入将重定向到另一个 URL 的 URL 如何从最终登陆页面获取 URL 在 HTTP 标头中 当存在 HTTP 200 代码时 就没有着陆页的标头字段 示
  • Django:在保存之前修改模型的字段

    我有一个模型 course 与ImageField and FileField所以我想在用户每次创建课程时创建一个文件夹 我想我可以在保存模型之前执行此操作 所以这是我的问题 如何在方法中访问模型的字段 模型 py Class Course
  • 在 Objective C 中复制整数数组最有效的方法是什么?

    在 Objective C 中将 1000 个整数的数组从一个数组复制到另一个数组的最有效方法是什么 这将在 iPhone 上运行一些绘图代码 因此尽可能高效很重要 Thanks 如果关心的是效率 我假设这是一个 C 整数数组 如果是这样
  • 如何使用ConfigurationManager解析app.config?

    我正在使用某种方法来解析我的 app config 文件 然后我被告知使用 ConfigurationManager 更好 更简单 但问题是我不知道如何使用 ConfigurationManager 来做到这一点 我原来的代码是这样的 Xm
  • Angular 2 - 所有组件的全局变量

    我的 angular2 应用程序在许多不同的组件中使用我的后端 Laravel API 我一直在想 将来我需要更改 API URL 这意味着我必须在我对 API 使用 http get post 方法的所有地方 在所有组件中 更改我的 AP
  • SSL - 如何以及何时使用它

    我有一个客户需要 SSL 来保护在线捐赠 但我对于如何 何时使用 SSL 的经验有限 据我所知 在购买证书时 我将该证书分配给整个域 实际上是 IP 地址 有没有办法将加密隔离到网站的单个页面 或者我应该继续保护整个网站 即使只有一个页面需
  • 如何正确使用 scikit-learn 的高斯过程进行 2D 输入、1D 输出回归?

    在发帖之前我做了很多搜索并发现这个问题 https stackoverflow com questions 21320964 how to make a 2d gaussian process using gpml matlab for r