(一)pytorch单任务图像分类

2023-10-29

深度学习主要由:数据读取、网络模型、损失函数、优化器这四个部分构成

最开始不应该纠结于这些细节,应该先让代码跑起来再去研究代码是怎么写的

下面的代码只是训练部分的代码,并加上验证模型准确率的功能。

1.项目分布:创建一个文件夹my_data1,在my_data1里面创建train和valid这个文件夹

(文件夹名称固定,train和valid不要写错,不然代码跑不起来)

  train是训练集的图片,valid是验证集的图片。在train这个文件夹里面,你训练多少个类别就创建    多少个文件夹(比如我只训练两类就只创建两个文件夹cat和dog,文件夹名称不固定

  valid文件夹  的格式和train的格式一样。

2.代码参数介绍:如果你训练的类别为3,就把代码里面的num_classes=2改成num_classes=3

   代码默认只训练100轮,想训练200轮的话就把代码里面的Epoches=100改成Epoches=200

   代码默认Batch_size为4,设置多少与显卡有关,显卡越好可以设的值就越大。

   代码默认将图片resize成【224,224】再进行训练,想改的话可以对Image_Size进行修改

3.训练代码train.py:里面用到的模型是resnet18,并加载预训练模型进行训练,然后冻结前30层

import torch
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms, models
import os
import matplotlib.pyplot as plt
import time
import torch.optim as optim

from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def train():
    running_loss = 0
    for batch_idx, (data, target) in enumerate(train_data):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = net(data)
        loss = criterion(out, target)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    return running_loss


def test():
    correct, total = 0, 0
    with torch.no_grad():
        for _, (data, target) in enumerate(val_data):
            data, target = data.to(device), target.to(device)
            out = net(data)
            prediction = torch.max(out.data, dim=1)[1]
            total += target.size(0)
            correct += (prediction == target).sum().item()
        print('Accuracy on test set: (%d/%d)=%d %%' % (correct, total, 100 * correct / total))


if __name__ == '__main__':
    loss_list = []
    Epoches = 100
    Batch_Size = 4
    Image_Size = [224, 224]

    # 1.数据加载
    data_dir = r'D:\Code\python\完整项目放置\classify_project\multi_classification\my_dataset1'
    # 1.1 定义要对数据进行的处理
    data_transform = {x: transforms.Compose([transforms.Resize(Image_Size), transforms.ToTensor()]) for x in
                      ["train", "valid"]}
    image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x), transform=data_transform[x]) for x in
                      ["train", "valid"]}
    dataloader = {x: torch.utils.data.DataLoader(dataset=image_datasets[x], batch_size=Batch_Size, shuffle=True) for x in
                  ["train", "valid"]}
    train_data, val_data = dataloader["train"], dataloader["valid"]

    index_classes = image_datasets["train"].class_to_idx
    print(index_classes)
    example_classes = image_datasets["train"].classes
    print(example_classes)

    # 2.数据预览, 在训练的时候可以注释掉
    # X_example, y_example = next(iter(dataloader["train"]))
    # img = torchvision.utils.make_grid(X_example)
    # img = img.numpy().transpose([1, 2, 0])
    # for i in range(len(y_example)):
    #     index = y_example[i]
    #     print(example_classes[index], end='   ')
    #     if (i+1)%8 == 0:
    #         print()
    # plt.imshow(img)
    # plt.show()

    # 3.模型加载, 并对模型进行微调
    net = models.resnet18(pretrained=True)
    fc_features = net.fc.in_features

    # 设置训练的类别个数,我这里只有两类所以写2
    num_classes = 2
    net.fc = torch.nn.Linear(fc_features, num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 4.pytorch fine tune 微调(冻结一部分层)。这里是冻结网络前30层参数进行训练。
    for i, param in enumerate(net.parameters()):
        if i < 30:
            param.requires_grad = False
    net.to(device)

    # 5.定义损失函数,以及优化器
    LR = 1e-3
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=LR)

    for epoch in range(Epoches):
        loss = train()
        loss_list.append(loss)
        print("第%d轮的loss为:%5f:" % (epoch, loss))
        test()

        # net.state_dict只保存模型的参数
        # torch.save(net.state_dict(), 'Model2.pth')

        # 保存整个模型
        torch.save(net, "my_model.pth")

    plt.title("Graph")
    plt.plot(range(Epoches), loss_list)
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.show()

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

(一)pytorch单任务图像分类 的相关文章

随机推荐

  • 编译mono-debugger-2.4出错

    usr bin ld cannot find ltermcapcollect2 ld 返回 1make fileman 错误 1 echo PKG CONFIG PATH To set the PKG CONFIG PATH value u
  • 使用jprofiler分析dump文件一个实例

    1 jstat 命令先分析一下 一次fullgc之后 old 老年代使用比例 只降低2 应该有什么大的对象常驻内存 2 可以使用jmap 命令查看对象大小 这里后面使用jprofiler 就没用这个命令 jmap histo live 72
  • 如何使用Python读写JSON文件

    1 读取JSON文件 假设我们有一个名为 data json 的文件 其内容如下 name Alice age 30 city New York 我们可以使用Python中的json模块来读取该文件并将其存储为Python对象 以下是一个读
  • NGINX proxy服务器

    1 代理原理 正向代理 内网客户机通过代理访问互联网 通常要设置代理服务器地址和端口 反向代理 外网用户通过代理访问内网服务器 内网服务器无感知 正向代理与反向代理的区别是 正向代理即是客户端代理 代理客户端 服务端不知道实际发起请求的客户
  • 西工大图书馆分拣经历的数学建模角度思考

    今天下午没课于是乎去图书馆做志愿 志愿内容简单来讲就是分拣书籍 装箱子 运走的三部曲 工作需要我们的耐心和细致 同时也要求很好的体力 做的时候我还在思考这样的一个问题 就是这件事情从数学建模角度能不能分析分析 我们所需要做出的模型假设 仅供
  • 外贸业务员专用的18个英文学习网站!

    今天 我收集了一些非常实用的英语网站 包括信息 翻译和口语等方面练习 01英语学习网站 1 https www businessenglishsite com 这个网站是由在商业领域拥有丰富经验的专业人士创建的 他们每天都使用商业英语 因此
  • WebShell工具特征流量分析合集

    目录 中国蚁剑流量抓包分析 配置代理 数据包分析 特征 中国菜刀流量抓包分析 数据包分析 特征 冰蝎流量抓包分析 配置代理 自带PhpWebshell分析 base64编码 数据包分析 弱特征 强特征 哥斯拉流量抓包分析 配置代理 生成we
  • SpringBoot异常处理

    我们在实际开发中 会因为各种问题而导致无法正常访问网址 网站的对象是群众 如果出现各种的报错信息 对于用户的体验是非常的不好的 所以我们需要对项目的内部进行异常处理 保证用户的体验舒爽 目录 1 异常处理一 默认异常处理机制 1 导入前端模
  • OneNet平台对接记录

    手头有一台支持中移动的OneNet平台的接口的烟感设备 刚好可以用来了解一下移动搭建的这套开放平台 OneNet平台简介 OneNet平台是中国移动物联网公司推出的物联网解决方案平台 对于集成了移动的物联网模块 NB IOT模块的设备 目前
  • Linux内核编译+Busybox文件系统制作(基础)

    本人小白纯属爱好折腾了好久 希望分享对小白有所帮助 linux 5 15 1 5 14 14版本都可以 编译linux 4 9 229 出错提示 cc1 error fcf protection is not compatible with
  • 十大C++实战项目,你会几个?【高薪必备】

    市面上有很多C 的实战项目 从简单到进阶 学习每个项目都可以掌握相应的知识点 如果你还是C 新手的话 那么这个C 的项目列表你可以拿去练手实战开发 毕竟学编程动手实践是少不了的 如果你不知道C 可以用来做哪些项目 可以应用在哪些地方 那么
  • 解决临时表空间不足

    第一种方法 数据库服务器切换到 oracle的根目录执行 su oracle oracle edzxbsdb source bash profile oracle edzxbsdb sqlplus as sysdba 进入sql SQL g
  • bat脚本-卸载并重新安装apk,强制关闭app并重新启动app

    卸载并重新安装apk echo off echo echo Get devices adb devices gt devices txt echo echo restartApp for f skip 1 tokens 1 delims i
  • 头条号如何快速涨100W+粉丝?

    最近一些做头条的朋友和我反映 最近头条的流量很不错 给账号的扶持很大 劝诫我们要抓住这次机会 01 提高爆文产出率 粗看是句废话 但其中藏有奥妙 依靠爆款优质内容涨粉看似 低效 但始终是最根本的途径 由此吸引的粉丝 忠诚度极高 小易这头条号
  • Go【gin和gorm框架】实现紧急事件登记的接口

    简单来说 就是接受前端微信小程序发来的数据保存到数据库 这是我写的第二个接口 相比前一个要稍微简单一些 而且因为前端页面也是我写的 参数类型自然是无缝对接 前端页面大概长这个样子 先用apifox模拟发送请求测试 apifox可以直接复制J
  • python 字符串长度

    Python是一种高级编程语言 它具有简单易学 可读性强 功能强大等特点 因此在各个领域都有广泛的应用 在Python中 字符串是一种非常重要的数据类型 它可以用来存储文本信息 比如说一段话 一篇文章等等 字符串的长度是指其中字符的个数 可
  • mysql查询 多门课程的平均成绩_数据分析中级 MySQL 任务6 总结复习

    0 入门 0 1 MySQL安装 Navicat安装 0 2 MySQL设置 Nacicat设置 包括链接点 unicode 8 0 3 创建表格 student course score teacher 1 简单查询 1 1 查询姓 猴
  • 2023年7月婴幼儿辅食市场数据分析(京东商品数据)

    随着人们对婴幼儿饮食健康的关注不断增加 市场对高品质 安全 营养丰富的辅食需求也日益旺盛 婴幼儿辅食市场增长放缓 但整体仍保持上升态势 鲸参谋数据显示 今年7月份 京东平台婴幼儿辅食市场的销量为1000万 同比增长约8 本月的总销额为3 7
  • ORM如何处理many -to -many的关系

    表之间的关联可以形成一张非常复杂的graph 但是我们对其进行抽象就会发现两个有关系的表之间只有两种可能 one to many 或者many to many many to many 时会加入一个关联表 所以这里讲述的是如何处理关联表映射
  • (一)pytorch单任务图像分类

    深度学习主要由 数据读取 网络模型 损失函数 优化器这四个部分构成 最开始不应该纠结于这些细节 应该先让代码跑起来再去研究代码是怎么写的 下面的代码只是训练部分的代码 并加上验证模型准确率的功能 1 项目分布 创建一个文件夹my data1