Tensorflow学习(五)——多任务学习验证码识别实战

2023-11-20

一、验证码生成

"""
验证码生成脚本(使用captcha包提供的ImageCaptcha方法)
"""

from captcha.image import ImageCaptcha

import sys
import random
import numpy as np

"""
使用四位数字验证码,当然也可以加入大小写字母。四位验证码有10000种可能(0000~9999)
但是由于生成过程具有随机性,难免出现重复情况,所以最终生成的验证码数量少于10000
"""
number = np.arange(0, 10)
number = [str(x) for x in number]

def random_captcha_text(char_set=number, captcha_size=4):
    # 验证码列表
    captcha_text = []
    for i in range(captcha_size):
        c = random.choice(char_set)     # 随机选中构成名称
        captcha_text.append(c)          # 加入列表
    return captcha_text

def gen_captcha_text_and_image():
    image = ImageCaptcha()
    # 获得随机生成的验证码
    captcha_text = random_captcha_text()
    # 把验证码列表转为字符串
    captcha_text = ''.join(captcha_text)
    # 生成验证码
    captcha = image.generate(captcha_text)
    image.write(captcha_text, 'captcha/images/' + captcha_text + '.jpg')


num = 10000
for i in range(num):
    gen_captcha_text_and_image()
    sys.stdout.write('\r>> Creating image %d/%d' % (i+1, num))
    sys.stdout.flush()
sys.stdout.write('\n')
sys.stdout.flush()
print('生成完毕')

验证码存放在 "./captcha/images/’ 目录下,如图:
在这里插入图片描述
验证码图片如下:在这里插入图片描述
每张图片的label就是验证码数字,此图验证码数字为0695所以文件命名为0695.jpg

二、制作tfrecord文件

1、关于tfrecord文件:

TFRecords可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了[tf.train.Example 协议内存块(protocol buffer)](协议内存块包含了字段[Features],你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过[tf.python_io.TFRecordWriter class]写入到TFRecords文件。

TFRecords文件格式在图像识别中有很好的使用,其可以将二进制数据和标签数据(训练的类别标签)数据存储在同一个文件中,它可以在模型进行训练之前通过预处理步骤将图像转换为TFRecords格式,此格式最大的优点实践每幅输入图像和与之关联的标签放在同一个文件中.TFRecords文件是一种二进制文件,其不对数据进行压缩,所以可以被快速加载到内存中.格式不支持随机访问,因此它适合于大量的数据流,但不适用于快速分片或其他非连续存取。

TFrecord文件读写方式参考:https://zhuanlan.zhihu.com/p/31992460

2、代码

from PIL import Image
import tensorflow as tf
import numpy as np
import os
import random
import sys

_NUM_TEST = 500
_RANDOM_SEED = 0
DATASET_DIR = 'captcha/images'
TFRECORD_DIR = 'captcha/'


# 判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        output_filename = os.path.join(dataset_dir, split_name + '.tfrecords')
        if not tf.gfile.Exists(output_filename):
            return False
    return True


def _get_filenames_and_classes(dataset_dir):
    photo_filenames = []
    for filename in os.listdir(dataset_dir):
        # 获取文件路径
        path = dataset_dir + '/' + filename
        photo_filenames.append(path)
    return photo_filenames


def bytes_feature(values):  # 格式转换(字符串)
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def int64_feature(values):  # 格式转换(64位int)
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def image_to_tfexample(image_date, label0, label1, label2, label3):
    # Abstract base class for protocol message
    return tf.train.Example(features=tf.train.Features(feature={
        'image': bytes_feature(image_date),
        'label0': int64_feature(label0),
        'label1': int64_feature(label1),
        'label2': int64_feature(label2),
        'label3': int64_feature(label3)
    }))


# 把数据转换成tfrecord格式
def _convert_dataset(split_name, filenames, dataset_dir):
    assert split_name in ['train', 'test']

    with tf.Session() as sess:
        # 定义tfrecord文件的路径和名称
        output_filename = os.path.join(TFRECORD_DIR, split_name + '.tfrecords')
        with tf.python_io.TFRecordWriter(output_filename, options=tf.python_io.TFRecordOptions(1)) as tfrecord_writer:
            for i, filename in enumerate(filenames):
                try:
                    sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                    sys.stdout.flush()
                    # 读取图片
                    image_data = Image.open(filename)
                    # 根据模型的结构resize
                    image_data = image_data.resize((224, 224))
                    # 灰度转换
                    image_data = np.array(image_data.convert('L'))
                    # 将图片转换为二进制数据
                    image_data = image_data.tobytes()
                    # 获取label
                    labels = filename.split('/')[-1][0:4]
                    num_labels = []
                    for j in range(4):
                        num_labels.append(int(labels[j]))
                    # 生成protocol数据类型
                    example = image_to_tfexample(image_data, num_labels[0], num_labels[1],
                                                 num_labels[2], num_labels[3])
                    tfrecord_writer.write(example.SerializeToString())
                except IOError as e:
                    print('Could not read:', filenames[i])
                    print('Error:', e)
                    print('Skip it\n')
    sys.stdout.write('\n')
    sys.stdout.flush()


# 判断tfrecord文件是否存在
if _dataset_exists(TFRECORD_DIR):
    print('tfrecord文件已经存在')
else:
    # 获得所有图片
    photo_filenames = _get_filenames_and_classes(DATASET_DIR)
    # 把数据集分割为训练集和测试集并打乱
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_TEST:]
    testing_filenames = photo_filenames[:_NUM_TEST]

    # 数据转换
    _convert_dataset('train', training_filenames, DATASET_DIR)
    _convert_dataset('test', training_filenames, DATASET_DIR)
    print('生成tfrecord文件')

说明:DATASET_DIR定义了数据集路径,TFRECORD_DIR定义了tfrecord文件存放路径,_NUM_TEST定义了test数据集数量,该程序将所有图片分为两部分,其中获得_NUM_TEST数量的图像作为测试数据集。在_convert_dataset()中我们对图像数据进行预处理包括灰度转换、图像大小转换已经二进制转换,这些操作方便了我们将数据写入文件以及训练时候对数据的使用。

最终生成的文件如下:
在这里插入图片描述

三、验证码识别模型训练

1、验证码识别思路

将验证码label拆分为4个

例如有一个验证码为0782,则拆分后的label如下(采用one-hot编码,对应位数值置1):

Label0:1000000000
Label1:0000000100
Label2:0000000010
Label3:0010000000

好处:可使用多任务学习

2、什么是多任务学习

在这里插入图片描述
其中X是输入,Shared Layer就是一些卷积与池化操作,Task1-4对应四个标签,产生四个loss,将四个loss求和得总的loss,用优化器优化总的loss,从而降低每个标签产生的loss。

3、获取谷歌提供的alexnet_v2网络

打开github,搜索 tensorflow/models,如下:
在这里插入图片描述
将models文件夹clone下来:
在这里插入图片描述
clone完成后,在路径 “/models/research/silm/” 下找到nets文件夹,将该文件夹拷贝到项目目录,我们在训练过程中会调用nets文件夹下提供的python代码(nets_factory.py)
在这里插入图片描述

4、修改alexnet.py代码

修改后代码如下:

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a model definition for AlexNet.

This work was first described in:
  ImageNet Classification with Deep Convolutional Neural Networks
  Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton

and later refined in:
  One weird trick for parallelizing convolutional neural networks
  Alex Krizhevsky, 2014

Here we provide the implementation proposed in "One weird trick" and not
"ImageNet Classification", as per the paper, the LRN layers have been removed.

Usage:
  with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
    outputs, end_points = alexnet.alexnet_v2(inputs)

@@alexnet_v2
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

slim = contrib_slim

# pylint: disable=g-long-lambda
trunc_normal = lambda stddev: tf.compat.v1.truncated_normal_initializer(
    0.0, stddev)


def alexnet_v2_arg_scope(weight_decay=0.0005):
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      activation_fn=tf.nn.relu,
                      biases_initializer=tf.compat.v1.constant_initializer(0.1),
                      weights_regularizer=slim.l2_regularizer(weight_decay)):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
        return arg_sc


def alexnet_v2(inputs,
               num_classes=1000,
               is_training=True,
               dropout_keep_prob=0.5,
               spatial_squeeze=True,
               scope='alexnet_v2',
               global_pool=False):
  """AlexNet version 2.

  Described in: http://arxiv.org/pdf/1404.5997v2.pdf
  Parameters from:
  github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
  layers-imagenet-1gpu.cfg

  Note: All the fully_connected layers have been transformed to conv2d layers.
        To use in classification mode, resize input to 224x224 or set
        global_pool=True. To use in fully convolutional mode, set
        spatial_squeeze to false.
        The LRN layers have been removed and change the initializers from
        random_normal_initializer to xavier_initializer.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: the number of predicted classes. If 0 or None, the logits layer
    is omitted and the input features to the logits layer are returned instead.
    is_training: whether or not the model is being trained.
    dropout_keep_prob: the probability that activations are kept in the dropout
      layers during training.
    spatial_squeeze: whether or not should squeeze the spatial dimensions of the
      logits. Useful to remove unnecessary dimensions for classification.
    scope: Optional scope for the variables.
    global_pool: Optional boolean flag. If True, the input to the classification
      layer is avgpooled to size 1x1, for any input size. (This is not part
      of the original AlexNet.)

  Returns:
    net: the output of the logits layer (if num_classes is a non-zero integer),
      or the non-dropped-out input to the logits layer (if num_classes is 0
      or None).
    end_points: a dict of tensors with intermediate activations.
  """
  with tf.compat.v1.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    # Collect outputs for conv2d, fully_connected and max_pool2d.
    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                        outputs_collections=[end_points_collection]):
      net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
                        scope='conv1')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
      net = slim.conv2d(net, 192, [5, 5], scope='conv2')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
      net = slim.conv2d(net, 384, [3, 3], scope='conv3')
      net = slim.conv2d(net, 384, [3, 3], scope='conv4')
      net = slim.conv2d(net, 256, [3, 3], scope='conv5')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')

      # Use conv2d instead of fully_connected layers.
      with slim.arg_scope(
          [slim.conv2d],
          weights_initializer=trunc_normal(0.005),
          biases_initializer=tf.compat.v1.constant_initializer(0.1)):
        net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
                          scope='fc6')
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout6')
        net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
        # Convert end_points_collection into a end_point dict.
        end_points = slim.utils.convert_collection_to_dict(
            end_points_collection)
        if global_pool:
          net = tf.reduce_mean(
              input_tensor=net, axis=[1, 2], keepdims=True, name='global_pool')
          end_points['global_pool'] = net
        if num_classes:
          net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                             scope='dropout7')
          net0 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_0')
          net1 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_1')
          net2 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_2')
          net3 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_3')

          if spatial_squeeze:
            net0 = tf.squeeze(net0, [1, 2], name='fc8_0/squeezed')
          end_points[sc.name + '/fc8_0'] = net0
          if spatial_squeeze:
            net1 = tf.squeeze(net1, [1, 2], name='fc8_1/squeezed')
          end_points[sc.name + '/fc8_1'] = net1
          if spatial_squeeze:
            net2 = tf.squeeze(net2, [1, 2], name='fc8_2/squeezed')
          end_points[sc.name + '/fc8_2'] = net2
          if spatial_squeeze:
            net3 = tf.squeeze(net3, [1, 2], name='fc8_3/squeezed')
          end_points[sc.name + '/fc8_3'] = net3
      return net0, net1, net2, net3, end_points
alexnet_v2.default_image_size = 224

说明:网络中的卷积层和池化层不发生变化,原网络只有一个net输出,由于我们的验证码识别项目将验证码拆分成四个标签,所以需要四个输出,因此在源代码基础上增加net1 ~ net3输出。

5、train代码

"""验证码识别
学习模式:多任务学习
网络模型:alexnet_v2
完成时间:2020-5-1
"""

import tensorflow as tf
from nets import nets_factory


CHAR_SET_LEN = 10  # 不同字符数量
IMAGE_HEIGHT = 60  # 图片高度
IMAGE_WIDTH = 160  # 图片宽度
BATCH_SIZE = 25
TFRECORD_FILE = 'D:/PycharmProject/StudyDemo/captcha/train.tfrecords'

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])
y0 = tf.placeholder(tf.float32, [None])
y1 = tf.placeholder(tf.float32, [None])
y2 = tf.placeholder(tf.float32, [None])
y3 = tf.placeholder(tf.float32, [None])

_learn_rate = tf.Variable(0.003, dtype=tf.float32)


# 从tfrecord文件中读出数据
def read_and_decode(filename):
    # 生成文件队列
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader(options=tf.python_io.TFRecordOptions(1))
    # 返回文件名和文件
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label0': tf.FixedLenFeature([], tf.int64),
        'label1': tf.FixedLenFeature([], tf.int64),
        'label2': tf.FixedLenFeature([], tf.int64),
        'label3': tf.FixedLenFeature([], tf.int64),
    })
    # 获取图片数据
    image = tf.decode_raw(features['image'], tf.uint8)
    # tf.train.shuffle_batch的使用必须确定shape
    image = tf.reshape(image, [224, 224])
    # 图片预处理
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # 获取label
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, label0, label1, label2, label3


# 获取图片数据与标签
image, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)
# 使用shuffle_batch随机打乱张量顺序创建批次
image_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, label0, label1, label2, label3], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1
)

# 定义网络结构
train_network_fn = nets_factory.get_network_fn('alexnet_v2',
                                               num_classes=CHAR_SET_LEN,
                                               weight_decay=0.0005,
                                               is_training=True)
with tf.Session() as sess:
    # input参数要符合Alexnet_v2网络的要求,所以先做个格式转换
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # 数据输入网络得到输出值
    logits0, logits1, logits2, logits3, _ = train_network_fn(X)

    # 把标签转换成one_hot形式
    one_hot_labels0 = tf.one_hot(indices=tf.cast(y0, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels1 = tf.one_hot(indices=tf.cast(y1, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels2 = tf.one_hot(indices=tf.cast(y2, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels3 = tf.one_hot(indices=tf.cast(y3, tf.int32), depth=CHAR_SET_LEN)

    # 计算损失值
    loss0 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits0,
                                                                   labels=one_hot_labels0))
    loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits1,
                                                                   labels=one_hot_labels1))
    loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits2,
                                                                   labels=one_hot_labels2))
    loss3 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits3,
                                                                   labels=one_hot_labels3))
    # 总和损失值
    total_loss = (loss0 + loss1 + loss2 + loss3) / 4.0
    # 优化器
    optimizer = tf.train.AdamOptimizer(learning_rate=_learn_rate).minimize(total_loss)
    # 计算准确率
    correct_prediction0 = tf.equal(tf.argmax(one_hot_labels0, 1), tf.argmax(logits0, 1))
    accuracy0 = tf.reduce_mean(tf.cast(correct_prediction0, tf.float32))
    correct_prediction1 = tf.equal(tf.argmax(one_hot_labels1, 1), tf.argmax(logits1, 1))
    accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))
    correct_prediction2 = tf.equal(tf.argmax(one_hot_labels2, 1), tf.argmax(logits2, 1))
    accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))
    correct_prediction3 = tf.equal(tf.argmax(one_hot_labels3, 1), tf.argmax(logits3, 1))
    accuracy3 = tf.reduce_mean(tf.cast(correct_prediction3, tf.float32))

    # 保存模型
    saver = tf.train.Saver()
    # 初始化变量
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # 创建一个协调器管理线程
    coord = tf.train.Coordinator()
    # 启动QueueRunner,此时文件名队列已经进队
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(6001):
        # 获得一个批次的数据和标签
        b_image, b_label0, b_label1, b_label2, b_label3 = sess.run([image_batch,
                                                                    label_batch0,
                                                                    label_batch1,
                                                                    label_batch2,
                                                                    label_batch3])
        # 优化模型
        sess.run(optimizer, feed_dict={
            x: b_image,
            y0: b_label0,
            y1: b_label1,
            y2: b_label2,
            y3: b_label3
        })
        # 每迭代50次计算并打印一次损失值和准确率
        if i % 50 == 0:
            # 每2000次降低学习率
            if i % 2000 == 0:
                sess.run(tf.assign(_learn_rate, _learn_rate / 3))
            acc0, acc1, acc2, acc3, loss_ = sess.run([accuracy0, accuracy1, accuracy2, accuracy3, total_loss],
                                                     feed_dict={
                                                         x: b_image,
                                                         y0: b_label0,
                                                         y1: b_label1,
                                                         y2: b_label2,
                                                         y3: b_label3
                                                     })
            learing_rate = sess.run(_learn_rate)
            print('Iter: %d  loss: %.3f  accuracy:%.2f,%.2f,%.2f,%.2f  learing_rate:%.4f'
                  % (i, loss_, acc0, acc1, acc2, acc3, learing_rate))
            # 停止训练 / 保存模型
            if i == 6000:   # global_step参数是把训练次数添加到模型名称中
                saver.save(sess, './captcha/models/crack_captcha.model', global_step=i)
                break
    coord.request_stop()    # 通知其他线程关闭
    coord.join(threads)     # 其他线程关闭后该函数才可返回

代码概述:从train.tfrecord读出数据和标签,打乱,将数据送入alexnet网络得到输出值,将输出的标签转化为one_hot形式,计算loss,对loss求和得total_loss并用优化器优化,计算准确率,训练6000次,保存模型。
注意:tfrecords文件读写前后数据格式一定要对应,TFRecordWriter和TFRecordReader的options一定要相同,不然容易出现读写错误,需仔细检查。

保存的模型如下:
在这里插入图片描述
提示:训练过程较慢,笔者使用NVIDIA 940mx显卡跑满2G显存总共花费13个小时完成训练,最终准确率达到99%。

四、模型测试

代码与训练代码相似:

import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from nets import nets_factory

# 不同字符数量
CHAR_SET_LEN = 10
# 图片高度和宽度
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160
# 批次
BATCH_SIZE = 1
# tfrecord文件存放路径
TFRECORD_FILE = 'captcha/test.tfrecords'

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])


# 从tfrecord读出数据
def read_and_decode(filename):
    # 生成文件队列
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader(options=tf.python_io.TFRecordOptions(1))
    # 返回文件名和文件
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label0': tf.FixedLenFeature([], tf.int64),
        'label1': tf.FixedLenFeature([], tf.int64),
        'label2': tf.FixedLenFeature([], tf.int64),
        'label3': tf.FixedLenFeature([], tf.int64),
    })
    # 获取图片数据
    image = tf.decode_raw(features['image'], tf.uint8)
    # 没有经过预处理的灰度图
    image_raw = tf.reshape(image, [224, 224])
    # tf.train.shuffle_batch的使用必须确定shape
    image = tf.reshape(image, [224, 224])
    # 图片预处理
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # 获取label
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, image_raw, label0, label1, label2, label3


# 获取图片数据与标签
image, image_raw, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)
# 获得批次
image_batch, image_raw_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, image_raw, label0, label1, label2, label3], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1
)

# 定义网络结构
train_network_fn = nets_factory.get_network_fn('alexnet_v2',
                                               num_classes=CHAR_SET_LEN,
                                               weight_decay=0.0005,
                                               is_training=False)
with tf.Session() as sess:
    # inputs格式[batch_size, height, width, channels]
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # 数据输入网络得到输出值
    logits0, logits1, logits2, logits3, _ = train_network_fn(X)
    # 预测值
    predict0 = tf.reshape(logits0, [-1, CHAR_SET_LEN])
    predict0 = tf.argmax(predict0, 1)

    predict1 = tf.reshape(logits1, [-1, CHAR_SET_LEN])
    predict1 = tf.argmax(predict1, 1)

    predict2 = tf.reshape(logits2, [-1, CHAR_SET_LEN])
    predict2 = tf.argmax(predict2, 1)

    predict3 = tf.reshape(logits3, [-1, CHAR_SET_LEN])
    predict3 = tf.argmax(predict3, 1)

    # 初始化
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # 载入模型
    saver = tf.train.Saver()
    saver.restore(sess, './captcha/models/crack_captcha.model-6000')
    # 创建一个协调器管理线程
    coord = tf.train.Coordinator()
    # 启动QueueRunner,此时文件名队列已经进队
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(10):
        # 获得一个批次的数据和标签
        b_image, b_image_raw, b_label0, b_label1, b_label2, b_label3 = sess.run([image_batch,
                                                                                 image_raw_batch,
                                                                                 label_batch0,
                                                                                 label_batch1,
                                                                                 label_batch2,
                                                                                 label_batch3])
        # 显示图片
        img = Image.fromarray(b_image_raw[0], 'L')
        plt.imshow(img)
        plt.axis('off')
        plt.show()
        # 打印标签
        print('label:', b_label0, b_label1, b_label2, b_label3)
        # 预测
        label0, label1, label2, label3 = sess.run([predict0, predict1, predict2, predict3],
                                                  feed_dict={x: b_image})
        # 打印预测值
        print('predict:', label0, label1, label2, label3)

    # 通知其他线程关闭
    coord.request_stop()
    coord.join(threads)

运行结果:
在这里插入图片描述
在这里插入图片描述

END

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow学习(五)——多任务学习验证码识别实战 的相关文章

  • 如何在刻度标签和轴之间添加空间

    我已成功增加刻度标签的字体 但现在它们距离轴太近了 我想在刻度标签和轴之间添加一点呼吸空间 如果您不想全局更改间距 通过编辑 rcParams 并且想要更简洁的方法 请尝试以下操作 ax tick params axis both whic
  • 将字符串转换为带有毫秒和时区的日期时间 - Python

    我有以下 python 片段 from datetime import datetime timestamp 05 Jan 2015 17 47 59 000 0800 datetime object datetime strptime t
  • 使用 openCV 对图像中的子图像进行通用检测

    免责声明 我是计算机视觉菜鸟 我看过很多关于如何在较大图像中查找特定子图像的堆栈溢出帖子 我的用例有点不同 因为我不希望它是具体的 而且我不确定如何做到这一点 如果可能的话 但我感觉应该如此 我有大量图像数据集 有时 其中一些图像是数据集的
  • 如何使用固定的 pandas 数据框进行动态 matplotlib 绘图?

    我有一个名为的数据框benchmark returns and strategy returns 两者具有相同的时间跨度 我想找到一种方法以漂亮的动画风格绘制数据点 以便它显示逐渐加载的所有点 我知道有一个matplotlib animat
  • Flask 和 uWSGI - 无法加载应用程序 0 (mountpoint='')(找不到可调用或导入错误)

    当我尝试使用 uWSGI 启动 Flask 时 出现以下错误 我是这样开始的 gt cd gt root localhost uwsgi socket 127 0 0 1 6000 file path to folder run py ca
  • Python 多处理示例不起作用

    我正在尝试学习如何使用multiprocessing但我无法让它发挥作用 这是代码文档 http docs python org 2 library multiprocessing html from multiprocessing imp
  • 如何在Windows上模拟socket.socketpair

    标准Python函数套接字 套接字对 https docs python org 3 library socket html socket socketpair不幸的是 它在 Windows 上不可用 从 Python 3 4 1 开始 我
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • SQL Alchemy 中的 NULL 安全不等式比较?

    目前 我知道如何表达 NULL 安全的唯一方法 SQL Alchemy 中的比较 其中与 NULL 条目的比较计算结果为 True 而不是 NULL 是 or field None field value 有没有办法在 SQL Alchem
  • 安装后 Anaconda 提示损坏

    我刚刚安装张量流GPU创建单独的后环境按照以下指示here https github com antoniosehk keras tensorflow windows installation 但是 安装后当我关闭提示窗口并打开新航站楼弹出
  • 交换keras中的张量轴

    我想将图像批次的张量轴从 batch size row col ch 交换为 批次大小 通道 行 列 在 numpy 中 这可以通过以下方式完成 X batch np moveaxis X batch 3 1 我该如何在 Keras 中做到
  • 在pyyaml中表示具有相同基类的不同类的实例

    我有一些单元测试集 希望将每个测试运行的结果存储为 YAML 文件以供进一步分析 YAML 格式的转储数据在几个方面满足我的需求 但测试属于不同的套装 结果有不同的父类 这是我所拥有的示例 gt gt gt rz shorthand for
  • python 集合可以包含的值的数量是否有限制?

    我正在尝试使用 python 设置作为 mysql 表中 ids 的过滤器 python集存储了所有要过滤的id 现在大约有30000个 这个数字会随着时间的推移慢慢增长 我担心python集的最大容量 它可以包含的元素数量有限制吗 您最大
  • 表达式中的 Python 'in' 关键字与 for 循环中的比较 [重复]

    这个问题在这里已经有答案了 我明白什么是in运算符在此代码中执行的操作 some list 1 2 3 4 5 print 2 in some list 我也明白i将采用此代码中列表的每个值 for i in 1 2 3 4 5 print
  • ExpectedFailure 被计为错误而不是通过

    我在用着expectedFailure因为有一个我想记录的错误 我现在无法修复 但想将来再回来解决 我的理解expectedFailure是它会将测试计为通过 但在摘要中表示预期失败的数量为 x 类似于它如何处理跳过的 tets 但是 当我
  • 循环中断打破tqdm

    下面的简单代码使用tqdm https github com tqdm tqdm在循环迭代时显示进度条 import tqdm for f in tqdm tqdm range 100000000 if f gt 100000000 4 b
  • 为美国东部以外地区的 Cloudwatch 警报发送短信?

    AWS 似乎没有为美国东部以外的 SNS 主题订阅者提供 SMS 作为协议 我想连接我的 CloudWatch 警报并在发生故障时接收短信 但无法将其发送到 SMS YES 经过一番挖掘后 我能够让它发挥作用 它比仅仅选择一个主题或输入闹钟
  • 检查所有值是否作为字典中的键存在

    我有一个值列表和一本字典 我想确保列表中的每个值都作为字典中的键存在 目前我正在使用两组来确定字典中是否存在任何值 unmapped set foo set bar keys 有没有更Pythonic的方法来测试这个 感觉有点像黑客 您的方
  • VSCode:调试配置中的 Python 路径无效

    对 Python 和 VSCode 以及 stackoverflow 非常陌生 直到最近 我已经使用了大约 3 个月 一切都很好 当尝试在调试器中运行任何基本的 Python 程序时 弹出窗口The Python path in your
  • 使用基于正则表达式的部分匹配来选择 Pandas 数据帧的子数据帧

    我有一个 Pandas 数据框 它有两列 一列 进程参数 列 包含字符串 另一列 值 列 包含相应的浮点值 我需要过滤出部分匹配列 过程参数 中的一组键的子数据帧 并提取与这些键匹配的数据帧的两列 df pd DataFrame Proce

随机推荐

  • python——selenium

    一 Selenium Python环境搭建及配置 1 1 selenium 介绍 selenium 是一个 web 的自动化测试工具 不少学习功能自动化的同学开始首选 selenium 因为它相比 QTP 有诸多有点 免费 也不用再为破解
  • cpolar内网穿透+ EasyImage组合,自建一个图床网站

    文章目录 1 前言 2 EasyImage网站搭建 2 1 EasyImage下载和安装 2 2 EasyImage网页测试 2 3 cpolar的安装和注册 3 本地网页发布 3 1 Cpolar云端设置 3 2 Cpolar内网穿透本地
  • 【马士兵】Python基础--15

    Python基础 15 文章目录 Python基础 15 编程思想 类与对象 类的创建 对象的创建 类属性 类方法 静态方法 动态绑定属性和方法 知识点总结 编程思想 类与对象 python中一切皆对象 类的创建 类的名称由一个或多个单词组
  • 【SpringCloud】SpringAMQP总结

    文章目录 1 AMQP 2 基本消息模型队列 3 WorkQueue模型 4 发布订阅模型 5 发布订阅 Fanout Exchange 6 发布订阅 DirectExchange 7 发布订阅 TopicExchange 8 消息转换器
  • 迁移学习 & 凯明初始化

    前言 这一章其实就是之前没做完的事 来补一下 两者其实没啥关系 迁移学习 以下内容学习自迁移学习 斯坦福21秋季 实用机器学习中文版 迁移学习包括什么 feature extraction train a model on a relate
  • 由于缺少调试目标 E:a\b\c\串口配置工具\bin\Debug\串口配置工具.exe“,visual Studio无法开始调试。请生成项目并重试,或者相应OutputPath和AssemblyNa

    最近做一个窗体程序时候出现这个错误 我的项目名称是串口配置工具 建议为英文来命名 项目名称下面有这两个 发现 没有这个串口配置工具 exe 然后再这个 这里面发现这个串口配置工具 exe 最后直接 exe文件把这个复制到 项目名称 bin
  • C++基础——const成员函数

    目录 一 Const成员函数 1 定义 2 格式 3 代码示例 h文件 definition cpp文件 特性 例 那么const对象既可以调用非const型成员函数吗 问题3 const成员函数内可以调用其它的非const成员函数吗 问题
  • 手机运行python 神器,pydroid3 包含库的版本

    初次安装pydroid 或者qpython的同学运行爬虫时是不是蛋疼的一比 lxml根本装不了 虽然可以下载whl折腾 可是也很麻烦 后来我不死心 终于找到了包含库的版本 只有pydroid 64位 https lanzous com id
  • msa2000映射到服务器,HPmsa2000i官方详细的设置操作流程步骤.doc

    HPmsa2000i官方详细的设置操作流程步骤 从本地管理主机登录进入 SMU 如要从本地管理主机登录进入 SMU 在网络浏览器的地址栏中 键入某个控制器机柜的以太网管理端口的 IP 地址 然后按Enter 此时显示 SMU Login 页
  • IDEA java.lang.NullPointerException (no error message)

    今天在不停启动debug 停止debug后无法再启动debug 提示java lang NullPointerException no error message 经百度 删除 project下 gradle无效 恢复代码后无效 且未更改配
  • 【C语言】合并两个数组,降序排列并删除重复元素(通俗易懂)

    问题描述 试着写一个程序 具体内容如下 建立两个整型数组 int n scanf d n int a n 将其合并 对他们进行降序排序 去掉相同项 输出处理过后的数组 输入形式 首先第一行输入第一个数组中的长度n 然后输入n个整型数 然后在
  • MYSQL进阶-msql日志-慢查询日志

    2 慢查询日志 慢查询日志主要用来记录执行时间超过设置的某个时长的SQL语句 能够帮助数据库维护人员找出执行时间比较长 执行效率比较低的SQL语句 并对这些SQL语句进行针对性优化 2 1 开启慢查询日志 可以在my cnf文件或者my i
  • ant design pro 代码学习(七) ----- 组件封装(登录模块)

    以登录模块为例 对ant design pro的组件封装进行相关分析 登录模块包含基础组件的封装 组件按模块划分 同类组件通过配置文件生成 跨层级组件直接数据通信等 相对来说还是具有一定的代表性 1 登录模块流程图 首先 全局了解一下登录模
  • 在idea中安装并且使用easy code插件 ,以及在idea中配置mysql数据库

    在idea中安装并且使用easy code插件 以及在idea中配置mysql数据库 1 从导航栏进入设置页面 2 点击plugins选项 在输入框中输入easy code查找 并点击installed安装 下载安装好了以后需要重启软件 点
  • GNURadio报错Unable to create context(windows10环境)

    GNURadio报错Unable to create context windows10环境 这里本人使用的是GNU Radio3 7 11 iiosupport win64 版本 外设是ADI的ADALM PLUTO 这里本人使用的是GN
  • 多维时序

    多维时序 MATLAB实现ELM极限学习机多维时序预测 股票价格预测 目录 多维时序 MATLAB实现ELM极限学习机多维时序预测 股票价格预测 效果一览 基本介绍 程序设计 结果输出 参考资料 效果一览 基本介绍
  • 2018-12-13 LeetCode Q5 最长回文子串

    5 最长回文子串 给定一个字符串 s 找到 s 中最长的回文子串 你可以假设 s 的最大长度为 1000 示例 1 输入 babad 输出 bab 注意 aba 也是一个有效答案 示例 2 输入 cbbd 输出 bb 暴力解法 6004ms
  • 关于linux进程间的close-on-exec机制

    转载请注明出处 帘卷西风的专栏 http blog csdn net ljxfblog 前几天写了一篇博客 讲述了端口占用情况的查看和解决 关于linux系统端口查看和占用的解决方案 大部分这种问题都能够解决 在文章的最后 提到了一种特殊情
  • 判断字符串的两半是否相似(1704.leetcode)-------------------c++实现

    判断字符串的两半是否相似 1704 leetcode unordered map c 实现 题目表述 给你一个偶数长度的字符串 s 将其拆分成长度相同的两半 前一半为 a 后一半为 b 两个字符串 相似 的前提是它们都含有相同数目的元音 a
  • Tensorflow学习(五)——多任务学习验证码识别实战

    一 验证码生成 验证码生成脚本 使用captcha包提供的ImageCaptcha方法 from captcha image import ImageCaptcha import sys import random import numpy