MultiHeadAttention Attention_mask [Keras、Tensorflow] 示例

2024-05-22

我正在努力掩盖 MultiHeadAttention 层的输入。我正在使用 Keras 文档中的 Transformer Block 进行自我关注。到目前为止,我在网上找不到任何示例代码,如果有人能给我一个代码片段,我将不胜感激。

变压器块来自this https://keras.io/examples/nlp/text_classification_with_transformer/ page:

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

屏蔽的文档可以在下面找到this https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention link:

Attention_mask:形状为 [B, T, S] 的布尔掩码,可防止 注意某些位置。布尔掩码指定哪个查询 元素可以关注哪些关键元素,1表示关注,0表示关注 表示不注意。可能会针对丢失的批次进行广播 尺寸和头部尺寸。

我唯一可以运行的是在图层类外部创建的掩码作为 numpy 数组:

mask = np.ones((observations, sequence_length, sequence_length))
mask[X[:observations,:,0]==0]=0

然后在调用层时输入,变压器块中唯一的变化是:

def call(self, inputs, mask, training):
    attn_output = self.att(inputs, inputs, attention_mask=mask)

然而,当在拟合时给定batch_size时,这当然不起作用,并且仅适用于我记忆中的5个观察,因此它没有任何意义。 除此之外,我认为这没有正确屏蔽输入 - 一般来说,考虑到注意力掩码的形状(观察、序列长度、序列长度),我对如何屏蔽感到非常困惑。我的输入的形状是(观察、序列长度、特征)。该输入被零填充,但是,当涉及到转换器块时,它已经通过了嵌入层和 CNN。 我尝试了各种方法来编写函数,该函数在使用不同的 Tensor 或 Keras 对象进行训练时创建掩码。然而我每次都会遇到错误。

我希望更熟悉 Tensorflow/Keras 的人能够提供一个例子。 或者有人告诉我,考虑到我的架构,屏蔽是没有用的。该模型表现良好。然而,我希望屏蔽可以帮助加快计算速度。 但令我烦恼的是我无法理解它。


也许有点晚了,但对于任何最终在这篇文章中寻找解决方案的人来说,这可能会有所帮助。

使用 Transformer 的典型场景是在 NLP 问题中,其中有成批的句子(为了简单起见,我们假设它们已经被标记化)。考虑以下示例:

sentences = [['Lorem', 'ipsum', 'dolor', 'sit', 'amet'], ['Integer', 'tincidunt', 'in', 'arcu', 'nec', 'fringilla', 'suscipit']]

如您所见,我们有两个长度不同的句子。为了在张量流模型中学习它们,我们可以用一个特殊的标记填充最短的一个,比方说'[PAD]',然后按照您的建议将它们输入到 Transformer 模型中。因此:

sentences = tf.constant([['Lorem', 'ipsum', 'dolor', 'sit', 'amet', '[PAD]', '[PAD]'], ['Integer', 'tincidunt', 'in', 'arcu', 'nec', 'fringilla', 'suscipit']])

还假设我们已经有了从某些语料库中提取的标记词汇表,例如词汇表1000令牌,我们可以定义一个StringLookup将我们的一批句子转换为给定词汇的数字投影的层。我们可以指定使用哪个令牌masking.

lookup = tf.keras.layers.StringLookup(vocabulary=vocabulary, mask_token='[PAD]')
x = lookup(sentences)
# x is a tf.Tensor([[2, 150, 19, 997, 9, 0, 0], [72, 14, 1, 1, 960, 58, 87]], shape=(2, 7), dtype=int64)

我们可以看到[PAD]令牌映射到0词汇中的价值。

典型的下一步是将这个张量输入到Embedding层,像这样:

embedding = tf.keras.layers.Embedding(input_dim=lookup.vocabulary_size(), output_dim=64, mask_zero=True)

这里的关键是论证mask_zero。根据文档 https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding,这个论证的意思是:

布尔值,输入值 0 是否是一个特殊的“填充”值,应该被屏蔽掉......

这允许embedding层为后续层生成一个掩码,以指示哪些位置应该出现,哪些位置不应该出现。该掩码可以通过以下方式访问:

mask = embedding.compute_mask(sentences)
# mask is a tf.Tensor([[True, True, True, True, True, False, False], [True, True, True, True, True, True, True]], shape=(2, 7), dtype=bool)

嵌入的张量的形式为:

y = embedding(sentences)
# y is a tf.Tensor of shape=(2, 7, 64), dtype=float32)

为了使用mask进入MultiHeadAttention层,必须重新调整掩模的形状才能满足形状要求,根据文档是[B, T, S] where B意味着批量大小(示例中为 2),T意味着查询大小(在我们的示例中为 7),以及S意味着key size(如果我们使用自注意力,则再次为 7)。同样在多头注意力层中,我们必须注意头的数量H。使用此输入创建兼容掩码的最简单方法是通过广播:

mask = mask[:, tf.newaxis, tf.newaxis, :]
# mask is a tf.Tensor of shape=(2, 1, 1, 7), dtype=bool) -> [B, H, T, S]

然后我们终于可以喂食了MultiHeadAttention层如下:

mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64)
z = mha(y, y, attention_mask=mask)

所以为了使用,你的TransformerBlock带有遮罩的图层,您应该添加到call方法一mask论证,如下:

def call(self, inputs, training, mask=None):
    attn_output = self.att(inputs, inputs, attention_mask=mask)
    ...

在您调用的层/模型中MultiHeadAttention层,您必须传递/传播使用生成的掩码Embedding layer.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

MultiHeadAttention Attention_mask [Keras、Tensorflow] 示例 的相关文章

  • 为什么 scikit learn 的平均精度分数返回 nan?

    我的 Keras 模型旨在接收两个输入时间序列 将它们连接起来 通过 LSTM 提供它们 并在下一个时间步骤中进行多标签预测 有 50 个训练样本 每个样本有 24 个时间步 每个样本有 5625 个标签 有 12 个验证样本 每个样本有
  • AttributeError:模块“keras.engine”没有属性“Layer”

    当我试图运行时Parking Slot mask rcnn py文件我收到如下错误mrcnn model py文件我该如何解决 gt 2021 06 17 08 25 18 585897 W tensorflow stream execut
  • 从 Keras 检查点加载

    我正在 Keras 中训练一个模型 我使用以下代码保存了所有内容 filepath project model hdh5 checkpoint ModelCheckpoint project model hdf5 monitor loss
  • 池化与随时间池化

    我从概念上理解最大 总和池中发生的情况作为 CNN 层操作 但我看到这个术语 随时间变化的最大池 或 随时间变化的总和池 例如 用于句子分类的卷积神经网络 https arxiv org pdf 1408 5882 pdfYoon Kim
  • 如何使用DecisionTreeClassifier平衡分类?

    我有一个数据集 其中类别不平衡 课程是0 1 or 2 如何计算每个类别的预测误差然后重新平衡weights相应地在 scikit learn 中 如果您想完全平衡 将每个类别视为同等重要 您可以简单地通过class weight bala
  • conv1D 中形状的尺寸

    我尝试过构建一个只有一层的 CNN 但遇到了一些问题 事实上 编译器告诉我 ValueError 检查模型输入时出错 预期的 conv1d 1 input 具有 3 个维度 但得到形状为 569 30 的数组 这是代码 import num
  • Tensorflow新Op CUDA内核内存管理

    我已经使用 GPU CUDA 内核在 Tensorflow 中实现了一个相当复杂的新 Op 该操作需要大量动态内存分配 这些变量不是张量 并且在操作完成后被释放 更具体地说 它涉及使用哈希表 现在我正在使用cudaMalloc and cu
  • 提高SVM分类器准确率的技术

    我正在尝试使用 UCI 数据集构建一个分类器来预测乳腺癌 我正在使用支持向量机 尽管我尽最大努力提高分类器的准确性 但仍无法超过 97 062 我尝试过以下方法 1 Finding the most optimal C and gamma
  • 理解高斯混合模型的概念

    我试图通过阅读在线资源来理解 GMM 我已经使用 K 均值实现了聚类 并且正在了解 GMM 与 K 均值的比较 以下是我的理解 如有错误请指出 GMM 类似于 KNN 在这两种情况下都实现了聚类 但在 GMM 中 每个簇都有自己独立的均值和
  • 无需安装 Tensorflow 即可服务 Tensorflow 模型

    我有一个经过训练的模型 想在 python 应用程序中使用 但我看不到任何在不安装 TensorFlow 或创建 gRPC 服务的情况下部署到生产环境的示例 有可能吗 在这种情况下 正确的做法是什么 如果不使用 TensorFlow 本身或
  • 如何在 python 中使用 libSVM 计算精度、召回率和 F 分数

    我想计算precision recall and f score using libsvm在Python中 但我不知道如何 我已经发现这个网站 http www csie ntu edu tw cjlin libsvmtools eval
  • 预处理 csv 文件以与 tflearn 一起使用

    我的问题是关于在将 csv 文件输入神经网络之前对其进行预处理 我想使用 python 3 中的 tflearn 为著名的 iris 数据集构建一个深度神经网络 数据集 http archive ics uci edu ml machine
  • 朴素贝叶斯分类器仅基于先验概率做出决策

    我试图根据推文的情绪将推文分为三类 买入 持有 卖出 我正在使用 R 和包 e1071 我有两个数据框 一个训练集和一组需要预测情绪的新推文 训练集数据框 text sentiment this stock is a good buy Bu
  • keras加载模型错误尝试将包含17层的权重文件加载到0层的模型中

    我目前正在使用 keras 开发 vgg16 模型 我用我的一些图层微调 vgg 模型 拟合我的模型 训练 后 我保存我的模型model save name h5 可以毫无问题地保存 但是 当我尝试使用以下命令重新加载模型时load mod
  • 交换keras中的张量轴

    我想将图像批次的张量轴从 batch size row col ch 交换为 批次大小 通道 行 列 在 numpy 中 这可以通过以下方式完成 X batch np moveaxis X batch 3 1 我该如何在 Keras 中做到
  • 如何在 py_function 之后重塑(图像,标签)数据集

    我正在尝试读取自定义映射数据集进行训练 但是在使用 py function 映射数据集后 我得到了未知的形状 例如 def process path file path label get label file path img tf io
  • 如何解释tf.map_fn的结果?

    看代码 import tensorflow as tf import numpy as np elems tf ones 1 2 3 dtype tf int64 alternates tf map fn lambda x x x x el
  • NotImplementedError:无法将符号张量 (lstm_2/strided_slice:0) 转换为 numpy 数组。时间

    张量流版本 2 3 1 numpy 版本 1 20 在代码下面 define model model Sequential model add LSTM 50 activation relu input shape n steps n fe
  • 无法使用 tf.data.Dataset 对组件 0 中具有不同形状的张量进行批处理

    我的输入管道中出现以下错误 tensorflow python framework errors impl InvalidArgumentError 不能 分量 0 中具有不同形状的批量张量 第一个元素有 形状为 2 48 48 3 元素
  • Keras CNN 回归模型损失低,准确度为 0

    我在 keras 中遇到这个 NN 回归模型的问题 我正在研究一个汽车数据集 以根据 13 个维度预测价格 简而言之 我已将其读取为 pandas 数据帧 将数值转换为浮点数 缩放值 然后对分类值使用 one hot 编码 这创建了很多新列

随机推荐

  • 节省页面加载时间的提示[重复]

    这个问题在这里已经有答案了 我的问题 削减那些不必要的 kb 并使页面加载速度更快的最佳方法是什么 全部是什么优化实践 编码实践 在js php中 如果执行可以使您的页面更轻 为什么我问这个 我读了这篇关于 jquery js 与 jque
  • 将 2D 数组映射到 1D 数组

    我想用一维数组来表示一个二维数组 函数将传递两个索引 x y 和要存储的值 这两个索引代表一维数组的单个元素 并相应地设置它 我知道一维数组需要具有 arrayWidth arrayHeight 的大小 但我不知道如何设置每个元素 例如 如
  • 测试 hdf5/c++ 中的组是否存在

    我正在打开一个现有的 HDF5 文件来附加数据 我想向那个叫做的小组保证 A存在以供后续访问 我正在寻找一种简单的方法来创建 A有条件地 如果不存在则创建并返回新组 或者返回现有组 一种方法是测试 A存在 我怎样才能高效地做到这一点 根据
  • 无法将中间件与 Firebase 和 NuxtJS 3 一起使用

    我正在尝试在示例项目中使用 Firebase 身份验证 身份验证按预期工作 但是一旦我想使用中间件来阻止用户访问管理页面或在已经登录的情况下访问登录页面 这是不可能的 我已经尝试了几个小时 但没有任何效果 这是我的package json
  • UITableView行高不变

    我创建了一个自定义单元格 我有一系列字典 对于我需要创建的字典值UILables 每个单元可能包含不同数量的UILabels 所以按照我的习惯UITableViewCell类我就是这样做的 void generateCell BOOL is
  • 如何在 C# 事件中区分更改是由代码还是由用户进行?

    我有一个简单的TextBox一开始是空的 我有一个简单的事件 TextChanged 可以知道用户何时更改了其中的任何内容TextBox 但是 如果我自己在代码中对其执行任何操作 该事件就会触发 喜欢设置textbox Text Test
  • 处理 LINQ sum 表达式中的 null

    我正在使用 LINQ 查询来查找列的总和 并且在少数情况下该值有可能为空 我现在使用的查询是 int score dbContext domainmaps Where p gt p SchoolId schoolid Sum v gt v
  • Groupby Sum 忽略几列

    在此数据框中 我想按 位置 进行分组并获得 分数 的总和 但我不希望 纬度 经度 和 年份 在此过程中受到影响 sample pd DataFrame Location A B C A B C Year 2001 2002 2003 200
  • 是否可以访问可执行 JAR 之外的 SQLite 数据库文件?

    我有一个作为可执行 JAR 文件部署的应用程序 最初 这个 JAR 文件将与 MySQL 数据库通信 但最近我决定改用 SQLite 然而 在测试时我发现从 JAR 文件运行应用程序时无法访问 SQLite 数据库文件 我使用来自以下网站的
  • 错误代码:1062。重复条目“PRIMARY”

    因此 我的教授给了我表格将其插入数据库 但是当我执行他的代码时 MySQL 不断给出错误代码 1062 这是冲突表和插入 TABLES CREATE TABLE FABRICANTES COD FABRICANTE integer NOT
  • 为什么要为字符变化类型指定长度

    参考 Postgres 文档字符类型 http www postgresql org docs current static datatype character html 我不清楚指定字符变化 varchar 类型的长度 假设 字符串的长
  • 如何使用 MPMusicPlayerController 播放音乐?

    任何人都可以建议我如何在我的应用程序中使用 MPMusicPlayerController 播放音乐 任何人的帮助将不胜感激 谢谢你 莫尼什 创建一个MPMediaPickerController这样你就可以从 iPod 中选择一些音乐 然
  • 在 Spring Boot 异常处理期间保留自定义 MDC 属性

    简短版本 有足够的细节 如何保留添加在MDC中的属性doFilter 的方法javax servlet Filter执行 public void doFilter ServletRequest request ServletResponse
  • .net 运行时 - Silverlight 运行时 =?

    我用 google 搜索了一下 但没能找到 net CLR 中的哪些类未包含在 CoreCLR 又名 Silverlight 中的详细列表 Windows net Framework 中缺少什么 Silverlight 另外 是否存在 Si
  • 通过递归扩展 Prolog 目标?

    我 最终 实现了一些目标 这些目标将根据开始由 开始之后 and duration 然而 计划目标仅接受规定数量的任务 我想扩展计划目标的功能以接受单个列表并在计划时迭代该列表 不幸的是 我认为这将需要与can run and 冲突目标如下
  • iOS设备和iPhone模拟器内存​​组织的差异

    我正在尝试使用 Xcode 4 3 3 和 iPhone 5 1 模拟器开发一个应用程序 当我在模拟器上运行这个应用程序时 我没有收到任何警告 并且它运行得很好 但是 当我尝试在 iOS 设备上执行此操作时 我收到一条警告消息 收到内存警告
  • ES6 Promises/在满足多个 Promise 后调用函数(不能使用 Promises.all)[重复]

    这个问题在这里已经有答案了 我正在编写 Javascript 它需要这些事件按以下顺序发生 同时触发多个 API 调用 所有调用完成且响应返回后 执行一行代码 听起来很简单 但棘手的部分是我不能使用 Promises all 因为我仍然希望
  • 未捕获的类型错误:未定义不是函数

    我收到消息Uncaught TypeError Undefined is not a function当我尝试调用家庭控制器中的方法时 也许关于我为什么收到此消息的建议 findIdpActivities function pernr ca
  • 对 data.table 进行子集化的最快方法是什么?

    在我看来 这是执行行 列子集的最快方法data table是使用 join 和nomatch option 它是否正确 DT data table rep 1 100 100000 rep 1 10 1000000 setkey DT V1
  • MultiHeadAttention Attention_mask [Keras、Tensorflow] 示例

    我正在努力掩盖 MultiHeadAttention 层的输入 我正在使用 Keras 文档中的 Transformer Block 进行自我关注 到目前为止 我在网上找不到任何示例代码 如果有人能给我一个代码片段 我将不胜感激 变压器块来