如何在 CUDA 中执行多个矩阵乘法?

2024-04-27

我有一个方阵数组int *M[10];以便M[i]定位第一个元素i-th 矩阵。我想将所有矩阵相乘M[i]通过另一个矩阵N,这样我就收到了方阵数组int *P[10]作为输出。

我看到有不同的可能性:

  1. 分配不同元素的计算M[i]到不同的线程;例如,我有10矩阵,4x4大小,以便涉及的线程数为160;如何使用CUDA来实现这种方法?
  2. 在上面例子的框架中,创建一个复合矩阵大小40x40(即收集10, 4x4大小矩阵在一起)并使用40x40线程;但这种方法似乎需要更多时间;我正在尝试使用矩阵数组,但我认为我做错了;我怎样才能使用这种方法10矩阵?如何在内核函数中编写它?

这就是我正在尝试的;

void GPU_Multi(int *M[2], int *N, int *P[2], size_t width)
{

    int *devM[2];
    int *devN[2];
    int *devP[2];
    size_t allocasize =sizeof(int) *width*width;

    for(int i = 0 ; i < 10 ; i ++ ) 
    {
        cudaMalloc((void**)&devM[ i ], allocasize );
        cudaMalloc((void**)&devP[ i ], allocasize ); 
    }

    cudaMalloc((void**)&devN, allocasize );

    for(int i = 0 ; i < 10 ; i ++ ) {

        cudaMemcpy(devM[ i ],M[ i ], allocasize , cudaMemcpyHostToDevice);
        cudaMemcpy(devN, N, allocasize , cudaMemcpyHostToDevice);
        dim3 block(width*2, width*2);
        dim3 grid(1,1,1);
        Kernel_Function<<<grid, block>>>  (devM[2], devN, devP[2],width);

        for(int i = 0 ; i < 10 ; i ++ ) 
        {
            cudaMemcpy(P[ i ], P[ i ], allocatesize, cudaMemcpyDeviceToHost);
            cudaFree(devM[ i ]);   
            cudaFree(devP[ i ]);
        }

    }

我认为使用以下方法可能会实现最快的性能CUBLAS批量gemm函数 http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemmbatched它是专门为此目的而设计的(执行大量“相对较小”的矩阵乘法运算)。

Even though you want to multiply your array of matrices (M[]) by a single matrix (N), the batch gemm function will require you to pass also an array of matrices for N (i.e. N[]), which will all be the same in your case.

EDIT:现在我已经完成了一个示例,对我来说很清楚,通过对下面的示例进行修改,我们可以传递一个N矩阵并有GPU_Multi函数只需发送单个N矩阵到设备,同时传递一个指针数组N, i.e. d_Narray在下面的示例中,所有指针都指向同一个N设备上的矩阵。

这是一个完整的批量 GEMM 示例:

#include <stdio.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <assert.h>

#define ROWM 4
#define COLM 3
#define COLN 5

#define cudaCheckErrors(msg) \
    do { \
        cudaError_t __err = cudaGetLastError(); \
        if (__err != cudaSuccess) { \
            fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
                msg, cudaGetErrorString(__err), \
                __FILE__, __LINE__); \
            fprintf(stderr, "*** FAILED - ABORTING\n"); \
            exit(1); \
        } \
    } while (0)


typedef float mytype;
// Pi = Mi x Ni
// pr = P rows = M rows
// pc = P cols = N cols
// mc = M cols = N rows
void GPU_Multi(mytype **M, mytype **N, mytype **P
  , size_t pr, size_t pc, size_t mc
  , size_t num_mat, mytype alpha, mytype beta)
{

    mytype *devM[num_mat];
    mytype *devN[num_mat];
    mytype *devP[num_mat];
    size_t p_size =sizeof(mytype) *pr*pc;
    size_t m_size =sizeof(mytype) *pr*mc;
    size_t n_size =sizeof(mytype) *mc*pc;
    const mytype **d_Marray, **d_Narray;
    mytype **d_Parray;
    cublasHandle_t myhandle;
    cublasStatus_t cublas_result;

    for(int i = 0 ; i < num_mat; i ++ )
    {
        cudaMalloc((void**)&devM[ i ], m_size );
        cudaMalloc((void**)&devN[ i ], n_size );
        cudaMalloc((void**)&devP[ i ], p_size );
    }
    cudaMalloc((void**)&d_Marray, num_mat*sizeof(mytype *));
    cudaMalloc((void**)&d_Narray, num_mat*sizeof(mytype *));
    cudaMalloc((void**)&d_Parray, num_mat*sizeof(mytype *));
    cudaCheckErrors("cudaMalloc fail");
    for(int i = 0 ; i < num_mat; i ++ ) {

        cudaMemcpy(devM[i], M[i], m_size , cudaMemcpyHostToDevice);
        cudaMemcpy(devN[i], N[i], n_size , cudaMemcpyHostToDevice);
        cudaMemcpy(devP[i], P[i], p_size , cudaMemcpyHostToDevice);
    }
    cudaMemcpy(d_Marray, devM, num_mat*sizeof(mytype *), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Narray, devN, num_mat*sizeof(mytype *), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Parray, devP, num_mat*sizeof(mytype *), cudaMemcpyHostToDevice);
    cudaCheckErrors("cudaMemcpy H2D fail");
    cublas_result = cublasCreate(&myhandle);
    assert(cublas_result == CUBLAS_STATUS_SUCCESS);
    // change to    cublasDgemmBatched for double
    cublas_result = cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N
      , pr, pc, mc
      , &alpha, d_Marray, pr, d_Narray, mc
      , &beta, d_Parray, pr
      , num_mat);
    assert(cublas_result == CUBLAS_STATUS_SUCCESS);

    for(int i = 0 ; i < num_mat ; i ++ )
    {
        cudaMemcpy(P[i], devP[i], p_size, cudaMemcpyDeviceToHost);
        cudaFree(devM[i]);
        cudaFree(devN[i]);
        cudaFree(devP[i]);
    }
    cudaFree(d_Marray);
    cudaFree(d_Narray);
    cudaFree(d_Parray);
    cudaCheckErrors("cudaMemcpy D2H fail");

}

int main(){

  mytype h_M1[ROWM][COLM], h_M2[ROWM][COLM];
  mytype h_N1[COLM][COLN], h_N2[COLM][COLN];
  mytype h_P1[ROWM][COLN], h_P2[ROWM][COLN];
  mytype *h_Marray[2], *h_Narray[2], *h_Parray[2];
  for (int i = 0; i < ROWM; i++)
    for (int j = 0; j < COLM; j++){
      h_M1[i][j] = 1.0f; h_M2[i][j] = 2.0f;}
  for (int i = 0; i < COLM; i++)
    for (int j = 0; j < COLN; j++){
      h_N1[i][j] = 1.0f; h_N2[i][j] = 1.0f;}
  for (int i = 0; i < ROWM; i++)
    for (int j = 0; j < COLN; j++){
      h_P1[i][j] = 0.0f; h_P2[i][j] = 0.0f;}

  h_Marray[0] = &(h_M1[0][0]);
  h_Marray[1] = &(h_M2[0][0]);
  h_Narray[0] = &(h_N1[0][0]);
  h_Narray[1] = &(h_N2[0][0]);
  h_Parray[0] = &(h_P1[0][0]);
  h_Parray[1] = &(h_P2[0][0]);

  GPU_Multi(h_Marray, h_Narray, h_Parray, ROWM, COLN, COLM, 2, 1.0f, 0.0f);
  for (int i = 0; i < ROWM; i++)
    for (int j = 0; j < COLN; j++){
      if (h_P1[i][j] != COLM*1.0f)
      {
        printf("h_P1 mismatch at %d,%d was: %f should be: %f\n"
          , i, j, h_P1[i][j], COLM*1.0f); return 1;
      }
      if (h_P2[i][j] != COLM*2.0f)
      {
        printf("h_P2 mismatch at %d,%d was: %f should be: %f\n"
          , i, j, h_P2[i][j], COLM*2.0f); return 1;
      }
    }
  printf("Success!\n");
  return 0;
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 CUDA 中执行多个矩阵乘法? 的相关文章

  • 在模板类中声明模板友元类时出现编译器错误

    我一直在尝试实现我自己的链表类以用于教学目的 我在迭代器声明中指定了 List 类作为友元 但它似乎无法编译 这些是我使用过的 3 个类的接口 Node h define null Node
  • 在一个数据访问层中处理多个连接字符串

    我有一个有趣的困境 我目前有一个数据访问层 它必须与多个域一起使用 并且每个域都有多个数据库存储库 具体取决于所调用的存储过程 目前 我只需使用 SWITCH 语句来确定应用程序正在运行的计算机 并从 Web config 返回适当的连接字
  • 机器Epsilon精度差异

    我正在尝试计算 C 中双精度数和浮点数的机器 epsilon 值 作为学校作业的一部分 我在 Windows 7 64 位中使用 Cygwin 代码如下 include
  • 如何在 C# 中打开 Internet Explorer 属性窗口

    我正在开发一个 Windows 应用程序 我必须向用户提供一种通过打开 IE 设置窗口来更改代理设置的方法 Google Chrome 使用相同的方法 当您尝试更改 Chrome 中的代理设置时 它将打开 Internet Explorer
  • 对类 static constexpr 结构的未定义引用,g++ 与 clang

    这是我的代码 a cp p struct int2 int x y struct Foo static constexpr int bar1 1 static constexpr int2 bar2 1 2 int foo1 return
  • C++ 多行字符串原始文字[重复]

    这个问题在这里已经有答案了 我们可以像这样定义一个多行字符串 const char text1 part 1 part 2 part 3 part 4 const char text2 part 1 part 2 part 3 part 4
  • 需要帮助优化算法 - 两百万以下所有素数的总和

    我正在尝试做一个欧拉计划 http projecteuler net问题 我正在寻找 2 000 000 以下所有素数的总和 这就是我所拥有的 int main int argc char argv unsigned long int su
  • 重载 (c)begin/(c)end

    我试图超载 c begin c end类的函数 以便能够调用 C 11 基于范围的 for 循环 它在大多数情况下都有效 但我无法理解和解决其中一个问题 for auto const point fProjectData gt getPoi
  • WcfSvcHost 的跨域异常

    对于另一个跨域问题 我深表歉意 我一整天都在与这个问题作斗争 现在已经到了沸腾的地步 我有一个 Silverlight 应用程序项目 SLApp1 一个用于托管 Silverlight SLApp1 Web 的 Web 项目和 WCF 项目
  • 结构体的内存大小不同?

    为什么第一种情况不是12 测试环境 最新版本的 gcc 和 clang 64 位 Linux struct desc int parts int nr sizeof desc Output 16 struct desc int parts
  • 为什么 C# 2.0 之后没有 ISO 或 ECMA 标准化?

    我已经开始学习 C 并正在寻找标准规范 但发现大于 2 0 的 C 版本并未由 ISO 或 ECMA 标准化 或者是我从 Wikipedia 收集到的 这有什么原因吗 因为编写 审查 验证 发布 处理反馈 修订 重新发布等复杂的规范文档需要
  • C 编程:带有数组的函数

    我正在尝试编写一个函数 该函数查找行为 4 列为 4 的二维数组中的最大值 其中二维数组填充有用户输入 我知道我的主要错误是函数中的数组 但我不确定它是什么 如果有人能够找到我出错的地方而不是编写新代码 我将不胜感激 除非我刚去南方 我的尝
  • 如何实例化 ODataQueryOptions

    我有一个工作 简化 ODataController用下面的方法 public class MyTypeController ODataController HttpGet EnableQuery ODataRoute myTypes pub
  • 如何在 Linq to SQL 中使用distinct 和 group by

    我正在尝试将以下 sql 转换为 Linq 2 SQL select groupId count distinct userId from processroundissueinstance group by groupId 这是我的代码
  • 使用特定参数从 SQL 数据库填充组合框

    我在使用参数从 sql server 获取特定值时遇到问题 任何人都可以解释一下为什么它在 winfom 上工作但在 wpf 上不起作用以及我如何修复它 我的代码 private void UpdateItems COMBOBOX1 Ite
  • 当文件流没有新数据时如何防止fgets阻塞

    我有一个popen 执行的函数tail f sometextfile 只要文件流中有数据显然我就可以通过fgets 现在 如果没有新数据来自尾部 fgets 挂起 我试过ferror and feof 无济于事 我怎样才能确定fgets 当
  • 将数据从 GPU 复制到 CPU - CUDA

    我在将数据从 GPU 复制到 CPU 时遇到问题 一开始我在 GPU 空间中创建变量 device float gpu array 在此 GPU 函数中 我想将数据从 od fS gi 值 0 43 复制到 gpu array global
  • C# 使用“?” if else 语句设置值这叫什么

    嘿 我刚刚看到以下声明 return name null name NA 我只是想知道这在 NET 中叫什么 是吗 代表即然后执行此操作 这是一个俗称的 条件运算符 三元运算符 http en wikipedia org wiki Tern
  • 如何确定 CultureInfo 实例是否支持拉丁字符

    是否可以确定是否CultureInfo http msdn microsoft com en us library system globalization cultureinfo aspx我正在使用的实例是否基于拉丁字符集 我相信你可以使
  • 使用 WGL 创建现代 OpenGL 上下文?

    我正在尝试使用 Windows 函数创建 OpenGL 上下文 现代版本 基本上代码就是 创建窗口类 注册班级 创建一个窗口 choose PIXELFORMATDESCRIPTOR并设置它 创建旧版 OpenGL 上下文 使上下文成为当前

随机推荐

  • accept() 创建一个新套接字是什么意思?

    我的问题基于以下理解 套接字由 ip port 定义 服务器和客户端都有自己的套接字 Socket连接由五组server ip server port client ip client port protocol定义 套接字描述符是标识套接
  • 如何将带有嵌套节点(父/子关系)的 XML 导入 Access?

    我正在尝试将 XML 文件导入 Access 但它创建了 3 个不相关的表 也就是说 子记录被导入到子表中 但无法知道哪些子记录属于哪个父记录 如何导入数据来维护父子节点 记录 之间的关系 以下是 XML 数据的示例
  • 将目录从 Assets 复制到本地目录

    我正在尝试使用资产文件夹中的目录并将其作为File 是否可以访问 Assets 目录中的某些内容File 如果没有 如何将 Assets 文件夹中的目录复制到应用程序的本地目录 我会像这样复制一个文件 try InputStream str
  • Tkinter 嵌套主循环

    我正在写一个视频播放器tkinter python 所以基本上我有一个可以播放视频的 GUI 现在 我想实现一个停止按钮 这意味着我将有一个mainloop 对于 GUI 还有另一个嵌套mainloop 播放 停止视频并返回 GUI 启动窗
  • JyNI Eclipse 设置

    我在 Eclipse 中有以下 Java 文件 package java python tutorial import org python core PyInstance import org python util PythonInte
  • 仅使用 NumPy einsum 处理上三角元素

    我使用 numpy einsum 来计算形状为 3 N 的列向量 pts 数组与其自身的点积 从而得到形状为 N N 的矩阵 dotps 与所有点积 这是我使用的代码 dotps np einsum ij ik gt jk pts pts
  • 为什么 Ruby 解析文件时常量不像局部变量那样被初始化?

    在 Ruby 中 我知道我可以做这样的事情 if false var Hello end puts var 应用程序不会崩溃 并且var只需设置为nil 我读到 这种情况的发生是由于 Ruby 解析器的工作方式造成的 为什么同样的方法不适用
  • 在 MVC 5 中,如何在单个 Ajax POST 请求中发送 ViewModel 和文件?

    我有一个 ASP NET MVC 5 应用程序 我正在尝试发送带有模型数据的 POST 请求 并且还包括用户选择的文件 这是我的 ViewModel 为了清晰起见进行了简化 public class Model public string
  • 给GAC,还是不给GAC?

    我有一个用 ASP NET 3 5 编写的数据访问层 DAL 并使用 Microsoft 模式和实践库 以下简称 P P 来完成其数据访问 我安装了 P P 它驻留在我的 GAC 中 因此 从逻辑上讲 我的 DAL 在 GAC 中引用它 因
  • `checkout` = `reset` + `symbolic ref`?

    Suppose a branch是一个现有分支 指向与之前不同的提交HEAD指着 HEAD可能直接或通过某些方式指向提交branch 以下命令等效吗 git checkout a branch and git symbolic ref HE
  • 分布式张量流中的并行进程

    我有带有训练参数的张量流神经网络 它是代理的 策略 网络正在核心程序的主张量流会话的训练循环中进行更新 在每个训练周期结束时 我需要将该网络传递给几个并行进程 工作人员 这些进程将使用它来从代理策略与环境的交互中收集样本 我需要并行执行 因
  • 没有传输安全性的 WCF 可靠会话不会按时发生故障事件

    我遇到了可靠会话的一个非常有趣的行为 我使用的是netTcp绑定 双工通道 可靠会话 当我尝试侦听channel faulted时 如果安全模式设置为transport 则当客户端断开连接时 故障事件将立即触发 但是 当我将绑定的安全模式设
  • 在实体框架中附加集合

    使用实体框架 我可以使用附加单个对象 entity Attach 但是 我没有看到任何方法允许我将多个对象的集合 数组添加到实体 我必须循环遍历集合中的每个项目并调用entity Attach 每一次 是的 您必须循环遍历子集合并Attac
  • 在 MySQL 中存储 IPv6 地址

    正如 需要支持 ipv6 的 inet aton 和 inet ntoa 函数 http bugs mysql com bug php id 34037 目前没有用于存储 IPv6 地址的 MySQL 函数 用于存储 插入的推荐数据类型 函
  • 如何在 CSS 中用 SVG 图标替换 Web 字体(Font Awesome)?

    我注意到在我的 CSS 文件中 有一些使用 Font Awesome Web 字体的规则 如下所示 ul fancy li before category page ul li before display none font style
  • 删除URL参数而不刷新页面

    我试图删除 之后的所有内容在文档准备好的浏览器 URL 中 这是我正在尝试的 jQuery document ready function var url window location href url url split 0 我可以做到
  • toLocaleLowerCase() 和 toLowerCase() 之间的区别[重复]

    这个问题在这里已经有答案了 我试图fiddle http jsfiddle net xameeramir kr33b0aL with toLocaleLowerCase http www w3schools com jsref jsref
  • 如何退出 Instagram API?

    Instagram API 身份验证页面没有任何有关如何注销用户的信息 在使用 API 的 iOS 应用程序上 我该如何允许用户注销 要注销用户 您只需删除令牌即可 如果用户不希望您的应用访问他们的数据 他们将取消您的应用访问权限 如果您想
  • 编写无 BOM 的 UTF-8

    这段代码 OutputStream out new FileOutputStream new File C file test txt out write A getBytes 和这个 OutputStream out new FileOu
  • 如何在 CUDA 中执行多个矩阵乘法?

    我有一个方阵数组int M 10 以便M i 定位第一个元素i th 矩阵 我想将所有矩阵相乘M i 通过另一个矩阵N 这样我就收到了方阵数组int P 10 作为输出 我看到有不同的可能性 分配不同元素的计算M i 到不同的线程 例如 我