为什么在云上训练时出现“IndexError:列表索引超出范围”?

2024-02-27

我求助于使用云培训工作流程。鉴于我得到的产品,我本希望直接放入与其他 tflite 模型一起使用的代码中,但云生成的模型不起作用。询问时我得到“索引超出范围”interpreter.get_tensor参数。

这是我的代码,基本上是一个修改后的示例,我可以在其中提取视频并生成带有结果的视频。

import argparse
import cv2
import numpy as np
import sys
import importlib.util



# Define and parse input arguments
parser = argparse.ArgumentParser()
parser.add_argument('--modeldir', help='Folder the .tflite file is located in',
                    required=True)
parser.add_argument('--graph', help='Name of the .tflite file, if different than detect.tflite',
                    default='model.tflite')
#                    default='/tmp/detect.tflite')
parser.add_argument('--labels', help='Name of the labelmap file, if different than labelmap.txt',
            default='dict.txt')
#                    default='/tmp/coco_labels.txt')
parser.add_argument('--threshold', help='Minimum confidence threshold for displaying detected objects',
                    default=0.5)
parser.add_argument('--video', help='Name of the video file',
                    default='test.mp4')
parser.add_argument('--edgetpu', help='Use Coral Edge TPU Accelerator to speed up detection',
                    action='store_true')

args = parser.parse_args()

MODEL_NAME = args.modeldir
GRAPH_NAME = args.graph
LABELMAP_NAME = args.labels
VIDEO_NAME = args.video
min_conf_threshold = float(args.threshold)
use_TPU = args.edgetpu

# Import TensorFlow libraries
# If tensorflow is not installed, import interpreter from tflite_runtime, else import from regular tensorflow
# If using Coral Edge TPU, import the load_delegate library
pkg = importlib.util.find_spec('tensorflow')
pkg = True
if pkg is None:
    from tflite_runtime.interpreter import Interpreter
    if use_TPU:
        from tflite_runtime.interpreter import load_delegate
else:
    from tensorflow.lite.python.interpreter import Interpreter
    if use_TPU:
        from tensorflow.lite.python.interpreter import load_delegate

# If using Edge TPU, assign filename for Edge TPU model
if use_TPU:
    # If user has specified the name of the .tflite file, use that name, otherwise use default 'edgetpu.tflite'
    if (GRAPH_NAME == 'detect.tflite'):
        GRAPH_NAME = 'edgetpu.tflite'   

# Get path to current working directory
CWD_PATH = os.getcwd()

# Path to video file
VIDEO_PATH = os.path.join(CWD_PATH,VIDEO_NAME)

# Path to .tflite file, which contains the model that is used for object detection
PATH_TO_CKPT = os.path.join(CWD_PATH,MODEL_NAME,GRAPH_NAME)

# Path to label map file
PATH_TO_LABELS = os.path.join(CWD_PATH,MODEL_NAME,LABELMAP_NAME)

# Load the label map
with open(PATH_TO_LABELS, 'r') as f:
    labels = [line.strip() for line in f.readlines()]

# Have to do a weird fix for label map if using the COCO "starter model" from
# https://www.tensorflow.org/lite/models/object_detection/overview
# First label is '???', which has to be removed.
if labels[0] == '???':
    del(labels[0])

# Load the Tensorflow Lite model.
# If using Edge TPU, use special load_delegate argument
if use_TPU:
    interpreter = Interpreter(model_path=PATH_TO_CKPT,
                              experimental_delegates=[load_delegate('libedgetpu.so.1.0')])
    print(PATH_TO_CKPT)
else:
    interpreter = Interpreter(model_path=PATH_TO_CKPT)

interpreter.allocate_tensors()

# Get model details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
height = input_details[0]['shape'][1]
width = input_details[0]['shape'][2]

floating_model = (input_details[0]['dtype'] == np.float32)

input_mean = 127.5
input_std = 127.5

# Open video file
video = cv2.VideoCapture(VIDEO_PATH)
imW = video.get(cv2.CAP_PROP_FRAME_WIDTH)
imH = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
out = cv2.VideoWriter('output.avi', cv2.VideoWriter_fourcc(
        'M', 'J', 'P', 'G'), 10, (1920, 1080))
while(video.isOpened()):

    # Acquire frame and resize to expected shape [1xHxWx3]
    ret, frame = video.read()
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_resized = cv2.resize(frame_rgb, (width, height))
    input_data = np.expand_dims(frame_resized, axis=0)

    # Normalize pixel values if using a floating model (i.e. if model is non-quantized)
    if floating_model:
        input_data = (np.float32(input_data) - input_mean) / input_std

    # Perform the actual detection by running the model with the image as input
    interpreter.set_tensor(input_details[0]['index'],input_data)
    interpreter.invoke()

    # Retrieve detection results
    boxes = interpreter.get_tensor(output_details[0]['index'])[0] # Bounding box coordinates of detected objects
    classes = interpreter.get_tensor(output_details[1]['index'])[0] # Class index of detected objects
    scores = interpreter.get_tensor(output_details[2]['index'])[0] # Confidence of detected objects
    print (boxes)
    print (classes)
    print (scores)
    #num = interpreter.get_tensor(output_details[3]['index'])[0]  # Total number of detected objects (inaccurate and not needed)

    # Loop over all detections and draw detection box if confidence is above minimum threshold
    for i in range(len(scores)):
        if ((scores[i] > min_conf_threshold) and (scores[i] <= 1.0)):

            # Get bounding box coordinates and draw box
            # Interpreter can return coordinates that are outside of image dimensions, need to force them to be within image using max() and min()
            ymin = int(max(1,(boxes[i][0] * imH)))
            xmin = int(max(1,(boxes[i][1] * imW)))
            ymax = int(min(imH,(boxes[i][2] * imH)))
            xmax = int(min(imW,(boxes[i][3] * imW)))

            cv2.rectangle(frame, (xmin,ymin), (xmax,ymax), (10, 255, 0), 4)

            # Draw label
            object_name = labels[int(classes[i])] # Look up object name from "labels" array using class index
            label = '%s: %d%%' % (object_name, int(scores[i]*100)) # Example: 'person: 72%'
            labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2) # Get font size
            label_ymin = max(ymin, labelSize[1] + 10) # Make sure not to draw label too close to top of window
            cv2.rectangle(frame, (xmin, label_ymin-labelSize[1]-10), (xmin+labelSize[0], 
label_ymin+baseLine-10), (255, 255, 255), cv2.FILLED) # Draw white box to put label text in
            cv2.putText(frame, label, (xmin, label_ymin-7), cv2.FONT_HERSHEY_SIMPLEX, 
0.7, (0, 0, 0), 2) # Draw label text

    # All the results have been drawn on the frame, so it's time to display it.
    cv2.imshow('Object detector', frame)
    #output_rgb = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    out.write(frame)
    # Press 'q' to quit
    if cv2.waitKey(1) == ord('q'):
        break

# Clean up
video.release()
out.release()
cv2.destroyAllWindows()

使用封装的 tflite 模型时,打印语句应如下所示:

[32. 76. 56. 76.  0. 61. 74.  0.  0.  0.]
[0.609375   0.48828125 0.44921875 0.44921875 0.4140625  0.40234375
 0.37890625 0.3125     0.3125     0.3125    ]
[[-0.01923192  0.17330796  0.747546    0.8384144 ]
 [ 0.01866053  0.5023282   0.39603746  0.6143299 ]
 [ 0.01673795  0.47382414  0.34407628  0.5580931 ]
 [ 0.11588445  0.78543806  0.8778869   1.0039229 ]
 [ 0.8106107   0.70675755  1.0080075   0.89248717]
 [ 0.84941524  0.06391776  1.0006479   0.28792098]
 [ 0.05543692  0.53557926  0.40413857  0.62823087]
 [ 0.07051808 -0.00938512  0.8822515   0.28100258]
 [ 0.68205094  0.33990026  0.9940187   0.6020821 ]
 [ 0.08010477  0.01998334  0.6011186   0.26135433]]

这是使用云创建的模型时出现的错误:

File "tflite_vid.py", line 124, in <module>
    classes = interpreter.get_tensor(output_details[1]['index'])[0] # Class index of detected objects
IndexError: list index out of range

因此,我恳请有人解释如何使用 Python 和 TF2 开发 TFlite 模型,或者如何让云生成可用的 TFlite 模型。拜托哦拜托不要给我指出一个需要通过互联网示例来思考的方向,除非它们是如何做到这一点的真正福音。,


In output_details[1], it is [1]

更多关于Python代码的使用请参考https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python以获得指导。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为什么在云上训练时出现“IndexError:列表索引超出范围”? 的相关文章

  • Django 代理模型的继承和多态性

    我正在开发一个我没有启动的 Django 项目 我面临着一个问题遗产 我有一个大模型 在示例中简化 称为MyModel这应该代表不同种类的物品 的所有实例对象MyModel应该具有相同的字段 但方法的行为根据项目类型的不同而有很大差异 到目
  • Django 管理员在模型编辑时间歇性返回 404

    我们使用 Django Admin 来维护导出到我们的一些站点的一些数据 有时 当单击标准更改列表视图来获取模型编辑表单而不是路由到正确的页面时 我们会得到 Django 404 页面 模板 它是偶尔发生的 我们可以通过重新加载三次来重现它
  • 将 Matplotlib 误差线放置在不位于条形中心的位置

    我正在 Matplotlib 中生成带有错误栏的堆积条形图 不幸的是 某些层相对较小且数据多样 因此多个层的错误条可能重叠 从而使它们难以或无法读取 Example 有没有办法设置每个误差条的位置 即沿 x 轴移动它 以便重叠的线显示在彼此
  • Python(Selenium):如何通过登录重定向/组织登录登录网站

    我不是专业程序员 所以请原谅任何愚蠢的错误 我正在做一些研究 我正在尝试使用 Selenium 登录数据库来搜索大约 1000 个术语 我有两个问题 1 重定向到组织登录页面后如何使用 Selenium 登录 2 如何检索数据库 在我解决
  • PyUSB 1.0:NotImplementedError:此平台不支持或未实现操作

    我刚刚开始使用 pyusb 基本上我正在玩示例代码here https github com walac pyusb blob master docs tutorial rst 我使用的是 Windows 7 64 位 并从以下地址下载 z
  • python 相当于 R 中的 get() (= 使用字符串检索符号的值)

    在 R 中 get s 函数检索名称存储在字符变量 向量 中的符号的值s e g X lt 10 r lt XVI s lt substr r 1 1 X get s 10 取罗马数字的第一个符号r并将其转换为其等效整数 尽管花了一些时间翻
  • 如何从网页中嵌入的 Tableau 图表中抓取工具提示值

    我试图弄清楚是否有一种方法以及如何使用 python 从网页中的 Tableau 嵌入图形中抓取工具提示值 以下是当用户将鼠标悬停在条形上时带有工具提示的图表示例 我从要从中抓取的原始网页中获取了此网址 https covid19 colo
  • ubuntu 20.04 上无法获取卷积算法错误~tensorflow-gpu

    我有一个 NVIDIA 2070 RTX GPU 我的操作系统是 Ubuntu20 04 我已经使用 conda 安装了tensorflow gpu 包 我有not安装了 CUDA toolkit 我相信它还会安装 CUDA toolkit
  • 测试 python Counter 是否包含在另一个 Counter 中

    如何测试是否是pythonCounter https docs python org 2 library collections html collections Counter is 包含在另一个中使用以下定义 柜台a包含在计数器中b当且
  • 以编程方式停止Python脚本的执行? [复制]

    这个问题在这里已经有答案了 是否可以使用命令在任意行停止执行 python 脚本 Like some code quit quit at this point some more code that s not executed sys e
  • Python pickle:腌制对象不等于源对象

    我认为这是预期的行为 但想检查一下 也许找出原因 因为我所做的研究结果是空白 我有一个函数可以提取数据 创建自定义类的新实例 然后将其附加到列表中 该类仅包含变量 然后 我使用协议 2 作为二进制文件将该列表腌制到文件中 稍后我重新运行脚本
  • AWS EMR Spark Python 日志记录

    我正在 AWS EMR 上运行一个非常简单的 Spark 作业 但似乎无法从我的脚本中获取任何日志输出 我尝试过打印到 stderr from pyspark import SparkContext import sys if name m
  • 如何使用Python创建历史时间线

    So I ve seen a few answers on here that helped a bit but my dataset is larger than the ones that have been answered prev
  • 如何在seaborn displot中使用hist_kws

    我想在同一图中用不同的颜色绘制直方图和 kde 线 我想为直方图设置绿色 为 kde 线设置蓝色 我设法弄清楚使用 line kws 来更改 kde 线条颜色 但 hist kws 不适用于显示 我尝试过使用 histplot 但我无法为
  • 每个 X 具有多个 Y 值的 Python 散点图

    我正在尝试使用 Python 创建一个散点图 其中包含两个 X 类别 cat1 cat2 每个类别都有多个 Y 值 如果每个 X 值的 Y 值的数量相同 我可以使用以下代码使其工作 import numpy as np import mat
  • 为字典中的一个键附加多个值[重复]

    这个问题在这里已经有答案了 我是 python 新手 我有每年的年份和值列表 我想要做的是检查字典中是否已存在该年份 如果存在 则将该值附加到特定键的值列表中 例如 我有一个年份列表 并且每年都有一个值 2010 2 2009 4 1989
  • 类型错误:预期单个张量时的张量列表 - 将 const 与 tf.random_normal 一起使用时

    我有以下 TensorFlow 代码 tf constant tf random normal time step batch size 1 1 我正进入 状态TypeError List of Tensors when single Te
  • Scrapy:如何使用元在方法之间传递项目

    我是 scrapy 和 python 的新手 我试图将 parse quotes 中的项目 item author 传递给下一个解析方法 parse bio 我尝试了 request meta 和 response meta 方法 如 sc
  • 导入错误:没有名为 site 的模块 - mac

    我已经有这个问题几个月了 每次我想获取一个新的 python 包并使用它时 我都会在终端中收到此错误 ImportError No module named site 我不知道为什么会出现这个错误 实际上 我无法使用任何新软件包 因为每次我
  • 如何将输入读取为数字?

    这个问题的答案是社区努力 help privileges edit community wiki 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 Why are x and y下面的代码中使用字符串而不是整数 注意 在Python 2

随机推荐