使用TensorFlow Object Detection API进行图像物体检测

2023-05-16

参考 https://github.com/tensorflow/models/tree/master/object_detection

使用TensorFlow Object Detection API进行图像物体检测

准备

  1. 安装TensorFlow

    参考 https://www.tensorflow.org/install/

    如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本

    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
  2. 配置TensorFlow Models

    • 下载TensorFlow Models
    git clone https://github.com/tensorflow/models.git
    • 编译protobuf
    
    # From tensorflow/models/
    
    protoc object_detection/protos/*.proto --python_out=.

    生成若干py文件在object_detection/protos/

    • 添加PYTHONPATH
    
    # From tensorflow/models/
    
    export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
    • 测试
    
    # From tensorflow/models/
    
    python object_detection/builders/model_builder_test.py

    若成功,显示OK

  3. 准备数据

    参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/preparing_inputs.md

    这里以PASCAL VOC 2012为例。

    • 下载并解压
    
    # From tensorflow/models
    
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    tar -xvf VOCtrainval_11-May-2012.tar
    • 生成TFRecord
    
    # From tensorflow/models
    
    mkdir VOC2012
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=train \
        --output_path=VOC2012/pascal_train.record
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=val \
        --output_path=VOC2012/pascal_val.record

    得到pascal_train.recordpascal_val.record

    如果需要用自己的数据,则参考create_pascal_tf_record.py编写处理数据生成TFRecord的脚本。可参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/using_your_own_dataset.md

  4. (可选)下载模型

    官方提供了不少预训练模型( https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md ),这里以ssd_mobilenet_v1_coco以例。

    
    # From tensorflow/models
    
    wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
    tar zxf ssd_mobilenet_v1_coco_11_06_2017.tar.gz

训练

如果使用现有模型进行预测则不需要训练。

  1. 文件结构

    为了方便查看文件,使用以下文件结构。

    models
    ├── object_detection
    │   ├── VOC2012
    │   │   ├── ssd_mobilenet_train_logs
    │   │   ├── ssd_mobilenet_val_logs
    │   │   ├── ssd_mobilenet_v1_voc2012.config
    │   │   ├── pascal_label_map.pbtxt
    │   │   ├── pascal_train.record
    │   │   └── pascal_val.record
    │   ├── infer.py
    │   └── create_pascal_tf_record.py
    ├── eval_voc2012.sh
    └── train_voc2012.sh
  2. 配置

    参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/configuring_jobs.md

    这里使用SSD w/MobileNet,把object_detection/samples/configs/ssd_mobilenet_v1_pets.config复制到object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config

    修改第9行为num_classes: 20

    修改第158行为fine_tune_checkpoint: "object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"

    修改第177行为input_path: "object_detection/VOC2012/pascal_train.record"

    修改第179行和193行为label_map_path: "object_detection/data/pascal_label_map.pbtxt"

    修改第191行为input_path: "object_detection/VOC2012/pascal_val.record"

  3. 训练

    新建tensorflow/models/train_voc2012.sh,内容以下:

    python object_detection/train.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --train_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        2>&1 | tee object_detection/VOC2012/ssd_mobilenet_train_logs.txt &

    进入tensorflow/models/,运行./train_voc2012.sh即可训练。

  4. 验证

    可一边训练一边验证,注意使用其它的GPU或合理分配显存。

    新建tensorflow/models/eval_voc2012.sh,内容以下:

    python object_detection/eval.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --checkpoint_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        --eval_dir=object_detection/VOC2012/ssd_mobilenet_val_logs &

    进入tensorflow/models/,运行CUDA_VISIBLE_DEVICES="1" ./train_voc2012.sh即可验证(这里指定了第二个GPU)。

  5. 可视化log

    可一边训练一边可视化训练的log,可看到Loss趋势。

    tensorboard --logdir ssd_mobilenet_train_logs/

    可视化验证的log,可看到Precision/mAP@0.5IOU的趋势以及具体image的预测结果。

    tensorboard --logdir ssd_mobilenet_val_logs/ --port 6007

测试

  1. 导出模型

    训练完成后得到一些checkpoint文件在ssd_mobilenet_train_logs中,如:

    • graph.pbtxt
    • model.ckpt-200000.data-00000-of-00001
    • model.ckpt-200000.info
    • model.ckpt-200000.meta

    其中meta保存了graph和metadata,ckpt保存了网络的weights。

    而进行预测时只需模型和权重,不需要metadata,故可使用官方提供的脚本生成推导图。

    python object_detection/export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --trained_checkpoint_prefix object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000 \
        --output_directory object_detection/VOC2012
  2. 测试图片

    • 运行object_detection_tutorial.ipynb并修改其中的各种路径即可。

    • 或自写编译inference脚本,如tensorflow/models/object_detection/infer.py

      import sys
      sys.path.append('..')
      import os
      import time
      import tensorflow as tf
      import numpy as np
      from PIL import Image
      from matplotlib import pyplot as plt
      
      from utils import label_map_util
      from utils import visualization_utils as vis_util
      
      PATH_TEST_IMAGE = sys.argv[1]
      PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
      PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
      NUM_CLASSES = 21
      IMAGE_SIZE = (18, 12)
      
      label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
      categories = label_map_util.convert_label_map_to_categories(
          label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
      category_index = label_map_util.create_category_index(categories)
      
      detection_graph = tf.Graph()
      with detection_graph.as_default():
          od_graph_def = tf.GraphDef()
          with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
              serialized_graph = fid.read()
              od_graph_def.ParseFromString(serialized_graph)
              tf.import_graph_def(od_graph_def, name='')
      
      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True
      
      with detection_graph.as_default():
          with tf.Session(graph=detection_graph, config=config) as sess:
              start_time = time.time()
              print(time.ctime())
              image = Image.open(PATH_TEST_IMAGE)
              image_np = np.array(image).astype(np.uint8)
              image_np_expanded = np.expand_dims(image_np, axis=0)
              image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
              boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
              scores = detection_graph.get_tensor_by_name('detection_scores:0')
              classes = detection_graph.get_tensor_by_name('detection_classes:0')
              num_detections = detection_graph.get_tensor_by_name('num_detections:0')
              (boxes, scores, classes, num_detections) = sess.run(
                  [boxes, scores, classes, num_detections],
                  feed_dict={image_tensor: image_np_expanded})
              print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
              vis_util.visualize_boxes_and_labels_on_image_array(
                  image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
                  category_index, use_normalized_coordinates=True, line_thickness=8)
              plt.figure(figsize=IMAGE_SIZE)
              plt.imshow(image_np)

      运行infer.py test_images/image1.jpg即可

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用TensorFlow Object Detection API进行图像物体检测 的相关文章

  • 关于访问说明符

    我定义了一个类基 class Base private int i Base int i this i i 所以基类的对象可以访问私有变量 class BaseDemo public static void main String args
  • 如何将未知字段类型的数据解组为 JSON

    我有这些 结构 type Results struct Gender string json gender Name struct First string json first Last string json last json nam
  • 尝试在 Facebook 中注册成就时出现 OAuthException 2500(未知路径组件)

    我正在尝试为应用程序注册 Facebook 开放图谱成就 我获取应用程序访问令牌并使用开放图 API 资源管理器发布以下请求 请注意 上面的应用程序 ID 和访问令牌不是真实的 但是 我得到以下答复 error message Unknow
  • 在Java中,当对象实例化失败时会发生什么?

    我有 C 背景 我发现自己经常在 Java 中这样做 SomeClass sc new SomeClass if null sc sc doSomething 我想知道的是 如果构造函数由于某种原因失败 比如可能没有足够的内存 变量 sc
  • 从字符串列表创建 TfRecords 并在解码后在张量流中提供图形

    目的是创建 TfRecords 数据库 给定 我有 23 个文件夹 每个文件夹包含 7500 个图像 以及 23 个文本文件 每个文件有 7500 行描述单独文件夹中 7500 个图像的特征 我通过以下代码创建了数据库 import ten
  • RESTful API:仅用于验证的方法/标头组合

    我希望我的 API 有一个仅验证请求 例如 如果我有一个 URL 例如 http api somesite com users 12345 用户正在客户端上填写一份信息表单 我最终会将其修补 放置 发布到该资源 当用户填写表单时 我可能希望
  • 将 Pytorch 模型 .pth 转换为 onnx 模型

    我有一个预训练的模型 其格式为 pth 扩展名 我想将其转换为 Tensorflow protobuf 但我没有找到任何方法来做到这一点 我见过 onnx 可以将模型从 pytorch 转换为 onnx 然后从 onnx 转换为 Tenso
  • Java控制台显示对象的地址而不是实际值[重复]

    这个问题在这里已经有答案了 好的 我正在用 Java 处理一个简单的数组 问题是 当我运行程序时 我得到的是对象的地址而不是实际值 我还发现循环 数组有问题 它应该显示房屋 3 5 和 7 但底部显示的是 3 4 和 5 我哪里出错了 请参
  • IB Java API:提取多个合约的股票数据(实时柱)

    我正在对算法交易和 IB API 进行一些自学和实验 我决定使用 Java 但我愿意切换到 C 我浏览了一个在线教程 该教程将引导您完成下面所示的代码 但我想知道是否可以将其扩展到一只股票之外 我想浏览所有 SP500 股票并检查股票数据
  • Tensorflow 到 ONNX 的转换

    我目前正在尝试转换我使用本教程创建的已保存 且正在工作 的 pb 文件 https github com thtrieu darkflow https github com thtrieu darkflow 到 onnx 文件中 我目前正在
  • 创建动态多维对象/数组

    我正在尝试使用 JS 创建一个多维数组 以便我可以通过 Ajax 调用 PHP 来发布一些数据 这可能很简单 但我对 JS 的了解很少关于这个具体的事情 这是带有代码的 JSFiddle http jsfiddle net k5Q3p 我想
  • 使用 Tkinter 显示 numpy 数组中的图像

    我对 Python 缺乏经验 第一次使用 Tkinter 制作一个 UI 显示我的数字分类程序与 mnist 数据集的结果 当图像来自 numpy 数组而不是我的 PC 上的文件路径时 我有一个关于在 Tkinter 中显示图像的问题 我为
  • Spring @RequestMapping 带有可选参数

    我的控制器在请求映射中存在可选参数的问题 请查看下面的控制器 GetMapping produces MediaType APPLICATION JSON VALUE public ResponseEntity
  • neo4j - python 驱动程序,服务不可用

    我对 neo4j 非常陌生 我正在尝试建立从 python3 6 到 neo4j 的连接 我已经安装了驱动程序 并且刚刚开始执行第一步 导入请求 导入操作系统 导入时间 导入urllib 从 neo4j v1 导入 GraphDatabas
  • NotImplementedError:无法将符号张量 (lstm_2/strided_slice:0) 转换为 numpy 数组。时间

    张量流版本 2 3 1 numpy 版本 1 20 在代码下面 define model model Sequential model add LSTM 50 activation relu input shape n steps n fe
  • Android REST API 连接

    我有点傻 对此感到抱歉 我编写了一个 API 它返回一些 JSON 我的目标是从 Android 应用程序使用此 API 我已经尝试过使用 AsyncTask 但失败了 我想像这样使用它 调用该类 告知 URL 和结果的类型 哪个json
  • 从 swift 数组创建张量

    这工作正常 import TensorFlow var t Tensor
  • Ruby 的 Faraday - 多次包含相同的参数

    我正在使用一个 API 该 API 迫使我多次发送相同的参数名称以级联不同的过滤条件 因此 示例 api GET 调用如下所示 GET http api site com search a b1 a b2 a b3 a c2 我使用 Far
  • 过滤条件的查询字符串与资源路径

    背景 我有2个资源 courses and professors A course具有以下属性 ID topic 学期号 年 部分 教授 id A professor具有以下属性 ID 学院 超级用户 名 姓 所以 你可以说一门课程有一位教
  • 通过 PayPal REST API 示例获得折扣?

    PayPal GURUS 我需要帮助 如何插入折扣 我使用 REST API 可能是某个 可以显示代码示例吗 有什么方法可以使用 PHP REST API 发送折扣金额吗 目前 REST 支付 API 不支持折扣 您最好的选择是计算您端的折

随机推荐

  • 上传文件超过限制,造成长时间无响应的解决方案

    在上传大文件 xff0c 造成长时间没有响应的情况的解决方案 xff1a 上传大文件时 xff0c 因为http协议的响应问题 xff0c 造成长时间不能向客户端发送响应请求头 解决方案 xff1a 1 向服务器发送上传大文件的reques
  • checkbox的jsTree的一个调用

    lt DOCTYPE HTML PUBLIC 34 W3C DTD HTML 4 01 Transitional EN 34 gt lt html gt lt head gt lt meta http equiv 61 34 Content
  • 灵活使用递归算法,生成Excel文件中的复合表头

    最近 xff0c 在开发中 xff0c 需要导出数据到excel文件 xff0c 文件的表头的格式是不一致的 有复合表头 xff0c 也有单表头 xff0c 那么如何灵活地生成excel文件中的复合表头 首先有一个JSON字符串格式的字段描
  • 在 ibm http server 和 websphere 之间配置 ssl

    在WebSphere的环境中 xff0c 配置SSL xff0c 有一些细节需要注意 xff1a 1 最好是先安装 ibm http server7 32bit xff0c websphere7 再安装插件 2 http server 需要
  • Ext4使用总结(二)简单的hbox布局

    布局的合理利用 xff1a 如图 xff1a xtype 39 container 39 margins 39 5 0 0 0 39 layout align 39 stretch 39 type 39 hbox 39
  • 软件开发者的精力管理(一)

    精力管理对于软件开发者来讲是非常重要的 不希望自己被长周期的项目拖垮 xff0c 不希望被连续的加班所累 我个人认为泛义的时间管理是涉及到多个方面的 而心理学 精力管理则是非常重要的 作为一名从事了多年软件开发的从业者 xff0c 我的一个
  • 如何高效能地学习和使用"工具"?

    在软件开发中 xff0c 应该注意工具的合理使用 xff0c 使得自己变得高效起来 1 工具也是产品 xff0c 有许多的工具是产品化的 既然是产品 xff0c 就很多的服务 xff0c 例如帮助文档 xff0c 论坛 xff0c 咨询人员
  • Ext4使用总结(十二) 采用 CellEditing 方式的Grid,如何取得修改的单元格数据值

    使用cellediting方式编辑数据的grid在保存数据时 xff0c 需要进行数据的处理 xff0c 所以数据处理的方式需要特别注意 cellEditing 插件的事件 listeners edit function editor e
  • 「Ubuntu」Ubuntu中的python终端配置(修改终端默认python配置,软连接,不同版本python环境配置)

    前言 通过这篇博客 xff08 Ubuntu安装Python xff09 安装完Python后 xff0c 想要在终端直接启动想启动的python版本 此时直接在终端输入python2或者python3 xff0c 发现系统已经配置好了py
  • [解题报告] CSDN竞赛第15期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 29 1 求并集 题目 由小到大输出两个单向有序链表的并集 如链表 A 1 gt 2 gt 5 gt 7 链表 B 3 gt 5 gt
  • JSP开发技术四——————EL表达式

    EL xff08 Expression Language xff09 表达式 xff0c 即正则表达式 用来操作字符串 用一些特定的字符来表示一些代码操作 xff0c 这样简化代码书写 学习正则表达式 xff0c 就是学习一些特殊符号的实用
  • [解题报告] CSDN竞赛第17期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 31 1 判断胜负 题目 已知两个字符串A B 连续进行读入n次 每次读入的字符串都为A B 输出读入次数最多的字符串 解题报告 模拟
  • [解题报告] CSDN竞赛第18期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 32 1 单链表排序 题目 单链表的节点定义如下 xff08 C 43 43 xff09 xff1a class Node publi
  • [解题报告] CSDN竞赛第22期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 36 1 c 43 43 难题 大数加法 题目 大数一直是一个c语言的一个难题 现在我们需要你手动模拟出大数加法过程 请你给出两个大整
  • [解题报告] CSDN竞赛第23期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 37 1 排查网络故障 题目 A地跟B地的网络中间有n个节点 xff08 不包括A地和B地 xff09 xff0c 相邻的两个节点是通
  • CSDN竞赛第24期

    CSDN编程竞赛报名地址 xff1a https edu csdn net contest detail 38 这次写完第一道题时遇到一个奇怪的情况 xff1a 一直在 运行中 xff0c 然后发现每道题输入做任意代码都出现一直运行中 跟小
  • [Python开发] 使用python读取图片的EXIF

    使用python读取图片的EXIF 方法 使用PIL Image读取图片的EXIF 使用https pypi python org pypi ExifRead 读取图片的EXIF xff0c 得到EXIF标签 xff08 dict类型 xf
  • Partial Least Squares Regression 偏最小二乘法回归

    介绍 定义 偏最小二乘回归 多元线性回归分析 43 典型相关分析 43 主成分分析 输入 xff1a n m 的预测矩阵 X n p 的响应矩阵 Y 输出 X 和 Y 的投影 分数 矩阵 T U R n l 目标 xff1a 最大化 cor
  • 使用TensorFlow-Slim进行图像分类

    参考 https github com tensorflow models tree master slim 使用TensorFlow Slim进行图像分类 准备 安装TensorFlow 参考 https www tensorflow o
  • 使用TensorFlow Object Detection API进行图像物体检测

    参考 https github com tensorflow models tree master object detection 使用TensorFlow Object Detection API进行图像物体检测 准备 安装Tensor