我有一个特殊的算法,作为最后的步骤之一,我需要执行 3-D 数组与 2-D 数组的乘法,以便 3-D 数组的每个矩阵切片与 3-D 数组的每一列相乘。二维数组。换句话说,如果说A
is an N x N x N
矩阵和B
is an N x N
矩阵,我需要计算一个矩阵C
尺寸的N x N
where C(:,i) = A(:,:,i)*B(:,i);
.
实现这一点的简单方法是循环,即
C = zeros(N,N);
for i = 1:N
C(:,i) = A(:,:,i)*B(:,i);
end
然而,循环并不是 Matlab 中最快的,应该避免。我正在寻找更快的方法来做到这一点。现在,我所做的就是利用以下事实(现在 Mathjax 会很棒!):
[A1 b1, A2 b2, ..., AN bN] = [A1, A2, ..., AN]*blkdiag(b1,b2,...,bN)
这允许摆脱循环,但是,我们必须创建一个大小为的块对角矩阵N^2 x N
。我正在通过sparse
为了高效,即像这样:
A_long = reshape(A,N,N^2);
b_cell = mat2cell(B,N,ones(1,N)); % convert matrix to cell array of vectors
b_cell{1} = sparse(b_cell{1}); % make first element sparse, this is enough to trigger blkdiag into sparse mode
B_blk = blkdiag(b_cell{:});
C = A_long*B_blk;
根据我的基准测试,尽管进行了必要的准备(仅乘法就比循环快 3 到 4 倍),但这种方法比循环快两倍左右(对于大 N)。
这是我所做的一个快速基准测试,改变了问题的大小N
并测量循环和替代方法的时间(有或没有准备步骤)。对于大型N
加速比约为 2...2.5。
不过,这对我来说看起来非常复杂。有没有更简单或更好的方法来实现这一目标?这看起来像是一个非常通用/标准的问题,所以我可以想象解决方案就在身边,我只是不知道真正要搜索什么。
P.S.: blkdiag(A1,...,AN)*B
是一个明显的替代方案,但这里块对角线已经是N^2 x N^2
所以我认为它不会比我所做的更好。
edit: 谢谢大家的评论!我在 Matlab R2016b 上进行了新的基准测试。不幸的是,我在同一台计算机上没有这两个版本,因此我们无法比较绝对数字,但相对比较仍然很有趣,因为它发生了一些变化。这里是:
这是高 N 区域的放大图:
一些观察结果:
- SumRepDot是Divakar提出的解决方案,即使用
squeeze(sum(bsxfun(@times,A,permute(B,[3,1,2])),2))
在 R2016b 上简化为squeeze(sum(A.*permute(B,[3,1,2]),2))
。它比高循环更快N
大约为 1.2...1.4 倍。
- 从某种意义上说,循环仍然“慢”,因为与稀疏块对角矩阵的乘法要快得多。
- 对于后者,准备开销似乎可以忽略不计
N
这使得它总体上比循环快 3...4 倍。这是一个很好的结果。