Tensorflow高级API

2023-11-15

一、Estimator

1、介绍

  • 编程堆栈

编程堆栈

  • Estimator:代表一个完整的模型。Estimator API 提供一些方法来训练模型、判断模型的准确率并生成预测。
  • 数据集:构建数据输入管道。Dataset API 提供一些方法来加载和操作数据,并将数据馈送到您的模型中。Dataset APIEstimator API 合作无间

2、鸢尾花进行分类

  • 数据集介绍:4个属性,分为3类:
花萼长度 花萼宽度 花瓣长度 花瓣宽度 品种(标签)
5.1 3.3 1.7 0.5 0(山鸢尾)
5.0 2.3 3.3 1.0 1(变色鸢尾)
6.4 2.8 5.6 2.2 2(维吉尼亚鸢尾)
  • 网络模型

网络模型

3、实现

  • EstimatorTensorFlow 对完整模型的高级表示。它会处理初始化、日志记录、保存和恢复等细节部分,并具有很多其他功能,以便您可以专注于模型。
3.1 预创建模型
import tensorflow as tf
import argparse
import iris_data


# 超参数
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=100, type=int, help="batch size")
parser.add_argument('--train_steps', default=1000, type=int, help="number of training steps")
  • 构建模型
    • 特征列:feature_column:特征列是一个对象,用于说明模型应该如何使用特征字典中的原始输入数据。在构建 Estimator 模型时,您会向其传递一个特征列的列表,其中包含您希望模型使用的每个特征。tf.feature_column 模块提供很多用于向模型表示数据的选项。
      • 对于鸢尾花问题,4 个原始特征是数值,因此我们会构建一个特征列的列表,以告知 Estimator 模型将这 4 个特征都表示为 32 位浮点值。
    • 实例化 Estimator: 使用的是预创建模型 cls = tf.estimator.DNNClassifier()模型
    • 训练模型 cls.train(input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None)
      • input_fn指定输入的函数,包含 (features, labels)tf.data.Dataset 类型的数据
      • steps 参数告知方法在训练多少步后停止训练。
    • 评估经过训练的模型:eval_res = cls.evaluate(input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
      • 输入和训练数据一致
      • 返回的有{'accuracy': 1.0, 'loss': 3.936471, 'average_loss': 0.1312157, 'global_step': 100}
    • 预测: predictions = cls.predict(input_fn, predict_keys=None, hooks=None, checkpoint_path=None, yield_single_examples=True)
      • 输入数据为 batch_size 的测试数据,不包含 label,返回生成器结果
def main(argv):
    args = parser.parse_args(argv[1:])
    # 加载数据, pandas类型
    (train_x, train_y), (test_x, test_y) = iris_data.load_data()
    # feature columns描述如何使用输入数据
    my_feature_columns = []
    for key in train_x.keys():
        my_feature_columns.append(tf.feature_column.numeric_column(key = key))
    # 建立模型
    cls = tf.estimator.DNNClassifier(hidden_units=[10,10], feature_columns=my_feature_columns, 
                                    n_classes=3)
    # 训练模型
    cls.train(input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
              steps=args.train_steps)
    # 评价模型
    eval_res = cls.evaluate(input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
    print("\n Test Set accuracy: {:0.3f}\n".format(eval_res['accuracy']))
    
    # 预测
    expected = ['Setosa', 'Versicolor', 'Virginica']
    predict_x = {
        'SepalLength': [5.1, 5.9, 6.9],
        'SepalWidth':  [3.3, 3.0, 3.1],
        'PetalLength': [1.7, 4.2, 5.4],
        'PetalWidth':  [0.5, 1.5, 2.1],        
    }
    
    predictions = cls.predict(input_fn=lambda:iris_data.eval_input_fn(predict_x, 
                                                                      labels=None,
                                                                      batch_size=args.batch_size))
    template = ('\n Prediction is "{}" ({:.1f}%), expected "{}"' )
    for pred_dict, expec in zip(predictions, expected):
        class_id = pred_dict['class_ids'][0]
        prob = pred_dict['probabilities'][class_id]
        print(template.format(iris_data.SPECIES[class_id], 100*prob, expec))
  • 运行函数
    • tf.app.run(main=main)会先解析命令行参数,然后执行main函数
if __name__ == "__main__":
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)
  • 保存和加载模型
    • 指定模型地址即可:model_dir,在第一次训练时会保存模型
      first train call
      • 如果未在 Estimator 的构造函数中指定 model_dir,则 Estimator 会将检查点文件写入由 Pythontempfile.mkdtemp 函数选择的临时目录中,可以print(classifier.model_dir)查看
    • 检查点频率:
      • 默认
        • 10 分钟(600 秒)写入一个检查点。
        • train 方法开始(第一次迭代)和完成(最后一次迭代)时写入一个检查点。
        • 只在目录中保留 5 个最近写入的检查点。
      • 自己配置:
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_secs = 20*60,   # 每20分钟保存一次
                                                  keep_checkpoint_max = 10)        # 保存10个最近的检查点
    cls = tf.estimator.DNNClassifier(hidden_units=[10,10], feature_columns=my_feature_columns, 
                                    n_classes=3,
                                    model_dir='model/',
                                    config=my_checkpoint_config)
    • 加载模型
      • 不需要改动,一旦存在检查点,TensorFlow 就会在您每次调用 train()evaluate()predict() 时重建模型。
        subsequent_calls
3.2 自定义模型
  • 完整代码:点击查看
  • 预创建的 Estimatortf.estimator.Estimator 基类的子类,而自定义 Estimatortf.estimator.Estimator 的实例
    estimator types
  • 创建模型
    • 模型函数(即 model_fn)会实现机器学习算法
    • params 参数会传递给自己实现的模型
    cls = tf.estimator.Estimator(model_fn=my_model, 
                                 params={
                                    'feature_columns': my_feature_columns,
                                    'hidden_units': [10, 10],
                                    'num_classes': 3
                                    })
  • 自定义my_model函数:
    • 输入层指定输入的数据和对应的feature columns
    • 隐藏层通过tf.layers.dense()创建
    • 通过mode来判断是训练、评价还是预测操作,返回必须是tf.estimator.EstimatorSpec 对象
      input layer
def my_model(features, labels, mode, params):
    '''自定义模型
       ---------------------------------------------
       features: 输入数据
       labels  : 标签数据
       mode    : 指示是训练、评价还是预测
       params  : 构建模型的参数
    
    '''
    net = tf.feature_column.input_layer(features=features, 
                                        feature_columns=params['feature_columns'])   # 输入层
    for units in params['hidden_units']:                                             # 隐藏层,遍历参数配置
        net = tf.layers.dense(inputs=net, units=units, activation=tf.nn.relu)
    
    logits = tf.layers.dense(net, params['num_classes'], activation=None)
    pred = tf.argmax(logits, 1)    # 预测结果
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': pred[:, tf.newaxis],
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # 计算loss
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    # 计算评价信息
    accuracy = tf.metrics.accuracy(labels=labels, predictions=pred, 
                                  name='acc_op')
    metrics = {'accuracy': accuracy}
    tf.summary.scalar(name='accuracy', tensor=accuracy[1])
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
    
    # 训练操作
    assert mode == tf.estimator.ModeKeys.TRAIN
    
    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
  • TensorBoard 中查看自定义 Estimator 的训练结果。(预定义的模型结果展示更丰富一些)
    • tensorboard --logdir=PATH
    • global_step/sec:这是一个性能指标,显示我们在进行模型训练时每秒处理的批次数(梯度更新)。
      global step
    • loss:所报告的损失。
      loss
    • accuracy:准确率由下列两行记录:
      • eval_metric_ops={‘my_accuracy’: accuracy})(评估期间)。
      • tf.summary.scalar(‘accuracy’, accuracy1)(训练期间)。
        accuracy

二、Dataset

  • tf.data 模块包含一系列类,可让轻松地加载数据、操作数据并通过管道将数据传送到模型中。

1、基本输入

  • 从数组中提取接片,上面用到的代码
    • feature:特征数据,为feature-name: array的字典或者DataFrame

    • labels: 标签数组

    • from_tensor_slices 会按第一个维度进行切片,比如输入为[6000, 28, 28]维度的数据,切片后返回600028, 28Dataset 对象

    • shuffle 方法使用一个固定大小的缓冲区,在条目经过时随机化处理条目。在这种情况下,buffer_size 大于 Dataset 中样本的数量,确保数据完全被随机化处理。

    • repeat 方法会在结束时重启 Dataset。要限制周期数量,请设置 count 参数。

    • batch 方法会收集大量样本并将它们堆叠起来以创建批次。这为批次的形状增加了一个维度。新的维度将添加为第一个维度。

def train_input_fn(features, labels, batch_size):
    """训练集输入函数"""
    dataset = tf.data.Dataset.from_tensor_slices((dict(features,), labels))   # 转化为Dataset
    
    dataset = dataset.shuffle(buffer_size=1000).repeat().batch(batch_size)    # Shuffle, batch
    
    return dataset

2、读取CSV文件

  • 代码
  • 处理一行数据,line: tf.string类型
CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
    '''解析一行数据'''
    field = tf.decode_csv(line, record_defaults=CSV_TYPES)
    features = dict(zip(CSV_COLUMN_NAMES, field))
    labels = features.pop("Species")
    return features, labels
  • 处理text 文件,得到dataset
    • 读取文本类型为:<SkipDataset shapes: (), types: tf.string>
    • 然后使用map 函数,每个对象处理
      map函数示意
def csv_input_fn(csv_path, batch_size):
    '''csv文件输入函数'''
    dataset = tf.data.TextLineDataset(csv_path).skip(1)   # 跳过第一行
    dataset = dataset.map(_parse_line)        # 应用map函数处理dataset中的每一个元素
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    return dataset

Reference

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow高级API 的相关文章

随机推荐

  • python pymysql emoji表情插入mysql数据库异常记录报错 pymysql.err.InternalError

    在数据库存储微信小程序用户昵称时候 发现用户昵称使用emoji表情时候就存不了数据库中间报错 pymysql err InternalError 1366 Incorrect string value xF0 x9F x98 x81 xF0
  • STM32F429串口1配置

    static void ConfigUART u32 bound GPIO InitTypeDef GPIO InitStructure GPIO InitStructure用于存放GPIO的参数 USART InitTypeDef USA
  • gethostbyname() -- 用域名或主机名获取IP地址

    http hi baidu com zengzhaonong item 87d9d296d0824cbb82d29570 include
  • leetcode236—二叉树的最近公共祖先(递归/深搜/理解)

    给定一个二叉树 找到该树中两个指定节点的最近公共祖先 百度百科中 最近公共祖先的定义为 对于有根树 T 的两个节点 p q 最近公共祖先表示为一个节点 x 满足 x 是 p q 的祖先且 x 的深度尽可能大 一个节点也可以是它自己的祖先 深
  • 解决CSS引用字体跨域问题

    最近一个需求客户要求换字体需要引入字体 开始没有注意 后来发现会出现跨越现象 特别是在手机上很明显 通过解决尝试终于解决 希望可以帮到大家 1 解决方案就是将文字设置为 base64 编码 字体转base64编码网址 点击进去 下载文件解压
  • oracle wait class user i/o,[ORACLE]管理方面的脚本收集

    1 查询AWR相关的视图名称 SELECT table name FROM dba tables t WHERE table name LIKE WRH AND NOT EXISTS SELECT x FROM dba tab column
  • 英文键盘盲打最快速练习口诀和方法

    下面我提供几种在新建的文档 如用于处理文字的word 里进行英文盲打的练习参考资料 以一天练习四个字母计算 最多7天你的盲打就基本练习成功了 下面的口诀 能帮助你快速记住键盘字母的排列顺序 爱上一个不爱回家的人 七 q 碗 w n 鹅 肉
  • Map遍历取值的五种方式

    方法1 Set set map keySet for Object o set System out println o map get o 方法2 Set set map keySet Iterator iterator set iter
  • 准备WebUI自动化测试面试?这30个问题你必须掌握(二)

    本文共有11000字 包含了后十五个问题 如需要前十五个问题 可查看文末链接 16 在WebUI自动化测试中 你如何处理验证码或图像识别的问题 1 人工识别 一种简单但费时费力的方法是使用人工手动识别验证码 测试人员可以手动输入验证码 将其
  • IntelliJ IDEA中谷歌打开页面,出现windows 找不到文件chrome

    1 右击桌面上的chrome浏览器图标 找到属性 gt 快捷方式 gt 目标 复制路径 即chrome浏览器 exe文件的路径 2 打开IntelliJ IDEA软件 找到file gt settings gt 找到Web Browsers
  • Java实现国密算法SM2,SM3,SM4,并且实现ECB和CBC模式

    代码中实现了电码本ECB模式和密文分组连接CBC模式 SM3 java和SM4 java为算法实现类 utils的都是根据实现类写的工具 可以根据需要调用杂凑算法SM3的杂凑功能获得杂凑值 SM4 java中 sm4 crypt ecb S
  • iOS objc_msgSend iOS too many arguments in function call 报错解决方案

    Build Settings gt 搜索 objc gt 设置 Enable Strict Checking of objc msgSend Calls 为 NO
  • [django项目] 利用elasticsearch实现搜索功能

    新闻搜索 I 搜索功能分析 本节我们来完成新闻搜索功能 首先让我们来思考一下 要做一个通过关键词搜索文章的功能 需要搜索哪些字段 以及使用什么技术方案呢 既然我们是准备做新闻博客网站 那我们就可以拿同类型网站的做一下对比 例如CSDN 简书
  • docker系列-搭建本地私有仓库-registry容器的各种坑

    总结的坑 a 关注daemon json的书写格式 一句话可以错好几个点 b tag要清楚的表示registry服务器的信息 才能push上传成功 不是可有可无的信息 c tag中有版本号要清楚的写上 系统自动补全的是用latest 搭建过
  • RPC

    RPC 远程过程调用 是什么 简单的说 RPC就是从一台机器 客户端 上通过参数传递的方式调用另一台机器 服务器 上的一个函数或方法 可以统称为服务 并得到返回的结果 RPC 会隐藏底层的通讯细节 不需要直接处理Socket通讯或Http通
  • 安卓pwn - De1taCTF(BroadcastTest)

    BroadcastTest 背景 逆向APK可知程序中仅有MainActivity Message和三个Receiver类 前者实现了一个Parcelable类 后三个则是广播 其中Receiver1是exported的 接收并向Recei
  • jsts 学习

    性能问题一直困扰了我很长的时间 今天听同事介绍了一个网站 感觉视角开阔了许多 一直做GIS开发 原来不只是java有jts包 原来javascript也有这样的一个包 叫做jsts 这个包的功能跟java里面的jts包差不多 前段提供空间关
  • Linux--信号

    文章目录 信号入门 生活角度的信号 技术应用角度的信号 注意 信号概念 使用kill l命令可以查看系统定义的信号列表 信号处理常见的方式 产生信号 1 通过终端按键产生信号 Core Dump 使用core dump进行事后调试 2 系统
  • How far away ? 【HDU - 2586】【在线LCA算法讲解】

    题目链接 做些LCA的算法 还是很提高代码能力的 这道题就是典型的LCA模板 所以用它来练一下我的LCA算法还是很好的 我们要求的是在一棵树上的任意两点的相互距离 既然在一棵树上 就可以直接调用LCA来解了 我们先任取一根节点 我取的是1
  • Tensorflow高级API

    本文个人博客地址 点击查看 一 Estimator 1 介绍 编程堆栈 Estimator 代表一个完整的模型 Estimator API 提供一些方法来训练模型 判断模型的准确率并生成预测 数据集 构建数据输入管道 Dataset API