Pytorch多进程Queue通信产生Segmentation fault (core dumped)——解决方案及代码规范

2023-11-03

最近在做一个强化学习的项目,运用多进程分布训练时遇到了段错误的问题,这里记录下解决的过程思路和方案。

由于智能体与环境交互的过程涉及到了第三方的程序以及大量的文件读写操作,使得整个实验过程非常慢,为了解决交互部分的速度瓶颈,采用Ape-X( Distributed Prioritized Experience Replay)的分布式训练思路,即多个actor负责与环境交互,得到的交互数据存储到公共replay memory中,一个leaner负责从memory中抽样训练更新网络。

由于Pytorch在多进程方面的封装较好,我采用torch.multiprocessing包来实现多进程,并通过其中的Queue队列来实现进程间通信,也就是actor将交互数据发送给learner。主要代码结构简化如下:

def actor(q):
    # 创建环境
    ...
    while True:
        # 获取交互数据 batch 类型为Tensor
        ...
        q.put(batch)

def learner(q)
    # 创建memory
    memory = Memory()
    ...
    while True:
        batch = q.get() # <--- *** 产生 SegFault的地方 ***
        memory.push(batch) 
        update_model()

if __name__ == '__main__':
    # 创建模型、优化器等
    model = DQN()
    model.share_memory()
    ...

    q = torch.multiprocessing.Queue() # 创建 队列q
    processes = []
    for id in range(actor_num):
        p = torch.multiprocessing.Process(target=actor, args=(q))
        processes.append(p)
    processes.append(torch.multiprocessing.Process(target=learn, args=(q)))

    for p in processes:
        p.start()
    for p in processes:
        p.join()
        

一开始程序运行正常,但循环到一定时候,learn进程直接就消失了,连报错都没有(多进程下,子进程出错是没有提示的)。后来将learn函数移到主进程运行,得到了错误时候的提示:

Segmentation fault (core dumped)

 意思是段错误(核心转储)。这种通常是比较严重的运行错误了,导致进程直接结束,因此也得不到python解释器发送的error。

网上查询之后得知,引发该错误的原因基本都和内存相关。经过print调试法,最终将引发错误的语句定位到了 q.get() 这一句,在百度和google上都搜索了一遍,完全找不到相关的解答。于是我开始从官方文档中寻求思路 MULTIPROCESSING BEST PRACTICES

我注意到了这么一句话:

Reuse buffers passed through a Queue

Remember that each time you put a Tensor into a multiprocessing.Queue, it has to be moved into shared memory. If it’s already shared, it is a no-op, otherwise it will incur an additional memory copy that can slow down the whole process. Even if you have a pool of processes sending data to a single one, make it send the buffers back - this is nearly free and will let you avoid a copy when sending next batch.

意思是,任何被放入Queue队列中的Tensor类型数据,都会被移入到共享内存中。因此我推测,是因为从Queue中取出的数据被直接使用,导致它们始终储存在共享内存中,最后爆内存了。经过验证,确实是类似的原因。解决方案如下

def learner(q)
    # 创建memory
    memory = Memory()
    ...
    while True:
        batch = q.get() 
        batch_local = batch.clone() # *** 新增代码 *** 创建一个属于本进程的数据副本
        del batch                   # *** 新增代码 *** 释放共享内存
        memory.push(batch_local)    # *** 修改代码 *** 存储数据副本而不是直接获取的数据
        update_model()

其实在在一个文档中,官方也提供了相应的代码规范,MULTIPROCESSING PACKAGE - TORCH.MULTIPROCESSING 尤其是涉及到GPU cuda Tensor数据的多进程共享,具体包括

  1. 尽快从消费者进程中释放内存
  2. 保持生产者进程的运行状态,直到所有的消费者进程结束。可以防止生产者进程释放消费者仍在使用的内存的情况
  3. 不要直接传递接收来的tensor变量

 总结:官方文档,永远滴神

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch多进程Queue通信产生Segmentation fault (core dumped)——解决方案及代码规范 的相关文章

随机推荐