深度学习之基于Xception实现四种动物识别

2023-11-09

本次实验类似于猫狗大战,只不过将两种动物识别变为了四种动物识别。
本文的重点是卷积神经网络Xception的实践,在之前的学习中,我们已经实验过其他几种比较常用的网络模型,但是Xception网络并未实践过。在弄本科毕设的时候,一个好朋友的毕设就是基于Xception实现海洋垃圾的识别,最终的实验效果达到了99%左右,由此可见Xception的模型性能还是不错的。
本次实验基于Xception实现动物识别,最终的模型准确率在95%左右。

1.导入库

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os,pathlib,PIL

# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

2.数据加载

data_dir = "E:/tmp/.keras/datasets/animal_photos"
data_dir = pathlib.Path(data_dir)
img_count = len(list(data_dir.glob('*/*')))

共4000张图片

all_images_paths = list(data_dir.glob('*'))
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [path.split("\\")[5].split(".")[0] for path in all_images_paths]
分为四类: ['cat', 'chook', 'dog', 'horse']

超参数的设置:

height = 224
width = 224
epochs =10
batch_size = 128

图像增强:
一共分为4类,每一类有1000张图片,数据并不是很多,因此对原数据进行数据加强。并按照8:2的比例划分训练集与测试集。

train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=45,
    shear_range=0.2,
    zoom_range=0.2,
    validation_split=0.2,
    horizontal_flip=True
)
train_ds = train_data_gen.flow_from_directory(
    directory=data_dir,
    target_size=(height,width),
    batch_size=batch_size,
    shuffle=True,
    class_mode='categorical',
    subset='training'
)
test_ds = train_data_gen.flow_from_directory(
    directory=data_dir,
    target_size=(height,width),
    batch_size=batch_size,
    shuffle=True,
    class_mode='categorical',
    subset='validation'
)

显示图像:

plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds:
    for i in range(8):
        ax = plt.subplot(5, 8, i + 1)
        plt.imshow(images[i])
        plt.title(all_label_names[np.argmax(labels[i])])
        plt.axis("off")
    break
plt.show()

在这里插入图片描述

3.Xception模型

Xception是Inception的改进版本,创新点便是 深度可分离卷积

深度可分离卷积 = 深度卷积+逐点卷积。具体步骤如下所示:

第一步:Depthwise 卷积,对输入的每个channel,分别进行 3 × 3 卷积操作,并将结果 concat
第二步:Pointwise 卷积,对 Depthwise 卷积中的 concat 结果,进行 1 × 1 卷积操作。
在这里插入图片描述
标准卷积与深度可分离卷积的对比如下所示:图片来源
在这里插入图片描述
既然最终的结果是一样的,那为什么深度可分离卷积方式更优呢?
原因就是利用深度可分离卷积,参数更少,从而在迭代更新的过程中,计算量就更小

本次实验利用迁移学习采用官方模型进行训练

base_model = tf.keras.applications.Xception(weights = 'imagenet',include_top = False,pooling = 'max',input_shape = (height,width,3))
base_model.trainable = False#前面的参数设置为不可训练
input = base_model.input
x = tf.keras.layers.Dense(256,activation='relu')(base_model.output)
x = tf.keras.layers.Dense(128,activation='relu')(x)
output = tf.keras.layers.Dense(4,activation='sigmoid')(x)
model = tf.keras.models.Model(inputs = input,outputs = output)

优化器的设置:

# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=300,
        decay_rate=0.96,
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

网络编译&&训练

model.compile(
    optimizer = optimizer,
    loss = "categorical_crossentropy",
    metrics = ['accuracy']
)

history = model.fit(
    train_ds,
    validation_data = test_ds,
    epochs = epochs
)

Accuracy与Loss图如下所示:
在这里插入图片描述
模型准确率比较高,在95%左右。

4.预测&&混淆矩阵

模型保存:

model.save("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")

模型加载:

model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/animal_rec/model.h5")

预测:

plt.figure(figsize=(50,50))
num = 0
for images,labels in test_ds:
    for i in range(64):
        ax = plt.subplot(8,8,i+1)
        plt.imshow(images[i])
        img_array = tf.expand_dims(images[i],0)

        pre = model.predict(img_array)
        if np.argmax(pre) == np.argmax(labels[i]):
            plt.title(all_label_names[np.argmax(pre)])
        else:
            plt.title("False :"+str(all_label_names[np.argmax(pre)]))
        if np.argmax(pre) == np.argmax(labels[i]):
            num += 1
        plt.axis("off")
    break
plt.suptitle("The Acc rating is:{}".format(num / 64))
plt.show()

在这里插入图片描述
混淆矩阵的绘制:

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

#绘制混淆矩阵
def plot_cm(labels,pre):
    conf_numpy = confusion_matrix(labels,pre)#根据实际值和预测值绘制混淆矩阵
    conf_df = pd.DataFrame(conf_numpy,index=all_label_names,columns=all_label_names)#将data和all_label_names制成DataFrame
    plt.figure(figsize=(8,7))

    sns.heatmap(conf_df,annot=True,fmt="d",cmap="BuPu")#将data绘制为混淆矩阵
    plt.title('混淆矩阵',fontsize = 15)
    plt.ylabel('真实值',fontsize = 14)
    plt.xlabel('预测值',fontsize = 14)
    plt.show()
test_pre = []
test_label = []
num = 0
for images,labels in test_ds:
    num = num + 1
    for image,label in zip(images,labels):
        img_array = tf.expand_dims(image,0)#增加一个维度
        pre = model.predict(img_array)#预测结果
        test_pre.append(all_label_names[np.argmax(pre)])#将预测结果传入列表
        test_label.append(all_label_names[np.argmax(label)])#将真实结果传入列表
    if num == 3:#由于硬件问题,只测试了3个batch_size
        break
plot_cm(test_label,test_pre)#绘制混淆矩阵

在这里插入图片描述
努力加油a啊

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习之基于Xception实现四种动物识别 的相关文章

随机推荐

  • python爬取英雄联盟所有皮肤

    import jsonpath import requests import json import os import time 程序开始时间 start time time from tqdm import tqdm from time
  • chrome应用商店打不开,怎么下载vue-devtools并安装呢?

    相信很多朋友曾经像我一样 安装vue devtools时总会从各种渠道最后综合转到chrome应用商店的网址 而国内chrome网页是打不开的 肿么办 一 下载 1 本地建立文件夹 自由命名 比如我的为了区分自己的和网上下载的 起名为vue
  • TypeScript 基础类型 —— void

    声明为 void 类型表示没有任何类型 当一个函数没有返回值时 通常其返回值会声明为 void 类型 function gretter void console log 123 编译成js function gretter console
  • 使用Python实现K均值聚类算法

    使用Python实现K均值聚类算法 K均值聚类算法是一种经典的无监督学习算法 它将数据集分为K个簇 每个簇中的数据点与同一簇中心点的距离最小 不同簇的数据点之间的距离较大 该算法常用于数据挖掘 图像处理等领域 以下是其优缺点和Python实
  • Electron+Vue入门(二)vue-cli3.0+electron项目初始化

    第一步 用vue cli3 0创建一个项目 打开命令行工具 vue create demo 选择默认 安装完成 第二步 安装vue cli plugin electron builder 进入项目 cd demo 进入vue项目管理器 vu
  • 怎么样理解同步清零和异步清零?

    DA专业论坛 通用设计 求助 大家是怎么样理解 同步清零和 异步清零的 查看完整版本 求助 大家是怎么样理解同步清零和异步清零的 mxflying 2005 4 20 03 45 求助 大家是怎么样理解 同步清零和 异步清零的 本人对 同步
  • ROS-kinetic中Gazebo中的机械臂仿真报错解决

    1 警告 其实是错误 但也要解决 WARN 1682069601 434351 0 000000 Controller Spawner couldn t find the expected controller manager ROS in
  • 有哪些因素影响服务器的访问速度

    在网络环境下 根据服务器提供的服务类型不同 分为文件服务器 数据库服务器 应用程序服务器 WEB服务器等 一些对服务器的了解不够深入的朋友 会认为服务器的配置越高 服务器的访问速度就会越快 这句话有一定的道理 但是服务器的配置高低只是影响服
  • 计算机视觉项目实战-图像特征检测harris、sift、特征匹配

    欢迎来到本博客 本次博客内容将继续讲解关于OpenCV的相关知识 作者简介 目前计算机研究生在读 主要研究方向是人工智能和群智能算法方向 目前熟悉python网页爬虫 机器学习 计算机视觉 OpenCV 群智能算法 然后正在学习深度学习的相
  • android中下拉菜单的制作(详解)

    在我们的android中下拉菜单的制作有两种的方法 1 一种的方式就是通过我们的布局文件的方法制作 2 第二种方式就是通过我们的java代码的方式制作 第一种方式
  • deepin 20.2版本亮度调节问题暂时解决方案

    可在设置 gt 键盘和语言 gt 快捷键 中设置自己需要的快捷键 建议alt 1和alt 2这两个 与现有快捷键没有冲突 使用原来的快捷键会提示冲突 如果覆盖了设置可能会使原来的快捷键失效 分别添加下面的命令 降低亮度 echo your
  • Anaconda 换源与更新

    参考 Windows下Anaconda安装 换源与更新 里面很详细介绍了 conda 的更新 与 Anaconda 的更新
  • Node.js入门笔记(一)——环境问题和版本号问题

    Node js入门笔记 一 1 node js的版本管理工具 nvm 2 npm全局安装和局部安装 3 开发环境安装与生产环境安装 4 其他常用的npm语法 5 版本号里面的讲究 6 npm上传包 其实就是寒假比较无聊搭了这个自己的博客网站
  • Visual Studio+VAssistX自动添加注释

    1 增加函数头注释 右击函数名 然后依次点击 Refacto gt Document Method 这个时候函数头注释就会蹦出来 不过这个注释的格式是默认的 想修改注释格式 可以通过以下方法 点击 VAssistX gt Visual VA
  • IE下载文件时,中文文件名乱码问题

    经排查 Content Disposition中的filename进行了两次URL转码 以汉字漫为例 第一次转码 漫变为 E6 BC AB 第二次转码 E6 BC AB变为 25E6 25BC 25AB 第二次转码时 因为 是特殊字符 所以
  • word2vector学习笔记(一)

    word2vector学习笔记 一 最近研究了一下google的开源项目word2vector http code google com p word2vec 其实这玩意算是神经网络在文本挖掘的一项成功应用 本文是看了论文 Distribu
  • Spring中如何在一个Bean中注入一个内部Bean呢?

    转自 Spring中如何在一个Bean中注入一个内部Bean呢 在日常开发中 有些实体类的定义 一个类中包含了另一个类 那么在Spring Bean中 同样也有此种操作 下文将讲述使用xml配置文件的方式注入内部bean的方法 实现思路 使
  • 经典面试题 为什么要用 Docker

    经典面试题 为什么要用 Docker 解决面试题 斩获心仪的 Offer 文章目录 经典面试题 为什么要用 Docker 一 Docker是什么 二 Docker 的优势 1 更高效的利用系统资源 2 更快速的启动时间 3 一致的运行环境
  • T检验python实现

    from scipy import stats import numpy as np 方差齐性检验 方差反映了一组数据与其平均值的偏离程度 方差齐性检验用以检验两组或多组数据与其均值偏离程度是否存在差异 也是很多检验和算法的先决条件 np
  • 深度学习之基于Xception实现四种动物识别

    本次实验类似于猫狗大战 只不过将两种动物识别变为了四种动物识别 本文的重点是卷积神经网络Xception的实践 在之前的学习中 我们已经实验过其他几种比较常用的网络模型 但是Xception网络并未实践过 在弄本科毕设的时候 一个好朋友的毕