如何获取基于Keras的LSTM模型中每个epoch的一层权重矩阵?

2024-04-08

我有一个基于 Keras 的简单 LSTM 模型。

X_train, X_test, Y_train, Y_test = train_test_split(input, labels, test_size=0.2, random_state=i*10)

X_train = X_train.reshape(80,112,12)
X_test = X_test.reshape(20,112,12)

y_train = np.zeros((80,112),dtype='int')
y_test = np.zeros((20,112),dtype='int')

y_train = np.repeat(Y_train,112, axis=1)
y_test = np.repeat(Y_test,112, axis=1)
np.random.seed(1)

# create the model
model = Sequential()
batch_size = 20

model.add(BatchNormalization(input_shape=(112,12), mode = 0, axis = 2))#4
model.add(LSTM(100, return_sequences=False, input_shape=(112,12))) #7 

model.add(Dense(112, activation='hard_sigmoid'))#9
model.compile(loss='binary_crossentropy', optimizer='RMSprop', metrics=['binary_accuracy'])#9

model.fit(X_train, y_train, nb_epoch=30)#9

# Final evaluation of the model
scores = model.evaluate(X_test, y_test, batch_size = batch_size, verbose=0)

我知道如何获取体重清单model.get_weights(),但这是模型完全训练后的值。我想获得每个时期的权重矩阵(例如,LSTM 中的最后一层),而不仅仅是它的最终值。换句话说,我有 30 个 epoch,需要获得 30 个权重矩阵值。

真的谢谢你,我在keras的wiki上没有找到解决方案。


您可以为其编写自定义回调:

from keras.callbacks import Callback

class CollectWeightCallback(Callback):
    def __init__(self, layer_index):
        super(CollectWeightCallback, self).__init__()
        self.layer_index = layer_index
        self.weights = []

    def on_epoch_end(self, epoch, logs=None):
        layer = self.model.layers[self.layer_index]
        self.weights.append(layer.get_weights())

属性self.model回调的引用是对正在训练的模型的引用。它是通过设置Callback.set_model()当训练开始时。

要获取每个时期最后一层的权重,请将其与以下命令一起使用:

cbk = CollectWeightCallback(layer_index=-1)
model.fit(X_train, y_train, nb_epoch=30, callbacks=[cbk])

然后权重矩阵将被收集到cbk.weights.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何获取基于Keras的LSTM模型中每个epoch的一层权重矩阵? 的相关文章

随机推荐

  • 如何更新 Haskell Map 中的项目?

    我是 Haskell 的新手 正在尝试找出一个合理的方法 写入地图的方式 为解决特定问题做准备 欧拉工程问题 我希望写一个函数来填充 带有记录的地图 但我无法让它发挥作用 let似乎创建局部变量而不是治疗smap作为一个全球性的 一定有某种
  • 使用javamail连接到hotmail?

    我想知道是否可以使用JavaMail 连接到Hotmail 我已经尝试过 但它不起作用 连接被拒绝 String host pop3 live com String username email protected cdn cgi l em
  • 自定义 QML 模块部署到 Android:缺少 QML 依赖项

    我正在开发一个包含一些特殊类型的自定义 QML 模块 我们称之为 MyModule 它用作其他应用程序项目的预编译库 即源代码对它们不可用 它通过 import MyModule 1 0 设置必要的导入路径等来使用 该模块包含基于 C 的
  • 过渡导航栏标题

    在一个名为 Luvocracy 的应用程序中 当用户在屏幕上向上滑动时 导航栏的标题会发生变化 旧标题被推高 而新标题则过渡进来 我现在没有视频 但这里有一些屏幕截图 https www dropbox com s sns0bsxkdv7p
  • 安全存储客户端敏感数据

    背景故事我在一家中小型公司工作 我们正在重新设计面向客户的会计门户 我的经理希望使用存储在最终用户计算机上的 cookie 中的信用卡信息来进行单击付款选项 我根本不喜欢这个想法 事实上我仍在努力改变他的想法 话虽这么说 我正在努力使其尽可
  • 如何在 alpine.js 应用程序中制作具有时间间隔的计时器

    使用 alpine js 2 我尝试在应用程序的页脚 为所有布局设置 中定义计时器 div div div span style background color yellow span div div div
  • 如何查看浏览器请求?

    我正在与另一位程序员合作 他最近向我发送了一个新的基于 JSON 的 API 来工作 他说我可以通过访问特定网站并查看浏览器请求来查看所有 API 调用的示例 我的问题是 如何查看我的浏览器请求 我之前曾使用 Wireshark 来分析我的
  • JAVA ANDROID - 获取文件夹的文件列表

    我想在一个文件夹中显示不同文件夹的图像GridView 但我不知道需要做什么才能获取包含可绘制文件夹内的文件名称的列表 此方法将为您提供一个包含 dir 文件夹列表以及子文件夹列表的列表 public void listf String d
  • 从调用者类停止异步 Spring 方法

    我有一个类调用 Rest Web 服务来从服务器接收文件 在传输字节时 我创建了一个异步任务 它检查与服务器的连接是否正常 以便在出现错误时允许停止连接 这个异步任务有一个我必须停止的循环 Component public class Co
  • POJO 反序列化期间忽略 @JsonTypeInfo 属性

    我使用 JsonTypeInfo 指示 Jackson 2 1 0 在 鉴别器 属性中查找具体类型信息 这很有效 但在反序列化期间 鉴别器属性没有设置到 POJO 中 根据 Jackon 的 Javadoc com fasterxml ja
  • CMake 未从 conan 中找到 boost 库

    所以我试图让我的 cmake 与 conan boost 一起工作 为此 我有一个简单的柯南文件 from conans import ConanFile class Boost Conan Cmake MinimalConfig Cona
  • 迭代Python字典并特殊追加到新列表?

    我想迭代字典 并将按其值 频率 重复的每个键 字母 附加到新列表中 例如 输入 A 1 B 2 预期输出 A B B 我正在做的事情不起作用 我应该在函数中写什么来做到这一点 def get freq dict freq dict J 1
  • 批处理文件中的文件时间超过 4 分钟

    我正在使用这个脚本来计算文件的时间 set filename myfile txt rem extract current date and time for f tokens 1 5 delims a in date time do se
  • 反应路由器路径内的问号

    我正在尝试在 URL 中传递参数 但读取时遇到问题 我正在使用反应路由器 v4 URL http localhost 3000 reset token 123 http localhost 3000 reset token 123我试着这样
  • 带有react-router-dom NavLinks的react-bootstrap导航栏中的collapseOnSelect

    我正在制作一个网站 在其中使用 React router dom NavLink 组件来防止单页面应用程序体验的重新渲染 当我试图使网站响应时 我一直在尝试使响应式导航栏在选择 NavLink 后从 React Bootstrap 折叠 但
  • 插入 Base64 图像作为 WordPress 帖子附件

    我正在将画布转换为 base64 png 图像 现在我想将此图像添加为帖子附件 这是我在服务器端的图像 image base64 decode preg replace data image w base64 i data pdf thum
  • 使用php脚本在csv文件中插入图像[重复]

    这个问题在这里已经有答案了 可能的重复 PHP代码可以将图像插入Excel文件并在MS Excel中正确打开它吗 https stackoverflow com questions 11337142 php code can insert
  • 函数 sum 无法正常工作 javascript [重复]

    这个问题在这里已经有答案了 当我添加两个字段时 代码不会对其进行求和 而是添加我输入的内容 这是我的代码 function calculate loan var amountBorrowed document form amountBorr
  • ElasticSearch 脚本:检查数组是否包含值

    假设我创建了一个这样的文档 PUT idx type 1 the field 1 2 3 我可以使用 GET idx type 1 检索我的文档 index idx type type id 1 version 1 found true s
  • 如何获取基于Keras的LSTM模型中每个epoch的一层权重矩阵?

    我有一个基于 Keras 的简单 LSTM 模型 X train X test Y train Y test train test split input labels test size 0 2 random state i 10 X t