机器学习框架Ray -- 2.4 基于Ray的Fashion Minst数据集识别

2023-11-08

1. 概述

使用 Ray 和 Ray Train,可以在多个 worker 上分发训练任务,从而加速整个训练过程。每个 worker 都在独立的数据子集上训练相同的神经网络结构。在训练过程中,所有 worker 共享并更新同一个神经网络的参数。

这里的并行计算并不是用于比较不同神经网络结构的训练效果,而是用于加速单个神经网络结构的训练。通过在多个 worker 上分发训练任务,可以更快地完成整个训练过程。这种方法特别适用于大型数据集和复杂模型,因为这些情况下单个设备(例如单个 CPU 或 GPU)可能会受到计算能力和内存限制。

下面将使用Ray Train 分布式训练 PyTorch 模型的Fashion MNIST 图像分类问题。

代码的主要组成部分:

  1. 导入所需库和模块
  2. 下载 Fashion MNIST 训练数据和测试数据
  3. 定义一个神经网络模型(NeuralNetwork 类)
  4. 定义训练和验证函数(train_epoch 和 validate_epoch)
  5. 实现一个训练循环函数(train_func),它将用于每个 Ray Train worker
  6. 定义一个 train_fashion_mnist 函数,用于设置和运行分布式训练。这个函数创建一个 TorchTrainer 实例,并使用给定的配置参数(如工作器数量和是否使用 GPU)来初始化训练
  7. 最后,使用 argparse 模块处理命令行参数,并在 __main__ 块中启动 Ray 和分布式训练

目标是设置并运行一个分布式训练任务,使用 Ray Train 在 Fashion MNIST 数据集上训练一个 PyTorch 神经网络模型。这个示例展示了如何使用 Ray Train 轻松地扩展训练任务,使其可以在多个工作器上并行运行。

2. 环境构建

训练环境为《机器学习框架Ray -- 1.4 Ray RLlib的基本使用》中创建的RayRLlib环境。

本案例中除了安装Ray以外,还需要安装pytorch。

Anaconda中环境创建如下:

conda create -n RayRLlib python=3.7 
conda activate RayRLlib 
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install ipykernel -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install pyarrow
pip install gputil
pip install "ray[rllib]" -i https://pypi.tuna.tsinghua.edu.cn/simple 

进入RayRLlib环境中,导入所需包。

import argparse
from typing import Dict
from ray.air import session

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import ray.train as train
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

3. 数据产生、定义神经网络

下载Fashion MNIST数据集,将自动联网下载数据集,位置在Linux系统根目录的data文件夹内。

# Download training data from open datasets.
# 下载训练数据
training_data = datasets.FashionMNIST(
    root="~/data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
# 下载测试数据
test_data = datasets.FashionMNIST(
    root="~/data",
    train=False,
    download=True,
    transform=ToTensor(),
)

定义神经网络模型。由于Fashion MNIST 数据集每张图片均为28x28大小,神经网络维度需要对应设置28x28。神经网络为若干全连接层:

nn.Linear(28 * 28,512),
nn.ReLU(),
nn.Linear(512,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim=1)

具体定义神经网络与训练函数。epochs原始代码为4,本文改为50。

本案例中,若使用GPU训练,--num-workers需要设置为GPU数量;--num-workers在仅CPU训练时最大可以设置为略小于CPU的线程数。

# Define model
# 定义神经网络模型
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            # nn.Linear(28 * 28, 512),
            # nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            # nn.Linear(512, 10),
            # nn.ReLU(),
            nn.Linear(28 * 28,512),
            nn.ReLU(),
            nn.Linear(512,128),
            nn.ReLU(),
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Linear(64,10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# 定义训练函数
def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) // session.get_world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# 定义验证函数
def validate_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset) // session.get_world_size()
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n "
        f"Accuracy: {(100 * correct):>0.1f}%, "
        f"Avg loss: {test_loss:>8f} \n"
    )
    return test_loss

# 定义 Ray Train 工作函数
def train_func(config: Dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]

    worker_batch_size = batch_size // session.get_world_size()

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=worker_batch_size)
    test_dataloader = DataLoader(test_data, batch_size=worker_batch_size)

    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    model = NeuralNetwork()
    model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for _ in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        session.report(dict(loss=loss))

# fashion mnist训练函数
def train_fashion_mnist(num_workers=1, use_gpu=True):
    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 50}, 
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    result = trainer.fit()
    print(f"Last result: {result.metrics}")
  • 如果使用CPU训练模型,可参考修改对应代码段(以20线程CPU为例):
...
trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 50},   
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )

...
parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=20,  
        help="Sets number of workers for training.",
)
parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
...
  • 如果使用单GPU计算,可参考修改对应代码段:
...
trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 50},   
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )

...
parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,  # 若使用GPU,此处为GPU的数量对应,否则会卡在pending
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=True, help="Enables GPU training"
    )
...

4.训练模型

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=20,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )

    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        # 2 workers + 1 for trainer.
        ray.init(num_cpus=3)
        train_fashion_mnist() 
    else:
        ray.init(address=args.address)
        train_fashion_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu)

使用GPU时,显示如下: 

以CPU与GPU分别以相同的学习率训练50epochs,最终准确率都在80%以上。

后续补充如何修改超参数,以提高预测精度。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习框架Ray -- 2.4 基于Ray的Fashion Minst数据集识别 的相关文章

随机推荐

  • 为云服务器添加python web环境

    为云服务器添加python web环境 自用不喜勿喷 当前配置 阿里云win10云服务器 anaconda配置的python环境 操作步骤 1 参照教程配置python及Django 2 pyCharm如何运行Django https ww
  • vue标签属性及其用法

    一 Vue的特点 1 采用组件化模式 提高代码的复用率 且让代码更好维护 2 声明编码 让编码人员无需直接操作DOM 提高开发效率 3 使用学你DOM 优秀的Diff算法 尽量服用DOM节点 二 Vue模板语法有两大类 1 插值语法 功能
  • Springboot修改内置Tomcat版本

    背景 Tomcat的安全漏洞需要升级版本进行解决 如 9 0 63 gt 9 0 75 1 Pom文件Springboot的依赖配置项 2 Ctrl 右键点击红色框选 3 全局搜索 修改 修改数值 启动测试
  • C++实现——string的所有操作

    C 中string的操作 Constructors 构造函数 用于字符串初始化 Operators 操作符 用于字符串比较和赋值 append 在字符串的末尾添加文本 assign 为字符串赋新值 at 按给定索引值返回字符 begin 返
  • 基于python实现的CS通信和P2P通信

    实验要求 C S通信实现要求 两台计算机分别模拟服务器 客户端 通过编程实现服务器端 客户端程序Socket Client 服务器端程序监听客户端向服务器端发出的请求 并返回数据给客户端 不采用方式 自定义通信协议 传输文件要足够大 例如
  • Python GUI: PyCahrm结合Pyqt5开发图形化界面 详细步骤 踩坑!

    1 下载安装pythonPython官网下载地址 注意 1 1 Python版本选择并不是越新越好 后面会提到 我安装的版本是 V3 5 4 64位 1 2 安装的时候一定要勾选pip和add python to path 自动添加到环境变
  • 清华大佬耗时36个小时,终于整理出来了一份Python自学计划,学不会退出IT界

    在人工智能的风口 Python这门胶水语言越来越火 很多小伙伴也开始学习Python 但是没有一份合适的学习规划怎么能行 今天特意为大家整理了一份Python自学计划 希望可以帮助到处在迷茫期的你们 文末获取o 这份自学计划是我精心整理的
  • React+AntDesign开发完整的考勤系统前端页面(一)

    一 项目准备工作 1 开发环境准备 准备好Visual Studio Code前端开发工具 下载并安装Node js 2 项目准备 本次项目使用umi脚手架的方式创建 1 打开开发工具打开项目文件夹并新建终端在终端里面输入命令 npm i
  • 【概率论与数理统计】猴博士 笔记 p17-20 一、二维连续型:已知F,求f;已知f,求f

    一维连续型已知F 求f 题型 步骤 f是F的导数 对F求导即可得到f 例1 解 例2 解 一维连续型已知f 求f 题型 已知f x 求f y 步骤 注意 要满足要求 Y g X 满足单增或单减才能用公式法 看起来有点抽象 我们看一道例题 此
  • GAMES101 作业3(附三角形重心坐标,Blinn-Phong光照模型及法线贴图推导)

    目录 写在前面 第一题 三角形重心坐标 第二题 Blinn Phong光照模型 第三题 纹理贴图 第四题 凹凸贴图实现及法线贴图推导 第五题 位移贴图 写在前面 main 函数中 std function
  • FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning

    FedDG Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space
  • 数据库锁表的查询和处理

    如果遇到数据库锁表 通常需要用如下方法处理 查看表的partnum情况 oncheck pt shjz mzb baf01 grep i partnum 这个里面包含了多个分区的partnum 也包含了索引的partnum Partitio
  • Windows安装和配置VCenter

    Vcenter的环境搭建和配置 Vcenter简介 Vcenter一般指 VMware vCenter Server VMware vCenterServer 提供了一个可伸缩 可扩展的平台 为 虚拟化管理奠定了基础 VMware vCen
  • 射线与AABB型包围盒相交算法

    基础知识 AABB包围盒 也叫轴对称包围盒 意思就是它的六个面总是分别平行XYZ三个轴的 相交计算原理 计算射线与包围盒每个面的平面的交点 计算这个点是否在包围盒面的范围 在就是相交 不在就是没有相交 图解 用个2D图形简单讲解一下 首先从
  • C++ 一些学习笔记(十一)类和对象-继承

    C 一些学习笔记 十一 类和对象 继承 主要是针对之前学习C的时候一些知识点的遗漏的补充 还有一些我自己觉得比较重要的地方 本文章的主要内容是关于继承 1 继承的基本语法 2 继承方式 3 继承中的对象模型 4 继承中的构造和析构顺序 5
  • 虚拟化技术调研

    虚拟化技术调研 容器 虚拟化技术 容器是一种轻量级虚拟化技术 它可以在一台宿主机上共享内核 并且在运行应用程序时具有独立的文件系统空间 网络空间 进程空间和用户空间 常见的容器技术有Docker和LXC KVM虚拟化技术 KVM是一种全虚拟
  • Nginx反向代理服务器搭建(超详细)

    一 简介 Nginx engine x 是一个高性能的Web服务器和反向代理服务器 也可以作为邮件代理服务器 反向代理 Reverse Proxy 方式是指以代理服务器来接受internet上的连接请求 然后将请求转发给内部网络上的服务器
  • 前端map传给后端接收

    前端let map new Map map set 1 1 map set 2 2 map set 3 3 map转obj let obj Object create null for let k v of map obj k v ajax
  • sql实现多字段去重

    sql实现多字段去重 且返回所有字段 1 主要思想 根据需求去重的字段进行分组 获取id 在联合查询 2 主要代码 SELECT from table A where id in SELECT max id from table A gro
  • 机器学习框架Ray -- 2.4 基于Ray的Fashion Minst数据集识别

    1 概述 使用 Ray 和 Ray Train 可以在多个 worker 上分发训练任务 从而加速整个训练过程 每个 worker 都在独立的数据子集上训练相同的神经网络结构 在训练过程中 所有 worker 共享并更新同一个神经网络的参数