ROS与机器学习(三)——手写数字识别

2023-05-16

ROS与机器学习(三)——手写数字识别

目录

    • 1、理论基础
    • 2、TensorFlow中的MNIST例程
      • 2.1 创建模型
      • 2.2 训练模型
      • 2.3 评估模型
    • 3、基于ROS实现MNIST
      • 3.1 初始化ROS节点
      • 3.2 设置ROS参数
      • 3.3 加入Subscriber和Publisher
      • 3.4 加入回调函数处理图像
      • 3.5 发布识别结果

1、理论基础

MNIST的下载链接:http://yann.lecun.com/exdb/mnist/。
MNIST是一个包含数字0~9的手写体图片数据集,图片已归一化为以手写数 字为中心的28*28规格的图片。
MNIST 由训练集与测试集两个部分组成,各部分的规模如下:

   训练集:60,000个手写体图片及对应标签 

   测试集:10,000个手写体图片及对应标签

在这里插入图片描述MNIST 是一个很有名的手写数字识别数据集,对于每张图片,存储的方式是一个 28 * 28 的矩阵,但是我们在导入数据进行使用的时候会自动展平成 1 * 784(28 * 28)的向量,这在TensorFlow导入很方便,在使用命令下载数据之后,可以看到有四个数据集:
在这里插入图片描述

2、TensorFlow中的MNIST例程

MNIST 是TensorFlow中的入门例程。先用原生MNIST例程的代码实现。

#!/usr/bin/env python3  
# -*- coding: utf-8 -*-  
  
import input_data  
import tensorflow as tf  
  
#MNIST数据输入  
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  
  
x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量  
W = tf.Variable(tf.zeros([784,10]))        #权重,初始化值为全零  
b = tf.Variable(tf.zeros([10]))            #偏置,初始化值为全零  
  
#进行模型计算,y是预测,y_ 是实际  
y = tf.nn.softmax(tf.matmul(x,W) + b)  
  
y_ = tf.placeholder("float", [None,10])  
  
#计算交叉熵  
cross_entropy = -tf.reduce_sum(y_*tf.log(y))  
#接下来使用BP算法来进行微调,以0.01的学习速率  
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  
  
#上面设置好了模型,添加初始化创建变量的操作  
init = tf.global_variables_initializer()  
#启动创建的模型,并初始化变量  
sess = tf.Session()  
sess.run(init)  

#开始训练模型,循环训练1000次  
for i in range(1000):  
    #随机抓取训练数据中的100个批处理数据点  
    batch_xs, batch_ys = mnist.train.next_batch(100)  
    sess.run(train_step, feed_dict={x:batch_xs,y_:batch_ys})  
      
''''' 进行模型评估 '''  
#判断预测标签和实际标签是否匹配  
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))   
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  
#计算所学习到的模型在测试数据集上面的正确率  
print( sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) )  

2.1 创建模型

x = tf.placeholder(tf.float32,[None, 784])

其中x不是一个特定的值,而是一个占位符placeholder,在TensorFlow运行计算时再输入这个值。我们希望能够输入任意数量的MNIST图像,每张图都可以展开为784维的向量。用二维的浮点数张量来表示这些图,这个张量的形状是[None,784],其中None表示此张量的第一个维度可以是任何长度。

W = tf.Variable(tf.zeros([784,10]))  
b = tf.Variable(tf.zeros([10]))

模型也需要权重值和偏置量。我们赋予tf.Variable不同的初值来创建不同的Variable:这里用全为零的张量来初始化W和b。
现在,可以实现模型了,只需要一行代码:

y = tf.nn.softmax(tf.matmul(x,W) + b)  

用tf.matmul(x,W)表示x乘以W,对应模型中的,这里x是一个二维张量,拥有多个输入;然后再加上b,把两者的和输入tf.softmax 函数中。

2.2 训练模型

为训练模型,我们需要定义一个指标来评估这个模型,也就是代价函数。常见的代价函数是“交叉熵”(cross-entropy)。交叉熵产生于信息压缩编码技术,但是后来演变成为从博弈论到机器学习等其他领域里的重要技术手段。
为计算交叉熵,首先需要添加一个用于输入真实值的占位符:

y_ = tf.placeholder("float", [None,10])

然后计算交叉熵:

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

这里的交叉熵不仅用来衡量一对预测和真实值,也是所有100幅图片交叉熵的总和。相比单一数据点预测,对于100个数据点的预测表现能更好地描述模型性能。
TensorFlow在后台为计算图增加了一系列新的计算操作单元,用于实现反向传播算法和梯度下降算法,然后返回一个单一操作。当运行这个操作时,将用梯度下降算法训练模型,微调变量,不断减少函数值。

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) 

这里我们要求TensorFlow用梯度下降算法以0.01的学习效率最小化交叉熵。当然TensorFlow也提供了许多其他种类的优化算法,只要简单调整这一行代码即可更换。
在运行计算之前,需要初始化创建的变量:

init = tf.global_variables_initializer()

现在可以通过Session启动模型,并且初始化变量:

sess = tf.Session()  
sess.run(init)

然后开始训练模型,这里让模型循环训练1000次。

for i in range(1000):    
    batch_xs, batch_ys = mnist.train.next_batch(100)  
    sess.run(train_step, feed_dict={x:batch_xs,y_:batch_ys})

该循环的每个步骤都会随机抓取训练数据中的100个批处理数据点,然后用这些数据点作为参数替换之前的占位符来运行train_step。
理想情况下,我们希望用所有数据进行每一步训练,从而实现更好的训练结果,但这显然需要很大的计算量。所以,每一次训练可以使用不同的数据子集,这样既可以减少计算量,又可以最大化地学习到数据集的总体特性。

2.3 评估模型

首先找出预测正确的标签。tf.argmax() 是一个非常有用的函数,它能给出某个tensor对象在某一维上数据最大值所在的索引值。由于标签向量由0、1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任意输入x预测到的标签值,而tf.argmax(y_,1)代表正确的标签,可以用tf.equal()来检测预测是否与真实标签匹配(索引位置一样表示匹配)。

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) 

以上代码会得到一组布尔值。为了确定正确预测项的比例,可以把布尔值转换成浮点数,然后取平均值。例如[True,False,True,True]会变成[1.0,0.0,1.0,1.0],取平均值后得到0.75。

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

最后,计算所学习到的模型在测试数据集上的正确率。

print( sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) )

执行程序,从运行效果可以看出最终结果应该是在91%左右,这个结果并不算太好,因为我们仅使用了一个非常简单的模型。如果进一步优化模型,就可以得到97%以上的正确率,最好的模型甚至可以获得超过99.7%的正确率。

3、基于ROS实现MNIST

结合ROS,利用MNIST识别输入图像中的手写数字,并且将识别结果发布出去。

#!/usr/bin/env python 
# -*- coding: utf-8 -*-
 
import rospy
from sensor_msgs.msg import Image
from std_msgs.msg import Int16
from cv_bridge import CvBridge
import cv2
import numpy as np
import input_data  
import tensorflow as tf

class MNIST():
    def __init__(self):
        image_topic = rospy.get_param("~image_topic", "")

        self._cv_bridge = CvBridge()

        #MNIST数据输入  
        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  
          
        self.x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量  
        self.W = tf.Variable(tf.zeros([784,10]))        #权重,初始化值为全零  
        self.b = tf.Variable(tf.zeros([10]))            #偏置,初始化值为全零  
          
        #进行模型计算,y是预测,y_ 是实际  
        self.y = tf.nn.softmax(tf.matmul(self.x, self.W) + self.b)  
          
        self.y_ = tf.placeholder("float", [None,10])  
          
        #计算交叉熵  
        self.cross_entropy = -tf.reduce_sum( self.y_*tf.log(self.y))  
        #接下来使用BP算法来进行微调,以0.01的学习速率  
        self.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(self.cross_entropy)  
          
        #上面设置好了模型,添加初始化创建变量的操作  
        self.init = tf.global_variables_initializer()  
        #启动创建的模型,并初始化变量  
        self.sess = tf.Session()  
        self.sess.run(self.init)  

        #开始训练模型,循环训练1000次  
        for i in range(1000):  
            #随机抓取训练数据中的100个批处理数据点  
            batch_xs, batch_ys = self.mnist.train.next_batch(100)  
            self.sess.run(self.train_step, feed_dict={self.x:batch_xs, self.y_:batch_ys})  

        ''''' 进行模型评估 '''  
        #判断预测标签和实际标签是否匹配  
        correct_prediction = tf.equal(tf.argmax(self.y,1),tf.argmax(self.y_,1))   
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  
       
        #计算所学习到的模型在测试数据集上面的正确率  
        print( "The predict accuracy with test data set: \n")
        print( self.sess.run(self.accuracy, feed_dict={self.x:self.mnist.test.images, self.y_:self.mnist.test.labels}) )  

        self._sub = rospy.Subscriber(image_topic, Image, self.callback, queue_size=1)
        self._pub = rospy.Publisher('result', Int16, queue_size=1)

    def callback(self, image_msg):
        #预处理接收到的图像数据
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
        cv_image_gray = cv2.cvtColor(cv_image, cv2.COLOR_RGB2GRAY)
        ret,cv_image_binary = cv2.threshold(cv_image_gray,128,255,cv2.THRESH_BINARY_INV)
        cv_image_28 = cv2.resize(cv_image_binary,(28,28))
        
        #转换输入数据shape,以便于用于网络中
        np_image = np.reshape(cv_image_28, (1, 784))

        predict_num = self.sess.run(self.y, feed_dict={self.x:np_image, self.y_:self.mnist.test.labels})
        
        #找到概率最大值
        answer = np.argmax(predict_num, 1)
        
        #发布识别结果
        rospy.loginfo('%d' % answer)
        self._pub.publish(answer)
        rospy.sleep(1) 

    def main(self):
        rospy.spin()

if __name__ == '__main__':
    rospy.init_node('ros_tensorflow_mnist')
    tensor = MNIST()
    rospy.loginfo("ros_tensorflow_mnist has started.")
    tensor.main()

在MNIST的基础上进行一些简单修改,使之融入ROS中。

3.1 初始化ROS节点

封装ROS节点的第一步是加入ROS节点的初始化,代码如下:

rospy.init_node('ros_tensorflow_mnist')

3.2 设置ROS参数

将图像话题名作为参数传入节点中,便于灵活设置,代码如下:

image_topic = rospy.get_param("~image_topic", "")

3.3 加入Subscriber和Publisher

创建订阅图像消息的Subscriber和发布最终识别结果的Publisher,代码如下:

self._sub = rospy.Subscriber(image_topic, Image, self.callback, queue_size=1)
self._pub = rospy.Publisher('result', Int16, queue_size=1)

3.4 加入回调函数处理图像

接收到图像后进入回调函数,然后使用cv_bridge将ROS图像转换成OpenCV的图像格式,进行识别处理,代码如下:

def callback(self, image_msg):
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
        ......

3.5 发布识别结果

图像处理完成后,发布识别结果,并且稍作延时,等待下一次识别,代码如下:

rospy.loginfo('%d' % answer)
self._pub.publish(answer)
rospy.sleep(1) 

使用ROS命令运行基于ROS的MNIST。启动后可以看到摄像头所拍摄的图像,通过命令查看识别结果。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ROS与机器学习(三)——手写数字识别 的相关文章

  • ros+arduino学习(六):重构ros_lib库文件

    前言 ros lib是arduino程序和ros连接的库文件 通过使用这些库文件和相关函数 可以在arduino上通过编程使得arduino硬件开ros节点程序 这样arduino硬件就可以与上位机通过话题进行通讯 从而把arduino从传
  • 无人飞行器智能感知竞赛--模拟器安装

    开发环境 win11 wsl2 注意事项 请配合视频使用 如果不看视频会对下面的配置过程迷惑 因为一开始我是想安装在ubuntu18 04的 中途发现ubuntu18 04没有ros noetic 所以转入ubuntu20 04配置 视频链
  • 关于相机与激光雷达数据采集与标定

    最近在做一个关于车路协同的项目 需要做一个路侧系统 传感器有摄像头和激光雷达 相机和激光雷达联合标定费了老半天劲 在此记录一下 雷达时间戳不对 导致摄像头和雷达的数据无法对齐 解决办法 修改雷达驱动发布点云消息时的时间戳 相机内参标定可以使
  • Ubuntu下vscode配置ROS环境

    摘要 最近准备放弃用clion开发ROS使用更主流的vscode 整理一下在ubuntu18 04下的VSCode安装和ROS环境配置流程 安装 方法一 软件商店安装 个人还是推荐使用ubuntu软件下载vscode 简单不容易出错 方法二
  • ROS noetic tf demo错误处理及python版本切换

    文章目录 报错描述及解决 ubuntu20 04下python版本切换 报错描述及解决 ubuntu版本 20 04 ROS版本 noetic roslaunch turtle tf turtle tf demo launch 报错信息 t
  • 解决ros安装 使用roscore命令测试问题

    本人安装教程完成ROS的安装后 在进行测试如图1命令 出现 解决办法输入完命令1后要输入命令2才行 即可测试成功 测试成功的界面如下
  • rosprofiler 安装和使用

    rosprofiler wiki 页面 http wiki ros org rosprofiler rosprofiler package 下载rosprofiler和ros statistics msgs 放到工程目录下编译 https
  • ubuntu18.04命令安装ros2

    ROS2官方文档 本教程为apt get命令安装方式 官网教程有点问题 借鉴一下大佬的安装方式 文章目录 1 安装ROS2 1 1 安装秘钥相关指令 1 2 授权秘钥 1 3 添加ROS2软件源 1 4 安装 2 设置环境 可选但是推荐 2
  • ROS1 ROS2学习

    ROS1 ROS2学习 安装 ROS ROS1 ROS2 命令行界面 ROS2 功能包相关指令 ROS 命令行工具 ROS1 CLI工具 ROS2 CLI工具 ROS 通信核心概念 节点 Node 节点相关的CLI 话题 Topic 编写发
  • roslaunch error: ERROR: cannot launch node of type

    今天在因为github上有个之前的包更新了 重新git clone后出现了一个问题 ERROR cannot launch node of type crazyflie demo controller py can t locate nod
  • 如何将从 rospy.Subscriber 数据获得的数据输入到变量中?

    我写了一个示例订阅者 我想将从 rospy Subscriber 获得的数据提供给另一个变量 以便稍后在程序中使用它进行处理 目前 我可以看到订阅者正在运行 因为当我使用 rospy loginfo 函数时 我可以看到打印的订阅值 虽然我不
  • 进入 docker 容器,exec 丢失 PATH 环境变量

    这是我的 Dockerfile FROM ros kinetic ros core xenial CMD bash 如果我跑docker build t ros docker run it ros 然后从容器内echo PATH 我去拿 o
  • 将 CUDA 添加到 ROS 包

    我想在 ros 包中使用 cuda 有人给我一个简单的例子吗 我尝试使用 cuda 函数构建一个静态库并将该库添加到我的包中 但总是出现链接错误 未定义的引用 cuda 我已经构建了一个可执行文件而不是库并且它可以工作 请帮忙 我自己找到了
  • 错误状态:平台不允许不安全的 HTTP:http://0.0.0.0:9090

    我正在尝试从我的 flutter 应用程序连接到 ws local host 9090 使用 rosbridge 运行 的 Ros WebSocket 服务 但我在 Flutter 中收到以下错误 错误状态 平台不允许不安全的 HTTP h
  • 我的代码的 Boost 更新问题

    我最近将 boost 更新到 1 59 并安装在 usr local 中 我的系统默认安装在 usr 并且是1 46 我使用的是ubuntu 12 04 我的代码库使用 ROS Hydro 机器人操作系统 我有一个相当大的代码库 在更新之前
  • Caught exception in launch(see debug for traceback)

    Caught exception in launch see debug for traceback Caught exception when trying to load file of format xml Caught except
  • catkin_make 编译报错 Unable to find either executable ‘empy‘ or Python module ‘em‘...

    文章目录 写在前面 一 问题描述 二 解决方法 参考链接 写在前面 自己的测试环境 Ubuntu20 04 一 问题描述 自己安装完 anaconda 后 再次执行 catkin make 遇到如下问题 CMake Error at opt
  • 无法在 Ubuntu 20.04 上安装 ROS Melodic

    我正在尝试使用这些命令在 Ubuntu 20 04 上安装 ROS Melodic sudo sh c echo deb http packages ros org ros ubuntu lsb release sc main gt etc
  • 如何使用一个凉亭同时创建两个地图?

    如下图所示 现在我的gazebo正在运行2个slam gmapping包 首先是 turtlebot slam gmapping 发布到 map 主题 第二个是 slam gmapping 发布到与第一个相同的 map 主题 我想创建一个新
  • 在 Google Colaboratory 上运行gym-gazebo

    我正在尝试在 Google Colaboratory 上运行gym gazebo 在Colab上运行gazebo服务器 没有gui的gazebo 有问题 显示警告 Unable to create X window Rendering wi

随机推荐