实践

2023-11-13

在这里插入图片描述
在这里插入图片描述
图像分类是根据图像的语义信息将不同类别图像区分开来,是计算机视觉中重要的基本问题

猫狗分类属于图像分类中的粗粒度分类问题

Step1:准备数据

自定义数据集
(1)数据集介绍

我们使用CIFAR10数据集。CIFAR10数据集包含60,000张32x32的彩色图片,10个类别,每个类包含6,0000张。其中50,000张图片作为训练集,10000张作为验证集。这次我们只对其中的猫和狗两类进行预测。

在这里插入图片描述
(2)train_dataset和eval_dataset

自定义读取器处理训练集和测试集

paddle.reader.shuffle()表示每次缓存BUF_SIZE个数据项,并进行打乱

paddle.batch()表示每BATCH_SIZE组成一个batch

# 导入需要的包
import paddle
# import os 
# import numpy as np
# from PIL import Image
# import matplotlib.pyplot as plt
# import sys
# import pickle
# from paddle.vision.transforms import ToTensor
import paddle.nn as nn
import paddle.nn.functional as F
print("本教程基于Paddle的版本号为:"+paddle.__version__)
'''
参数配置
'''
train_parameters = {
    "input_size": [3, 32, 32],                           #输入图片的shape
    "src_path":"/home/aistudio/data/data9154/cifar-10-python.tar.gz",       #原始数据集路径
    "target_path":"/home/aistudio/cifar-10-batches-py",        #要解压的路径 
    "num_epochs": 1,                                    #训练轮数
    "train_batch_size": 64,                             #批次的大小
    "learning_strategy": {                              #优化函数相关的配置
        "lr": 0.001                                     #超参数学习率
    } 
}
def unzip_data(src_path,target_path):

    '''
    解压原始数据集,将src_path路径下的zip包解压至/home/aistudio/目录下
    '''

    if(not os.path.isdir(target_path)):    
        import tarfile
        tar = tarfile.open(src_path,'r')
        tar.extractall(PATH=target_path)
        tar.close()
    else:
        print("文件已解压")
'''
参数初始化
'''
src_path=train_parameters['src_path']
target_path=train_parameters['target_path']
batch_size=train_parameters['train_batch_size']
image_size=train_parameters['input_size']
epoch_num=train_parameters['num_epochs']
lr=train_parameters['learning_strategy']['lr']
'''
解压原始数据到指定路径
'''
unzip_data(src_path,target_path)
#定义数据序列化函数
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

print(unpickle("cifar-10-batches-py/data_batch_1").keys())
print(unpickle("cifar-10-batches-py/test_batch").keys())

dict_keys([b’batch_label’, b’labels’, b’data’, b’filenames’])
dict_keys([b’batch_label’, b’labels’, b’data’, b’filenames’])

自定义数据集

'''
自定义数据集
'''
from paddle.io import Dataset
class MyDataset(paddle.io.Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, mode='train'):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super(MyDataset, self).__init__()
        if mode == 'train':
            xs=[]
            ys=[]
            self.data = []
            self.label = []
            #批量读入数据
            for i in range(1,6):
                train_dict=unpickle("cifar-10-batches-py/data_batch_%d" % (i,))
                xs.append(train_dict[b'data'])
                ys.append(train_dict[b'labels'])
            #拼接数据
            Xtr = np.concatenate(xs)
            Ytr = np.concatenate(ys)
            #数据归一化处理
            for (x,y) in zip(Xtr,Ytr):  
                x= x.flatten().astype('float32')/255.0
                x= x.reshape(image_size)
                #将数据同一添加到data和label中
                self.data.append(x)
                self.label.append(np.array(y).astype('int64'))
        else:
            self.data = []
            self.label = []
            #读入数据
            test_dict=unpickle("cifar-10-batches-py/test_batch")
            X=test_dict[b'data']
            Y=test_dict[b'labels']
            for (x,y) in zip(X,Y):  
                #数据归一化处理
                x= x.flatten().astype('float32')/255.0
                x= x.reshape(image_size)
                #将数据同一添加到data和label中
                self.data.append(x)
                self.label.append(np.array(y).astype('int64'))
    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        #返回单一数据和标签
        data = self.data[index]
        label = self.label[index]
        #注:返回标签数据时必须是int64
        return data, np.array(label, dtype='int64')
    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        #返回数据总数
        return len(self.data)

# 测试定义的数据集
train_dataset = MyDataset(mode='train')
eval_dataset = MyDataset(mode='val')
print('=============train_dataset =============')
#输出数据集的形状和标签
print(train_dataset.__getitem__(1)[0].shape,train_dataset.__getitem__(1)[1])
#输出数据集的长度
print(train_dataset.__len__())
print('=============eval_dataset =============')
#输出数据集的形状和标签
for data, label in eval_dataset:
    print(data.shape, label)
    break
#输出数据集的长度
print(eval_dataset.__len__())

=============train_dataset =============
(3, 32, 32) 9
50000
=============eval_dataset =============
(3, 32, 32) 3
10000

Step2.网络配置

(1)RESNET网络模型

在这里插入图片描述
本示例直接调用飞桨API内置网络,resnet18进行训练

(2)飞桨内置网络

print('飞桨内置网络:', paddle.vision.models.__all__)
飞桨内置网络: ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet']
model = paddle.vision.models.resnet18()
paddle.summary(model,(1,3,32,32))

Step3.模型训练

方式1:基于基础API,完成模型的训练与预测

模型配置

接下来,用一个循环来进行模型的训练,将会:
使用 paddle.optimizer.Adam 优化器来进行优化。
使用 F.cross_entropy 来计算损失值。
使用 paddle.io.DataLoader 来加载数据并组建batch。

print('start training ... ')
# turn into training mode
model.train()

opt = paddle.optimizer.Adam(learning_rate=lr,
                            parameters=model.parameters())

train_loader = paddle.io.DataLoader(train_dataset,
                                    shuffle=True,
                                    batch_size=batch_size)

valid_loader = paddle.io.DataLoader(eval_dataset, batch_size=batch_size)

for epoch in range(epoch_num):
    for batch_id, data in enumerate(train_loader()):
        x_data = data[0]
        y_data = paddle.to_tensor(data[1])
        y_data = paddle.unsqueeze(y_data, 1)

        logits = model(x_data)
        loss = F.cross_entropy(logits, y_data)
        acc = paddle.metric.accuracy(logits,y_data)#计算精度
        if batch_id!=0 and batch_id%100==0:
            Batch = Batch + 100 
            Batchs.append(Batch)
            all_train_loss.append(loss.numpy()[0])
            all_train_accs.append(acc.numpy()[0])
            print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(epoch,batch_id,loss.numpy(),acc.numpy()))
        loss.backward()
        opt.step()
        opt.clear_grad() #opt.clear_grad()来重置梯度
paddle.save(model.state_dict(),'resnet18')#保存模型
draw_train_acc(Batchs,all_train_accs)
draw_train_loss(Batchs,all_train_loss)

模型验证

训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。

# 图片预处理
def load_image(file):
        '''
        预测图片预处理
        '''
        #打开图片
        im = Image.open(file)
        #将图片调整为跟训练数据一样的大小  32*32,设定ANTIALIAS,即抗锯齿.resize是缩放
        im = im.resize((32, 32), Image.ANTIALIAS)
        #建立图片矩阵 类型为float32
        im = np.array(im).astype(np.float32)
        #矩阵转置 
        im = im.transpose((2, 0, 1))                               
        #将像素值从【0-255】转换为【0-1】
        im = im / 255.0
        #print(im)       
        im = np.expand_dims(im, axis=0)
        # 保持和之前输入image维度一致
        print('im_shape的维度:',im.shape)
        return im
'''
模型预测
'''
para_state_dict = paddle.load("resnet18")
model = paddle.vision.models.resnet18()
model.set_state_dict(para_state_dict) #加载模型参数
model.eval() #训练模式

#展示预测图片
infer_path='/home/aistudio/data/data7940/dog.png'
img = Image.open(infer_path)
plt.imshow(img)          #根据数组绘制图像
plt.show()               #显示图像

#对预测图片进行预处理
infer_img = load_image(infer_path)
infer_img = infer_img.reshape(3,32,32)

#定义标签列表
label_list = [ "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse","ship", "truck"]

data = infer_img
dy_x_data = np.array(data).astype('float32')
dy_x_data=dy_x_data[np.newaxis,:, : ,:]
img = paddle.to_tensor (dy_x_data)
out = model(img)
lab = np.argmax(out.numpy())  #argmax():返回最大数的索引
print(label_list[lab])

在这里插入图片描述

方式2:基于高层API,完成模型的训练与预测

模型配置

#step3:训练模型
# 用Model封装模型
model = paddle.Model(model)
# 定义损失函数
model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()),
                    loss=paddle.nn.CrossEntropyLoss(),
                    metrics=paddle.metric.Accuracy())
# 训练可视化VisualDL工具的回调函数
visualdl = paddle.callbacks.VisualDL(log_dir='visualdl_log')
# 启动模型全流程训练
model.fit(train_dataset,            # 训练数据集
           eval_dataset,            # 评估数据集
          epochs=epoch_num,            # 总的训练轮次
          batch_size = batch_size,    # 批次计算的样本量大小
          shuffle=True,             # 是否打乱样本集
          verbose=1,                # 日志展示格式
          save_dir='./chk_points/', # 分阶段的训练模型存储路径
          callbacks=[visualdl])     # 回调函数使用
#保存模型
model.save('model_save_dir')

模型验证

model.evaluate(eval_dataset, batch_size=batch_size, verbose=1)

模型预测

#定义标签列表
label_list = [ "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse","ship", "truck"]
#读入测试图片并展示
infer_path='/home/aistudio/data/data7940/dog.png'
img = Image.open(infer_path)
plt.imshow(img)   
plt.show()    

#载入要预测的图片
infer_img = load_image(infer_path)
infer_img = infer_img.reshape(1,1,3,32,32)
#将图片变为数组
# infer_img=np.array(infer_img).astype('float32')
#进行预测
result = model.predict(infer_img)
# 输出预测结果
# print('results',result)
print("infer results: %s" % label_list[np.argmax(result[0][0])])  

在这里插入图片描述
给我整笑了、、

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

实践 的相关文章

  • 具有多处理功能的 Python 代码无法在 Windows 上运行

    以下简单的绝对初学者代码在 Ubuntu 14 04 Python 2 7 6 和 Cygwin Python 2 7 8 上运行 100 但在 Windows 64 位 Python 2 7 8 上挂起 我使用另一个片段观察到了同样的情况
  • Python Nose 导入错误

    我似乎无法理解鼻子测试框架 https nose readthedocs org en latest 识别文件结构中测试脚本下方的模块 我已经设置了演示该问题的最简单的示例 下面我会解释一下 这是包文件结构 init py foo py t
  • Python 在 chroot 中运行时出现错误

    我尝试在 chroot 中运行一些 Python 程序 但出现以下错误 Could not find platform independent libraries
  • 将 API 数据存储到 DataFrame 中

    我正在运行 Python 脚本来从 Interactive Brokers API 收集金融市场数据 连接到API后 终端打印出请求的历史数据 如何将数据保存到数据帧中而不是在终端中流式传输 from ibapi wrapper impor
  • Python sqlite3游标没有属性commit

    当我运行这段代码时 path Scripts wallpapers single png conn sqlite3 connect Users Heaven Library Application Support Dock desktopp
  • conda 无法从 yml 创建环境

    我尝试运行下面的代码来从 YAML 文件创建虚拟 Python 环境 我在 Ubuntu 服务器上的命令行中运行代码 虚拟环境名为 py36 当我运行下面的代码时 我收到下面的消息 环境也没有被创建 这个问题是因为我有几个必须使用 pip
  • 如何在 ReportLab 段落中插入回车符?

    有没有办法在 ReportLab 的段落中插入回车符 我试图将 n 连接到我的段落字符串 但这不起作用 Title Paragraph Title n Page myStyle 我想要这样做 因为我将名称放入单元格中 并且想要控制单元格中的
  • 用Python中的嵌套for循环替换重复的if语句?

    在我编写的下面的代码中 n 4 所以有五个 if 语句 所以如果我想将 n 增加到 比如说 10 那么就会有很多 if 语句 因此我的问题是 如何用更优雅的东西替换所有 if 语句 n p 4 5 number of trials prob
  • 在 matplotlib 中使用 yscale('log') 时缺少误差线

    在某些情况下 当使用对数刻度时 matplotlib 会错误地显示带有误差条的图 假设这些数据 例如在 pylab 内 s 19 0 20 0 21 0 22 0 24 0 v 36 5 66 814250000000001 130 177
  • Python 视频框架

    我正在寻找一个 Python 框架 它将使我能够播放视频并在该视频上绘图 用于标记目的 我尝试过 Pyglet 但这似乎效果不是特别好 在现有视频上绘图时 会出现闪烁 即使使用双缓冲和所有这些好东西 而且似乎没有办法在每帧回调期间获取视频中
  • Python正则表达式从字符串中获取浮点数

    我正在使用正则表达式来解析字符串中的浮点数 re findall a zA Z d d t 是我使用的代码 这段代码有问题 如果数字和任何字符之间没有空格 则不会解析该数字 例如 0 1 2 3 4 5 6 7 8 9 的预期输出为 0 1
  • 一起使用 Flask 和 Tornado?

    我是以下的忠实粉丝Flask 部分是因为它很简单 部分是因为它有很多扩展 http flask pocoo org extensions 然而 Flask 是为了在 WSGI 环境中使用而设计的 而 WSGI 不是非阻塞的 所以 我相信 它
  • 求解不等式系统时“多项式错误:仅允许使用单变量多项式”

    我想找到以下两个常数的区间cons1 and cons2我写了下面的代码 from sympy import Poly from sympy import Abs from sympy solvers inequalities import
  • 为什么“return self”返回 None ? [复制]

    这个问题在这里已经有答案了 我正在尝试获取链的顶部节点getTopParent 当我打印出来时self name 它确实打印出了父实例的名称 然而 当我回来时self 它返回 None 为什么是这样 class A def init sel
  • 从 IMDbPy 结果中的片目中获取电影 ID

    我正在尝试创建一个数据集 允许我根据 Python IMDb API 中的演员 ID 和电影 ID 加入演员和电影 现在 我正在尝试从演员的电影作品中提取电影 ID 列表 但无法做到 例如 我知道 Rodney Dangerfield 在
  • Spark中的count和collect函数抛出IllegalArgumentException

    当我使用时抛出此异常时 我尝试在本地 Spark 上加载一个小数据集count 在 PySpark 中 take 似乎有效 我试图搜索这个问题 但没有找到原因 看来RDD的分区有问题 有任何想法吗 先感谢您 sc stop sc Spark
  • Spyder 如何在同一线程的后台运行 asyncio 事件循环(或者确实如此?)

    我已经研究 asyncio 模块 功能几天了 因为我想将它用于我的应用程序的 IO 绑定部分 并且我认为我现在对它的工作原理有一个合理的理解 或者在至少我认为我已经理解了以下内容 任一时刻 任一线程中只能运行一个异步事件循环 一旦一切都设置
  • 如何设置 matplotlib 表中列的背景颜色

    我在一个目录中有多个 txt 文件 例如 d memdump 0 txt 1 txt 10 txt 示例文本文件如下 Applications Memory Usage kB Uptime 7857410 Realtime 7857410
  • 最小硬币找零问题——回溯

    我正在尝试用最少数量的硬币解决硬币找零问题 采用回溯法 我实际上已经完成了它 但我想添加一些选项 按其单位打印硬币数量 而不仅仅是总数 这是我下面的Python代码 def minimum coins coin list change mi
  • 在Python中从日期时间中减去秒

    我有一个 int 变量 它实际上是秒 让我们调用这个秒数X 我需要得到当前日期和时间 以日期时间格式 减去的结果X秒 Example If X是 65 当前日期是2014 06 03 15 45 00 那么我需要得到结果2014 06 03

随机推荐

  • Uboot启动参数说明

    29 Uboot 启动参数说明 bootcmd cp b 0xc4200000 0x7fc0 0x200000 bootm 倒计时到 0 以后 自动执行的指令 bootdelay 2 baudrate 38400 串口波特率 一般使用 38
  • Springboot实现MQTT通信

    目录 一 MQTT简介 1 MQTT协议 2 MQTT协议特点 二 MQTT服务器搭建 三 使用Springboot整合MQTT协议 1 在父工程下创建一个Springboot项目作为消息的提供者 1 1 导入依赖包 1 2 修改配置文件
  • vue3 的 ref、 toRef 、 toRefs

    1 ref 对原始数据进行拷贝 当修改 ref 响应式数据的时候 模版中引用 ref 响应式数据的视图处会发生改变 但原始数据不会发生改变
  • 同行评审

    在IBM 微软等很多公司都有一个很好的实践 那就是代码复审 这种代码审查的过程 不是将代码发给某一个人或某几个人去看 而是强调程序员自己定期走上台 向其他人讲解自己源程序的活动 因为要向大家讲解自己的程序 程序员会极其重视自己的工作进度 代
  • SeleniumLibrary4.5.0 关键字详解(九)

    SeleniumLibrary4 5 0 关键字详解 九 库版本 4 5 0 库范围 全局 命名参数 受支持 简介 SeleniumLibrary是Robot Framework的Web测试库 本文档说明了如何使用SeleniumLibra
  • linux安装rz、sz上传下载文件工具

    在centos版本linux系统中执行如下命令 yum install lrzsz 如下图所以即可安装成功
  • windows 7编辑启动选项

    问题 开机之后 提示编辑启动选项 路径 windows system32 winload exe 分区 1 硬盘 f3c3f39 NOEXECUTE OPTIN 如图 解决步骤 1 按回车键 进入操作系统之后 查看启动项配置 msconfi
  • 自定义ZoomRecyclerView可缩放可点击

    可直接使用喔 public class PinchRecyclerView extends RecyclerView implements View OnTouchListener private static final int INVA
  • html网页效果跳动的心

    跳动的心代码 用到了css的轮廓 动画效果
  • 使用Eclipse Babel语言包汉化eclipse

    eclipse下载下来是默认是英文版的 在eclipse的设置里似乎不能直接更改eclipse的语言文字 我想把eclipse改成中文版 我发现在官网上有个叫Eclispe Babel的可以更改Eclipse的语言 这是一个多国语言包 可以
  • 2.7-3 Android Studio 的Gradle一点理解, 查看gradle 版本和android 插件的版本

    参考 https developer android com studio releases gradle plugin html gradle 最大的优点就是对依赖管理的强力支持 查看gradle 版本和android 插件的版本 Fil
  • Kubernetes 101,第一部分,基础知识

    已经有一段时间了 我想花点时间坐下来写写关于Kubernetes 的文章 时机已到 简而言之 Kubernetes是一个用于自动化和管理容器化应用程序的开源系统 Kubernetes 就是关于容器的 如果你对什么是容器不太了解 请先参考我的
  • 函数模板,重载函数模板,模板的显式具体化,实例化

    目录 一 函数模板应用场景 二 函数模板 1 直白理解函数模板 函数模板就是建立一个通用的函数 其参数类型和返回类型不具体指定 用一个虚拟的类型来代表 2 函数模板的声明 3 函数模板的代码 三 重载的模板 1 为什么要使用重载模板 2 代
  • H5静态页面跳转微信小程序;从外部浏览器,点击H5链接跳转打开微信小程序;以及在微信内直接点击H5链接打开微信小程序;

    参考链接 需求 从外部浏览器 点击H5链接跳转打开微信小程序 以及在微信内直接点击H5链接打开微信小程序 步骤1 小程序开发需要使用云开发创建项目 使用云开发生成的项目会自带云函数文件夹 步骤2 项目开启云开发 步骤3 下载官方的H5静态h
  • 详解AdaBoost

    Boosting 顾名思义 这是一个增强算法 而它增强的对象 就是机器学习中我们所熟知的学习器 在Valiant引入的PAC Probably Approximately Correct 又称可能近似正确 中 学习器可被分为强学习器和弱学习
  • 虚拟机同时连接内网 (通过网线连接到开发板) 和外网 (连接至Internet) 的一种解决办法

    因为嵌入式实验需要搭建开发环境 因此需要将虚拟机通过网线连接到开发板 同时因为更新的需要 也要将虚拟机连接至Internet 所以写了一下自己的解决方法 注 我的虚拟机为VMware 装的是Linux系统 Ubuntu 目录 虚拟机同时连接
  • Visual Studio 2015开发Android App启动调试始终无法完成应用部署的解决方案

    创建一个Android App项目后 直接启动调试发现Visual Studio Emulator for Android已成功运行 但应用始终处于Build中 等待时间超过1小时 并未如预期通过adb部署到模拟器中 将应用直接导出apk
  • 机器学习算法竞赛实战--1,初见竞赛

    目录 竞赛流程 1 2 思考练习 之所以强烈推荐用竞赛作为积极学习适当的重要方式是因为他实在是一个快速入门 积极学习的极佳方式 对于初学者来说 他们的水平并不足以支撑他们直接进到企业接触实际的应用场景 而从书里得来的知识终究有些浅薄 在时代
  • 面试经典150题(1)

    文章目录 前言 除自身以外数组的乘积 要求 思路 代码 跳跃游戏 要求 题解 代码 跳跃游戏 要求 题解 代码 前言 今天开始我将陆续为大家更新面试经典150题中较难理解的题目 今天我为大家分享的是 除自身以外数组的乘积 跳跃游戏 和 跳跃
  • 实践

    CNN到ResNet Step1 准备数据 自定义数据集 Step2 网络配置 1 RESNET网络模型 2 飞桨内置网络 Step3 模型训练 方式1 基于基础API 完成模型的训练与预测 模型配置 模型验证 方式2 基于高层API 完成