tf.keras - 尽管使用 GPU 内存,但第一个时期的训练并未取得进展

2023-12-31

我一直在尝试训练使用 Keras 的 Tensorflow 实现编写的 CNN。看起来训练在到达第一个 epoch 时就陷入了困境——尽管根据 nvidia-smi 的说法,我的 GPU 似乎仍在使用内存。也没有错误消息或回溯打印到终端,这使得调试对我来说有点棘手。我还使用 TF 估计器和数据集编写了此代码,当我将其放置过夜时,网络并未进行训练。因此,我不认为这只是让代码运行更长时间的情况 - 这可能是我所做的事情,但也可能是由于(据称已修复)错误(根据下面的第二个链接)。

目前,我还尝试使用 model.fit() 中的“verbose”参数来跟踪训练过程,以查看是否发生了任何情况。但我没有看到终端中出现任何内容。其他遇到此问题的人似乎仍然会出现进度条。

我还尝试使用 TensorBoard 进行日志记录并保存模型检查点。没有保存检查点,并且关于 Tensorboard,看起来也没有保存图表。

关于可能导致这种情况的原因有什么想法吗?

无法通过第一个纪元——只是挂起 [Keras 迁移学习初始阶段] https://stackoverflow.com/questions/47382952/cant-get-past-first-epoch-just-hangs-keras-transfer-learning-inception

Keras 拟合在第一个 epoch 结束时冻结 https://stackoverflow.com/questions/48748413/keras-fit-freezes-at-the-end-of-the-first-epoch

import os
import tensorflow as tf
from tensorflow import keras
import cv2
import numpy as np
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.keras import backend as K

cwd = os.getcwd()
log_dir = cwd + "/Keras_Model/"
callbacks = [keras.callbacks.ModelCheckpoint(filepath="./Checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5"),
         keras.callbacks.TensorBoard(log_dir="./logs")]

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
TAKEN FROM HERE: https://stackoverflow.com/questions/45466020/how-to-export-keras-h5-to-tensorflow-pb
Freezes the state of a session into a pruned computation graph. Used later to save model as TF pb file.

Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.

@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
                      or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
graph = session.graph
with graph.as_default():
    freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
    output_names = output_names or []
    output_names += [v.op.name for v in tf.global_variables()]
    input_graph_def = graph.as_graph_def()
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                  output_names, freeze_var_names)
    return frozen_graph

### IMPORT TRAINING IMAGES AS NUMPY ARRAY ###

t_dir = cwd + "/data-1/training/" 
e_dir = cwd + "/data-1/evaluation"

xtrain = []
ytrain = []

print(" - Collating training data and labels... - ")

for subdir, dirs, files in os.walk(t_dir):
    for f in files:
        img = os.path.join(subdir, f)
        x = cv2.imread(img) # --> Produces 8-bit tensor from image file.
        y = int(img.split("/")[-2]) - 1 # --> Get label from file path.
        xtrain.append(x)
        ytrain.append(y)

data = np.asarray(xtrain)
print(" - Training data collated. - ")
labels = np.asarray(ytrain)
print(" - Training labels collated. - ")


### IMPORT EVALUATION IMAGES AS TF ITERATOR ###

xeval = []
yeval = []

print(" - Collating validation data and labels... - ")

for subdir, dirs, files in os.walk(e_dir):
    for f in files:
        img = os.path.join(subdir, f)
        x = cv2.imread(img) # --> Produces 8-bit tensor from image file.
        y = int(img.split("/")[-2]) - 1 # --> Get label from file path.
        xeval.append(x)
        yeval.append(y)

 val_data = np.asarray(xeval)
 print(" - Validation data collated. - ")
 val_labels = np.asarray(yeval)
 print(" - Validation labels collated. - ")

 ### CREATE MODEL ###

 model = keras.Sequential()

 model.add(keras.layers.Conv2D(filters=32, kernel_size=5, strides=1, padding="same", data_format = "channels_last", activation="relu", input_shape=    (480,640,3)))

 model.add(keras.layers.GlobalMaxPool2D(data_format = "channels_last"))

 model.add(keras.layers.Dense(64, activation="relu"))

 model.add(keras.layers.Dropout(0.4)) # --> Change dropout rate here.

 model.add(keras.layers.Dense(8, activation="softmax"))

 model.compile(optimizer=tf.train.AdamOptimizer(0.001), # --> Choose learning rate here.
          loss=keras.losses.sparse_categorical_crossentropy,
          metrics=[keras.metrics.categorical_accuracy])

print(" - Model created... - ")
print(" - Model Summary - ")
model.summary() # --> Print model summary.

### TRAIN AND EVALUATE MODEL ###

print(" - Training model... - ")
model.fit(data, labels, epochs = 5, batch_size=32, callbacks=callbacks, validation_data=(val_data, val_labels), verbose = 2)
print(" - Model trained! - ")

### SAVE MODEL AS H5 AND PB FILES ###

model.save("./Keras_Model/model.h5", save_format="h5")
print(" - Saved model as h5. - ")

frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, "./Tensorflow_Model/", "model.pb", as_text=False)
print(" - Saved model as pb. - ")

print(" - Clearing session. - ")
keras.clear_session()

如果可以的话,我还可以提供使用 TF 数据集和评估器的版本,或者其他任何内容。如果我遗漏了任何明显的内容,我深表歉意,我刚刚开始使用 SO。

更新:我昨晚回家并在我的计算机上运行了这个脚本 - 它似乎工作得很清楚,这不是使用问题,但可能是 TF 本身的问题或它在我们服务器上的配置方式的问题。这有点奇怪,因为 TF 之前在某个时刻正在工作,但你能做什么呢?大家干杯。


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tf.keras - 尽管使用 GPU 内存,但第一个时期的训练并未取得进展 的相关文章

随机推荐

  • 使用 R,将可变行数的文本合并到单个文本元素中

    什么样的 R 代码可以将下面的模拟数据框中每个人的叙述条目合并到一个变量中 数据来自 Excel 电子表格 其中记录的叙述条目可以有 1 到 8 行 每个计时员的记录都以空行结束 假设这个数据框 dput 如下 gt df timekeep
  • 提前终止工作人员 puma 日志意味着什么?为什么会发生这种情况?

    对于我的 Elastic Beanstalk 实例 我得到了504每当我访问它时状态代码响应 当我跟踪日志时 我在 puma 应用服务器上看到以下日志 gt var log puma puma log lt 27240 Early term
  • 在 PL/SQL 中将 varchar2 转换为日期 ('MM/DD/YYYY')

    我需要将字符串从 varchar 转换为 MM DD YYYY 格式的日期 我的输入字符串是 4 9 2013 我的预期输出是 04 09 2013 即 2 位月份 2 位日期和 4 位年份 以 分隔 我有以下数据 DOJ varchar2
  • 为什么提供静态文件不安全

    这可能是一个愚蠢的问题 并且有一个明显的答案 但我正在测试 404 和 500 错误处理程序 这意味着我必须将 debug 切换为 False 我进入 Django 管理页面 注意到没有提供静态文件 我知道它们应该通过 Apache 路由
  • iOS 7.1 滑动解锁文字动画

    我不确定以前是否有人问过这个问题 但我很难找到它 也许我没有使用正确的搜索词 所以如果答案已经存在 如果有人能指出我正确的方向 我将不胜感激 我刚刚注意到 随着 iOS 7 1 的更新 锁屏 滑动解锁 文本上的闪烁动画发生了变化 聚光灯现在
  • 什么时候编写“ad hoc sql”与存储过程更好[重复]

    这个问题在这里已经有答案了 我的应用程序中有 100 的即席 SQL 我的一个朋友建议我转换为存储过程以获得额外的性能和安全性 这在我脑海中提出了一个问题 除了速度和安全性之外 还有其他理由坚持使用即席 SQL 查询吗 SQL Server
  • 将 Camunda 嵌入现有 Java 应用程序

    我已经提取了 Camunda 最新映像并在它自己的 docker 容器中运行 Camunda 我有一个 dmn 上传到 Camunda Cockpit 并且我能够进行 Rest 调用以从我上传到 Camunda Cockpit 的决策表中获
  • 错误:访问属性“处理程序”的权限被拒绝

    我有一个 Firefox 的 Greasemonkey 脚本 昨天运行得很好 我今天尝试使用它 没有修改代码 我注意到它停止工作 经过进一步检查 脚本现在抛出以下错误 Error Permission denied to access pr
  • 我可以将 cperl 模式与 perl 模式着色一起使用吗?

    Emacs cperl 模式似乎比 perl 模式更容易混淆 但彩虹糖效应使该东西对我来说无法使用 有谁知道或知道 emacs 块的示例 该示例使 cperl mode 使用 perl mode 的着色 理想情况下以一种足够可读的形式 以便
  • 寻找适合企业网站的轻文本富编辑器,比tinymce更轻,带有用于评论表单的基本按钮

    我正在寻找适合企业网站的轻文本富编辑器 比tinymce更轻 带有用于评论表单的基本按钮 重要的是编辑器也可以在 IE6 中运行 到目前为止 我尝试使用 cleditor 15KB 但当按 enter 键时 IE 出现问题 客户有问题 Jq
  • EmberJS 使用 HasMany 取消(回滚)对象

    假设我有一个 ParentObjecthasMany项目 我想在我的应用程序中实现取消功能Add将回滚所有内容的路线 简而言之 我有 父对象IsNew and IsDirty 并且有可能 项目 也将是IsNew and IsDirty 所以
  • 使用 like 关键字在单个查询中匹配多个标题

    使用 like 关键字在单个查询中匹配多个标题 我正在尝试获取与给定标题匹配的所有记录 下面是数据库的结构请参阅 数据库截图 https prnt sc JduJ6NSIr1E 当我传递单个类似查询时 它返回数据 Query SELECT
  • 聚焦离子输入时有没有办法隐藏键盘?

    我想要一个可以聚焦的离子输入 并且键盘不应该出现 有什么办法或者有可能吗 谢谢你 是的 安装这个插件 gt https ionicframework com docs native keyboard https ionicframework
  • 在Scheme中注释代码

    我正在查看一些代码Scheme from Festival并且似乎无法弄清楚评论 目前 我可以看到 and 用于指示注释行 网络上的其他来源表明上面的一些可能是指示多行注释的方法 我的问题是 有什么区别 and 用于发表评论 什么时候应该使
  • mvc 和 webapi 之间的身份验证(单独的域/应用程序)

    我正在为以下场景寻找好的想法 资源 实现 MVC 网站位于http mywebsite com http mywebsite com Webapi REST 服务位于http myapi com http myapi com 重要信息 请注
  • 如何为 Arduino IDE 安装 openCV 库?

    我正在开发一个使用面部跟踪 对象跟踪 面部识别等的 Arduino 项目 为了实现这一目标 我决定使用 OpenCV 库 然而问题是 我不知道如何安装 Arduino 和处理的 OpenCV 库 谁能告诉我该怎么做 谢谢 如果您使用的是处理
  • 绘制多条路线谷歌地图

    我想根据Google中的路线服务绘制多条路线 代码如下 p s Data是我从json调用中获得的列表 for i 0 i lt 20 i route data i start new google maps LatLng route fr
  • Visual C# Studio 项目中的哪些文件不需要版本控制?

    我是 Visual C Studio 的新手 实际上使用的是 Express 版本 但另一个开发人员正在使用完整版本 并且我们正在使用版本控制 svn 将项目文件添加到存储库对我来说是可以接受的 因为此存储库仅适用于我们两个使用 Visua
  • 为什么我的 vscode 光标在 div 周围显示一个块

    一旦我进入 DIV 标签或任何函数 我的 vscode 就会在这些标签周围显示一个空白框 在此输入图像描述 https i stack imgur com GhhQ2 png我也添加了一张图片 有人可以帮我禁用这个吗 所以我只能看到光标 看
  • tf.keras - 尽管使用 GPU 内存,但第一个时期的训练并未取得进展

    我一直在尝试训练使用 Keras 的 Tensorflow 实现编写的 CNN 看起来训练在到达第一个 epoch 时就陷入了困境 尽管根据 nvidia smi 的说法 我的 GPU 似乎仍在使用内存 也没有错误消息或回溯打印到终端 这使