dmvnorm MVN 密度 - RcppArmadillo 实现比 R 包慢,包括一些 Fortran

2024-05-13

The solution现已上线RCPP画廊 http://gallery.rcpp.org/articles/dmvnorm_arma/


我从 RcppArmadillo 中的 mvtnorm 包重新实现了 dmvnorm。我有点喜欢犰狳,但我想它也可以在普通的 Rcpp 中工作。 dmvnorm 的方法基于马哈拉诺比斯距离,因此我有一个函数,然后是多元正态密度函数。

让我向你展示我的代码:

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]

// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::mat mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }



// [[Rcpp::export]]
arma::vec dmvnorm ( arma::mat x,  arma::mat mean,  arma::mat sigma, bool log){ 

arma::vec distval = mahalanobis_arma(x,  mean, sigma);

    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = 1.8378770664093454835606594728112352797227949472755668;
    arma::vec logretval = -( (x.n_cols * log2pi + logdet + distval)/2  ) ;

       if(log){ 
         return(logretval);

       }else { 
       return(exp(logretval));
         }
}

所以,并没有让我大失望:

模拟一些数据

sigma <- matrix(c(4,2,2,3), ncol=2)
x <- rmvnorm(n=5000000, mean=c(1,2), sigma=sigma, method="chol")

和基准

system.time(mvtnorm::dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.05    0.02    0.06 

system.time(dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.12    0.02    0.14 

不!!!!!! :-(

[EDIT]

The 问题是: 1) 为什么 RcppArmadillo 实现比普通 R 实现慢? 2) 如何创建击败 R 实现的 Rcpp/RcppArmadillo 实现?

[EDIT 2]

我将 mahalanobis_arma 放入 mvtnorm::dmvnorm 函数中,它也会减慢速度。


如果您想要更快地实现马哈拉诺比斯距离,您只需重写算法并模仿 R 使用的算法即可。这非常简单

我稍微修改了你的功能mahalanobis_arma转动mu to a rowvec.

基本上我只是将 R 代码翻译为 RcppArmadillo

mahalanobis
function (x, center, cov, inverted = FALSE, ...) 
{
    x <- if (is.vector(x)) 
        matrix(x, ncol = length(x))
    else as.matrix(x)
    x <- sweep(x, 2, center)
    if (!inverted) 
        cov <- solve(cov, ...)
    setNames(rowSums((x %*% cov) * x), rownames(x))
}
<bytecode: 0x6e5b408>
<environment: namespace:stats>

这里是

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov){
    int n = x.n_rows;
    arma::mat x_cen;
    x_cen.copy_size(x);
    for (int i=0; i < n; i++) {
        x_cen.row(i) = x.row(i) - center;
    }
    return sum((x_cen * cov.i()) % x_cen, 1);    
}


// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::rowvec mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }

现在,让我们来比较一下这个新的犰狳版本(Mahalanobis),你的第一个版本(mahalanobis_arma)和 R 实现(mahalanobis).

我将此 Cpp 代码保存为mahalanobis.cpp

require(RcppArmadillo)
sourceCpp("mahalanobis.cpp")

set.seed(1)
x <- matrix(rnorm(10000 * 10), ncol = 10)
Sx <- cov(x)


all.equal(c(Mahalanobis(x, colMeans(x), Sx))
          ,mahalanobis(x, colMeans(x), Sx))
## [1] TRUE

all.equal(mahalanobis_arma(x, colMeans(x), Sx)
          ,Mahalanobis(x, colMeans(x), Sx))
## [1] TRUE


require(rbenchmark)
benchmark(Mahalanobis(x, colMeans(x), Sx),
          mahalanobis(x, colMeans(x), Sx),
          mahalanobis_arma(x, colMeans(x), Sx),
          order = "elapsed")


##                                   test replications elapsed
## 1      Mahalanobis(x, colMeans(x), Sx)          100   0.124
## 2      mahalanobis(x, colMeans(x), Sx)          100   0.741
## 3 mahalanobis_arma(x, colMeans(x), Sx)          100   4.509
##   relative user.self sys.self user.child sys.child
## 1    1.000     0.173    0.077          0         0
## 2    5.976     0.804    0.670          0         0
## 3   36.363     4.386    4.626          0         0

正如您所看到的,新的实现比 R 的实现更快。 我非常确定,通过使用乔列斯基分解来求解协方差矩阵或使用其他矩阵分解,我们可以做得更好。

最后,我们可以将其插入Mahalanobis功能进入你的dmvnorm并测试它:

require(mvtnorm)
set.seed(1)
sigma <- matrix(c(4, 2, 2, 3), ncol = 2)
x <- rmvnorm(n = 5000000, mean = c(1, 2), sigma = sigma, method = "chol")


all.equal(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          c(dmvnorm(x, t(1:2), .2+diag(2), FALSE)))
## [1] TRUE

benchmark(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          dmvnorm(x, t(1:2), .2+diag(2), FALSE),
          order = "elapsed")

##                                                test replications
## 2          dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
## 1 mvtnorm::dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
##   elapsed relative user.self sys.self user.child sys.child
## 2  35.366    1.000    31.117    4.193          0         0
## 1  60.770    1.718    56.666   13.236          0         0

现在几乎快了一倍。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

dmvnorm MVN 密度 - RcppArmadillo 实现比 R 包慢,包括一些 Fortran 的相关文章

  • 为什么在 data.frame 中预先指定类型会比较慢?

    我预先分配了一个大 data frame 以便稍后填写 我通常这样做NA是这样的 n lt 1e6 a lt data frame c1 1 n c2 NA c3 NA 我想知道如果我预先指定数据类型是否会让事情变得更快 所以我测试了 f1
  • 从每小时中提取/子集分钟值

    我的数据框包含以下格式的日期值YYYY MM DD HH MM SS跨越 125000 多行 按分钟细分 每行代表一分钟 1 2018 01 01 00 04 00 2 2018 01 01 00 05 00 3 2018 01 01 00
  • 无效的命令名称“tk_chooseDirectory”错误

    我使用 bioconductor 进行 WES 管道 并使用 tk choose dir 选择用户存储输入文件的目录 并将其存储以供进一步使用 这里是命令行 library tcltk dataDir lt dirname tk choos
  • 关于子组的新列和另一列中的百分比范围

    我有一个如下所示的示例 df df test lt data frame Group Name c Group1 Group2 Group1 Group2 Group2 Group2 Group1 Sub group name c A A
  • R中无法连接odbc数据库

    我一直在尝试使用以下命令将我公司的 DMS 连接到 RodbcConnect命令 但收到以下消息 myConn lt odbcConnect NZSQL uid cejacobson pwd password Warning message
  • mlogit:需要 TRUE/FALSE 时缺少值

    我有来自离散选择实验 DCE 的数据 该实验研究了来自不同行业的个人的招聘偏好 我已经格式化为长格式 我想使用 mlogit 进行建模 我已导出数据 并且可以使用 asclogit 命令在 Stata 中成功运行模型 但在 R 中运行时遇到
  • 在 R 中使用深度网络和 MNIST 数据读取手写数字第 3 部分

    我尝试编写一个基于深度网络的程序来读取手写数字 我在 Youtube 上找到了一个代码 https www youtube com watch v 5bso 5X7Zu4 https www youtube com watch v 5bso
  • 根据 R 中的字符串模式选择行

    假设我有以下数据 df lt data frame name c TO for Turnover for people HC people Hello world beenie man apple pears TO is number c
  • R中使用余弦距离的层次聚类

    我想通过使用余弦相似度与 R 编程语言对文档语料库进行层次聚类 但出现以下错误 if is na n n gt 65536L stop 大小不能为 NA 或 超过 65536 需要 TRUE FALSE 时缺少值 我应该怎么办 为了重现它
  • R正则表达式获取第二个下划线之前的所有文本

    s lt 1 343 43Hello 2 323 14 fdh 99H 在 R 中 我想使用正则表达式来获取第二个下划线之前的子字符串 如何使用一个正则表达式来完成此操作 另一种方法是用 分割 然后粘贴前两个 一些东西 paste sapp
  • R 中的点图每行有多个值

    我有以下 R 输入文件 car 1 car 2 car 3 car2 1 car2 2 car2 3 然后 我使用以下命令来绘制图表 autos data 点图 autos data V2 autos data V1 但这将每个汽车和 ca
  • 跨类别和列自动化卡方

    我有一个调查数据框 其中包含几个问题 列 编码为 1 同意 0 不同意 受访者 行 根据 年龄 年轻 中年 老年 地区 东 中 西 等指标进行分类 大约有30个类别总共 3个年龄 3个地区 2个性别 11个职业等 在每个指标中 类别不重叠且
  • 构造奎因(自我复制功能)

    有没有人构建过 quine 生成自己源文本的副本作为其完整输出的程序 http www nyx net gthompso quine htm http www nyx net gthompso quine htm 在 R 中 quine 标
  • 基本 dyplr 函数给出错误:“check_dots_used”

    试图找出为什么我会收到此错误 以前从未见过 谷歌没有帮助 check dots used action warn 中的错误 未使用参数 action warn 我在下面的非常基本的试验中收到错误 而且在 group by count 中也收
  • R中具有特定条件的多列变异

    我有这个数据 M1 M2 M3 UCL 1 2 3 1 5 我想在这种情况下创建新列 如果M1大于UCL MM1将为 UP 否则为 NULL 如果M2大于UCL MM2将为 UP 否则为 NULL 如果M3大于UCL MM3将为 UP 否则
  • 如何使用 Facet R 添加线条[重复]

    这个问题在这里已经有答案了 所以我有一个多面图 我希望能够向其中添加随每个面而变化的线 这是代码 p lt ggplot mtcars aes x wt geom histogram bins 20 aes fill factor cyl
  • ggplot2 + 使用比例 X 的日期结构

    我真的需要帮助 因为我已经迷路了 我正在尝试创建一个折线图 显示几个团队一年来的表现 我将一年分为几个季度 2012 年 1 月 1 日 2012 年 4 月 1 日 2012 年 8 月 1 日 12 1 12 并将 csv 数据帧加载到
  • ggplot 按因子和梯度颜色

    我正在尝试绘制一个对两个变量 一个因子和一个强度 进行着色的图 我希望每个因素都是不同的颜色 并且我希望强度是白色和该颜色之间的渐变 到目前为止 我已经使用了诸如对因子进行分面等技术 将颜色设置为两个变量之间的相互作用 并将颜色设置为因子并
  • SparkR 和 Sparklyr 之间导入 parquet 文件所需的时间差异

    我正在使用 databricks 导入镶木地板文件SparkR and sparklyr data1 SparkR read df dbfs data202007 source parquet header TRUE inferSchema
  • 从 leafletProxy() 返回渲染的传单地图

    是否可以在渲染后在 Shiny 中检索传单地图 下面是一个代码示例 展示了如何生成地图leaflet 与返回的不同leafletProxy 即使它们在渲染时看起来完全相同 是否有一个功能可能不同于leafletProxy 获取实际的 htm

随机推荐

  • 如何在嵌入式tomcat中配置valve?

    我需要在嵌入式tomcat中配置valvehttp tomcat apache org tomcat 8 0 doc config valve html Remote IP Valve http tomcat apache org tomc
  • MongoDB 在仅返回 _id 时使用 COLLSCAN

    我想返回 MongoDB 集合中的所有 ID 我使用了以下代码 db coll find id 1 但MongoDB扫描整个集合而不是从默认读取信息index id 1 从日志中 find collection filter project
  • 每第 n 个字符分割一个字符串

    在 JavaScript 中 这就是我们如何在每 3 个字符处分割一个字符串 foobarspam match 1 3 g 我正在尝试弄清楚如何在 Java 中做到这一点 有什么指点吗 你可以这样做 String s 1234567890
  • 在 Delphi 2007 中将具有透明度的位图保存为 PNG

    我有一个包含透明度信息的 Delphi 位图 32 位 我需要将其转换并保存为 PNG 文件 同时保留透明度 我目前拥有的工具是graphics32 Library GR32 PNG 由Christian Budde 提供 和PNGImag
  • 并行启动服务

    我有一个脚本可以检查不同服务器上的某些服务是否已启动 如果没有启动 该脚本应该启动该服务 问题是 它不会并行启动服务 而是等待每个服务启动 Code server list Get Content path D Path list of s
  • Google G-Suite API 控制台未显示启用 G Suite 域范围委派

    我正在与客户合作设置服务帐户凭据 以便通过 API 读取 G Suite 目录信息 我之前已经这样做了十几次 没有任何问题 现在我遇到了一个问题 设置没有向客户端显示 下面的图片显示了我通常会看到的内容 阅读中圈出的区域是启用域范围委派的能
  • 使用 VNext 构建后,TFS tbl_Content 开始快速增长

    直到一个月前我们一直在使用旧样式 XAML 构建 然后开始使用 vNext 构建 之后我注意到 TFS 数据库中的 tbl Content 表开始快速增长 例如 在过去 8 小时内 它增长了 10 GB 但我不明白为什么会这样做 有谁知道它
  • 我可以通过链接分享我的私人 GitHub 存储库吗?

    我在 GitHub 上的私人存储库中有一个 Java 应用程序 我想与没有帐户的人共享它 我在网站上没有找到任何与此相关的选项 有没有办法做到这一点 协作者只能是 GitHub 用户 无法在非 Github 用户之间共享私有存储库 您需要
  • 使用 XPath 3.1 fn:serialize 进行 JSON 序列化

    我在 Saxon HE 9 8 中使用 XSLT 3 0 并且希望将 JSON 文档用作链接数据JSON LD https json ld org 在 JSON LD 中 完整的 HTTP URI 通常显示为值 当我使用 XPath 3 1
  • 插入并发问题-多线程环境

    我有一个问题 即使用完全相同的参数在完全相同的时间调用相同的存储过程 存储过程的目的是获取记录 如果存在 或创建并获取记录 如果不存在 问题是两个线程都在检查记录是否存在并报告错误 然后都插入新记录 在数据库中创建重复记录 我尝试将操作保留
  • 钛金 Android 导航组

    您好 我是钛合金新手 它允许开发人员创建跨平台应用程序 我需要创建一个适用于 Android 和 iOS 的导航组 有没有明确的解决方案 因为 Ti UI iPhone createNavigationGrou 仅适用于 iphone 谢谢
  • Itunes Connect 测试飞行公共链接有效性

    苹果最近为试飞版本启用了公共链接功能 我们可以与任何人共享此链接 他可以使用此公共链接安装应用程序 此公共链接背后的构建有效期为 90 天 我的问题是 与用户共享公共链接后 我们可以增加构建的到期时间吗 这样公共链接的有效性就会增加 我们不
  • 将颜色映射到plotly go.饼图中的标签

    我正在使用 make subplots 和 go Pie 绘制一系列 3 个饼图 我想最终将它们放入破折号应用程序中 用户可以在其中过滤数据并且数字将更新 如何将特定颜色映射到变量 以便男性始终为蓝色 女性始终为粉红色 等等 您可以使用 c
  • 使用 Terraform 管理访问 RDS 数据库的凭据时出现问题

    我通过 Terraform 创建了一个秘密 该秘密用于访问也在 Terraform 中定义的 RDS 数据库 并且在秘密中 我不想包含username and password 因此我创建了一个空密钥 然后在 AWS 控制台中手动添加凭证
  • 在继承的ctypes.Structure类的构造函数中调用from_buffer_copy

    我有以下代码 class MyStruct ctypes Structure fields id ctypes uint perm ctypes uint 定义类后 我可以直接从缓冲区复制数据到我的字段上 例如 ms MyStruct fr
  • 一个新的 JavaScript 数组长度是否无法使用? [复制]

    这个问题在这里已经有答案了 根据MDN 文档new Array length https developer mozilla org en US docs Web JavaScript Reference Global Objects Ar
  • 将数值数据更改为分类数据 - Pandas [重复]

    这个问题在这里已经有答案了 我有一个 pandas 数据框 其中有一个数字列 金额 金额从 0 到 20000 不等 我想将其更改为定义范围的分类变量 因此 分类变量将是 0 1000 之间 1000 2000 美元之间 依此类推 直到 1
  • 多个where条件codeigniter

    如何将此查询转换为活动记录 UPDATE table user SET email email last ip last ip where username username and status status 我尝试将上面的查询转换为 d
  • JavascriptCore:在 JSExport 中将 javascript 函数作为参数传递

    JavascriptCore是iOS7中支持的新框架 我们可以使用 JSExport 协议将 objc 类的部分内容公开给 JavaScript 在javascript中 我尝试将函数作为参数传递 像这样 function getJsonC
  • dmvnorm MVN 密度 - RcppArmadillo 实现比 R 包慢,包括一些 Fortran

    The solution现已上线RCPP画廊 http gallery rcpp org articles dmvnorm arma 我从 RcppArmadillo 中的 mvtnorm 包重新实现了 dmvnorm 我有点喜欢犰狳 但我