基于pytorch的手势识别

2023-11-13

本次实验主要是使用pytorch完成手势识别。网络包含两个隐藏层,第一层隐藏层有576个节点,第二层隐藏层有144个节点,输入784个节点(图片大小为28×28),输出10个节点(10种手势)。

目录

1. 数据集处理

2. 神经网络的建立

3. 神经网络的训练

4. 神经网络的测试


1. 数据集处理

本次实验所用数据集为自建数据集,首先预览了解数据,确保数据能够被正常载入。

import pandas
from torch.utils.data import Dataset

import torch
import matplotlib.pyplot as plt


class GestureDataset(Dataset):
    def __init__(self, csv_file):
        self.dataset = pandas.read_csv(csv_file, header=0)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # 图像标签
        label = self.dataset.iloc[index, 0]
        target = torch.zeros(10)  # 神经网络预期输出
        target[label] = 1.0
        # 图像数据,取值范围是0~255,标准化为0~1
        image_value = torch.FloatTensor(self.dataset.iloc[index, 1:].values) / 255
        # 返回标签、图像数据张量以及目标张量
        return label, image_value, target

    def plot_image(self, index):
        arr = self.dataset.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label = " + str(self.dataset.iloc[index, 0]))
        plt.imshow(arr, interpolation='none', cmap='gray')
        plt.show()


# 查看图片
gesture_dataset = GestureDataset('train.csv')
gesture_dataset.plot_image(9)
print(gesture_dataset[100])
print(len(gesture_dataset))

以上代码中各函数含义如下:

__len__() 函数的作用是返回DataFrame的大小。

__getitem__()函数索引获取数据集中的第 n 项,数据集中的第index项中提取一个标签(label)。返回值中的 target 表示神经网络的预期输出。除了与标签相对应的位置是1之外,其他值皆为0。比如手势 2 的 target 应该表示为[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]。

运行以上代码,结果如下:

2. 神经网络的建立

import torch
import torch.nn as nn

import pandas, numpy
import matplotlib.pyplot as plt


class Classifier(nn.Module):
    def __init__(self):
        # 初始化pytorch父类
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
            nn.Linear(784, 576),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(576),

            nn.Linear(576, 144),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(144),

            nn.Linear(144, 10),
            nn.Sigmoid()
        )

        # 创建损失函数
        self.Loss_function = nn.BCELoss()
        # 优化器
        self.optimiser = torch.optim.Adam(self.parameters(),
                                          lr=0.0001)
        # 记录训练进展的计数器和列表
        self.counter = 0
        self.process = []

    def forward(self, inputs):
        # 直接运行模型
        return self.model(inputs)

    def cnn_train(self, inputs, targets):
        # 计算网络的输出值
        outputs = self.forward(inputs)
        # 计算损失值
        loss = self.Loss_function(outputs, targets)
        # 梯度归零,反向传播,并更新权重
        self.optimiser.zero_grad()  # 梯度全部归零
        loss.backward()
        self.optimiser.step()  # 使用梯度更新可学习参数

        # 每隔10个训练样本增加一次计数器的值,并将损失值添加进列表的末尾,共36080张图片
        self.counter += 1
        if self.counter % 10 == 0:
            self.process.append(loss.item())

        # 在每10000次训练后打印计数器的值,了解训练进展的快慢
        if self.counter % 10000 == 0:
            print("counter=", self.counter)

    # 绘制训练过程的损失值
    def plot_progress(self):
        df = pandas.DataFrame(self.process, columns=["loss"])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.',
                grid=True, yticks=(0, 0.25, 0.5))
        plt.show()

3. 神经网络的训练

数据集处理部分代码保存为gesture_dataset.py,神经网络的建立部分代码保存为gesture_cnn.py,在创建了网络后,需要使用数据集训练网络,并保存网络参数以便后续使用。

import torch

from gesture_dataset import GestureDataset
from gesture_cnn import Classifier

# 创建神经网络
C = Classifier()
gesture_dataset = GestureDataset('train.csv')
# 在数据集训练神经网络
epochs = 3
for i in range(epochs):
    print('training epoch', i + 1, "of", epochs)
    for label, image_data_tensor, target_tensor in gesture_dataset:
        C.cnn_train(image_data_tensor, target_tensor)
    pass
pass

# 绘制分类器损失值
C.plot_progress()
# 保存网络
torch.save(C.model, 'gesture_cnn_model.pkl')

4. 神经网络的测试

from gesture_dataset import GestureDataset

import torch
import pandas
import numpy as np
import matplotlib.pyplot as plt

# 加载测试集数据
test_dataset = GestureDataset('test.csv')
record = 19
test_dataset.plot_image(record)

image_data = test_dataset[record][1]
# 调用训练后的神经网络
cnn_model = torch.load('gesture_cnn_model.pkl')
output = cnn_model(image_data)
# 绘制输出张量
# pandas.DataFrame(output.detach().numpy()).plot(kind='bar',
#                                                legend=False, ylim=(0, 1))
# plt.show()

predict = output.detach().numpy()
print(np.where(predict == np.max(predict))[0][0])

# 测试正确率
T_test = 0
counter_test = 0

for label, image_data_tensor, target_tensor in test_dataset:
    predict = cnn_model(image_data_tensor).detach().numpy()
    if np.where(predict == np.max(predict))[0][0] == label:
        T_test += 1
    pass
    counter_test += 1
    if counter_test % 100 == 0:
        print('counter_test = ', counter_test)
pass

test_accuracy = T_test/len(test_dataset)
print('Test Accuracy = ', test_accuracy)

# 训练集正确率
train_dataset = GestureDataset('train.csv')

T_train = 0
counter_train = 0

for label, image_data_tensor, target_tensor in train_dataset:
    predict = cnn_model(image_data_tensor).detach().numpy()
    if np.where(predict == np.max(predict))[0][0] == label:
        T_train += 1
    pass
    counter_train += 1
    if counter_train % 1000 == 0:
        print('counter_train = ', counter_train)
pass

train_accuracy = T_train/len(train_dataset)
print('Train Accuracy = ', train_accuracy)
print(len(train_dataset))

测试结果如下:

网络最终在训练集上的正确率约为99.83%,在测试集上的正确率约为97.50%,测试结果表明网络性能较好,训练结果较好,最终的手势识别效果较好。

代码注释详细,作者能力有限,如有发现问题欢迎评论提出。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于pytorch的手势识别 的相关文章

  • Python BigQuery 存储。并行读取多个流

    我有以下玩具代码 import pandas as pd from google cloud import bigquery storage v1beta1 import os import google auth os environ G
  • 下载 PyQt6 的 Qt Designer 并使用 pyuic6 将 .ui 文件转换为 .py 文件

    如何下载 PyQt6 的 QtDesigner 如果没有适用于 PyQt6 的 QtDesigner 我也可以使用 PyQt5 的 QtDesigner 但是如何将此 ui 文件转换为使用 PyQt6 库而不是 PyQt5 的 py 文件
  • 如何在刻度标签和轴之间添加空间

    我已成功增加刻度标签的字体 但现在它们距离轴太近了 我想在刻度标签和轴之间添加一点呼吸空间 如果您不想全局更改间距 通过编辑 rcParams 并且想要更简洁的方法 请尝试以下操作 ax tick params axis both whic
  • 如何打印没有类型的defaultdict变量?

    在下面的代码中 from collections import defaultdict confusion proba dict defaultdict float for i in xrange 10 confusion proba di
  • 如何在Windows上模拟socket.socketpair

    标准Python函数套接字 套接字对 https docs python org 3 library socket html socket socketpair不幸的是 它在 Windows 上不可用 从 Python 3 4 1 开始 我
  • 如何等到 Excel 计算公式后再继续 win32com

    我有一个 win32com Python 脚本 它将多个 Excel 文件合并到电子表格中并将其另存为 PDF 现在的工作原理是输出几乎都是 NAME 因为文件是在计算 Excel 文件内容之前输出的 这可能需要一分钟 如何强制工作簿计算值
  • 打破嵌套循环[重复]

    这个问题在这里已经有答案了 有没有比抛出异常更简单的方法来打破嵌套循环 在Perl https en wikipedia org wiki Perl 您可以为每个循环指定标签 并且至少继续一个外循环 for x in range 10 fo
  • 安装后 Anaconda 提示损坏

    我刚刚安装张量流GPU创建单独的后环境按照以下指示here https github com antoniosehk keras tensorflow windows installation 但是 安装后当我关闭提示窗口并打开新航站楼弹出
  • 从 scikit-learn 导入 make_blobs [重复]

    这个问题在这里已经有答案了 我收到下一个警告 D Programming Python ML venv lib site packages sklearn utils deprecation py 77 DeprecationWarning
  • 如何使用装饰器禁用某些功能的中间件?

    我想模仿的行为csrf exempt see here https docs djangoproject com en 1 11 ref csrf django views decorators csrf csrf exempt and h
  • 在 NumPy 中获取 ndarray 的索引和值

    我有一个 ndarrayA任意维数N 我想创建一个数组B元组 数组或列表 其中第一个N每个元组中的元素是索引 最后一个元素是该索引的值A 例如 A array 1 2 3 4 5 6 Then B 0 0 1 0 1 2 0 2 3 1 0
  • Pandas Dataframe 中 bool 值的条件前向填充

    问题 如何转发 fill boolTruepandas 数据框中的值 如果是当天的第一个条目 True 到一天结束时 请参阅以下示例和所需的输出 Data import pandas as pd import numpy as np df
  • 表达式中的 Python 'in' 关键字与 for 循环中的比较 [重复]

    这个问题在这里已经有答案了 我明白什么是in运算符在此代码中执行的操作 some list 1 2 3 4 5 print 2 in some list 我也明白i将采用此代码中列表的每个值 for i in 1 2 3 4 5 print
  • 循环中断打破tqdm

    下面的简单代码使用tqdm https github com tqdm tqdm在循环迭代时显示进度条 import tqdm for f in tqdm tqdm range 100000000 if f gt 100000000 4 b
  • 如何改变Python中特定打印字母的颜色?

    我正在尝试做一个简短的测验 并且想将错误答案显示为红色 欢迎来到我的测验 您想开始吗 是的 祝你好运 法国的首都是哪里 法国 随机答案不正确的答案 我正在尝试将其显示为红色 我的代码是 print Welcome to my Quiz be
  • 在Python中重置生成器对象

    我有一个由多个yield 返回的生成器对象 准备调用该生成器是相当耗时的操作 这就是为什么我想多次重复使用生成器 y FunctionWithYield for x in y print x here must be something t
  • 检查所有值是否作为字典中的键存在

    我有一个值列表和一本字典 我想确保列表中的每个值都作为字典中的键存在 目前我正在使用两组来确定字典中是否存在任何值 unmapped set foo set bar keys 有没有更Pythonic的方法来测试这个 感觉有点像黑客 您的方
  • 对输入求 Keras 模型的导数返回全零

    所以我有一个 Keras 模型 我想将模型的梯度应用于其输入 这就是我所做的 import tensorflow as tf from keras models import Sequential from keras layers imp
  • 从 Python 中的类元信息对 __init__ 函数进行类型提示

    我想做的是复制什么SQLAlchemy确实 以其DeclarativeMeta班级 有了这段代码 from sqlalchemy import Column Integer String from sqlalchemy ext declar
  • 协方差矩阵的对角元素不是 1 pandas/numpy

    我有以下数据框 A B 0 1 5 1 2 6 2 3 7 3 4 8 我想计算协方差 a df iloc 0 values b df iloc 1 values 使用 numpy 作为 cov numpy cov a b I get ar

随机推荐

  • 一起学nRF51xx 10 -  rng

    前言 随机数产生器 RNG 的结构 随机数发生器 RNG 根据内部热产生真实的非确定性随机数噪音 RNG通过触发START任务启动 并通过触发STOP任务停止 当随机数已经生成 它会产生一个VALRDY事件 同时把随机数存入VALUE寄存器
  • 智慧城市领域大单,巨头占尽优势

    智慧城市领域 哪个公司做的比较好 一 前言 二 智慧城市中标大单 清单 三 中标厂商分析 1 华为 2 科大讯飞 3 腾讯 4 阿里 5 中国电科 6 中国电子 7 百度 8 数字广东 四 获取 智慧城市等全套最新解决方案合集 一 前言 在
  • python eclipse+pydev(An error has occurred when creating this preference page)

    Eclipse 安装pydev Help gt Install New Software gt add gt Location http pydev org updates 点击pydev左边的小三角勾选pydev for eclipse
  • Shell init Ubuntu

    echo HISTFILESIZE 99999 gt gt bashrc echo HISTSIZE 99999 gt gt bashrc echo HISTTIMEFORMAT F T gt gt bashrc echo PROMPT C
  • Thrift原理简析(JAVA)

    Apache Thrift是一个跨语言的服务框架 本质上为RPC 同时具有序列化 反序列化机制 当我们开发的service需要开放出去的时候 就会遇到跨语言调用的问题 JAVA语言开发了一个UserService用来提供获取用户信息的服务
  • CUDA编程 基础与实践 学习笔记(十)

    线程束 warp 一个GPU由多个SM组成 一个SM上可以放多个线程块 不同线程块之间并行或顺序执行 一个线程块分为多个线程束 一个线程束由32个线程 有连续的线程号 组成 从更细粒度来看 一个SM以一个线程束为单位产生 管理 调度 执行线
  • Java面向对象 - 封装、继承和多态

    第1关 什么是封装 如何使用封装 相关知识 为了完成本关任务 你需要掌握 1 什么是封装 2 封装的意义 3 实现Java封装的步骤 package case1 public class TestPersonDemo public stat
  • GoLang之”奇怪用法“实践总结

    2013 11 23 wcdj 0 摘要 本文通过对A Tour of Go的实践 总结Go语言的基础用法 1 Go语言 奇怪用法 有哪些 1 go的变量声明顺序是 先写变量名 再写类型名 此与C C 的语法孰优孰劣 可见下文解释 http
  • 销售心理学

    销售中的心理学 影响你一生的销售心理学书籍 要想钓到鱼 就要像鱼一样思考 在生活中 如果想钓到鱼 你就得像鱼那样思考 而不是像渔夫那样思考 当你对鱼了解得越多 你也就越来越会钓鱼了 这样的想法用在销售中同样适用 要知道 销售的过程其实就是销
  • 【Redis17】Redis进阶:管道

    Redis进阶 管道 管道是啥 我们做开发的同学们经常会在 Linux 环境中用到管道命令 比如 ps ef grep php 在之前学习 Laravel框架时的 Laravel6 4 管道过滤器https mp weixin qq com
  • Latex使用

    问题 在使用latex的过程中插入图片 在某些条件下 图片可能会出现越过后续的文字出现在下一页的页首 解决办法 在该tex文件首部加上 usepackage stfloats 然后参数设置成H如下 begin figure H center
  • 使用frp 实现内网穿透 & 将私人电脑变成一个服务器

    使用frp 实现内网穿透 frp 是什么 frp 是一个可用于内网穿透的高性能的反向代理应用 支持 tcp udp 协议 为 http 和 https 应用协议提供了额外的能力 且尝试性支持了点对点穿透 作用 比如你需要用到云服务器部署你的
  • 阅读GFS论文

    GFS论文发表距今已经十几年了 据之开源的hdfs也已经在业界得到了广泛应用 为了取得分布式系统的真经 拜读一下这篇经典论文 重要假设 软硬件失败乃家常便饭 我们写大文件 不屑小文件 文件改动的主流是追加新数据 随机写是非主流 一旦写完 仅
  • Neon Instruction C支持的向量运算

    转载请标明出处 https blog csdn net u013752202 article details 92008843 文章目的 快速索引到需要的向量运算 vadd gt ri ai bi 1 Vector add 正常指令 r a
  • pagehelper使用方法及参数说明

    pagehelper使用方法及参数说明 使用方法 Override public PageInfo
  • spring源码--10--IOC高级特性--autowiring实现原理

    spring源码 10 IOC高级特性 autowiring实现原理 1 Spring IoC容器提供了2种方式 管理Bean的依赖关系 1 1 显式管理 通过BeanDefinition的属性值和构造方法实现Bean依赖关系管理 1 2
  • vue学习笔记:在vscode中使用@提示路径

    在vscode中输入 后如果可以智能提示路径 可以有效防止路径名称输入错误 减少不必要的麻烦 效果如下图所示 安装 Path Autocomplete 插件后可以实现路径的智能提示 步骤如下 1 在vscode中查找Path Autocom
  • 关于shell运行python文件中的错误——shell脚本换行

    问题 https ask csdn net questions 7900411 spm 1001 2014 3001 5505 问题由来 由于工程需要在本地window中写 当需要比较少的算力时在本地跑 当需要比较大的算力时就需要在auto
  • K8S调用GPU资源配置指南

    06 09 K8S调用GPU资源配置指南 时间 版本号 修改描述 修改人 2022年6月9日15 33 12 V0 1 新建K8S调用GPU资源配置指南 编写了Nvidia驱动安装过程 2022年6月10日11 16 52 V0 2 添加K
  • 基于pytorch的手势识别

    本次实验主要是使用pytorch完成手势识别 网络包含两个隐藏层 第一层隐藏层有576个节点 第二层隐藏层有144个节点 输入784个节点 图片大小为28 28 输出10个节点 10种手势 目录 1 数据集处理 2 神经网络的建立 3 神经