tflite热帖: Tensorflow量化步骤及生成量化的tflite(2)

2023-10-30

https://blog.csdn.net/angela_12/article/details/85000072

      版权声明:本文为博主原创文章,未经博主允许不得转载。          https://blog.csdn.net/angela_12/article/details/85000072        </div>
        <link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_views-f57960eb32.css">
                          <div id="content_views" class="markdown_views prism-atom-one-dark">
        <!-- flowchart 箭头图标 勿删 -->
        <svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
          <path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path>
        </svg>
        <h1><a name="t0"></a><a id="Tensorflowtflite_0" target="_blank"></a>第二部分:Tensorflow伪量化训练操作(生成量化的tflite)</h1>

序:

(2018.12.24更新:鉴于评论中有人误解,我这里写清楚一点,后来者也可以先去看看评论区的问题然后再决定要不要看这篇博客~
我当初在添加伪量化代码执行完全量化训练这部分,只找到一篇有用的文章,所以决定给我这样的迷茫小白在这一部分添加一份“小白启发指南”(希望有用)。
这是第一次写博客,我会从评论中不断改进滴~)
1.普通量化步骤及使用详见“第一部分”:
2019年2月21日更新:https://blog.csdn.net/angela_12/article/details/84999473
2.伪量化训练,新出现的博客:
[min,max]的部分可以看看:TensorFlow量化训练
3.伪量化含义理解:
伪量化是完全量化的第一步,它只是模拟了量化的过程,并没有实现量化,只是在训练过程中添加了伪量化节点(也就是后面要说的两句话,创建训练图推理图分别添加进去),计算过程还是用float32计算。
然后训练得出.pb文件,放到toco指令里去实现第二步完整的量化,最后生成tflite,实现int8计算。

我的环境:

2018.12.24更新:
1.系统:Ubuntu16.04
2. Python版本:3.5.2
3. TensorFlow Lite(tflite)版本:Release 0.1.7
4. TF版本: 源码安装时版本为1.9.0,后来用pip升级为1.11.0。TF版本查看方法:
命令行输入python ,进入Python环境,输入:

import tensorflow as tf
tf.__version__  

 
 
 
 
  • 1
  • 2

前面先说一点概述性的东西,也简单记录下我做这个的心路历程,想直接看干货的就翻到下面去吧。
做了几个月,主要是找不到人问,自己基础差又解决不了 ,很是浪费时间,这个方面的网上的博客千篇一律,写的能用的简直太少了(可能大神们都没时间写基础的东西吧…),现在终于有点成效,所以就想好好整理一下这个实现的过程,给像我一样的小白们一点启发吧……

正文:

量化训练不用工具都是手写代码来实现各种功能的话,感觉不用源码安装用pip来安装就好了(参见量化安装,还没写完,这部分别人写了很多,我也是照着他们的做的,先把链接给在这吧),然后主要是写代码来加入量化节点,也就是官方给出的两个关键语句
create_training_graphcreate_eval_graph
这两句话放到一起,也就是放在一个.py文件里面写我没弄出来,感觉也不适合放在一起。所以,为了逻辑清晰,也为了减少麻烦(吾等小白只能乖乖按照官方的例子来,试过各种写法全部失败……),还是建议分开写,像官方的两个例子那样:
分为3个.py文件:模型定义(所用网络的定义)、train .py(加入create_training_graph) 和 eval .py (或freeze .py:加入create_eval_graph )。

两个例子分别是:slim写的mobilenet_v1_train.py、mobilenet_v1_eval.py和mobilenet_v1.py,位置的话在这里
这个例子我最开始看的,里面的格式跟官网给出的“定点量化”文章很像,但是total_loss这里我总出问题,还有learning_rate配着GradientDescentOptimizer用也有问题,然后找不到答案,就跑去社区看问题,发现了后面给的链接中Zongjun大神的回复,然后就问了我的问题,没想到遇到一个又耐心又会讲解的大神,真是这几个月以来最幸福的时刻,再次感谢大神~
有兴趣的可以去看看那个问题,在参考1.这里,基本上就是我做出这个伪量化的整个过程了,里面遇到的问题也基本都有问到(貌似用了将近20天的时间,好慢……)。

#############好啦,现在就直接写过程吧##############

上面那个例子没有speech_commands简单清晰,所以主要看这个例子就好了。
参考的代码是:models.pytrain.pyfreeze.py

我的主要目的是生成量化的tflite,用8bit来计算,之后可以用在手机端来测试性能啥的,所以看了官方文档之后发现,这个只有一种方法就是:quantization-aware training

  1. 在训练图中加入create_training_graph作为fake quantization nodes,然后训练,生成ckpt和pbtxt文件(参考train.py)。
  2. 创建推理图(freeze.py里的create_inference_graph()),就是重新调用训练时用的网络模型(例子里就是models.py 的create_model),然后给出一个输出,也就是logits = models.create_model()和tf.nn.softmax(logits, name=‘labels_softmax’)这两句话,前面的都是数据处理,没什么用。
  3. 将上面的输出(训练好的数据)传给create_eval_graph(),就是models.load_variables_from_checkpoint(),这里的eval把这张推理图数据转换成tflite能够识别并量化的格式。
  4. 后面的frozen_graph_def把ckpt和pbtxt固化,freeze 成.pb文件,这个图中就是带有fake quantization nodes的了。可以用tensorboard或者Netron来查看图。
  5. 上面的就完成了官方文档说的第一步,然后将上面生成的pb文件放到toco工具中,运行就可以生成tflite了。
toco工具的安装也是用bazel,编译并执行是下面的指令:

2019.3.6更新: 这里的文件夹换了,改成:tensorflow/lite/toco,注意替换一下!!

bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \
--input_file=/home/.../tensorflow/.../Mnist_train/speech_my_frozen_graph.pb \
--output_file=/home/.../tensorflow/.../Mnist_train/speech_my_frozen_graph.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=QUANTIZED_UINT8 \
--input_shapes=1,98,40,1 \
--input_arrays=Reshape_1 \
--output_arrays=labels_softmax \
--allow_custom_ops

 
 
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
仅执行:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/.../tensorflow/.../Mnist_train/speech_my_frozen_graph.pb \
--output_file=/home/.../tensorflow/.../Mnist_train/speech_my_frozen_graph.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=QUANTIZED_UINT8 \
--input_shapes=1,98,40,1 \
--input_arrays=Reshape_1 \
--output_arrays=labels_softmax \
--allow_custom_ops

 
 
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
注:

1)tflite文件不只是用在移动端的,PC端也是可以的,感觉是因为移动端主要用的是tflite,然后PC端也可以用Python脚本来使用tflite(用的是tf.contrib.lite.Interpreter),所以谷歌就都换成tflite了,之前的quantize_graph.py被谷歌删除了,这个工具感觉可以不用了。现在唯一的保证精度的方法就是伪量化训练生成pb,然后toco转化为tflite,toco转化也可以用Python来写,这部分,我之后来补上。
(这一部分,建议参考社区问题Zongjun大神的第一个回答:https://www.tensorflowers.cn/t/7136)
2)将上面input_file(这里的pb文件路径是最后freeze生成的pb路径)和output_file(这里的tflite路径是要生成的tflite要保存的路径)的路径换成自己的路径,写绝对路径比较好。
3)allow_custom_ops:这个指令可以避免一些不必要的错误,允许一些传统方法,相关知识可以去下面的参考博客中去看。
4)–input_shapes它后面那几个要注意末尾的“s”,要么都加,要么都不加。这几个的获取可以用:

sudo bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/home/.../tensorflow/Mnist_train/speech_my_frozen_graph.pb

 
 
 
 
  • 1
  • 2

来查看简要信息(summarize_graph这个也是要bazel来编译的)
不过这里面针对这个例子打印出来的 input_arrays是wav_data,看图也是这个,但是用这个的话会报错,实际上是 Reshape_1,这个是大神试出来的,打印完整结构的话,能看到这个输入,就是在 summarize_graph指令的后面加一句:–print_structure=true。

sudo bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/home/.../tensorflow/Mnist_train/speech_my_frozen_graph.pb \
--print_structure=true

 
 
 
 
  • 1
  • 2
  • 3

5)至于有人提到的–default_ranges_min=0 --default_ranges_max=6这两个指令,我觉得它是Post Training Quantization 需要的,这个方法可以不训练直接转化,但是精度不如训练的,应该是差很多,而且还是用float32计算的(这里可能有点问题,需要再想想)。
6)quantization-aware training文章里提到的–std_value=127.5 --mean_value=127.5这两个指令我没用,还没确定这两个怎么用,知道了再来更新。

运行speech例子:

2018.12.18更新: 今天tensorflow的官方公众号发布了一篇讲解speech这个例子的文章,感觉挺有用的,只是除了里面的量化没有讲…不过作为了解TF框架的例子来讲还是挺好的,位置在这里。)

  1. train .py:没有数据的话,它会自动下载数据包,有2.4G,有语音数据的话,可以自己看看数据格式然后用自己的数据训练。具体可以看train.py上面的注释。要想得到量化的结果记得–quantize这里改成True。
  2. freeze.py里面—output_file这里写上自己想要存储的路径。–start_checkpoint这里写上要使用的ckpt文件,格式是:default=’/tmp/speech_commands_train/conv.ckpt-110’。
  3. 在Netron中查看的图是这样的:.pb   在这里插入图片描述speech.pb
    在这里插入图片描述 speech.tflite

上面是官方的例子,下面写一下我自己的练习,对于这些新技术,我都是在mnist里面实现的,之前测试量化工具和学习tensorflow训练也是用这个,都说它是tensorflow的“hello world”嘛。
(2018.12.21补充:speech.pb大小为3.7M,speech.tflite大小为929.5K)

mnist伪量化训练练习代码:

也分为三部分:

(代码可参考我的github:https://github.com/officeyang/mnist_fakequantization)
1.mnist_build_network.py,代码如下:

代码解析参考博客:https://blog.csdn.net/real_myth/article/details/51782207

import tensorflow as tf
# 创建图片占位符:x,标签占位符:y 和 随机失活系数keep_prob,以供处理图片,训练和预测时使用
#[x并不是一个特定的值,它是一个placeholder,一个我们需要输入数值当我们需要tensorflow进行运算时。我们想要输入任意数量的mnist图片,每一个都展开成一个784维的向量。我们用一个二维的[None, 784]浮点张量代表。 (这里的None表示维度可以是任意的长度.)]
x = tf.placeholder("float", shape=[None, 784], name='input')
y = tf.placeholder("float", shape=[None, 10], name='labels')
keep_prob = tf.placeholder("float", name='keep_prob')
# 定义mnist网络结构
def build_network(is_training):
# 定义网络类型
#[我们的模型中也需要权重和bias。我们可以把它们看成是额外的输入,Tensorflow有更加好的方法来表示它: Variable. Variable是一个Tensorflow图交互操作中一个可以修改的张量。 它可以在计算中修改。对于机器学习的,一般都有一些Variable模型参数。]
    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)
    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)
    # convolution and pooling
    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
    def max_pool_2x2(x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    # convolution layer
    def lenet5_layer(layer, weight, bias):
        W_conv = weight_variable(weight)
        b_conv = bias_variable(bias)
        h_conv = conv2d(layer, W_conv) + b_conv
        return max_pool_2x2(h_conv)
    # connected layer
    def dense_layer(layer, weight, bias):
        W_fc = weight_variable(weight)
        b_fc = bias_variable(bias)
        return tf.matmul(layer, W_fc) + b_fc
# 开始搭建网络结构
    # first layer
    with tf.name_scope('first') as scope:
        x_image = tf.pad(tf.reshape(x, [-1,28,28,1]), [[0,0],[2,2],[2,2],[0,0]])
        firstlayer = lenet5_layer(x_image, [5,5,1,6], [6])
    # second layer
    with tf.name_scope('second') as scope:
        secondlayer = lenet5_layer(firstlayer, [5,5,6,16], [16])
    # third layer
    with tf.name_scope('third') as scope:
        W_conv3 = weight_variable([5,5,16,120])
        b_conv3 = bias_variable([120])
        thirdlayerconv = conv2d(secondlayer, W_conv3) + b_conv3
        thirdlayer = tf.reshape(thirdlayerconv, [-1,120])
    # dense layer1
    with tf.name_scope('dense1') as scope:
        dense_layer1 = dense_layer(thirdlayer, [120,84], [84])
    # dense layer2
    with tf.name_scope('dense2') as scope:
        dense_layer2 = dense_layer(dense_layer1, [84,10], [10])
# 运行得到真实输出:finaloutput
    if is_training:
        finaloutput = tf.nn.softmax(tf.nn.dropout(dense_layer2, keep_prob), name="softmax")
   # 为eval调用准备,eval用的网络要去掉 dropout
    else:
        finaloutput = tf.nn.softmax(dense_layer2, name='softmax')
    print('finaloutput:', finaloutput)
    return finaloutput

 
 
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60

想了解整体mnist代码比较细致的分析,可以看以下博客:博客1博客2等。

2. mnist_fakequantize_train.py代码:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from Mnist_train.mnist_build_network import build_network, x, y, keep_prob

# 加载mnist下载的数据,有四个,数据我是在一个博客给出的百度网盘里下的,不过我的git账号还没弄好,不能上传,需要的话可以去看mnist实现的相关博客找数据下载。
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)

def create_training_graph():
#创建训练图,加入create_training_graph:
g = tf.get_default_graph() # 给create_training_graph的参数,默认图
#调用网络定义,也就是拿到输出
logits = build_network(is_training=True) #这里的is_training设置为True,因为前面模型定义写了训练时要用到dropout
# 写loss,mnist的loss是用交叉熵来计算的,loss和optimize方法可以根据自己的情况来设置。
with tf.name_scope(‘cross_entropy’):
cross_entropy_mean = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
print(‘cost:’, cross_entropy_mean)
# 加入 create_training_graph函数,注意位置要在loss之后, optimize之前
# if FLAGS.quantize:
# 上面这句是如果用parser设置flag参数的话,就用这种方式设置开关,用法可以自己查一下,或者参考speech的例子就知道了。
tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=0)
# optimize用原来的Adam效果较好,不知道我这里为什么用GradientDescentOptimizer的话,基本不收敛。
optimize = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy_mean)
# optimize = tf.train.GradientDescentOptimizer(1e-4).minimize(cross_entropy_mean)
# 比较输出类别概率的最大值[tf.argmax 是项的极其有益的函数,它给返回在一个标题里最大值的索引。例如,tf.argmax(y,1) 是我们的模型输出的认为是最有可能是的那个值,而 tf.argmax(y_,1) 是正确的标签的标签。]
prediction_labels = tf.argmax(logits, axis=1, name=“output”)
# 将得出的最大值与实际分类标签对比,看二者是否一致[如果我们的预测与匹配真正的值,我们可以使用tf.equal来检查。]
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
# 给出识别准确率[这会返回我们一个布尔值的列表.为了确定哪些部分是正确的,我们要把它转换成浮点值,然后再示均值。 比如, [True, False, True, True] 会转换成 [1,0,1,1] ,从而它的准确率就是0.75.]
accuracy = tf.reduce_mean(tf.cast(correct_prediction, “float”))
#后面with这部分对于量化应该没啥影响,记得是tensorboard要用的,应该是出准确率啥的曲线图的……(不过这里应该没用到吧,,我只是搬过来了还没看,会用的想用就用吧,不会的就删掉吧)
with tf.get_default_graph().name_scope(‘eval’):
tf.summary.scalar(‘cross_entropy’, cross_entropy_mean)
tf.summary.scalar(‘accuracy’, accuracy)
# 返回所需数据,供训练使用
return dict(
x=x,
y=y,
keep_prob=keep_prob,
optimize=optimize,
cost=cross_entropy_mean,
correct_prediction=correct_prediction,
accuracy=accuracy,
)

#开始训练
def train_network(graph):
# 初始化
init = tf.global_variables_initializer()
# 调用Saver函数保存所需文件
saver = tf.train.Saver()
# 创建上下文,开始训练sess.run(init)
with tf.Session() as sess:
sess.run(init)
# 一共训练两万次,也可以更多,不过两万次感觉准确率就能达到将近1了
for i in range(20000):
# 每次处理50张图片
batch = mnist.train.next_batch(50)
# 每100次保存并打印一次准确率等
if i % 100 == 0:
# feed_dict喂数据
train_accuracy = sess.run([graph[‘accuracy’]], feed_dict={
graph[‘x’]:batch[0], # batch[0]存的图片数据
graph[‘y’]:batch[1], # batch[1]存的标签
graph[‘keep_prob’]: 1.0}) # 随机失活(全部?)
print(“step %d, training accuracy %g”%(i, train_accuracy[0]))
sess.run([graph[‘optimize’]], feed_dict={
graph[‘x’]:batch[0],
graph[‘y’]:batch[1],
graph[‘keep_prob’]:0.5})
test_accuracy = sess.run([graph[‘accuracy’]], feed_dict={
graph[‘x’]: mnist.test.images,
graph[‘y’]: mnist.test.labels,
graph[‘keep_prob’]: 1.0})
print(“Test accuracy %g” % test_accuracy[0])
# 保存ckpt(checkpoint)和pbtxt。记得把路径改成自己的路径,写不好相对路径的就直接写绝对路径。绝对路径就是我写的这种完整的路径。
saver.save(sess, ‘/home/angela/tensorflow/tensorflow/Mnist_train/mnist_fakequantize.ckpt’)
tf.train.write_graph(sess.graph_def, ‘/home/angela/tensorflow/tensorflow/Mnist_train/’, ‘mnist_fakequantize.pbtxt’, True)

def main():
g1 = create_training_graph()
train_network(g1)

main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
3. mnist_fakequantize_freeze.py:代码:
import tensorflow as tf
import os.path
from Mnist_train.mnist_build_network import build_network
from tensorflow.python.framework import graph_util

# 创建推理图
def create_inference_graph():
“”“Build the mnist model for evaluation.”""
# 调用网络,Create an output to use for inference.
logits = build_network(is_training=False)
# 得到分类输出
tf.nn.softmax(logits, name=‘output’)

def load_variables_from_checkpoint(sess, start_checkpoint):
“”“Utility function to centralize checkpoint restoration.
Args:
sess: TensorFlow session.
start_checkpoint: Path to saved checkpoint on disk.
“””

saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, start_checkpoint)

def main():
# Create the model and load its weights.
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# 推理图
create_inference_graph()
# 加入create_eval_graph(),转化为tflite可接受的格式。以下语句中有路径的,记得改路径。
# if FLAGS.quantize:
tf.contrib.quantize.create_eval_graph()
load_variables_from_checkpoint(sess, ‘/home/angela/tensorflow/tensorflow/Mnist_train/mnist_fakequantize.ckpt’)
# Turn all the variables into inline constants inside the graph and save it.
# 固化 frozen:ckpt + pbtxt
frozen_graph_def = graph_util.convert_variables_to_constants(
sess, sess.graph_def, [‘output’])
# 保存最终的pb模型
tf.train.write_graph(
frozen_graph_def,
os.path.dirname(’/home/angela/tensorflow/tensorflow/Mnist_train/mnist_frozen_graph.pb’),
os.path.basename(’/home/angela/tensorflow/tensorflow/Mnist_train/mnist_frozen_graph.pb’),
as_text=False)
tf.logging.info(‘Saved frozen graph to %s’, ‘/home/angela/tensorflow/tensorflow/Mnist_train/mnist_frozen_graph.pb’)

main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

以上三部分,就是对照speech例子完成的mnist伪量化训练,这样得出的pb模型才能通过toco工具转化为可用的tflite,转化后的tflite大小大约变为了pb模型的1/4。
(2018.12.21补充:mnist.pb大小为256.3K,mnist.tflite大小为66.4K)

toco转化为tflite:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/.../tensorflow/.../Mnist_train/mnist_fakequantize.pb \
--output_file=/home/.../tensorflow/.../Mnist_train/mnist_fakequantize.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=QUANTIZED_UINT8 \
--input_shapes=1,28,28,1 \
--input_arrays=input \
--output_arrays=output \
--allow_custom_ops

 
 
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
Netron看图:

在这里插入图片描述mnist.pb
在这里插入图片描述mnist.tflite

注:

1)图片里的是之前训练和转化得出来的结果,但是现在里面的input都没有了,我也不知道为什么,现在的指令都是按照官方文档写的,出来后是没有input的,但是之前是有的,然后我测试没有input的结果是正确的,所以不知道这有什么影响,可能之前用的指令在inference_type那里不太一样吧。这个我有时间再研究一下,也可能是我TF版本弄乱了。
2)如果不添加伪量化训练,只用sudo bazel-bin/tensorflow/tools/quantization/quantize_graph工具将普通训练得出的pb文件进行量化的话,得出的pb只是float类型的,尽管也压缩成了1/4,还没报错,但却不是int8的,而这种再用toco转化成的tflite是不能用的,不过具体原因我还不太清楚(此部分具体参见“2.参考博客”)。

其他代码及方法补充:

1.toco转化tflite参考:
Converter Python API guide
2.post_training_quantize 方法:
Post Training Quantization
(可参考:参考博客)

扩展:

公司给的任务是完成基于yolo_loss的手势分类和画框检测,将训练出的pb模型量化压缩,转化成tflite。所以研究了上面那些内容,在用tf底层代码写的mnist上面是测试成功了,但是公司给我的代码都是用keras写的,我现在的问题就在于怎么把伪量化加到keras里面。
不过keras没法量化,这个我也在社区问题里面问了,但还是不知道怎么改,如果有大佬有这方面的例子,方便的话发出来看看,我好学习学习~感谢~

(2018.12.20更新)
看了keras和tensorflow结合使用的方法,还有大神之前说的tf.keras.layers方法,我觉得构建模型的地方可以用keras,也是因为我的代码就是keras写的,所以懒得全部去改了。
然后,loss和optimize的部分整个都要改成tf的,keras的Model方法都不要了,模型后面应该也像mnist那样给个输出,然后再去改train函数,把loss处理这部分还加在这里,然后以前keras带的整个train方法,就是fit方法也都不要了,改成运用图,保存ckpt和pbtxt的形式。
至于输出有类别和画框该怎么写,我先试试再说~

总结:

生成量化的tflite就是这样了,如果有错的地方,欢迎指正~大家有问题也可以在评论区回复,一起交流学习~

参考:

  1. 参考问题:TF中文社区量化问题记录(重要):https://www.tensorflowers.cn/t/7136#pid21630
  2. 参考博客(重要,这个就是社区里提出问题的楼主):
    Tensorflow Lite之编译生成tflite文件:https://blog.csdn.net/qq_16564093/article/details/78996563
  3. 参考文章(前两个重要):
    1)定点量化(TF官方文档中文版):https://tensorflow.juejin.im/performance/quantization.html
    2)TF定点量化官方文档(英文版:Fixed Point Quantization):可在TF源码或git中查找
    3)Tensorflow 模型量化 (Quantizing deep convolutional networks for efficient inference: A whitepaper 译文):https://blog.csdn.net/guvcolie/article/details/81286349
    4)【论文阅读笔记】Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference:https://blog.csdn.net/qq_19784349/article/details/82883271
    5)Post Training Quantization:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb
      </div>
      <link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-258a4616f7.css" rel="stylesheet">
              </div>
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tflite热帖: Tensorflow量化步骤及生成量化的tflite(2) 的相关文章

  • Spring之@AutoWired注解

    开发中依赖注入一般用 AutoWired 首先需要bean xml文件中需要配置
  • C语言字符串必备练习题

    1 作业标题 633 字符串的结束标志是 作业内容 A 是 0 B 是EOF C 是 0 D 是空格 答案解析 C语言规定 以 0 作为有效字符串的结尾标记 A 错误 是 0 不是字符0 B EOF一般用来作为检测文本文件的末尾 C 正确
  • 机器学习笔试题汇总

    文章目录 树 特征工程 样本处理 K近邻 聚类 深度学习 分类 距离 相似度 指标性能评价 HMM 数学 为什么将回归问题转化成分类问题 解空间变小 softmax和sigmoid的区别 为什么用多项分布 多项分布能天然刻画值域变化趋势的
  • 演讲文档和视频《元宇宙与区块链IT基础设施》下载

    当今科技迅猛发展 知识爆炸的时代 有些知识 非公司保密信息 及那些不需要申请知识产权 构建护城河的信息 不及时分享 就过期了 重要的是 构建个人知识壁垒的根本是 快速迭代学习和交流碰撞 另外 我希望能遇到更多有共鸣 相互启发的朋友 我的QQ
  • RESETLOGS

    使用resetlogs选项 会把当前的日志序号 log sequence number 重设为1 并抛弃所有日志信息 在以下条件时需要使用resetlogs选项 在不完全恢复 介质恢复 使用备份控制文件 使用resetlogs打开数据库后
  • AttributeError: module ‘networkx‘ has no attribute ‘from_numpy_matrix‘解决方法

    在我学习louvain算法时 运行了这样一段代码 from communities algorithms import louvain method from communities visualization import draw co
  • 初次使用QT5串口类QSerialPort

    因为工作需要 现在正在学习Qt的串口通信 Qt4的话需要使用第三方类 使用起来也非常简单 只需要把对应的文件添加到自己的项目中就可以了 我参考的是Qt论坛上的demo 刘大师和yafei的demo都非常详细 网上都可以下载到 不过 Qt5添
  • Elementui设置样式不起效

    在使用ElementUI时 其渲染的dom元素有时是在模板外的 虽然代码写在了vue文件对用的模板内 但elementui在渲染的时候可能会渲染到和body一级 如这种弹窗 有时候想给这种el dialog加个样式 发现是不起效的 原因是
  • VCS命令行CTRL+C后dump完整的fsdb波形

    UCLI命令行CTRL C后dump完整的fsdb波形 1 ucli fsdbDumpFinish 2 ucli fsdbDumpvars 0 harness mda struct 如果仿真过程中直接CTRL C会调到UCLI接口 此时如果
  • 工作不好找,普通打工人如何破局

    大家好 我是苍何 我的一位阿里朋友被裁后 找工作找了一个月都没结果 很多到最后一面被pass了 不由得做一下感慨 即使是大厂背景又如何 面对经济环境和大环境市场 每个人都不容易 我身边很多都是程序员群体 最近也在在编程导航 收到了很多小伙伴
  • 13.linux进程基础

    一 进程基础 基础概念 关于进程和线程的基本概念在操作系统中早已学过 可以概括为一下几点 根本区别 进程是操作系统资源分配的基本单位 而线程是处理器任务调度和执行的基本单位 资源开销 每个进程都有独立的代码和数据空间 程序上下文 程序之间的
  • Unable to negotiate with 172.16.28.137 port 22: no matching host key type found. Their offer: ssh-rs

    Unable to negotiate with 172 16 28 137 port 22 no matching host key type found Their offer ssh rsa ssh dss ssh连接服务器报错 Un
  • 数字化转型升级是企业的一项重要决策

    无独有偶 世界经济数字化转型是一个大命题 也是一个大趋势 未来一段时期 数字经济将成为拉动经济增长的一个重要引擎 各行业各领域数字化转型步伐将大大加快 不论是行业老大 还是国家政策 数字化转型都纷纷提上了日程 看来 在2020年 进行数字化
  • 贪吃蛇的小程序

    1 创建项目 1 打开微信开发者工具如图所示的界面 点击 2 填写项目以后 点击确定即可 如图所示 2 编程 1 编写index wxml的代码如下
  • ITIL是什么意思?ITIL是什么?

    ITIL是什么 ITIL是Information Technology Infrastructure Library的缩写 即 信息技术基础架构库 ITIL是由英国政府部门CCTA Central Computing and Telecom
  • 解决 hsdb jinfo jmap sa-jdi等mac不可用问题

    mac 使用 hsdb 调试的时候报错 hsdb gt attach 3196 Attaching to process 3196 please wait ERROR attach task for pid 3196 failed os k
  • linux查看文件行数

    这本阿里P8撰写的算法笔记 再次推荐给大家 身边不少朋友学完这本书最后加入大厂 Github 疯传 史上最强悍 阿里大佬 LeetCode刷题手册 开放下载了 使用wc命令 具体通过wc help 可以查看 如 wc l filename
  • 论文笔记:nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation

    nature 2019 配套论文笔记 nnU Net Self adapting Frameworkfor U Net Based Medical Image Segmentation UQI LIUWJ的博客 CSDN博客 1 abstr
  • flutter windows 配置

    按照官网的教程安装好Android Studio flutter3 3 7 添加flutter目录的bin到环境变量Path 特别要注意的是 要添加以下两个环境变量 否则在运行flutter run 编译android程序时 会非常慢 Ru

随机推荐