If you want a faster Mahalanobis distance, you just need to rewrite the algorithm so that it mimics the one R uses. It's quite simple.
I slightly modified your function `mahalanobis_arma`, turning `mu` into a `rowvec`.
Basically, I just translated R's own code to RcppArmadillo. Here is R's `mahalanobis` function:
```r
mahalanobis
function (x, center, cov, inverted = FALSE, ...)
{
    x <- if (is.vector(x))
        matrix(x, ncol = length(x))
    else as.matrix(x)
    x <- sweep(x, 2, center)
    if (!inverted)
        cov <- solve(cov, ...)
    setNames(rowSums((x %*% cov) * x), rownames(x))
}
<bytecode: 0x6e5b408>
<environment: namespace:stats>
```
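What makes R's version fast is that it never loops over rows: after centring, everything is one matrix product followed by an element-wise product. Writing $X_c$ for the centred data matrix (row $i$ is $(x_i - \mu)^\top$), the per-row definition and the vectorized form R uses are the same quantity:

$$
D_i^2 = (x_i - \mu)^\top \Sigma^{-1} (x_i - \mu)
      = \sum_{j=1}^{p} \left(X_c \Sigma^{-1}\right)_{ij} \left(X_c\right)_{ij}
$$

So `rowSums((x %*% cov) * x)` returns all $n$ distances at once, and $\Sigma^{-1}$ (`solve(cov)`) is computed only once.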
Here is the RcppArmadillo translation (`Mahalanobis`), together with your modified `mahalanobis_arma`:
```cpp
#include <RcppArmadillo.h>  // also pulls in Rcpp.h
// [[Rcpp::depends(RcppArmadillo)]]

// Translation of stats::mahalanobis: centre the rows, invert cov once,
// then take the row sums of (x_cen * cov^-1) % x_cen.
// [[Rcpp::export]]
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov){
  int n = x.n_rows;
  arma::mat x_cen;
  x_cen.copy_size(x);
  // equivalent of sweep(x, 2, center)
  for (int i=0; i < n; i++) {
    x_cen.row(i) = x.row(i) - center;
  }
  // % is element-wise multiplication; sum(..., 1) sums along each row
  return sum((x_cen * cov.i()) % x_cen, 1);
}

// Your version: one linear solve with sigma for every row
// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x , arma::rowvec mu, arma::mat sigma ){
  int n = x.n_rows;
  arma::vec md(n);
  for (int i=0; i<n; i++){
    arma::mat x_i = x.row(i) - mu;
    arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
    md(i) = arma::as_scalar(x_i * Y);
  }
  return md;
}
```
Now let's compare this new Armadillo version (`Mahalanobis`), your first version (`mahalanobis_arma`) and R's implementation (`mahalanobis`).
I saved the C++ code above as `mahalanobis.cpp`.
```r
library(Rcpp)  # provides sourceCpp(); RcppArmadillo must also be installed
sourceCpp("mahalanobis.cpp")
set.seed(1)
x <- matrix(rnorm(10000 * 10), ncol = 10)
Sx <- cov(x)
all.equal(c(Mahalanobis(x, colMeans(x), Sx)),
          mahalanobis(x, colMeans(x), Sx))
## [1] TRUE
all.equal(mahalanobis_arma(x, colMeans(x), Sx),
          Mahalanobis(x, colMeans(x), Sx))
## [1] TRUE
require(rbenchmark)
benchmark(Mahalanobis(x, colMeans(x), Sx),
          mahalanobis(x, colMeans(x), Sx),
          mahalanobis_arma(x, colMeans(x), Sx),
          order = "elapsed")
##                                   test replications elapsed
## 1      Mahalanobis(x, colMeans(x), Sx)          100   0.124
## 2      mahalanobis(x, colMeans(x), Sx)          100   0.741
## 3 mahalanobis_arma(x, colMeans(x), Sx)          100   4.509
##   relative user.self sys.self user.child sys.child
## 1    1.000     0.173    0.077          0         0
## 2    5.976     0.804    0.670          0         0
## 3   36.363     4.386    4.626          0         0
```
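As you can see, the new implementation is about 6 times faster than R's and about 36 times faster than `mahalanobis_arma`. The main difference is that `mahalanobis_arma` calls `solve()` once per row (10,000 separate linear solves with `sigma`), whereas `Mahalanobis` inverts `cov` only once and then does a single matrix product and an element-wise multiplication.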
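I'm fairly sure we could do even better by using a Cholesky decomposition (or another matrix factorization) of the covariance matrix instead of inverting it.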
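A minimal sketch of the Cholesky idea, written in plain standard C++ (no Armadillo) so that it compiles on its own. It factorises $\Sigma = LL^\top$ once, then for each row solves $Lz = x_i - \mu$ by forward substitution, so $D_i^2 = \lVert z \rVert^2$ and $\Sigma^{-1}$ is never formed. The `Matrix` struct and the function names `cholesky_lower` / `mahalanobis_chol` are hypothetical. In RcppArmadillo the same steps would roughly correspond to `arma::chol(sigma, "lower")` followed by `arma::solve(arma::trimatl(L), ...)`, but this has not been benchmarked here against `Mahalanobis`, so any speed-up is an untested assumption.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical row-major dense matrix stored as a flat vector.
struct Matrix {
    std::size_t rows, cols;
    std::vector<double> a;
    Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), a(r * c, 0.0) {}
    double& operator()(std::size_t i, std::size_t j) { return a[i * cols + j]; }
    double operator()(std::size_t i, std::size_t j) const { return a[i * cols + j]; }
};

// Cholesky factorisation Sigma = L L^T with L lower triangular.
// Throws if Sigma is not (numerically) positive definite.
Matrix cholesky_lower(const Matrix& sigma) {
    assert(sigma.rows == sigma.cols);
    const std::size_t p = sigma.rows;
    Matrix L(p, p);
    for (std::size_t j = 0; j < p; ++j) {
        double d = sigma(j, j);
        for (std::size_t k = 0; k < j; ++k) d -= L(j, k) * L(j, k);
        if (d <= 0.0) throw std::runtime_error("covariance not positive definite");
        L(j, j) = std::sqrt(d);
        for (std::size_t i = j + 1; i < p; ++i) {
            double s = sigma(i, j);
            for (std::size_t k = 0; k < j; ++k) s -= L(i, k) * L(j, k);
            L(i, j) = s / L(j, j);
        }
    }
    return L;
}

// Squared Mahalanobis distance of every row of x from center.
// Sigma is factorised once; each row needs only one forward substitution
// L z = (x_i - center), and D_i^2 = sum of z_k^2.
std::vector<double> mahalanobis_chol(const Matrix& x, const std::vector<double>& center,
                                     const Matrix& sigma) {
    assert(center.size() == x.cols && sigma.rows == x.cols);
    const std::size_t p = x.cols;
    const Matrix L = cholesky_lower(sigma);
    std::vector<double> out(x.rows), z(p);
    for (std::size_t r = 0; r < x.rows; ++r) {
        double d2 = 0.0;
        for (std::size_t i = 0; i < p; ++i) {
            double s = x(r, i) - center[i];          // centred value
            for (std::size_t k = 0; k < i; ++k) s -= L(i, k) * z[k];
            z[i] = s / L(i, i);                       // forward substitution
            d2 += z[i] * z[i];
        }
        out[r] = d2;
    }
    return out;
}
```

For example, with $\Sigma = \begin{pmatrix}4 & 2\\ 2 & 3\end{pmatrix}$ (the `sigma` used below for `rmvnorm`), the point $(1, 2)$ and centre $(0, 0)$ give $D^2 = 1.375$.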
Finally, we can plug this `Mahalanobis` function into your `dmvnorm` and test it:
```r
require(mvtnorm)
set.seed(1)
sigma <- matrix(c(4, 2, 2, 3), ncol = 2)
x <- rmvnorm(n = 5000000, mean = c(1, 2), sigma = sigma, method = "chol")
all.equal(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          c(dmvnorm(x, t(1:2), .2 + diag(2), FALSE)))
## [1] TRUE
benchmark(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          order = "elapsed")
##                                                test replications
## 2          dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
## 1 mvtnorm::dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
##   elapsed relative user.self sys.self user.child sys.child
## 2  35.366    1.000    31.117    4.193          0         0
## 1  60.770    1.718    56.666   13.236          0         0
```
It is now almost twice as fast as `mvtnorm::dmvnorm`.
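For reference, assuming your `dmvnorm` implements the standard multivariate normal density (its agreement with `mvtnorm::dmvnorm` above suggests it does), the squared Mahalanobis distance is the only per-row quantity; everything else is a constant computed once from $\Sigma$. If $\Sigma$ is factorised as $\Sigma = LL^\top$, the log-determinant comes for free from the diagonal of $L$:

$$
\log f(x_i) = -\frac{p}{2}\log(2\pi) - \frac{1}{2}\log\lvert\Sigma\rvert - \frac{1}{2} D_i^2,
\qquad
\log\lvert\Sigma\rvert = 2\sum_{j=1}^{p} \log L_{jj}
$$

With `log = FALSE` the density is simply $f(x_i) = \exp\bigl(\log f(x_i)\bigr)$.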