PyTorch: 训练分类CIFAR10

2023-11-11

# !/usr/bin/env python
# -- coding: utf-8 --
# @Author zengxiaohui
# Datatime:8/13/2021 11:20 AM
# @File:train_cifar10
import os
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def shufflenet_v2_x0_5(nc, pretrained):
    model_ft = torchvision.models.shufflenet_v2_x0_5(pretrained=pretrained)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, nc)
    return model_ft


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    epochs = 5
    batch_size = 256
    num_workers = 8
    classes = 10

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                              pin_memory=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = shufflenet_v2_x0_5(classes, True)
    model.cuda()
    model.train()

    criterion = nn.CrossEntropyLoss()
    # SGD with momentum
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in tqdm(enumerate(trainloader)):
            inputs, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)
            # loss
            loss = criterion(outputs, labels)
            # backward
            loss.backward()
            # update weights
            optimizer.step()

            # print statistics
            running_loss += loss
        print('%d/%d loss: %.3f' % (epochs, epoch + 1, running_loss / len(trainset)))

    correct = 0
    model.eval()
    for j, (images, labels) in tqdm(enumerate(testloader)):
        outputs = model(images.cuda())
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted.cpu() == labels).sum()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / len(testset)))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch: 训练分类CIFAR10 的相关文章

  • 图神经网络入门推荐好文(附GNN大佬资料包下载福利)

    今天想和大家分享的是图卷积神经网络 随着人工智能发展 很多人都听说过机器学习 深度学习 卷积神经网络这些概念 但图卷积神经网络 却不多人提起 那什么是图卷积神经网络呢 简单的来说就是其研究的对象是图数据 Graph 研究的模型是卷积神经网络
  • 【机器学习】Q-Learning详细介绍

    Q learning Q learning 是一种机器学习方法 它使模型能够通过采取正确的操作来迭代学习和改进 Q learning属于强化学习的算法 通过强化学习 可以训练机器学习模型来模仿动物或儿童的学习方式 好的行为会得到奖励或加强
  • CodeGeex AI代码提示插件使用

    这里写自定义目录标题 下载插件 Jetbrains IDEA安装示例 下载插件 可在官网按照指引安装不同编译器的插件 目前支持VS Code 和Jetbrains全家桶 官网地址 https codegeex cn zh CN Jetbra
  • 基于自然语言处理技术的智能化自然语言生成技术应用于智能写作工具开发

    文章目录 基于自然语言处理技术的智能化自然语言生成技术应用于智能写作工具开发 1 引言 2 技术原理及概念 2 1 基本概念解释 2 2 技术原理介绍 算法原理 操作步骤 数学公式等 2 2 1 语音识别 2 2 2 自然语言理解 2 2
  • ChatGPT Prompting开发实战(五)

    一 如何编写有效的prompt 对于大语言模型来说 编写出有效的prompt能够帮助模型更好地理解用户的意图 intents 生成针对用户提问来说是有效的答案 避免用户与模型之间来来回回对话多次但是用户不能从LLM那里得到有意义的反馈 本文
  • 产业AI公开课正式开播!60分钟解读AI对金融科技的全新破局

    京东数科 产业AI公开课 第一季第一期 重 磅 开 播 行业热门话题 实力业内大咖 深度解读 经典对话 绝对让你这1个小时的时间欲罢不能 干货满满 从SARS到这次新冠肺炎 黑天鹅 事件对资本市场造成极大影响 不同时期的应对之道有何不同 疫
  • Unity3D研究院之游戏开发中的人工智能AI

    人工智能这个东西在游戏中是非常重要的 人工智能说简单了就是根据随机的数字让敌人执行一些动作或逻辑 说难了TA需要一个非常复杂的算法 本文我主要说说Unity3D中人工智能的脚本如何来编写 首先你应该搞清楚的一点AI脚本属于一个工具类脚本 工
  • 解决报错ImportError: IProgress not found. Please update jupyter and ipywidgets

    在终端 pip install ipywidgets 然后重启jupyter notebook即可
  • ChatGPT 类 AI 软件供应链的安全及合规风险

    AIGC将成为重要的软件供应链 近日 OpenAI推出的ChatGPT通过强大的AIGC 人工智能生产内容 能力让不少人认为AI的颠覆性拐点即将到来 基于AI将带来全新的软件产品体验 而AI也将会成为未来软件供应链中非常重要的一环 在Ope
  • 你是否看到过如此有趣的AI网站?

    1 营销文案 CopyAI Create Marketing Copy In Seconds 2 美化ppt设计 https www beautiful ai 3 图片修改 https hotpot ai 4 照片变视频 https www
  • 【VSCode】推荐一款Microsoft Visual Studio Code能在编辑器内智能补全代码的插件 - Tabnine AI

    Tabnine AI Autocomplete for Javascript Python Typescript PHP Go Java Ruby more Tabnine是一个AI代码补全插件 支持JavaScript Python Ja
  • 当我们谈人工智能 我们在谈论什么

    我们对一个事物的认识模糊往往是因为宣传过剩冲淡了理论的真实 我们陷在狂欢里 暂时忘记为什么要狂欢 如何踏上这趟飞速发展的列车成为越来越多人心心念念的事情 人工智能的浪潮更像是新闻舆论炒起来的话题 城外的人想进去 城内的人也不想出来 当我们谈
  • 大学四年,因为这8个网站,我成为同学眼中的学霸

    作者简介 CSDN top100 阿里云博客专家 华为云享专家 网络安全领域优质创作者 推荐专栏 对网络安全感兴趣的小伙伴可以关注专栏 网络安全入门到精通 大学期间 几乎每一个教过我的老师都反应 我的学习态度不好 我上课很少仔细听老师在讲什
  • 【AI之路】使用huggingface_hub优雅解决huggingface大模型下载问题

    文章目录 前言 一 Hugging face是什么 二 准备工作 三 下载整个仓库或单个大模型文件 1 下载整个仓库 2 下载单个大模型文件 总结 附录 前言 Hugging face 资源很不错 可是国内下载速度很慢 动则GB的大模型 下
  • 使用SVM对随机生成数据集进行分类 (线性可分 硬间隔)

    具体数学原理参考 统计学习方法 在学习过程中有疑惑如下 一直想不明白为什么式7 11中的分子没有用并且可以被当作常数 下面的解释是当w与b同比例变换时 函数间隔 即分子 亦会同比例变换 的确是这样 自己纸上写一下就好 但是为什么w和b一定要
  • 武汉大学空间智能化处理复习

    空间数据处理智能化的重要性 提高地理信息处理的效率 减轻人在地理信息处理中的劳动量 使一般的地理信息用户也能让专家一样解决问题 大型的空间决策服务需要归纳 分析多种方案 智能化处理方法的来源 常常来自于人工智能学科的研究成果 如 知识工程
  • #挑战Open AI!马斯克宣布成立xAI,你怎么看?# 马斯克的xAI:充满困难与希望

    文章目录 1 什么是xAI公司 2 xAI公司的图标 3 反AI斗士 马斯克进军AI 期待与挑战并存 3 1 关于马斯克 3 2 这位 反AI斗士 3 3 我的看法 3 4 可能会遇到的困难与优势 3 5 蓄谋已久的马斯克 3 6 xAI
  • EasyRecovery易恢复2024最新免费版电脑数据恢复软件功能介绍

    EasyRecovery从 易恢复2024 支持恢复不同存储介质数据 在Windows中恢复受损和删除文件 以及能检索数据格式化或损坏卷 甚至还可以从初始化磁盘 同时 你只需要最简单的操作就可以恢复数据文件 如 硬盘 光盘 U盘 移动硬盘
  • 由于人工智能和自动化,2030 年将不存在的 6 个科技工作岗位

    我们都知道人工智能和自动化已经存在 并且有很多关于它们将如何扰乱日常业务实践以及支撑它们的专业角色的讨论 虽然预测某些工作岗位将彻底消失似乎很戏剧性 但对未来可能发生的情况保持现实态度是明智的 以便为接下来发生的事情做好准备 因此 考虑到这
  • 用对AI工具,工作效率嘎嘎提高

    随着人工智能 AI 技术的飞速发展 AI软件已经深入到我们生活的方方面面 为我们的工作和生活带来了前所未有的便利 本文将为您介绍几款具有代表性的AI软件 让您了解这一强大技术引擎的魅力所在 一 AI软件介绍 1 悦音配音 这是一款基于AI人

随机推荐

  • Android中获取View宽高方法

    今天遇到一个问题 就是view获取宽度 高度都为0的问题 其实这个大家都遇到过 这里转载别人的 大家好共同学习 本文转载于 http www jianshu com p f56c92e29dea Android开发中经常需要获取控件的宽高
  • FileZilla的下载与安装

    FileZilla的下载与安装 为什么要使用FileZilla进行文件互传呢 Windows下 FileZilla客户端下载与安装 1 FileZilla的下载 1 FileZilla的安装 1 双击运行安装包 点击 i agree 2 n
  • Shader中的一些专业用语的解释

    Shader中的一些专业用语的解释 此文章收录于我主页顶置的 Unity Shader入门精要文章目录 点击即可跳转 一 什么是OPenGL DirectX 简单的来说 就是图像应用编程的接口 这些接口用语渲染二维和三维的图形 架起了上层应
  • 【毕业设计】基于单片机的桌面炫酷律动灯条 -物联网 嵌入式 单片机

    文章目录 0 前言 1 简介 2 主要器件 3 实现效果 4 设计原理 5 部分核心代码 6 最后 0 前言 这两年开始毕业设计和毕业答辩的要求和难度不断提升 传统的毕设题目缺少创新和亮点 往往达不到毕业答辩的要求 这两年不断有学弟学妹告诉
  • 公办幼儿园教师要涨工资了???

    终于盼到这一天了 已在市区公办园上班3年多却一直没有编制的季馨 听说从明年开始要涨工资了 高兴坏了 记者从日前召开的全市学前教育工作会议上获悉 从2012年起 确保市区公办幼儿园中具有国家教师资格的聘用教师最低工资水平不得低于当地最低工资标
  • 蓝桥杯 问题 1083: Hello, world!(C/C++ vector实现)

    问题 1083 Hello world 时间限制 1Sec 内存限制 64MB 提交 944 解决 476 题目描述 This is the first problem for test Since all we know the ASCI
  • 《一周搞定模电》—功率放大器

    系列文章目录 文章目录 系列文章目录 前言 一 功率放大电路三极管的工作模式 二 功率放大器内部结构 前言 功率放大器指一种以输出较大功率为目的的放大电路 特点 输出电压大 输出电流大 放大电路的输出电阻与负载匹配 电压放大器和功率放大器的
  • 三子棋创作(c语言)

    我们写三子棋之前首先要思考一下三子棋的实现逻辑 一 1 游戏菜单 是选择开始游戏还是结束游戏 2 打印一个棋盘出来 并且进行棋盘的初始化 即没有旗子的棋盘 3 玩家下棋 用 表示 4 电脑下棋 用 表示 5 判断胜负 电脑和玩家下完棋之后
  • java使用lambda表达式对List集合进行操作(Java8)

    import java util ArrayList import java util List import java util function Predicate import java util stream Collectors
  • token会被截取吗_OAuth2 为什么要用 code 换 token

    先简单介绍下 OAuth2 再用一个例子说明下为什么要用 code 换 token OAuth2 简单介绍 4 个角色 resource owner 可以授权访问被保护资源的实体 如果是人的话 即是最终用户 resource server
  • h2数据库优缺点

    h2数据库是嵌入式的内存型数据库 也可以存储在磁盘上 效率比通过socket调用的redis执行的要快 纯java编写就一个jar h2数据库的缺点是不适合大数据量高并发的操作
  • centos 安装防火墙,并开启对应端口号

    1 查看防火墙状态 命令 systemctl status firewalld service 开启防火墙时 提示没有安装防火墙 root localhost systemctl start firewalld service Failed
  • 关于锁的面试题

    1 synchronized和ReentranctLock有什么区别 底层实现 synchronized是jvm层面的锁 通过monitor对象完成 对象只能在同步代码块和同步方法中调用wait notify方法 ReentranctLoc
  • Java多线程——线程的sleep方法、中断线程的睡眠

    一 关于Sleep方法的应用 public static void sleep long millis throws InterruptedException 让当前正在执行的线程进入休眠 暂时停止执行 指定的毫秒数 静态方法 Thread
  • 数字媒体技术专业方向

    现在是大三下 这篇文章是大一时 整理知乎青岛大学 某学姐的高赞回答 咱这个专业 你可以根据你的学校进行选择 学校好 按部就班的学 以下几个方向都走得通 学校不好 很普通 那么大概率也不学了什么 普通本科院校的学风啊 教学质量啊 与其都学个皮
  • C++11/14/17中提供的mutex系列区别

    C 11 14 17中提供的mutex系列类型如下 互斥量 C 版本 作用 mutex C 11 基本的互斥量 timed mutex C 11 timed mutex带超时功能 在规定的等待时间内 没有获取锁 线程不会一直阻塞 代码会继续
  • 监听小程序切换到后台

    注意要写在app js里面 onHide wx onAppHide
  • 图像处理学习笔记(三):基于匹配的目标识别

    Matlab图像处理学习笔记 三 基于匹配的目标识别 如果要在一幅图像中寻找已知物体 最常用且最简单的方法之一就是匹配 在目标识别的方法中 匹配属于基于决策理论方法的识别 匹配方法可以是最小距离分类器 相关匹配 本文code是基于最小距离分
  • 三进制计算机_数学糖果S10:N进制

    不同进制各有各的特点 二进制更为基础 十进制匹配人体手指数量 十二进制之基数12所含因数多 十六进制之基数16易被多次二分 六十进制结合了五进制与十二进制 世界可能是由概率控制的 现实世界中十进制被选中 计算机世界中二进制被选中 N 进 制
  • PyTorch: 训练分类CIFAR10

    usr bin env python coding utf 8 Author zengxiaohui Datatime 8 13 2021 11 20 AM File train cifar10 import os import torch