构建 keras 模型

2024-02-10

我不明白这段代码中发生了什么:

def construct_model(use_imagenet=True):
    # line 1: how do we keep all layers of this model ?
    model = keras.applications.InceptionV3(include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3),
                                          weights='imagenet' if use_imagenet else None) # line 1: how do we keep all layers of this model ?

    new_output = keras.layers.GlobalAveragePooling2D()(model.output)

    new_output = keras.layers.Dense(N_CLASSES, activation='softmax')(new_output)
    model = keras.engine.training.Model(model.inputs, new_output)
    return model

具体来说,我的困惑是,当我们调用最后一个构造函数时

model = keras.engine.training.Model(model.inputs, new_output)

我们指定输入层和输出层,但它如何知道我们希望所有其他层保留?

换句话说,我们将 new_output 层附加到我们在第 1 行加载的预训练模型,即 new_output 层,然后在最终的构造函数(最后一行)中,我们只需创建并返回一个具有指定输入和的模型输出层,但它如何知道我们想要在中间放置哪些其他层?

附带问题1):keras.engine.training.Model 和 keras.models.Model 有什么区别?

附带问题 2):当我们执行 new_layer = keras.layers.Dense(...)(prev_layer) 时到底会发生什么? () 操作是否返回新层,它到底做了什么?


该模型是使用功能性API模型 https://keras.io/getting-started/functional-api-guide/

基本上它的工作原理是这样的(也许如果你在阅读本文之前转到下面的“附带问题2”,它可能会变得更清楚):

  • 你有一个输入张量(您也可以将其视为“输入数据”)
  • 您创建(或重用)图层
  • 您将输入张量传递给一个层(您用输入“调用”一个层)
  • 你得到一个输出张量

您继续使用这些张量,直到创建了整个张量graph.

但这还没有创建一个“模型”。 (你可以训练和使用其他东西)。
你所拥有的只是一张图表,告诉你哪些张量去哪里。

要创建模型,您需要定义其起点和终点。


在例子中。

  • 他们采用现有模型:model = keras.applications.InceptionV3(...)
  • 他们想要扩展这个模型,所以他们得到了它输出张量: model.output
  • 他们将此张量作为输入GlobalAveragePooling2D
  • 他们得到该层的输出张量为new_output
  • 他们将其作为输入传递给另一层:Dense(N_CLASSES, ....)
  • 并得到它的输出new_output(这个变量被替换,因为他们对保留其旧值不感兴趣......)

但是,由于它与函数式 API 一起使用,我们还没有模型,只有图表。为了创建模型,我们使用Model定义输入张量和输出张量:

new_model = Model(old_model.inputs, new_output)    

现在你有了你的模型。
如果你像我一样在另一个变量中使用它(new_model),旧模型仍然存在model。这些模型共享相同的层,每当你训练其中一个模型时,另一个模型也会更新。


问题:它如何知道我们想要在中间添加哪些其他层?

当你这样做时:

outputTensor = SomeLayer(...)(inputTensor)    

输入和输出之间有连接。 (Keras 将使用内部张量流机制并将这些张量和节点添加到图中)。如果没有输入,输出张量就不可能存在。整个InceptionV3模型从头到尾都是连接的。它的输入张量经过所有层以产生输出张量。数据遵循的方式只有一种可能,而图表就是方式。

当您获得该模型的输出并使用它来获得进一步的输出时,所有新输出都将连接到此模型,从而连接到模型的第一个输入。

大概是属性_keras_history添加到张量中的值与其跟踪图的方式密切相关。

所以,做Model(old_model.inputs, new_output)自然会遵循唯一可能的方式:图表。

如果您尝试使用未连接的张量执行此操作,您将收到错误。


附带问题1

更喜欢从“keras.models”导入。基本上,该模块将从其他模块导入:

  • https://github.com/keras-team/keras/blob/master/keras/models.py https://github.com/keras-team/keras/blob/master/keras/models.py

请注意该文件keras/models.py进口Model from keras.engine.training。所以,这是同样的事情。

附带问题2

它不是new_layer = keras.layers.Dense(...)(prev_layer).

It is output_tensor = keras.layers.Dense(...)(input_tensor).

你在同一行做两件事:

  • 创建一个图层 - 使用keras.layers.Dense(...)
  • 使用输入张量调用层以获得输出张量

如果您想使用具有不同输入的同一层:

denseLayer = keras.layers.Dense(...) #creating a layer

output1 = denseLayer(input1)  #calling a layer with an input and getting an output
output2 = denseLayer(input2)  #calling the same layer on another input
output3 = denseLayer(input3)  #again   

奖励 - 创建一个与顺序模型相同的功能模型

如果您创建此顺序模型:

model = Sequential()
model.add(Layer1(...., input_shape=some_shape))   
model.add(Layer2(...))
model.add(Layer3(...))

你所做的与以下完全相同:

inputTensor = Input(some_shape)
outputTensor = Layer1(...)(inputTensor)
outputTensor = Layer2(...)(outputTensor)    
outputTensor = Layer3(...)(outputTensor)

model = Model(inputTensor,outputTensor)

有什么不同?

嗯,函数式 API 模型是完全免费的,可以按照您想要的方式构建。您可以创建分支:

out1 = Layer1(..)(inputTensor)    
out2 = Layer2(..)(inputTensor)

您可以加入张量:

joinedOut = Concatenate()([out1,out2])   

有了这个,您可以创建anything你想要各种奇特的东西,分支,门,串联,添加等等,这是顺序模型无法做到的。

事实上,一个Sequential模型也是一个Model,但创建它是为了在没有分支的模型中快速使用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

构建 keras 模型 的相关文章

  • 使用多个具有不同日志级别的处理程序时出现意外的 python 记录器输出

    我正在尝试将数据记录到 stderr 并记录到文件中 该文件应包含all日志消息 并且 stderr 应该只转到命令行上配置的日志级别 这在日志记录指南中多次描述 但它似乎对我不起作用 我创建了一个小测试脚本来说明我的问题 usr bin
  • 如何在 python 中使用 libSVM 计算精度、召回率和 F 分数

    我想计算precision recall and f score using libsvm在Python中 但我不知道如何 我已经发现这个网站 http www csie ntu edu tw cjlin libsvmtools eval
  • 使用 for 循环 Python 为数组赋值

    我正在尝试将字符串的值分配给不同的数组索引 但我收到一个名为 列表分配超出范围 的错误 uuidVal distVal uuidArray distArray for i in range len returnedList for beac
  • 插入多行并返回主键时 Sqlalchemy 的奇怪行为

    插入多行并返回主键时 我注意到一些奇怪的事情 如果我在 isert 查询中添加使用参数值 我会得到预期的行为 但是当将值传递给游标时 不会返回任何内容 这可能是一个错误还是我误解了什么 我的sqlachemy版本是0 9 4 下面如何重现错
  • 创建一个行为类似于任何变量但具有更改/读取回调的类

    我想创建一个类 其行为类似于 python 变量 但在更改 读取 变量 时调用一些回调函数 换句话说 我希望能够按如下方式使用该类 x myClass change callback read callback 将 x 定义为 myclas
  • 计算 for 循环期间的运行总计 - Python

    编辑 下面是我根据收到的反馈 答案编写的工作代码 这个问题源于我之前使用 MIT 的开放课件学习 Python CS 时提出的问题 在这里查看我之前的问题 https stackoverflow com questions 4990159
  • 熊猫 style.background_gradient 忽略 NaN

    我有以下代码来转储数据帧results到 HTML 表格中 这样的列TIME FRAMES根据seaborn 的颜色图进行着色 import seaborn as sns TIME FRAMES 24h 7d 30d 1y Set CSS
  • 在 PyCharm 中运行命令行命令

    你好 我正在使用Python 但之前从未真正使用过它 我收到一些命令 需要在终端中运行 基本上 python Test py GET feeds 我正在使用 PyCharm 我想知道是否有办法从该 IDE 中运行这些相同的命令 按 Alt
  • 如何对嵌套函数进行单元测试? [复制]

    这个问题在这里已经有答案了 您将如何对嵌套函数进行单元测试f1 在下面的例子中 def f def f1 return 1 return 2 或者需要测试的函数不应该嵌套吗 有一个类似的问题这个链接 https stackoverflow
  • 替换 pandas 数据框中的点

    我有一个如图所示的数据框 数字实际上是对象 正在做df treasury rate pd to numeric df treasury rate 可预见的炸弹 然而 做df replace np nan 似乎没有摆脱这个点 所以我很困惑 有
  • 使用 python 写入 aws lambda 中的 /tmp 目录

    Goal 我正在尝试将 zip 文件写入 python aws lambda 中的 tmp 文件夹 因此我可以在压缩之前提取操作 并将其放入 s3 存储桶中 Problem 操作系统 Errno30 只读文件系统 这段代码在我的计算机上进行
  • 将 Python 控制台集成到 GUI C++ 应用程序中

    I m going to add a python console widget into a C GUI below some other controls 许多类将暴露给 python 代码 包括一些对 GUI 的访问 也许我会考虑 P
  • wxPython:更新wx.ListBox列表

    我在 python 程序中有一个 wx ListBox 我不想在 wx Timer 更新时更改其中的列表 我的计时器正在工作 我只是不知道如何更改它显示的列表 这是一个例子 http www daniweb com code snippet
  • 如何在 Python 中包含 PHP 脚本?

    我有一个 PHP 脚本 news generator php 当我包含它时 它会抓取一堆新闻项并打印它们 现在 我在我的网站 CGI 中使用 Python 当我使用 PHP 时 我在 新闻 页面上使用了这样的内容 为了简单起见 我删掉了这个
  • Python:如何使用生成器来避免 sql 内存问题

    我有以下方法来访问 mysql 数据库 并且查询在服务器中执行 我无权更改有关增加内存的任何内容 我对生成器很陌生 并开始阅读更多有关它的内容 并认为我可以将其转换为使用生成器 def getUNames self globalUserQu
  • 检测计算机何时解锁 Windows

    我用过这个优秀的方法 https stackoverflow com questions 20733441 lock windows workstation using python 20733443锁定 Windows 计算机 那部分工作
  • Tensorboard——High-level节点的计算时间与其子节点计算时间的总和不同

    继tutorial https www tensorflow org programmers guide graph viz在 TensorFlow 上 我试图使用张量板来理解运行时统计数据 我发现代表名称范围的高级节点的计算时间不等于其子
  • 有效积累稀疏 scipy 矩阵的集合

    我有一个 O N NxN 的集合scipy sparse csr matrix 每个稀疏矩阵都有 N 个元素集 我想将所有这些矩阵加在一起以获得一个常规的 NxN numpy 数组 N 约为 1000 矩阵内非零元素的排列使得所得总和肯定不
  • 从 xgb.train() 获取概率

    我是 Python 和机器学习的新手 我在网上搜索了我的问题 并尝试了人们建议的解决方案 但仍然没有得到它 如果有人能帮助我 我将非常感激 我正在开发我的第一个 XGboost 模型 我已经使用 xgb XGBClassifier 调整了参
  • 无法在 Python 2.4 中解码 unicode 字符串

    这是Python 2 4 中的 这是我的情况 我从数据库中提取一个字符串 它包含一个变音的 o xf6 此时 如果我运行 type value 它会返回 str 然后我尝试运行 decode utf 8 但收到错误 utf8 编解码器无法解

随机推荐

  • 如何动态添加和扩展私有数据集合?

    设想 I have 3个组织 O1 O3 O1 是申请人的组织 O2 O3 管理与他们共享的公共和私人数据 O1 O3 彼此共享私有数据 O1 O2 共享私有数据 网络正在运行 集合已经分发 一切正常 当我现在想要添加更多组织 以千计 O4
  • 为什么从另一个文件导入类会调用 __init__ 函数?

    该项目的结构是 project 主 py 会话 py 蜘蛛 py session py中有一个类 import requests class Session def init self self session requests Sessi
  • 如何通过反射找出方法的可见性?

    Context 我正在尝试学习 练习 TDD 并决定我需要创建一个不可变的类 为了测试 不变性不变量 你能这么说吗 我想我只需通过反射调用类中的所有公共方法 然后检查类之后是否没有更改 这样我以后就不太可能不小心破坏这个不变量了 这本身可能
  • 为什么Python中的元组可以使用reversed但没有__reversed__?

    在讨论中这个答案 https stackoverflow com questions 9449674 how to implement a persistent python list 9449852 9449852我们意识到元组没有 re
  • 更正应用程序的类路径,使其包含类 Log4J2LoggingSystem 和 PropertiesUtil 的兼容版本

    我正在将一个项目从 Spring Boot 2 6 1 迁移到 Spring Boot 3 0 2 但我遇到了 log4j 依赖项版本的问题 我已经修改了所有给我带来问题的依赖项 但我仍然无法解决问题 错误如下 Java HotSpot T
  • Flowplayer 播放一切

    我有一个flowplayer我正在使用它 下面有几张图片 当您点击这些图片时dialog是用这些图片的放大版本创建的 问题是flowplayer永远会在最上面dialog 我尝试过设置z index of the dialog高和flowp
  • 如何在 SwiftUI 中处理拖动到停靠栏图标上的操作?

    我已经设置了一个 SwiftUI 应用程序 它似乎接受拖放到停靠图标上的图像 但我无法弄清楚在应用程序代码中处理拖放图像的位置 如何处理将图像 或任何特定文件 拖放到 SwiftUI 应用程序的停靠图标上 背景 对于使用 NSApplica
  • 将枚举数据绑定到 WPF + MVVM 中的组合框

    我读了这个非常相关的问题在这里 https stackoverflow com questions 58743 databinding an enum property to a combobox in wpf 由于答案中的链接 这非常有帮
  • Golang:将文件附加到现有的 tar 存档中

    如何将文件附加到 Go 中现有的 tar 存档中 我没有看到任何明显的东西docs http golang org pkg archive tar 关于如何去做 我有一个已经创建的 tar 文件 我想在它关闭后向其中添加更多内容 EDIT
  • 为什么我不必在第二个 TableViewController 中释放 ManagedObjectContext

    我有两个显示 CoreData 对象的表视图控制器 一种是详细视图 带句子 一种是概述 带故事 选择一个故事 gt 查看句子 看来我过度释放了管理对象上下文 我最初在 dealloc 的两个 TableViewController 中发布了
  • 优化Python代码

    关于优化此 python 代码的任何提示寻找下一个回文 输入号码可以为1000000位 添加评论 usr bin python def inc lst lng this function first extract the left hal
  • 修复 Swift 3 中的警告“C-style for Statement is deprecated”

    我有更新Xcode到 7 3 现在我对用于创建随机字符串的函数发出警告 我尝试过改变for声明与for i in 0 lt len 然而 警告变成了错误 我怎样才能删除警告 static func randomStringWithLengt
  • Swift stdlib 工具错误

    我在使用 Xcode 8 1 和 Swift 3 编译时遇到此错误 Swift stdlib 工具错误 编译日志的末尾如下所示 Users Library Developer Xcode DerivedData Build Products
  • 让用户将记录器注入 Nodejs 模块的最佳实践

    我为 nodejs 编写了这个模块 可用于通过 sockjs 从任何地方向客户端分派事件 现在我想包括一些可配置的日志记录机制 目前 我将 winston 添加为依赖项 要求它作为每个类中的记录器并使用 logger error logge
  • 如何使用 MATLAB 和 JDBC 加速表检索?

    我正在使用 MATLAB 调用的 JDBC 访问 PostGreSQL 8 4 数据库 我感兴趣的表基本上由不同数据类型的各个列组成 他们是通过时间戳来选择的 由于我想检索大量数据 因此我正在寻找一种使请求比现在更快的方法 我现在正在做的事
  • 如何在 XAML 中使用 C# 中定义的画笔资源

    到目前为止我有这个
  • 新的 Conda 环境以及适用于 Jupyter Notebook 的最新 Python 版本

    由于 Python 版本变化很少 我总是忘记如何使用最新的 Python for Jupyter Notebook 创建新的 Conda 环境 所以我想下次将其列出来 从 StackOverflow 来看 有一些答案不再有效 下面是我在 S
  • 从 Apache Cordova 开始

    我刚刚下载了 Apache Cordova 似乎有特定于平台的版本 在将其移植到另一个平台之前 我是否必须为特定平台编写代码 是否可以创建一个多平台项目 我是否正确理解了我应该开始工作的方式 Apache Cordova 主页也是这么说的
  • 网络应用程序的照片存储[重复]

    这个问题在这里已经有答案了 可能的重复 用户镜像 数据库与文件系统存储 https stackoverflow com questions 585224 user images database vs filesystem storage
  • 构建 keras 模型

    我不明白这段代码中发生了什么 def construct model use imagenet True line 1 how do we keep all layers of this model model keras applicat