Pytorch 多GPU训练

2023-11-02

Pytorch 多GPU训练



1 导入库

import torch#深度学习的pytoch平台
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

 


2 指定GPU

2.1 单GPU声明

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

2.2 多GPU声明

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5' #指定GPU编号
device = torch.device("cuda") #创建GPU对象

 

3 数据放到GPU

x_train = Variable(train,requires_grad=True).to(device=device,dtype=torch.float32) #把训练变量放到GPU

 


4 把模型网络放到GPU 【重要】

net = DNN(layers)
net = nn.DataParallel(net)
net.to(device=device)

重要:nn.DataParallel

net = nn.DataParallel(net)
net.to(device=device)

  1. 使用 nn.DataParallel 打包模型
  2. 然后用 nn.DataParallel 的 model.to(device) 把模型传送到多块GPU中进行运算

torch.nn.DataParallel(DP)

DataParallel(DP)中的参数:

  • module即表示你定义的模型
  • device_ids表示你训练时用到的gpu device
  • output_device这个参数表示输出结果的device,默认就是在第一块卡上,因此第一块卡的显存会占用的比其他卡要更多一些。

当调用nn.DataParallel的时候,input数据是并行的,但是output loss却不是这样的,每次都会在output_device上相加计算
===> 这就造成了第一块GPU的负载远远大于剩余其他的显卡。

DP的优势是实现简单,不涉及多进程,核心在于使用nn.DataParallel将模型wrap一下,代码其他地方不需要做任何更改。

例子:
在这里插入图片描述


 

5 其他:多GPU并行

加个判断:

 if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
model = Model(input_size, output_size)  # 实例化模型对象
if torch.cuda.device_count() > 1:  # 检查电脑是否有多块GPU
    print(f"Let's use {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)  # 将模型对象转变为多GPU并行运算的模型

model.to(device)  # 把并行的模型移动到GPU上
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch 多GPU训练 的相关文章

  • 如何在python中读取多个文件中的文本

    我的文件夹中有许多文本文件 大约有 3000 个文件 每个文件中第 193 行是唯一包含重要信息的行 我如何使用 python 将所有这些文件读入 1 个文本文件 os 模块中有一个名为 list dir 的函数 该函数返回给定目录中所有文
  • InterfaceError:连接已关闭(使用 django + celery + Scrapy)

    当我在 Celery 任务中使用 Scrapy 解析函数 有时可能需要 10 分钟 时 我得到了这个信息 我用 姜戈 1 6 5 django celery 3 1 16 芹菜 3 1 16 psycopg2 2 5 5 我也使用了psyc
  • Pycharm Python 控制台不打印输出

    我有一个从 Pycharm python 控制台调用的函数 但没有显示输出 In 2 def problem1 6 for i in range 1 101 2 print i end In 3 problem1 6 In 4 另一方面 像
  • SQL Alchemy 中的 NULL 安全不等式比较?

    目前 我知道如何表达 NULL 安全的唯一方法 SQL Alchemy 中的比较 其中与 NULL 条目的比较计算结果为 True 而不是 NULL 是 or field None field value 有没有办法在 SQL Alchem
  • __del__ 真的是析构函数吗?

    我主要用 C 做事情 其中 析构函数方法实际上是为了销毁所获取的资源 最近我开始使用python 这真的很有趣而且很棒 我开始了解到它有像java一样的GC 因此 没有过分强调对象所有权 构造和销毁 据我所知 init 方法对我来说在 py
  • 在pyyaml中表示具有相同基类的不同类的实例

    我有一些单元测试集 希望将每个测试运行的结果存储为 YAML 文件以供进一步分析 YAML 格式的转储数据在几个方面满足我的需求 但测试属于不同的套装 结果有不同的父类 这是我所拥有的示例 gt gt gt rz shorthand for
  • python 集合可以包含的值的数量是否有限制?

    我正在尝试使用 python 设置作为 mysql 表中 ids 的过滤器 python集存储了所有要过滤的id 现在大约有30000个 这个数字会随着时间的推移慢慢增长 我担心python集的最大容量 它可以包含的元素数量有限制吗 您最大
  • Python:字符串不会转换为浮点数[重复]

    这个问题在这里已经有答案了 我几个小时前写了这个程序 while True print What would you like me to double line raw input gt if line done break else f
  • Geopandas 设置几何图形:MultiPolygon“等于 len 键和值”的 ValueError

    我有 2 个带有几何列的地理数据框 我将一些几何图形从 1 个复制到另一个 这对于多边形效果很好 但对于任何 有效 多多边形都会返回 ValueError 请指教如何解决这个问题 我不知道是否 如何 为什么应该更改 MultiPolygon
  • 如何将 numpy.matrix 提高到非整数幂?

    The 运算符为numpy matrix不支持非整数幂 gt gt gt m matrix 1 0 0 5 0 5 gt gt gt m 2 5 TypeError exponent must be an integer 我想要的是 oct
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • 如何改变Python中特定打印字母的颜色?

    我正在尝试做一个简短的测验 并且想将错误答案显示为红色 欢迎来到我的测验 您想开始吗 是的 祝你好运 法国的首都是哪里 法国 随机答案不正确的答案 我正在尝试将其显示为红色 我的代码是 print Welcome to my Quiz be
  • VSCode:调试配置中的 Python 路径无效

    对 Python 和 VSCode 以及 stackoverflow 非常陌生 直到最近 我已经使用了大约 3 个月 一切都很好 当尝试在调试器中运行任何基本的 Python 程序时 弹出窗口The Python path in your
  • 在 Pandas DataFrame Python 中添加新列[重复]

    这个问题在这里已经有答案了 例如 我在 Pandas 中有数据框 Col1 Col2 A 1 B 2 C 3 现在 如果我想再添加一个名为 Col3 的列 并且该值基于 Col2 式中 如果Col2 gt 1 则Col3为0 否则为1 所以
  • 使用基于正则表达式的部分匹配来选择 Pandas 数据帧的子数据帧

    我有一个 Pandas 数据框 它有两列 一列 进程参数 列 包含字符串 另一列 值 列 包含相应的浮点值 我需要过滤出部分匹配列 过程参数 中的一组键的子数据帧 并提取与这些键匹配的数据帧的两列 df pd DataFrame Proce
  • 在 Python 类中动态定义实例字段

    我是 Python 新手 主要从事 Java 编程 我目前正在思考Python中的类是如何实例化的 我明白那个 init 就像Java中的构造函数 然而 有时 python 类没有 init 方法 在这种情况下我假设有一个默认构造函数 就像
  • 您可以在 Python 类型注释中指定方差吗?

    你能发现下面代码中的错误吗 米皮不能 from typing import Dict Any def add items d Dict str Any gt None d foo 5 d Dict str str add items d f
  • Spark.read 在 Databricks 中给出 KrbException

    我正在尝试从 databricks 笔记本连接到 SQL 数据库 以下是我的代码 jdbcDF spark read format com microsoft sqlserver jdbc spark option url jdbc sql
  • 改变字典的哈希函数

    按照此question https stackoverflow com questions 37100390 towards understanding dictionaries 我们知道两个不同的字典 dict 1 and dict 2例
  • Pandas 与 Numpy 数据帧

    看这几行代码 df2 df copy df2 1 df 1 df 1 values 1 df2 ix 0 0 我们的教练说我们需要使用 values属性来访问底层的 numpy 数组 否则我们的代码将无法工作 我知道 pandas Data

随机推荐