在Caffe中调用TensorRT提供的MNIST model

2023-11-17

在TensorRT 2.1.2中提供了MNIST的model,这里拿来用Caffe的代码调用实现,原始的mnist_mean.binaryproto文件调整为了纯二进制文件mnist_tensorrt_mean.binary,测试结果与使用TensorRT调用(http://blog.csdn.net/fengbingchun/article/details/78552908)一致:

测试代码如下:

#include "funset.hpp"
#include <memory>
#include <fstream>
#include <tuple>
#include "common.hpp"

int mnist_tensorrt_predict()
{
#ifdef CPU_ONLY
	caffe::Caffe::set_mode(caffe::Caffe::CPU);
#else
	caffe::Caffe::set_mode(caffe::Caffe::GPU);
#endif

	const std::string deploy_file{ "E:/GitCode/Caffe_Test/test_data/model/mnist/mnist_tensorrt.prototxt" };
	const std::string model_filename{ "E:/GitCode/Caffe_Test/test_data/model/mnist/mnist_tensorrt.caffemodel" };
	const std::string mean_file{ "E:/GitCode/Caffe_Test/test_data/model/mnist/mnist_tensorrt_mean.binary" };
	const std::string image_path{ "E:/GitCode/Caffe_Test/test_data/images/handwritten_digits/" };

	caffe::Net<float> caffe_net(deploy_file, caffe::TEST);
	caffe_net.CopyTrainedLayersFrom(model_filename);

	// print net info
	fprintf(stdout, "input blob num: %d, output blob num: %d\n", caffe_net.num_inputs(), caffe_net.num_outputs());
	const boost::shared_ptr<caffe::Blob<float> > blob_by_name = caffe_net.blob_by_name("data");
	int image_num = blob_by_name->num();
	int image_channel = blob_by_name->channels();
	int image_height = blob_by_name->height();
	int image_width = blob_by_name->width();
	fprintf(stdout, "inpub blob shape(num, channels, height, width): %d, %d, %d, %d\n",
		image_num, image_channel, image_height, image_width);

	fprintf(stdout, "layer names: ");
	for (int i = 0; i < caffe_net.layer_names().size(); ++i) {
		fprintf(stdout, "  %s  ", caffe_net.layer_names()[i].c_str());
	}
	fprintf(stdout, "\nblob names: ");
	for (int i = 0; i < caffe_net.blob_names().size(); ++i) {
		fprintf(stdout, "  %s  ", caffe_net.blob_names()[i].c_str());
	}
	fprintf(stdout, "\nlayer types: ");
	for (int i = 0; i < caffe_net.layers().size(); ++i) {
		fprintf(stdout, "  %s  ", caffe_net.layers()[i]->type());
	}
	const std::vector<caffe::Blob<float>*> output_blobs = caffe_net.output_blobs();
	fprintf(stdout, "\noutput blobs num: %d, blob(num, channel, heihgt, width): %d, %d, %d, %d\n",
		output_blobs.size(), output_blobs[0]->num(), output_blobs[0]->channels(), output_blobs[0]->height(), output_blobs[0]->width());

	const int image_size{ image_num * image_channel * image_height * image_width };
	std::unique_ptr<float[]> mean_values(new float[image_size]);
	std::ifstream in(mean_file.c_str(), std::ios::in | std::ios::binary);
	if (!in.is_open()) {
		fprintf(stderr, "read mean file fail: %s\n", mean_file.c_str());
		return -1;
	}
	in.read((char*)mean_values.get(), image_size * sizeof(float));
	in.close();

	const std::vector<int> target{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
	typedef std::tuple<int, float> result;
	std::vector<result> results;

	for (const auto& num : target) {
		std::string str = std::to_string(num);
		str += ".png";
		str = image_path + str;

		cv::Mat mat = cv::imread(str.c_str(), 0);
		if (!mat.data) {
			fprintf(stderr, "load image error: %s\n", str.c_str());
			return -1;
		}

		cv::resize(mat, mat, cv::Size(image_width, image_height));
		mat.convertTo(mat, CV_32FC1);

		float* p = (float*)mat.data;
		for (int i = 0; i < image_size; ++i) {
			p[i] -= mean_values.get()[i];
		}

		const std::vector<caffe::Blob<float>*>& blob_input = caffe_net.input_blobs();
		blob_input[0]->set_cpu_data((float*)mat.data);

		const std::vector<caffe::Blob<float>*>& output_blob_ = caffe_net.Forward(nullptr);
		const float* output = output_blob_[0]->cpu_data();

		float tmp{ -1.f };
		int pos{ -1 };

		for (int j = 0; j < output_blobs[0]->count(); j++) {
			if (tmp < output[j]) {
				pos = j;
				tmp = output[j];
			}
		}

		result ret = std::make_tuple(pos, tmp);
		results.push_back(ret);
	}

	for (auto i = 0; i < target.size(); i++)
		fprintf(stdout, "actual digit is: %d, result digit is: %d, probability: %f\n",
			target[i], std::get<0>(results[i]), std::get<1>(results[i]));

	fprintf(stdout, "predict finish\n");

	return 0;
}
测试图像如下:


执行结果如下:

GitHub: https://github.com/fengbingchun/Caffe_Test

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在Caffe中调用TensorRT提供的MNIST model 的相关文章

随机推荐

  • 使用QGraphicsItem绘制微信消息文本框

    微信消息框如下 使用QGraphicsItem绘制 怎么绘制呢 先不考虑头像 那文本框就是由一个菱形矩形加一个小箭头组成的 所以很简单就能画出来了 void PopoItem paint QPainter painter const QSt
  • 彻底解决Python(win)导包from import错误问题

    1 一句话 一句话 关键是os sys path这个目录 这个目录有 就from import没问题 没有 就报错 解决办法就是千方百计加进去即可 例如 import os print os sys path import dd from
  • 单链表中求倒数第几个节点

    问题描述 在单链表中求出倒数第K个节点 要求快速 方法一 利用链表的长度 不推荐 此方法必须事先知道链表的长度 在有长度的信息链表中 此方法可行 比如我之前的链表是这样的实现 参考博文 http blog csdn net dawn aft
  • 机器学习之梯度提升决策树(GBDT)

    1 GBDT算法简介 GBDT Gradient Boosting Decision Tree 是一种迭代的决策树算法 由多棵决策树组成 所有树的结论累加起来作为最终答案 我们根据其名字 Gradient Boosting Decision
  • SpringAOP来监控service层中每个方法的执行时间

    使用AOP来说 太方便了 并且特别适合这类场景 代码如下 这里是将要统计的信息写到log文件中 也可以设计成写入表中 package com ecsoft interceptor import org aspectj lang Procee
  • linux版本的发行版和内核版是什么意思

    linux内核版本的分类 Linux内核版本有两种 稳定版和开发版 Linux内核版本号由3组数字组成 第一个组数字 第二组数字 第三组数字 第一个组数字 目前发布的内核主版本 第二个组数字 偶数表示稳定版本 奇数表示开发中版本 第三个组数
  • Linux扫盲篇:CentOS、Ubuntu、Gentoo

    http www williamlong info info archives 197 html Linux最早由Linus Benedict Torvalds在1991年开始编写 在这之前 Richard Stallman创建了Free
  • DirectX在VS2017环境配置

    提示 此方法是解决DirectX9在windows环境下的配置问题 原文 https xygeng cn post 249 html 具体方法 1 问题 无法打开包括文件 stdlib h 解决办法 视图 gt 属性管理器 点击 user属
  • VMware Workstation 16 安装教程

    哈喽 大家好 今天一起学习的是VMware Workstation 16的安装 vm虚拟机是小编非常喜欢的生产力软件 小编之前发布的测试教程钧在vm上进行的实验 VMware Workstation是一款功能强大的桌面虚拟计算机软件 它能够
  • K8s微服务从0到1入门及命令实战

    写在前面 本文主要介绍k8s的核心概念 基础语法 常用命令和常用操作 Kubernetes介绍 Kubernetes是一种流行的开源容器编排和管理系统 它的目标是简化部署 扩展和管理容器化应用程序 Kubernetes最初由Google开发
  • 如何让女人满意?多个心眼爱女人

    别以为只有男人甜言蜜语地哄骗女人 女人有时也会设下甜蜜的陷阱让男人钻 如果有一天 你那个素来刁蛮的小女人突然变得乖巧柔顺 温温柔柔地抱着你的胳膊说 亲爱的 我今天心情特别好 给你一分钟的时间诉诉苦苦吧 平时我有哪些缺点令你敢怒不敢言的 尽管
  • python学习笔记第一天

    一 Python的基本语法元素 Python程序从默认的第一条语句开始 按顺序依次执行各条语句 代码块可视为复合语句 Python使用严格的缩进 空格 来表示代码块 连续的多条具有相同缩进量的语句为一个代码块 注释用于为程序添加说明性的文字
  • Deep learning Reading List

    Following is a growing list of some of the materials i found on the web for Deep Learning beginners Free Online Books De
  • 简单HTML+css太极图

  • Tauri 应用中发送 http 请求

    最近基于 Tauri 和 React 开发一个用于 http https 接口测试的工具 Get Tools 其中使用了 tauri 提供的 fetch API 在开发调试过程中遇到了一些权限和参数问题 在此记录下来 权限配置 在 taur
  • vue中的input输入框按回车键自动搜索

    vue中的input输入框按回车键自动搜索 在input标签内部增加 keyup enter事件即可 事件名为按钮点击名称
  • python之文件夹拷贝(亲测可用)

    效果 import os import shutil def copy dir src path dst path source path os path abspath src path target path os path abspa
  • centos7 mysql 机器重启后pid文件丢失导致mysql 服务无法重启

    1 首先执行命令vim etc my cnf 查看pid存储的路径 pid file xxxxxx 2 到对应的路径下查看发现已经丢失了 mysqld pid创建在系统的run目录下 该目录是运行在内存中的 因此服务器重启后文件不存在 3
  • CentOS 7 下 minikube 部署 && 配置

    CentOS 7 下 minikube 部署 配置 文章目录 CentOS 7 下 minikube 部署 配置 下载 安装 下载安装脚本 安装 minikube 启动 minikube 环境 安装 kubectl 工具 启动 miniku
  • 在Caffe中调用TensorRT提供的MNIST model

    在TensorRT 2 1 2中提供了MNIST的model 这里拿来用Caffe的代码调用实现 原始的mnist mean binaryproto文件调整为了纯二进制文件mnist tensorrt mean binary 测试结果与使用