Object Detection in Images with the TensorFlow Object Detection API


Reference: https://github.com/tensorflow/models/tree/master/object_detection

  1. Install TensorFlow

    See https://www.tensorflow.org/install/

    For example, to install TensorFlow 1.2.0 with GPU support for Python 2.7 on Ubuntu:

    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
  2. Set up TensorFlow Models

    • Download TensorFlow Models
    git clone https://github.com/tensorflow/models.git
    • Compile the protobuf definitions
    # From tensorflow/models/
    protoc object_detection/protos/*.proto --python_out=.
    • Add the models and slim directories to PYTHONPATH (repeat in every new terminal)
    # From tensorflow/models/
    export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
    • Test the installation
    # From tensorflow/models/
    python object_detection/builders/model_builder_test.py
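    The `export PYTHONPATH` line only affects the current shell. If a script is launched some other way (an IDE, cron, a notebook), the same effect can be had from inside Python. A minimal sketch, assuming `models_dir` points at your tensorflow/models checkout; `add_models_to_path` is a hypothetical helper, not part of the API:

```python
import os
import sys


def add_models_to_path(models_dir):
    """Prepend tensorflow/models and tensorflow/models/slim to sys.path.

    Hypothetical helper mirroring `export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim`
    run from tensorflow/models/. Returns the paths that were actually added.
    """
    models_dir = os.path.abspath(models_dir)
    added = []
    for path in (models_dir, os.path.join(models_dir, "slim")):
        if path not in sys.path:  # avoid duplicate entries on repeated calls
            sys.path.insert(0, path)
            added.append(path)
    return added
```

    Call it once, before importing any `object_detection` modules.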


  3. Prepare the data

    See https://github.com/tensorflow/models/blob/master/object_detection/g3doc/preparing_inputs.md

    PASCAL VOC 2012 is used as the example here.

    • Download and extract
    # From tensorflow/models
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    tar -xvf VOCtrainval_11-May-2012.tar
    • Generate TFRecord files
    # From tensorflow/models
    mkdir object_detection/VOC2012
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=train \
        --output_path=object_detection/VOC2012/pascal_train.record
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=val \
        --output_path=object_detection/VOC2012/pascal_val.record


    To use your own data, write a script modeled on create_pascal_tf_record.py that converts it into TFRecord files. See https://github.com/tensorflow/models/blob/master/object_detection/g3doc/using_your_own_dataset.md
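    For VOC-style data, the core of create_pascal_tf_record.py is reading each XML annotation, dividing the box corners by the image width and height, and mapping class names to ids through the label map. A stdlib sketch of that parsing step, as a starting point for your own converter; `parse_voc_annotation` is a hypothetical name, and writing the actual tf.train.Example records is left out:

```python
import xml.etree.ElementTree as ET


def parse_voc_annotation(xml_text, label_map):
    """Parse one PASCAL VOC annotation into normalized boxes.

    label_map: dict from class name to integer id, e.g. {"dog": 12}.
    Boxes are returned as (xmin, ymin, xmax, ymax) fractions of the
    image width/height, the convention the TFRecord fields use.
    """
    root = ET.fromstring(xml_text)
    size = root.find("size")
    width = float(size.find("width").text)
    height = float(size.find("height").text)
    boxes, names, ids = [], [], []
    for obj in root.findall("object"):
        name = obj.find("name").text
        bndbox = obj.find("bndbox")
        xmin = float(bndbox.find("xmin").text) / width
        ymin = float(bndbox.find("ymin").text) / height
        xmax = float(bndbox.find("xmax").text) / width
        ymax = float(bndbox.find("ymax").text) / height
        boxes.append((xmin, ymin, xmax, ymax))
        names.append(name)
        ids.append(label_map[name])  # KeyError if the class is not in the map
    return {"filename": root.find("filename").text,
            "width": int(width), "height": int(height),
            "boxes": boxes, "names": names, "ids": ids}
```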

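  4. (Optional) Download a pre-trained model

    The official model zoo (https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md) provides a number of pre-trained models; ssd_mobilenet_v1_coco is used as the example here.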

    # From tensorflow/models
    wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
    tar zxf ssd_mobilenet_v1_coco_11_06_2017.tar.gz -C object_detection

Training and evaluation

  1. Directory layout (under tensorflow/models/)


    ├── object_detection
    │   ├── VOC2012
    │   │   ├── ssd_mobilenet_train_logs
    │   │   ├── ssd_mobilenet_val_logs
    │   │   ├── ssd_mobilenet_v1_voc2012.config
    │   │   ├── pascal_label_map.pbtxt
    │   │   ├── pascal_train.record
    │   │   └── pascal_val.record
    │   ├── infer.py
    │   └── create_pascal_tf_record.py
    ├── eval_voc2012.sh
    └── train_voc2012.sh
  2. Configure the pipeline

    See https://github.com/tensorflow/models/blob/master/object_detection/g3doc/configuring_jobs.md

    SSD with MobileNet is used here. Copy object_detection/samples/configs/ssd_mobilenet_v1_pets.config to object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config and change the following lines (line numbers refer to the version of the sample config used here and may differ in newer versions):

    Line 9: num_classes: 20 (PASCAL VOC has 20 object classes)

    Line 158: fine_tune_checkpoint: "object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"

    Line 177 (train_input_reader): input_path: "object_detection/VOC2012/pascal_train.record"

    Lines 179 and 193: label_map_path: "object_detection/data/pascal_label_map.pbtxt"

    Line 191 (eval_input_reader): input_path: "object_detection/VOC2012/pascal_val.record"
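    Line numbers shift between versions of the sample config, so it can be more robust to edit the fields by name. A minimal regex-based sketch, assuming each field sits on one line as `name: value` and that the first `input_path` belongs to train_input_reader and the second to eval_input_reader; `patch_pipeline_config` is a hypothetical helper, not part of the API:

```python
import re


def patch_pipeline_config(text, num_classes, checkpoint, train_record,
                          val_record, label_map):
    """Rewrite the fields listed above by name rather than by line number."""

    def quoted(value):
        # A callable replacement avoids backslash escapes in paths.
        return lambda m: '%s: "%s"' % (m.group(1), value)

    # \b keeps e.g. "max_num_classes" from matching.
    text = re.sub(r"\b(num_classes):\s*\d+",
                  lambda m: "%s: %d" % (m.group(1), num_classes), text)
    text = re.sub(r'\b(fine_tune_checkpoint):\s*"[^"]*"',
                  quoted(checkpoint), text)
    # Both train and eval readers share the same label map.
    text = re.sub(r'\b(label_map_path):\s*"[^"]*"', quoted(label_map), text)
    # First input_path -> train record, second -> val record.
    records = iter([train_record, val_record])
    text = re.sub(r'\b(input_path):\s*"[^"]*"',
                  lambda m: '%s: "%s"' % (m.group(1), next(records)),
                  text, count=2)
    return text
```

    For example, read ssd_mobilenet_v1_pets.config, pass in the paths from the list above, and write the result to object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config. Diff the output against the original before training.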

  3. Train


    python object_detection/train.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --train_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        2>&1 | tee object_detection/VOC2012/ssd_mobilenet_train_logs.txt &
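    The `tee` above saves the console output to ssd_mobilenet_train_logs.txt, which is handy for a quick look at the loss without TensorBoard. A small sketch, assuming the trainer writes lines of the form `INFO:tensorflow:global step 100: loss = 3.2154 (0.421 sec/step)`; check your own log, since this format is an assumption, and `parse_losses` is a hypothetical helper:

```python
import re

# Assumed log line format, e.g.
# "INFO:tensorflow:global step 100: loss = 3.2154 (0.421 sec/step)"
STEP_RE = re.compile(r"global step (\d+): loss = (\S+) \((\S+) sec/step\)")


def parse_losses(lines):
    """Return (step, loss, seconds_per_step) for every matching log line."""
    results = []
    for line in lines:
        match = STEP_RE.search(line)
        if match:
            results.append((int(match.group(1)),
                            float(match.group(2)),
                            float(match.group(3))))
    return results
```

    For example, open object_detection/VOC2012/ssd_mobilenet_train_logs.txt and print `parse_losses(f)[-5:]` to see the five most recent steps.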


  4. Evaluate



    python object_detection/eval.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --checkpoint_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        --eval_dir=object_detection/VOC2012/ssd_mobilenet_val_logs &

    Alternatively, put this command in eval_voc2012.sh and, from tensorflow/models/, run CUDA_VISIBLE_DEVICES="1" ./eval_voc2012.sh to start evaluation on the second GPU.

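  5. Visualize the logs

    Training and evaluation summaries can be viewed in two TensorBoard instances (the first uses the default port 6006, the second port 6007):

    # From tensorflow/models
    tensorboard --logdir object_detection/VOC2012/ssd_mobilenet_train_logs/

    tensorboard --logdir object_detection/VOC2012/ssd_mobilenet_val_logs/ --port 6007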


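Exporting and testing the model

  1. Export the model

    After training, object_detection/VOC2012/ssd_mobilenet_train_logs/ contains files such as:

    • graph.pbtxt
    • model.ckpt-200000.data-00000-of-00001
    • model.ckpt-200000.index
    • model.ckpt-200000.meta

    Export the checkpoint to a frozen inference graph (object_detection/VOC2012/frozen_inference_graph.pb):
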
    python object_detection/export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --trained_checkpoint_prefix object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000 \
        --output_directory object_detection/VOC2012
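    The step number in `--trained_checkpoint_prefix` (200000 here) must match a checkpoint that is actually present in the training directory. A stdlib sketch for picking the newest one, assuming checkpoints are named `model.ckpt-<step>.index` as listed above; `latest_checkpoint_prefix` is a hypothetical helper (from inside TensorFlow, `tf.train.latest_checkpoint(train_dir)` serves a similar purpose by reading the `checkpoint` file):

```python
import os
import re

CKPT_INDEX_RE = re.compile(r"^model\.ckpt-(\d+)\.index$")


def latest_checkpoint_prefix(train_dir):
    """Return the prefix of the highest-step model.ckpt-<step>, or None."""
    best_step = None
    for name in os.listdir(train_dir):
        match = CKPT_INDEX_RE.match(name)
        # Compare steps numerically, not as strings ("99999" < "200000").
        if match and (best_step is None or int(match.group(1)) > best_step):
            best_step = int(match.group(1))
    if best_step is None:
        return None
    return os.path.join(train_dir, "model.ckpt-%d" % best_step)
```

    Pass the returned prefix to `--trained_checkpoint_prefix` instead of hard-coding the step.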
  2. Test on images

    • Run object_detection_tutorial.ipynb after changing the model, label map and image paths in it.

    • Or write your own inference script, such as tensorflow/models/object_detection/infer.py:

      import sys
      import os
      import time
      import tensorflow as tf
      import numpy as np
      from PIL import Image
      from matplotlib import pyplot as plt
      from utils import label_map_util
      from utils import visualization_utils as vis_util
      PATH_TEST_IMAGE = sys.argv[1]
      PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
      PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
      NUM_CLASSES = 20  # PASCAL VOC has 20 object classes (ids 1-20)
      IMAGE_SIZE = (18, 12)
      label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
      categories = label_map_util.convert_label_map_to_categories(
          label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
      category_index = label_map_util.create_category_index(categories)
      detection_graph = tf.Graph()
      with detection_graph.as_default():
          od_graph_def = tf.GraphDef()
          with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
              serialized_graph = fid.read()
              # Deserialize the frozen graph before importing it.
              od_graph_def.ParseFromString(serialized_graph)
              tf.import_graph_def(od_graph_def, name='')
      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True
      with detection_graph.as_default():
          with tf.Session(graph=detection_graph, config=config) as sess:
              start_time = time.time()
              # Convert to RGB so grayscale/RGBA inputs also give an H x W x 3 array.
              image = Image.open(PATH_TEST_IMAGE).convert('RGB')
              image_np = np.array(image).astype(np.uint8)
              image_np_expanded = np.expand_dims(image_np, axis=0)
              image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
              boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
              scores = detection_graph.get_tensor_by_name('detection_scores:0')
              classes = detection_graph.get_tensor_by_name('detection_classes:0')
              num_detections = detection_graph.get_tensor_by_name('num_detections:0')
              (boxes, scores, classes, num_detections) = sess.run(
                  [boxes, scores, classes, num_detections],
                  feed_dict={image_tensor: image_np_expanded})
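              print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
              # Draw the detected boxes and labels onto image_np in place.
              vis_util.visualize_boxes_and_labels_on_image_array(
                  image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
                  category_index, use_normalized_coordinates=True, line_thickness=8)
              plt.figure(figsize=IMAGE_SIZE)
              plt.imshow(image_np)
              plt.show()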

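      From tensorflow/models/object_detection, run python infer.py test_images/image1.jpg (the `from utils import ...` lines and the relative VOC2012/ paths assume this working directory).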
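    `sess.run` returns batched arrays: `boxes` of shape [1, N, 4] holding normalized `[ymin, xmin, ymax, xmax]` coordinates (hence `use_normalized_coordinates=True`), plus `scores` and `classes` of shape [1, N]. If you need pixel boxes rather than a drawn image, a minimal sketch, assuming a score threshold of 0.5; both the threshold and the helper name `to_pixel_detections` are assumptions:

```python
def to_pixel_detections(boxes, scores, classes, category_index,
                        width, height, min_score=0.5):
    """Convert one image's normalized detections to pixel boxes.

    boxes:   N x 4 sequence of [ymin, xmin, ymax, xmax] in [0, 1]
    scores:  N confidence scores; classes: N class ids (float or int)
    category_index: {id: {"id": id, "name": name}}, as built by
                    label_map_util.create_category_index
    Returns [(name, score, (left, top, right, bottom)), ...] in pixels.
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue  # drop low-confidence detections
        ymin, xmin, ymax, xmax = box
        name = category_index.get(int(cls), {}).get("name", "N/A")
        results.append((name, float(score),
                        (int(round(xmin * width)), int(round(ymin * height)),
                         int(round(xmax * width)), int(round(ymax * height)))))
    return results
```

    In infer.py this could be called after `sess.run` as `to_pixel_detections(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), category_index, image_np.shape[1], image_np.shape[0])`.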


