TensorFlow 图像分类

2024-02-25

我对 TensorFlow 很陌生。我正在做图像分类使用我自己的训练数据库。

然而,在我训练了自己的数据集之后,我不知道如何对输入图像进行分类。

这是我的代码准备我自己的数据集

filenames = ['01.jpg', '02.jpg', '03.jpg', '04.jpg']
label = [0,1,1,1]
filename_queue = tf.train.string_input_producer(filenames)

reader = tf.WholeFileReader()
filename, content = reader.read(filename_queue)
image = tf.image.decode_jpeg(content, channels=3)
image = tf.cast(image, tf.float32)
resized_image = tf.image.resize_images(image, 224, 224)

image_batch , label_batch= tf.train.batch([resized_image,label], batch_size=8, num_threads = 3, capacity=5000)

这是训练数据集的正确代码吗?

之后,我尝试使用它通过以下代码对输入图像进行分类。

test = ['test.jpg', 'test2.jpg']
test_queue=tf.train.string_input_producer(test)
reader = tf.WholeFileReader()
testname, test_content = reader.read(test_queue)
test = tf.image.decode_jpeg(test_content, channels=3)
test = tf.cast(test, tf.float32)
resized_image = tf.image.resize_images(test, 224,224)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    res = sess.run(resized_image)
    coord.request_stop()
    coord.join(threads)

但是,它不会返回输入图像的预测标签。 我正在寻找有人教我如何使用自己的数据集对图像进行分类。

谢谢。


也许你可以在安装 PIL python lib 后尝试这个:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import math
import numpy
import numpy as np
import random
from PIL import Image
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# Basic model parameters as external flags.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 4, 'Batch size.  '
                     'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                     'for unit testing.')
NUM_CLASSES = 2 
IMAGE_SIZE = 28 
CHANNELS = 3
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE * CHANNELS





def inference(images, hidden1_units, hidden2_units):
  # Hidden 1
  with tf.name_scope('hidden1'):
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  # Hidden 2
  with tf.name_scope('hidden2'):
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  # Linear
  with tf.name_scope('softmax_linear'):
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    logits = tf.matmul(hidden2, weights) + biases
  return logits


def cal_loss(logits, labels):
  labels = tf.to_int64(labels)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits, labels, name='xentropy')
  loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
  return loss


def training(loss, learning_rate):
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  global_step = tf.Variable(0, name='global_step', trainable=False)
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op


def evaluation(logits, labels):
  correct = tf.nn.in_top_k(logits, labels, 1)
  return tf.reduce_sum(tf.cast(correct, tf.int32))


def placeholder_inputs(batch_size):
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
  return images_placeholder, labels_placeholder

def fill_feed_dict(images_feed,labels_feed, images_pl, labels_pl):
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict

def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = 4 // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(train_images,train_labels,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = true_count / num_examples
  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))

# Get the sets of images and labels for training, validation, and
train_images = []
for filename in ['01.jpg', '02.jpg', '03.jpg', '04.jpg']:
  image = Image.open(filename)
  image = image.resize((IMAGE_SIZE,IMAGE_SIZE))
  train_images.append(np.array(image))

train_images = np.array(train_images)
train_images = train_images.reshape(4,IMAGE_PIXELS)

label = [0,1,1,1]
train_labels = np.array(label)

def run_training():
  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(4)

    # Build a Graph that computes predictions from the inference model.
    logits = inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = cal_loss(logits, labels_placeholder)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = evaluation(logits, labels_placeholder)

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Run the Op to initialize the variables.
    init = tf.initialize_all_variables()
    sess.run(init)

    # And then after everything is built, start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      feed_dict = fill_feed_dict(train_images,train_labels,
                                 images_placeholder,
                                 labels_placeholder)
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)
      duration = time.time() - start_time
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        saver.save(sess, FLAGS.train_dir, global_step=step)
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                train_images)

def main(_):
  run_training()
if __name__ == '__main__':
  tf.app.run()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow 图像分类 的相关文章

随机推荐

  • 如何从 SQL 二进制字段获取单个字节

    我在 SQL Server 中有一个二进制字段 我想在 SQL 函数中一次读取一个字节 在代码中我将使用字节数组 SQL 中有等效的吗 我用谷歌找不到任何东西 SUBSTRING 函数应该足够了 一个简单的例子 假设表 MyTable 的列
  • 使用 jquery/javascript 增加点击时的 CSS 亮度颜色?

    因此 如果我有一个文本 Click Me to Brighten 它具有某种深绿色十六进制颜色 如 00801a 的CSS颜色属性 我想将其设置为当我单击它时 它会使其变为浅绿色 同样 如果它是某种蓝色 单击它会使它变浅蓝色 基本上我想知道
  • 可可:框架和边界有什么区别?

    UIView及其子类都具有以下属性frame and bounds 有什么不同 The bounds of an UIView http developer apple com iPhone library documentation UI
  • 适用于 Java、Python、Ruby、Node.JS 和 PHP 的开放支付网关库

    我正在寻找支持许多不同支付处理器 API 的通用开源支付库 换句话说 我想开发一个使用单一支付处理 API 的应用程序 但能够轻松地在支付网关之间切换 例如 Authorize Net Payflow Pro Braintree PayPa
  • 信号器程序集加载问题 OWIN

    我在尝试加载类时收到此错误Microsoft AspNet SignalR Owin集会 执行离开后抛出异常Configuration中的方法startup cs 我已经注册了一个全局异常处理程序来尝试捕获异常 但它没有被捕获 public
  • 警报显示使用已弃用的 HREF 而没有绝对 URL

    Facebook 开发者页面中的消息提醒我的网站当前正在使用以下已弃用的功能 社交插件 Like Button Like Box 中没有绝对 URLhref范围 此问题必须在 2013 年 7 月之前解决 我猜它正在谈论 喜欢 的 data
  • TDirect2DCanvas 速度慢还是我做错了什么?

    在寻找替代 GDI 的替代品时 我试图测试 Delphi 的 2010TDirect2D画布Windows 7 中的性能 我通过使用 Direct2D 绘制一条巨大的折线来测试它 结果速度慢得离谱 即使数据量比我使用 GDI 运行相同测试的
  • Instagram API 不返回关注者

    我已通过 Instagram 进行身份验证 并且获得了具有范围的访问令牌follower list 然后我尝试获取我的关注者列表 https api instagram com v1 users self followed by acces
  • 未知的指令类型“toctree”。 Pycharm 出错,但 index.html 有效

    在 PyCharm 中工作时 我在 Sphinx 中创建的文档的预览模式显示 System Message ERROR 3
  • “Where like”子句使用 2 列的串联值与雄辩

    我有一个查询 在多个列中搜索一个术语 其中之一必须是全名 我已将姓名和姓氏分开 因此在搜索时必须连接这两个值 我现在只有搜索名字 我如何将连接添加到姓氏 我正在调查突变体 但我不知道这是否是正确的方法 public function sea
  • 初学者的 C 套接字编程

    我刚刚开始学习套接字编程 发现它非常有趣 目前我正在制作服务器和客户端在同一台计算机上因此我可以拥有IP地址作为环回地址 127 0 0 1一切似乎都运行良好 但现在我正在考虑拥有两台计算机并做这件事 我有以下问题 假设一台计算机是服务器
  • Android SQLiteConstraintException:错误代码19:约束失败

    我已经看到了有关此异常的其他问题 但所有这些问题似乎都通过解决方案解决了 即已存在指定主键的行 对我来说似乎并非如此 我尝试用双引号替换字符串中的所有单引号 但出现了同样的问题 我正在尝试通过执行以下操作将一行插入到我创建的 SQLite
  • 使用 T & F 代替 TRUE & FALSE 有什么问题吗?

    我注意到使用T and F代替TRUE and FALSER 中的函数给了我相同的结果 当然 T and F更简洁 但是 我明白了TRUE and FALSE被更频繁地使用 我想知道两者之间有什么区别吗 使用有什么问题吗T and F T
  • oracle sqlplus中获取sql脚本的执行时间

    我有一个脚本 用于将数据加载到 Oracle 中的表中 通过插入语句列表 如何获取整个加载过程的执行时间 我尝试过set timing on 但这给了我每个插入语句的持续时间 而不是整个过程的持续时间 脚本如下所示 spo load log
  • 是否可以将 supertest 与 hapi 一起使用?

    我用的是hapi 不是express 超级测试还应该有效吗 如果是这样 有没有一种快速方法可以更改我的代码以使其运行 我的测试看起来像这样 基于文档 https github com visionmedia supertest import
  • 如何在 Mockito 中模拟 CompletableFuture 的完成

    我想模拟当某个代码被调用时CompletableFuture已成功完成 我有这门课 public class MyClassImplementRunner implements Runnable private final String p
  • 通过参数对函数调用进行反跳

    David Walsh 拥有出色的去抖动实现here https davidwalsh name javascript debounce function Returns a function that as long as it cont
  • Firebase:观察 childAdded 返回现有/旧记录?

    我有一个查询 用 swift 编写 FIRDatabase database reference withPath ORDERS PATH lId child orders observe childAdded with firebaseS
  • Bash sqlite3 行 |如何转换为JSON格式

    我想将数据库中的 sqlite 数据转换为 JSON 格式 我想使用这个语法 sqlite3 linemembers db 从成员LIMIT 3中选择 gt members txt OUTPUT id 1 fname Leif gname
  • TensorFlow 图像分类

    我对 TensorFlow 很陌生 我正在做图像分类使用我自己的训练数据库 然而 在我训练了自己的数据集之后 我不知道如何对输入图像进行分类 这是我的代码准备我自己的数据集 filenames 01 jpg 02 jpg 03 jpg 04