Mac上使用GPU加速训练模型

2023-11-12

文章目录

前言

上一篇文章中我介绍了使用pytorch的一个完整模型训练套路,其中没有使用gpu,如果要使用gpu的话,win上我们可以使用cuda,mac上可以使用mps,而我自己是mac电脑,需要进行如下修改。

使用GPU

import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model2 import *
import time

# 创建数据集
train_data = torchvision.datasets.CIFAR10("./source", train=True,
                                          transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10("./source", train=False,
                                          transform=torchvision.transforms.ToTensor(), download=True)

# 加载数据集
train_loader = DataLoader(train_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# 查看数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的大小为{train_data_size}")
print(f"测试数据集的大小为{test_data_size}")

# 创建网络模型 搭建神经网络
# class Aniu(nn.Module):
#     def __init__(self):
#         super(Aniu, self).__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(3, 32, 5, 1, 2),
#             nn.MaxPool2d(2),
#             nn.Conv2d(32, 32, 5, 1, 2),
#             nn.MaxPool2d(2),
#             nn.Conv2d(32, 64, 5, 1, 2),
#             nn.MaxPool2d(2),
#             nn.Flatten(),
#             nn.Linear(64 * 4 * 4, 64),
#             nn.Linear(64, 10)
#         )
#
#     def forward(self, x):
#         x = self.model(x)
#         return x

# 定义训练的设备
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# 创建神经网络模型
aniu = Aniu()
aniu = aniu.to(device)

# 定义损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 定义优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(aniu.parameters(), lr=learning_rate)

# 训练网络:

# 设置训练网络的一些参数:
# 总共训练次数
total_train_step = 0
# 总共测试次数
total_test_step = 0
# 总轮次
epoch = 50

# 添加 tensorboard 以便观察
writer = SummaryWriter("./log_train2")
start_time = time.time()
for i in range(epoch):
    print(f"------------第{i+1}轮训练开始------------")

    # 训练开始
    aniu.train()
    for data in train_loader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        output = aniu(imgs)
        loss = loss_fn(output, targets)

        # 优化器优化模型
        optimizer.zero_grad() # 优化器梯度清零
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0 :
            end_time = time.time()
            print(end_time - start_time)
            print(f"训练次数{total_train_step},Loss:{loss.item()}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始:
    aniu.eval()
    total_test_loss = 0

    # 整体正确的个数
    total_accuracy = 0
    with torch.no_grad():
        for data in test_loader:
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            output = aniu(imgs)
            loss = loss_fn(output, targets)
            total_test_loss += loss.item()
            accuracy = (output.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

    print(f"整体测试集上的Loss为{total_test_loss}")
    print(f"整体测试集上的正确率:{total_accuracy / test_data_size}")

    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)

    total_test_step += 1

    torch.save(aniu, f"aniu_{i}.pth")
    print("模型已保存")

writer.close()

总的来说就是添加了几行代码:

# 定义训练的设备
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

aniu = aniu.to(device)

loss_fn = loss_fn.to(device)

output = aniu(imgs)

loss = loss_fn(output, targets)

速度大概快了10几倍。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Mac上使用GPU加速训练模型 的相关文章

随机推荐

  • javascript高级 --- 惰性函数

    一 介绍 惰性函数表示在函数执行的过程中 函数会在首次被成功调用的时候覆盖当前函数 成功后的逻辑不会被执行 二 案例 因为浏览器行为之间的差异 我们在处理一些差异的同时 必须考虑其兼容性问题 例如 addEventListener remo
  • 微服务swagger公共模块(SpringBoot 2.7.7 Swagger 3.0.0)

    一 SpringBoot和Swagger版本 SpringBoot
  • 电脑安装多个版本Java如何进行快速切换

    安装好Java之后 首先查看环境变量 在Path栏中寻找地址值为 C盘 java bin 之类的值 删除 然后找到该目录 删除具有java exe javaw exe javaws exe的文件夹 我们可以运用批处理脚本 进行快速的Java
  • javaweb——jsp(学习总结,javaweb必备技能)

    javaweb jsp 1 jsp简介 2 jsp的生命周期 3 jsp的三种语法 3 1 头部的page 指令 3 1 1 page指令的相关属性 3 2 表达式脚本 3 3 jsp注释 4 jsp的九大隐含对象 内置对象 5 jsp的四
  • mysql迁移kingbase缺少is_ipv4函数自定义实现

    mysql迁移kingbase缺少is ipv4函数自定义实现 TOC mysql is ipv4函数 is ipv4函数判断传入的字符串是否是一个ipv4地址 IS IPV4 expr Returns 1 if the argument
  • Elasticsearch学习

    0 带着问题上路 ES是如何产生的 1 思考 大规模数据如何检索 如 当系统数据量上了10亿 100亿条的时候 我们在做系统架构的时候通常会从以下角度去考虑问题 1 用什么数据库好 mysql sybase oracle 达梦 神通 mon
  • configure: error: C preprocessor "/lib/cpp" fails

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 错误代码 root localhost libevent 2 0 21 stable configure make make install checking for a
  • ARTS挑战打卡第十八周

    Algorithm 一周至少一道算法题 Review 阅读并点评至少一篇英文技术文章 Tip 学习至少一个技术技巧 总结和归纳在日常工作中所遇到的知识点 Share 分享一篇有观点和思考的技术文章 01 Algorthm https lee
  • 推荐一个开源虚拟化及云管管理平台

    能解决哪些问题 将几台物理服务器虚拟化成一个私有云平台 需要一个紧凑而且功能相对完整的物理机全生命周期管理工具 将 VMware vSphere 虚拟化集群转换为一个可以自服务的私有云平台 存在使用多云场景 能够在一个界面管理私有云和公有云
  • IDEA配置.gitignore不生效

    问题描述 我在 gitignore里面添加了日志文件不进行追踪 但是每次还是都上传到了云端 gitignore并没有生效 原因 gitignore只能忽略未被track的文件 而git本地缓存 如果某些文件已经被纳入了版本管理中 则修改 g
  • C/C++ 程序自删除

    文章目录 前言 一 代码 二 部分代码解释 前言 一般病毒之内的可能都带有自删除功能 而目前可进行完美自删除的方法并不多 其中一种较好的解决方法就是利用批处理文件 批处理文件一个优点就是 即使自身在运行的情况下也可以删除自己 所以实现的逻辑
  • 【Opencv&Cpp】12 像素统计:最大/小值、平均值、标准差

    minMaxLoc 找到全局最小和最大值 meanStdDev 计算矩阵的均值和标准偏差 找到全局最小和最大值 minmaxloc minMaxLoc InputArray src double minVal double maxVal 0
  • 大家来讨论怎么写概要设计

    http blog csdn net sunwill chen article details 7864904 笔者声明 本文讲述笔者浅薄的观点 意在抛砖引玉 望网友一起发表观点共同切磋 目前网络上的概要设计格式繁多 质量也是参差不齐 许多
  • 单链表C语言代码实现

    一 代码 include
  • sqli-labs-master靶场搭建以及报错解决

    一 前提准备 1 下载 sqli labs master mirrors audi 1 sqli labs GitCode 2 安装PHP study Windows版phpstudy下载 小皮面板 phpstudy xp cn 二 搭建靶
  • 华为OD机试 - 最小传输时延(Java)

    题目描述 某通信网络中有N个网络结点 用1到N进行标识 网络通过一个有向无环图表示 其中图的边的值表示结点之间的消息传递时延 现给定相连节点之间的时延列表times i u v w 其中u表示源结点 v表示目的结点 w表示u和v之间的消息传
  • Pycharm中修改注释文本的颜色(详细设置步骤)

    下面是在Pycharm中设置注释文本颜色的详细步骤 下面是修改前后对比 修改前注释行的颜色 修改后注释行的颜色 以上就是Pycharm中修改注释文本颜色的详细步骤 希望能帮到你
  • 小程序真机调试连接本地服务器进行调试

    小程序连接本地服务器 开发小程序时经常会遇到需要连接本地服务器进行调试的时候 但是总是连接不上 这里就说一下本菜鸟连接本地服务器的方法 第一步 把下图红框的地方勾选住 好多方法都得选这一步 第二步 设置里面代理按图中勾选 第三步是连接的方法
  • JavaScript避免使用return跳出多重循环从而继续执行函数;使用break跳出多重for循环

    一 先来看一下使用break仅跳出一层for循环的用法 const foo function for let i 1 i lt 3 i for let j 1 j lt 3 j if i 2 break console log 输出j的值
  • Mac上使用GPU加速训练模型

    文章目录 前言 使用GPU 前言 上一篇文章中我介绍了使用pytorch的一个完整模型训练套路 其中没有使用gpu 如果要使用gpu的话 win上我们可以使用cuda mac上可以使用mps 而我自己是mac电脑 需要进行如下修改 使用GP