干货|Pytorch弹性训练极简实现( 附源码)

2023-11-16

点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达db11525605175d78744d802e14d3b387.png

作者丨颜挺帅@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/489892744

编辑丨极市平台

导读

 

作者将以往抽象的分布式训练的概念以代码的形式展现出来,并保证每个代码可执行、可验证、可复现,并贡献出来源码让大家相互交流。本例中会先在Node0上启动4 GPU的worker group ,等其训练一段时间后,会在Node1上再启动4 GPU的workers,并与Node1上的workers构成一个新的worker group,最终构成一个2机8卡的分布式训练。

由于工作需要,最近在补充分布式训练方面的知识。经过一番理论学习后仍觉得意犹未尽,很多知识点无法准确get到(例如:分布式原语scatter、all reduce等代码层面应该是什么样的,ring all reduce 算法在梯度同步时是怎么使用的,parameter server参数是如何部分更新的)。

著名物理学家,诺贝尔奖得主Richard Feynman办公室的黑板上写了:"What I cannot create, I do not understand."。在程序员界也经常有"show me the code"的口号。因此,我打算写一系列的分布式训练的文章,将以往抽象的分布式训练的概念以代码的形式展现出来,并保证每个代码可执行、可验证、可复现,并贡献出来源码让大家相互交流。

经过调研发现pytorch对于分布式训练做好很好的抽象且接口完善,因此本系列文章将以pytorch为主要框架进行,文章中的例子很多都来自pytorch的文档,并在此基础上进行了调试和扩充。

最后,由于分布式训练的理论介绍网络上已经很多了,理论部分的介绍不会是本系列文章的重点,我会将重点放在代码层面的介绍上面。

Pytorch - 分布式训练极简体验:https://zhuanlan.zhihu.com/p/477073906

Pytorch - 分布式通信原语(附源码):https://zhuanlan.zhihu.com/p/478953028

Pytorch - 手写allreduce分布式训练(附源码):https://zhuanlan.zhihu.com/p/482557067

Pytorch - 算子间并行极简实现(附源码):https://zhuanlan.zhihu.com/p/483640235

Pytorch - 多机多卡极简实现(附源码):https://zhuanlan.zhihu.com/p/486130584

1. 介绍

Pytorch在1.9.0引入了torchrun,用其替代1.9.0以前版本的torch.distributed.launch。torchrun在torch.distributed.launch 功能的基础上主要新增了两个功能:

  • Failover: 当worker训练失败时,会自动重新启动所有worker继续进行训练;

  • Elastic: 可以动态增加或或删除node节点,本文将通过一个例子说明Elastic Training应该如何使用;

本例中会先在Node0上启动4 GPU的worker group ,等其训练一段时间后,会在Node1上再启动4 GPU的workers,并与Node1上的workers构成一个新的worker group,最终构成一个2机8卡的分布式训练。

6d315545e97f4ca70ab2adccce1e6189.png

2. 模型构建

一个简单的全连接模型神经网络模型

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

3. checkpoint 处理

由于再每次增加或删除node时,会将所有worker kill掉,然后再重新启动所有worker进行训练。因此,在训练代码中要对训练的状态进行保存,以保证重启后能接着上次的状态继续训练。

需要保存的信息一般有如下内容:

  • model :模型的参数信息

  • optimizer :优化器的参数信心

  • epoch:当前执行到第几个epoch

save和load的代码如下所示

  • torch.save:利用python的pickle将python的object 进行序列化,并保存到本地文件;

  • torch.load : 将torch.save后的本地文件进行反序列化,并加载到内存中;

  • model.state_dict(): 存储了model 每个layer和其对应的param信息

  • optimizer.state_dict():存储了优化器的参数信信息

def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimize_state_dict": optimizer.state_dict(),
}, path)

def load_checkpoint(path):
    checkpoint = torch.load(path)
    return checkpoint

4. 训练代码

初始化逻辑如下:

  • 1~3行: 输出当前worker的关键环境变量,用于后面的结果展示

  • 5~8行:创建模型、优化器和损失函数

  • 10~12行:初始化参数信息

  • 14~19行:如果存在checkpoint,则加载checkpoint,并赋值给model、optimizer和firt_epoch

local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) train worker starting...")
    
    model = ToyModel().cuda(local_rank)
    ddp_model = DDP(model, [local_rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    optimizer.zero_grad()
    max_epoch = 100
    first_epoch = 0
    ckp_path = "checkpoint.pt"
    
    if os.path.exists(ckp_path):
        print(f"load checkpoint from {ckp_path}")
        checkpoint = load_checkpoint(ckp_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimize_state_dict"])
        first_epoch = checkpoint["epoch"]

训练逻辑:

  • 1行:epoch执行的次数为first_epoch到max_epoch,以便能够在worker被重启后继续原有的epoch继续训练;

  • 2行:为了展示动态添加node效果,这里添加sleep函数来降低训练的速度;

  • 3~8行:模型训练流程;

  • 9行:为了简单,文本每个epoch进行一次checkpoint保存;将当前的epoch,model和optimizer保存到checkpoint中;

for i in range(first_epoch, max_epoch):
        time.sleep(1) # 为了展示动态添加node效果,这里添加sleep函数来降低训练的速度
        outputs = ddp_model(torch.randn(20, 10).to(local_rank))
        labels = torch.randn(20, 5).to(local_rank)
        loss = loss_fn(outputs, labels)
        loss.backward()
        print(f"[{os.getpid()}] epoch {i} (rank = {rank}, local_rank = {local_rank}) loss = {loss.item()}\n")
        optimizer.step()
        save_checkpoint(i, model, optimizer, ckp_path)

5. 启动方式

由于我们使用torchrun来启动多机多卡任务,无需使用spawn接口来启动多个进程(torchrun会负责将我们的python script启动为一个process),因此直接调用上文编写的train函数,并在前后分别添加DistributedDataParallel的初始化和效果函数即可。

下面代码描述了上文train接口的调用。

def run():
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    train()
    dist.destroy_process_group()


if __name__ == "__main__":
    run()

本例中使用torchrun来执行多机多卡的分布式训练任务(注:torch.distributed.launch 已经被pytorch淘汰了,尽量不要再使用)。启动脚本描述如下(注:node0和node1均通过该脚本进行启动)

  • --nnodes=1:3 :表示当前训练任务接受最少1个node,最多3个node参与分布式训练;

  • --nproc_per_node=4:表示每个node上节点有4个process

  • --max_restarts=3: worker group最大的重启次数;这里需要注意的是,node fail、node scale down和node scale up都会导致restart;

  • --rdzv_id=1:一个unique的job id,所有node均使用同一个job id;

  • --rdzv_backend: rendezvous的backend实现,默认支持c10d和etcd两种;rendezvous用于多个node之间的通信和协调;

  • --rdzv_endpoint:rendezvous的地址,应该为一个node的host ip和port;

torchrun \
    --nnodes=1:3\
    --nproc_per_node=4\
    --max_restarts=3\
    --rdzv_id=1\
    --rdzv_backend=c10d\
    --rdzv_endpoint="192.0.0.1:1234"\
    train_elastic.py

6. 结果分析

代码:BetterDL - train_elastic.py:https://github.com/tingshua-yts/BetterDL/blob/master/test/pytorch/DDP/train_elastic.py

运行环境: 2台4卡 v100机器

image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

gpu: v100

先在node0上执行执行启动脚本

torchrun \
    --nnodes=1:3\
    --nproc_per_node=4\
    --max_restarts=3\
    --rdzv_id=1\
    --rdzv_backend=c10d\
    --rdzv_endpoint="192.0.0.1:1234"\
    train_elastic.py

得到如下结果

  • 2~5行:当前启动的是单机4卡的训练任务,因此WORLD_SIZE为4, LOCAL_WORKD_SIZE也为4

  • 6~9行:共有4个rank参与了分布式训练,rank0~rank3

  • 10~18行: rank0~rank3 均从epoch=0开始训练

r/workspace/DDP# sh run_elastic.sh
[4031] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4030] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4032] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] (rank = 0, local_rank = 0) train worker starting...
[4030] (rank = 1, local_rank = 1) train worker starting...
[4032] (rank = 3, local_rank = 3) train worker starting...
[4031] (rank = 2, local_rank = 2) train worker starting...
[4101] epoch 0 (rank = 1, local_rank = 1) loss = 0.9288564920425415
[4103] epoch 0 (rank = 3, local_rank = 3) loss = 0.9711472988128662
[4102] epoch 0 (rank = 2, local_rank = 2) loss = 1.0727070569992065
[4100] epoch 0 (rank = 0, local_rank = 0) loss = 0.9402943253517151
[4100] epoch 1 (rank = 0, local_rank = 0) loss = 1.0327017307281494
[4101] epoch 1 (rank = 1, local_rank = 1) loss = 1.4485043287277222
[4103] epoch 1 (rank = 3, local_rank = 3) loss = 1.0959293842315674
[4102] epoch 1 (rank = 2, local_rank = 2) loss = 1.0669530630111694
...

在node1上执行与上面相同的脚本

torchrun \
    --nnodes=1:3\
    --nproc_per_node=4\
    --max_restarts=3\
    --rdzv_id=1\
    --rdzv_backend=c10d\
    --rdzv_endpoint="192.0.0.1:1234"\
    train_elastic.py

node1上结果如下:

  • 2~5行:由于添加node1,当前执行的是2机8卡的分布式训练任务,因此WORLD_SIZE=8, LOCAL_WORLD_SIZE=4

  • 6~9行:当前node1上workers的rank为rank4 ~rank7

  • 13~20行: 由于node1是在node0上work训练到epoch35的时候加入的,因此其接着epoch 35开始训练

/workspace/DDP# sh run_elastic.sh
[696] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[695] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[694] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] (rank = 7, local_rank = 3) train worker starting...
[695] (rank = 5, local_rank = 1) train worker starting...
[694] (rank = 4, local_rank = 0) train worker starting...
[696] (rank = 6, local_rank = 2) train worker starting...
load checkpoint from checkpoint.ptload checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[697] epoch 35 (rank = 7, local_rank = 3) loss = 1.1888569593429565
[694] epoch 35 (rank = 4, local_rank = 0) loss = 0.8916441202163696
[695] epoch 35 (rank = 5, local_rank = 1) loss = 1.5685604810714722
[696] epoch 35 (rank = 6, local_rank = 2) loss = 1.11683189868927
[696] epoch 36 (rank = 6, local_rank = 2) loss = 1.3724170923233032
[694] epoch 36 (rank = 4, local_rank = 0) loss = 1.061527967453003
[695] epoch 36 (rank = 5, local_rank = 1) loss = 0.96876460313797
[697] epoch 36 (rank = 7, local_rank = 3) loss = 0.8060566782951355
...

node0上结果如下:

  • 6~9行: node0上的works在执行到epoch 35时,node1上执行了训练脚本,请求加入到训练任务中

  • 10~13行:所有workers重新启动,由于添加了node1,当前执行的是2机8卡的分布式训练任务,因此WORLD_SIZE=8, LOCAL_WORLD_SIZE=4

  • 14~17行:当前node1上works的rank为rank0~rank3

  • 18~21行:加载checkpoint

  • 22~30行:接着checkpoint中的model、optimizer和epoch继续训练

...
[4100] epoch 35 (rank = 0, local_rank = 0) loss = 1.0746158361434937
[4101] epoch 35 (rank = 1, local_rank = 1) loss = 1.1712706089019775
[4103] epoch 35 (rank = 3, local_rank = 3) loss = 1.1774182319641113
[4102] epoch 35 (rank = 2, local_rank = 2) loss = 1.0898035764694214
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4100 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4101 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4102 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4103 closing signal SIGTERM
[4164] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4165] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4163] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] (rank = 0, local_rank = 0) train worker starting...
[4163] (rank = 1, local_rank = 1) train worker starting...
[4164] (rank = 2, local_rank = 2) train worker starting...
[4165] (rank = 3, local_rank = 3) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[4165] epoch 35 (rank = 3, local_rank = 3) loss = 1.3437936305999756
[4162] epoch 35 (rank = 0, local_rank = 0) loss = 1.5693414211273193
[4163] epoch 35 (rank = 1, local_rank = 1) loss = 1.199862003326416
[4164] epoch 35 (rank = 2, local_rank = 2) loss = 1.0465545654296875
[4163] epoch 36 (rank = 1, local_rank = 1) loss = 0.9741991758346558
[4162] epoch 36 (rank = 0, local_rank = 0) loss = 1.3609280586242676
[4164] epoch 36 (rank = 2, local_rank = 2) loss = 0.9585908055305481
[4165] epoch 36 (rank = 3, local_rank = 3) loss = 0.9169824123382568
...

f2906302c44c0794a3998a0d403e388c.png

outside_default.png

点个在看 paper不断!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

干货|Pytorch弹性训练极简实现( 附源码) 的相关文章

随机推荐

  • Self-study Python Fish-C Note-5 P20-P26 (part2)

    python 中的列表 Part 2 本文主要讲解了python中列表的使用 本文为自学B站上鱼C的python课程随手做的笔记 如有问题 欢迎大家批评指正 原视频链接 https www bilibili com video BV1c44
  • LeetCode算法心得——和可被 K 整除的子数组(前缀和+HashMap)

    大家好 我是晴天学长 同余定理的应用 需要的小伙伴可以关注支持一下哦 后续会继续更新的 1 和可被 K 整除的子数组 题目描述 给定一个整数数组 A 返回其中元素之和可被 K 整除的 连续 非空 子数组的数目 示例 输入 A 4 5 0 2
  • 【极简数据结构】快速了解并实现顺序表,速通玩家的最爱

    顺序表目录 前言 一 线性表 二 顺序表 1 顺序表的概念 2 接口函数 顺序表 初始化 顺序表 尾插 顺序表 打印 顺序表 销毁 顺序表 尾删 顺序表 头插 和 顺序表 扩容 优化顺序表 尾删 顺序表 头删 顺序表 查找 顺序表 任意po
  • Android如何安全的关闭线程

    正常情况下 当线程中的run方法执行完毕后 线程是会自动关闭 不需要我们手动去关闭的 如 new Thread new Runnable Override public void run 执行操作 start 该线程在run方法中的操作执行
  • C语言经典100例题(22)--两个乒乓球队进行比赛,各出三人。甲队为a, b, c三人,乙队为x, y, z三人。已抽签决定//比赛名单,有人向队员打听比赛的名单.a说他不和x比,c说他不和x, z

    目录 题目 问题分析 代码 运行结果 题目 两个乒乓球队进行比赛 各出三人 甲队为a b c三人 乙队为x y z三人 已抽签决定比赛名单 有人向队员打听比赛的名单 a说他不和x比 c说他不和x z比 请编程序找出三队赛手的名单 问题分析
  • linux屏蔽海外流量的两种方法

    方法一 使用大神的开源脚本 屏蔽指定国家地区的IP访问 wget https raw githubusercontent com iiiiiii1 Block IPs from countries master block ips sh s
  • RSA的C++语言描述简单实现

    文章目录 前言 代码仓库 代码特点 大 素 数讨论 部分资料 作者理解 代码 rsa h rsa cpp main cpp 结果 总结 参考资料 作者的话 前言 网络安全中RSA的C 语言描述简单实现 代码仓库 yezhening Prog
  • excel 生成sql

    参考文章 https blog csdn net m0 67695717 article details 127406830 新增语句 INSERT INTO table name column1 column2 VALUES A2 D2
  • 判断一个文件是否为CSV文件的Python代码

    在Python中 我们可以使用os模块的path splitext 函数来获取文件扩展名 然后判断这个扩展名是否为 csv 以下是一个示例代码 import os def is csv file file path file extensi
  • WSL2和Docker for Windows

    文章目录 一 Docker和WSL2概述 二 WSL安装使用 三 基于Docker导入任意WSL分发 参考资料 一 Docker和WSL2概述 Docker 是一个开源的应用容器引擎 让开发者可以打包他们的应用以及依赖包到一个可移植的容器中
  • git中出现“interactive rebase in progress; onto 11dde1e”错误分析与解决方案

    出错原因分析 进行提交前 需提前拉取远程仓库的代码 拉取之后 需要重新add commit 避免仓库的数据被修改 但是再次提交之后会出现上图的错误 原因 是因为你现在正在编辑的提交将要覆盖在 11ddele commited 之前使用过gi
  • 472-I/O阻塞和非阻塞,同步和异步

    阻塞 非阻塞 同步 异步 典型的一次I O的两个阶段是什么 数据准备 和 数据读写 我们作为服务器 接收客户端的请求 得先监听客户端有没有数据过来 这是一个状态 还有就是数据过来了该怎么去读写 这又是一个状态 实际上 阻塞 非阻塞 同步 异
  • 基于51单片机的羽毛球计分器proteus仿真程序设计

    硬件设计 末尾附文件 方案 在像羽毛球这样的竞技比赛中 计分器占着很大的作用 如果我们就只在心里记着双方的比分 显然是不实际的 而且在现在的乒乓球比赛中采用的都是21分制 因此我们不能再用传统的计分方式了 本次课题采用单片机设计了一个羽毛球
  • JetBrains IDE Support 调试工具(页面自动刷新)

    1 谷歌浏览器安装JetBrains IDE Support 插件 2 更改端口 webstorm 和 JetBrains IDE Support 端口不对插件图标会黑 3 debug 4 同步
  • 求定制闲鱼爬虫获取最新发布商品

    闲鱼采集及监控下单软件开发 1 点击宝贝右键打开网页 2 点击宝贝右键拉黑卖家 3 点击宝贝右键清空列表 4 点击宝贝左键显示二维码和主图 5 软件页面显示宝贝二维码 6 软件页面显示宝贝主图 7 软件页面显示检测搜索词及下单宝贝 8 每次
  • 介绍计算机方队,方阵

    f ng zh n 方阵 语音 编辑 锁定 讨论 上传视频 方阵是古代军队作战时采用的一种队形 是把军队在野外开阔地上排列成方形阵式 远古方阵由前军 中军和后军相互嵌套排列而成 方阵平面呈现 回 字形状 反映出远古观念中的一种政治地理结构
  • Java丨JVM虚拟机与类加载器

    一丶JVM 虚拟机介绍 Sun HotSpot VM 这个目前看起来 血统纯正 的虚拟机在最初并非由Sun公司开发 而是由一家名为 Longview Technologies 的小公 司设计的 甚至这个虚拟机最初并非是为Java语言而开发的
  • FPGA图像处理基础----直方图统计

    直方图统计的原理 直方图统计从数学上来说 是对图像中的像素点进行统计 图像直方图统计常用于统计灰度图像 表示图像中各个灰度级出现的次数或者概率 统计直方图的实现采用C C 或者其他高级语言实现十分简单 单采用FPGA来实现直方图的统计就稍显
  • Tensorflow分布式训练

    Tensorflow分布式训练 一 分布式训练模式 1 模型并行 In graph 2 数据并行 Between graph 二 异步 同步训练 1 异步训练 2 同步训练 三 同步更新和异步更新的优缺点 四 分布式机器类型 TODO 1
  • 干货|Pytorch弹性训练极简实现( 附源码)

    点击上方 视学算法 选择加 星标 或 置顶 重磅干货 第一时间送达 作者丨颜挺帅 知乎 已授权 来源丨https zhuanlan zhihu com p 489892744 编辑丨极市平台 导读 作者将以往抽象的分布式训练的概念以代码的形