PyTorch - 参数不变

2024-03-21

为了了解 pytorch 的工作原理,我尝试对多元正态分布中的一些参数进行最大似然估计。然而,它似乎不适用于任何协方差相关的参数。

所以我的问题是:为什么这段代码不起作用?

import torch


def make_covariance_matrix(sigma, rho):
    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],
                         [rho * torch.prod(sigma), sigma[1]**2]])


mu_true = torch.randn(2)
rho_true = torch.rand(1)
sigma_true = torch.exp(torch.rand(2))

cov_true = make_covariance_matrix(sigma_true, rho_true)
dist_true = torch.distributions.MultivariateNormal(mu_true, cov_true)

samples = dist_true.sample((1_000,))

mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)
atanh_rho = torch.zeros(1, requires_grad=True)

lbfgs = torch.optim.LBFGS([mu, log_sigma, atanh_rho])


def closure():
    lbfgs.zero_grad()
    sigma = torch.exp(log_sigma)
    rho = torch.tanh(atanh_rho)
    cov = make_covariance_matrix(sigma, rho)
    dist = torch.distributions.MultivariateNormal(mu, cov)
    loss = -torch.mean(dist.log_prob(samples))
    loss.backward()
    return loss


lbfgs.step(closure)

print("mu: {}, mu_hat: {}".format(mu_true, mu))
print("sigma: {}, sigma_hat: {}".format(sigma_true, torch.exp(log_sigma)))
print("rho: {}, rho_hat: {}".format(rho_true, torch.tanh(atanh_rho)))

output:

mu: tensor([0.4168, 0.1580]), mu_hat: tensor([0.4127, 0.1454], requires_grad=True)
sigma: tensor([1.1917, 1.7290]), sigma_hat: tensor([1., 1.], grad_fn=<ExpBackward>)
rho: tensor([0.3589]), rho_hat: tensor([0.], grad_fn=<TanhBackward>)

>>> torch.__version__
'1.0.0.dev20181127'

换句话说,为什么有这样的估计log_sigma and atanh_rho没有改变它们的初始值?


创建协方差矩阵的方式不是后验概率:

def make_covariance_matrix(sigma, rho):
    return torch.tensor([[sigma[0]**2, rho * torch.prod(sigma)],
                         [rho * torch.prod(sigma), sigma[1]**2]])

从(多个)张量创建新张量时,仅保留输入张量的值。来自输入张量的所有附加信息都被剥离,因此所有图连接您的参数从此时开始被切断,因此反向传播无法通过。

这是一个简短的例子来说明这一点:

import torch

param1 = torch.rand(1, requires_grad=True)
param2 = torch.rand(1, requires_grad=True)
tensor_from_params = torch.tensor([param1, param2])

print('Original parameter 1:')
print(param1, param1.requires_grad)
print('Original parameter 2:')
print(param2, param2.requires_grad)
print('New tensor form params:')
print(tensor_from_params, tensor_from_params.requires_grad)

Output:

Original parameter 1:
tensor([ 0.8913]) True
Original parameter 2:
tensor([ 0.4785]) True
New tensor form params:
tensor([ 0.8913,  0.4785]) False

正如您所看到的,根据参数创建的张量param1 and param2,不跟踪梯度param1 and param2.

因此,您可以使用此代码来保留图形连接 and is 后验概率:

def make_covariance_matrix(sigma, rho):
    conv = torch.cat([(sigma[0]**2).view(-1), rho * torch.prod(sigma), rho * torch.prod(sigma), (sigma[1]**2).view(-1)])
    return conv.view(2, 2)

使用以下方法将这些值连接到一个平面张量torch.cat。然后使用将它们调整为正确的形状view().
这会产生与函数中相同的矩阵输出,但它保持与参数的连接log_sigma and atanh_rho.

这是更改后的步骤之前和之后的输出make_covariance_matrix。如您所见,现在您可以优化参数,并且值确实会发生变化:

Before:
mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.,  0.])
sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.,  1.])
rho: tensor([ 0.2558]), rho_hat: tensor([ 0.])

After:
mu: tensor([ 0.1191,  0.7215]), mu_hat: tensor([ 0.0712,  0.7781])
sigma: tensor([ 1.4222,  1.0949]), sigma_hat: tensor([ 1.4410,  1.0807])
rho: tensor([ 0.2558]), rho_hat: tensor([ 0.2235])

希望这可以帮助!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch - 参数不变 的相关文章

随机推荐

  • Entity Framework 4.1 - Code First:多对多关系

    我想建立这样的关系 一个区域位于 x 个其他区域的附近 public class Zone public string Id get set public string Name get set public virtual ICollec
  • 在 Java 中使用 ENUMS 验证值组合的最佳方法是什么?

    我通过如下定义 ENUM 来验证从数据库检索的记录的状态 public enum RecordStatusEnum CREATED CREATED INSERTED INSERTED FAILED FAILED private String
  • 在Linux中使用自定义规则在多个端口上运行的SSH服务[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我正在努力设置一台在多个端口上运行 SSH 服务的服务器 例如端口 22 和 5522 这些端口应该具有一组不同的规则 即 我们为端口 2
  • 在 C# 中如何将字符串转换为 ascii 二进制?

    不久前 高中一年级 我请一位非常优秀的大三 C 程序员制作一个简单的应用程序 将字符串转换为二进制 他给了我以下代码示例 void ToBinary char str char tempstr int k 0 tempstr new cha
  • 列表未添加 C# 中的所有值

    我尝试了下面的代码来创建 json 代码 代码工作正常 我从数据库加载值 但只有最后一个值我得到了输出 剩余值未添加 DataTable dt new DataTable var objectToSerialize new RootObje
  • 解除PDF密码保护,知道密码[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我有一堆 pdf 文件 我想从中删除密码 请注意 我知道密码 因此无需暴力破解 我正在 Mac 上工作 所以我想制作一个应用程序来删除这些
  • Git合并分支到master

    我有一个主分支和一个工作分支branch 1 我想 动 一下branch 1正是如此master 所以我想要这样的东西 git checkout master git merge branch 1 I don t know what is
  • symfony2 - twig - 如何从树枝模板内部渲染树枝模板

    我有一个 xxx html twig 文件 它显示一个页面 但是当我想用不同的数据刷新页面并用新数据更新它时 我有一个选择和一个提交按钮 问题是我不知道如何在控制器中调用一个动作 我从我的树枝传递参数并调用新数据 然后我用新参数再次渲染相同
  • Python:单击按钮[重复]

    这个问题在这里已经有答案了 我在单击此按钮时遇到问题 该按钮的 HTML 代码如下所示
  • Eventbug 的实际工作原理

    Eventbug http getfirebug com wiki index php Firebug Extensions Eventbug是 Firebug 的一个附加组件 是的 附加组件的附加组件 其目的是跟踪分配给 DOM 元素的所
  • ld:架构armv7的871个重复符号,clang:错误:链接器命令失败,退出代码1(使用-v查看调用)

    我在 iPhone 应用程序中使用 FastPDFKit 来显示 PDF 当我在模拟器上运行该项目时 它工作正常 但是 当我在 iPhone 上运行该项目时 出现以下错误 duplicate symbol value map in User
  • 如何多次查询并最后关闭连接?

    我想打开与 mysql 数据库的连接并使用不同的查询检索数据 我是否需要在每次获取数据时关闭连接 或者是否有更好的方法可以多次查询并仅在最后关闭连接 目前我这样做 db dbConnect MySQL user root password
  • 我们可以导出 Kibana 中的所有搜索结果数据吗?

    我正在尝试导出 Kibana 5 中的所有搜索结果数据 但它仅导出结果的计数 有没有办法将所有数据导出为 CSV 格式 在基巴纳 到目前为止尝试过 单击搜索结果底部的符号 可视化 尝试使用 原始 和 格式化 选项 数据以 CSV 格式导出
  • symfony:如何设置不同环境的配置参数文件?

    如何为每个环境设置不同的配置参数文件 目前参数在parameters yml两者都使用dev and prod环境 但我需要不同的参数才能在产品中部署我的应用程序 您可以将所有使用的参数放入dev环境在一个app config parame
  • Postgresql计数+排序性能

    我使用 postgresql 和 psycopg2 构建了一个小型库存系统 一切都很好 除了当我想创建内容的聚合摘要 报告时 由于 count 和排序 我的性能非常糟糕 数据库架构如下 CREATE TABLE hosts id SERIA
  • 如何更新 Kubernetes 中的 api 版本列表

    我尝试在我的配置中使用 autoscaling v2beta2 apiVersion 如下本教程 https kubernetes io docs tasks run application horizontal pod autoscale
  • Perl 中的简单并行处理

    我在某个对象的函数内有一些代码块 它们可以并行运行并加快速度 我尝试使用subs parallel通过以下方式 所有这些都在函数体内 my is a done parallelize block a do some work return
  • 意外的 T_ENCAPSED_AND_WHITESPACE,期待 T_STRING 或 T_VARIABLE 或 T_NUM_STRING 错误 [重复]

    这个问题在这里已经有答案了 我对这个错误一直茫然 似乎不知道问题是什么 当我运行查询时 我收到此错误 意外的 T ENCAPSED AND WHITESPACE 需要 T STRING 或 T VARIABLE 或 T NUM STRING
  • 带 Bootstrap 的 Google 地图没有响应

    我正在使用 bootstrap 并嵌入了 Google Maps API 3 map canvas没有反应 它是固定宽度 另外 如果我使用height auto and width auto地图未显示在页面中 Why div class c
  • PyTorch - 参数不变

    为了了解 pytorch 的工作原理 我尝试对多元正态分布中的一些参数进行最大似然估计 然而 它似乎不适用于任何协方差相关的参数 所以我的问题是 为什么这段代码不起作用 import torch def make covariance ma