是的,使用 Protobuf Python API 非常简单:
编辑管道.py:
import argparse
import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2
def parse_arguments():
parser = argparse.ArgumentParser(description='')
parser.add_argument('pipeline')
parser.add_argument('output')
return parser.parse_args()
def main():
args = parse_arguments()
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(args.pipeline, "r") as f:
proto_str = f.read()
text_format.Merge(proto_str, pipeline_config)
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 300
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 300
config_text = text_format.MessageToString(pipeline_config)
with tf.gfile.Open(args.output, "wb") as f:
f.write(config_text)
if __name__ == '__main__':
main()
我调用脚本的方式:
TOOL_DIR=tool/tf-models/research
(
cd $TOOL_DIR
protoc object_detection/protos/*.proto --python_out=.
)
export PYTHONPATH=$PYTHONPATH:$TOOL_DIR:$TOOL_DIR/slim
python3 edit_pipeline.py pipeline.config pipeline_new.config
复合字段
如果存在重复字段,则必须将它们视为数组(例如使用extend()
, append()
方法):
pipeline_config.train_input_reader.tf_record_input_reader.input_path[0] = '/tensorflow/models/data/train100.record'
Eval 输入读取器错误
这是尝试编辑复合字段时的常见错误。 (如果是 eval_input_reader,则“未找到属性 tf_record_input_reader” )
下面@latida 的回答中提到了这一点。
通过将其设置为数组字段来解决此问题。
pipeline_config.eval_input_reader[0].label_map_path = label_map_full_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = val_record_path