有没有一种简单的方法来扩展现有的激活函数?我的自定义 softmax 函数返回: 操作具有“无”梯度

2024-02-27

我想通过仅使用向量中的前 k 个值来实现使 softmax 更快的尝试。

为此,我尝试为张量流实现一个自定义函数以在模型中使用:

def softmax_top_k(logits, k=10):
    values, indices = tf.nn.top_k(logits, k, sorted=False)
    softmax = tf.nn.softmax(values)
    logits_shape = tf.shape(logits)
    return_value = tf.sparse_to_dense(indices, logits_shape, softmax)
    return_value = tf.convert_to_tensor(return_value, dtype=logits.dtype, name=logits.name)
    return return_value

我正在使用 Fashion mnist 来测试这种尝试是否有效:

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# normalize the data
train_images = train_images / 255.0
test_images = test_images / 255.0

# split the training data into train and validate arrays (will be used later)
train_images, train_images_validate, train_labels, train_labels_validate = train_test_split(
    train_images, train_labels, test_size=0.2, random_state=133742,
)

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=softmax_top_k)
])


model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

model.fit(
    train_images, train_labels,
    epochs=10,
    validation_data=(train_images_validate, train_labels_validate),
)

model_without_cnn.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

model_without_cnn.fit(
    train_images, train_labels,
    epochs=10,
    validation_data=(train_images_validate, train_labels_validate),
)

但执行过程中出现错误:

ValueError: An operation hasNonefor gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable).

我发现了this:(如何制作自定义激活函数) https://stackoverflow.com/questions/39921607/how-to-make-a-custom-activation-function-with-only-python-in-tensorflow#39921608,它解释了如何为tensorflow实现完全自定义的激活函数。但由于它使用并扩展了softmax,我认为梯度应该仍然是相同的。

这是我使用 Python 和 Tensorflow 进行编码的第一周,因此我还没有对所有内部实现有一个很好的概述。

有没有更简单的方法将 softmax 扩展到新函数,而不是从头开始实现?

提前致谢!


不要使用稀疏张量来使张量“除 softmaxed 前 K 值外全部为零”,而是使用tf.scatter_nd https://www.tensorflow.org/api_docs/python/tf/scatter_nd:

import tensorflow as tf

def softmax_top_k(logits, k=10):
    values, indices = tf.nn.top_k(logits, k, sorted=False)
    softmax = tf.nn.softmax(values)
    logits_shape = tf.shape(logits)
    # Assuming that logits is 2D
    rows = tf.tile(tf.expand_dims(tf.range(logits_shape[0]), 1), [1, k])
    scatter_idx = tf.stack([rows, indices], axis=-1)
    return tf.scatter_nd(scatter_idx, softmax, logits_shape)

编辑:这是具有任意维数的张量的稍微复杂的版本。不过,代码仍然要求在图构建时已知维数。

import tensorflow as tf

def softmax_top_k(logits, k=10):
    values, indices = tf.nn.top_k(logits, k, sorted=False)
    softmax = tf.nn.softmax(values)
    # Make nd indices
    logits_shape = tf.shape(logits)
    dims = [tf.range(logits_shape[i]) for i in range(logits_shape.shape.num_elements() - 1)]
    grid = tf.meshgrid(*dims, tf.range(k), indexing='ij')
    scatter_idx = tf.stack(grid[:-1] + [indices], axis=-1)
    return tf.scatter_nd(scatter_idx, softmax, logits_shape)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

有没有一种简单的方法来扩展现有的激活函数?我的自定义 softmax 函数返回: 操作具有“无”梯度 的相关文章

  • Django 管理员在模型编辑时间歇性返回 404

    我们使用 Django Admin 来维护导出到我们的一些站点的一些数据 有时 当单击标准更改列表视图来获取模型编辑表单而不是路由到正确的页面时 我们会得到 Django 404 页面 模板 它是偶尔发生的 我们可以通过重新加载三次来重现它
  • 将 Matplotlib 误差线放置在不位于条形中心的位置

    我正在 Matplotlib 中生成带有错误栏的堆积条形图 不幸的是 某些层相对较小且数据多样 因此多个层的错误条可能重叠 从而使它们难以或无法读取 Example 有没有办法设置每个误差条的位置 即沿 x 轴移动它 以便重叠的线显示在彼此
  • OpenCV Python cv2.mixChannels()

    我试图将其从 C 转换为 Python 但它给出了不同的色调结果 In C Transform it to HSV cvtColor src hsv CV BGR2HSV Use only the Hue value hue create
  • 使用 matplotlib 绘制时间序列数据并仅在年初显示年份

    rcParams date autoformatter month b n Y 我正在使用 matpltolib 来绘制时间序列 如果我按上述方式设置 rcParams 则生成的图会在每个刻度处标记月份名称和年份 我怎样才能将其设置为仅在每
  • 如何使用 Ansible playbook 中的 service_facts 模块检查服务是否存在且未安装在服务器中?

    我用过service facts检查服务是否正在运行并启用 在某些服务器中 未安装特定的软件包 现在 我如何知道这个特定的软件包没有安装在该特定的服务器上service facts module 在 Ansible 剧本中 它显示以下错误
  • 如何替换 pandas 数据框列中的重音符号

    我有一个数据框dataSwiss其中包含瑞士城市的信息 我想用普通字母替换带有重音符号的字母 这就是我正在做的 dataSwiss Municipality dataSwiss Municipality str encode utf 8 d
  • 根据列值突出显示数据框中的行?

    假设我有这样的数据框 col1 col2 col3 col4 0 A A 1 pass 2 1 A A 2 pass 4 2 A A 1 fail 4 3 A A 1 fail 5 4 A A 1 pass 3 5 A A 2 fail 2
  • Python 函数可以从作用域之外赋予新属性吗?

    我不知道你可以这样做 def tom print tom s locals locals def dick z print z name z name z guest Harry print z guest z guest print di
  • BeautifulSoup 中的嵌套标签 - Python

    我在网站和 stackoverflow 上查看了许多示例 但找不到解决我的问题的通用解决方案 我正在处理一个非常混乱的网站 我想抓取一些数据 标记看起来像这样 table tbody tr tr tr td td td table tr t
  • Python 的“zip”内置函数的 Ruby 等价物是什么?

    Ruby 是否有与 Python 内置函数等效的东西zip功能 如果不是 做同样事情的简洁方法是什么 一些背景信息 当我试图找到一种干净的方法来进行涉及两个数组的检查时 出现了这个问题 如果我有zip 我可以写这样的东西 zip a b a
  • 在f字符串中转义字符[重复]

    这个问题在这里已经有答案了 我遇到了以下问题f string gt gt gt a hello how to print hello gt gt gt f a a gt gt gt f a File
  • python获取上传/下载速度

    我想在我的计算机上监控上传和下载速度 一个名为 conky 的程序已经在 conky conf 中执行了以下操作 Connection quality alignr wireless link qual perc wlan0 downspe
  • 无法在 Python 3 中导入 cProfile

    我试图将 cProfile 模块导入 Python 3 3 0 但出现以下错误 Traceback most recent call last File
  • 使用 \r 并打印一些文本后如何清除控制台中的一行?

    对于我当前的项目 有一些代码很慢并且我无法使其更快 为了获得一些关于已完成 必须完成多少的反馈 我创建了一个进度片段 您可以在下面看到 当你看到最后一行时 sys stdout write r100 80 n I use 80覆盖最终剩余的
  • Pandas:merge_asof() 对多行求和/不重复

    我正在处理两个数据集 每个数据集具有不同的关联日期 我想合并它们 但因为日期不完全匹配 我相信merge asof 是最好的方法 然而 有两件事发生merge asof 不理想的 数字重复 数字丢失 以下代码是一个示例 df a pd Da
  • 如何在Python中对类别进行加权随机抽样

    给定一个元组列表 其中每个元组都包含一个概率和一个项目 我想根据其概率对项目进行采样 例如 给出列表 3 a 4 b 3 c 我想在 40 的时间内对 b 进行采样 在 python 中执行此操作的规范方法是什么 我查看了 random 模
  • 为字典中的一个键附加多个值[重复]

    这个问题在这里已经有答案了 我是 python 新手 我有每年的年份和值列表 我想要做的是检查字典中是否已存在该年份 如果存在 则将该值附加到特定键的值列表中 例如 我有一个年份列表 并且每年都有一个值 2010 2 2009 4 1989
  • 如何计算 pandas 数据帧上的连续有序值

    我试图从给定的数据帧中获取连续 0 值的最大计数 其中包含来自 pandas 数据帧的 id date value 列 如下所示 id date value 354 2019 03 01 0 354 2019 03 02 0 354 201
  • 在 Qt 中自动调整标签文本大小 - 奇怪的行为

    在 Qt 中 我有一个复合小部件 它由排列在 QBoxLayouts 内的多个 QLabels 组成 当小部件调整大小时 我希望标签文本缩放以填充标签区域 并且我已经在 resizeEvent 中实现了文本大小的调整 这可行 但似乎发生了某
  • Python 类继承 - 诡异的动作

    我观察到类继承有一个奇怪的效果 对于我正在处理的项目 我正在创建一个类来充当另一个模块的类的包装器 我正在使用第 3 方 aeidon 模块 用于操作字幕文件 但问题可能不太具体 以下是您通常如何使用该模块 project aeidon P

随机推荐

  • 使用单向多对多映射进行删除级联

    我正在使用 Fluent 和 NHibernate 我有两个对象 A 和 B 它们之间具有多对多关系 当 A HasMany B 时 我使用单向多对多映射 B中没有关于A 单向 的参考 这会在数据库中创建第三个表 名为 ABMapping
  • 将日期时间插入 SQLite 数据库

    我试图将时间插入数据库 但是当我打印插入的时间时 它不正确 我在将时间变量插入数据库之前打印了它 它是 12 01 09 149059 我用的时候效果很好strftime但我换了 因为时间已经到了 from datetime import
  • Bootstrap 轮播显示下一张和上一张图像

    引导程序轮播是否可扩展以在滑块中显示下一个和上一个图像 div class carousel slide ol class carousel indicators li class active li li li ol div
  • React Native Android:方法不会覆盖或实现超类型的方法

    我已经添加react native fbsdk到我的 React Native 项目并让它在 iOS 上正常构建 但在android方面 我无法让gradle来构建项目 当尝试编译react native fbsdk时 我遇到了 方法不会覆
  • eclipse 4 RCP 应用程序中启动屏幕上的进度条

    我想在 Eclipse 4 RCP 初始屏幕上添加一个进度条 我已经尝试了以下代码和设置 但仍然无法获取进度条 org eclipse ui SHOW PROGRESS ON STARTUP true 在plugin customizati
  • 将大对象插入 Postgresql 返回 53200 Out of Memory 错误

    PostgreSQL 9 1 NPGSQL 2 0 12 我想将二进制数据存储在 postgresql 数据库中 大多数文件加载良好 但是 大型二进制 664 Mb 文件会导致问题 当尝试通过 Npgsql 使用大对象支持将文件加载到 po
  • DropDownList 的 MVC2 编辑器模板

    过去一周的大部分时间我都在深入研究 MVC2 中的新模板功能 我很难让 DropDownList 模板正常工作 我一直在努力解决的最大问题是如何将下拉列表的源数据获取到模板 我看到很多示例 您可以将源数据放入 ViewData 字典 Vie
  • LoadError: 无法加载此类文件 -- capybara 独立代码

    我正在使用 Ruby 和以下教程构建一个简单的后挖矿程序 http ngauthier com 2014 06 scraping the web with ruby html http ngauthier com 2014 06 scrap
  • 自定义 Spring Bean 参数

    我正在使用 activator 上发布的 Spring Akka 示例来创建 Spring 托管 bean actor 这是我当前使用的代码 包括演示类 Component class Test extends UntypedActor A
  • 检查 Asp.Net(Core) 应用程序是否托管在 IIS 中

    如何检查应用程序是否托管在 IIS 中 检查环境变量 APP POOL ID 是否设置 public static bool InsideIIS gt System Environment GetEnvironmentVariable AP
  • Android 应用程序基 64 公钥

    如何获取 或查看 Android 应用程序 Base 64 公钥 我有许可证文件 并且我之前已经发布过我的应用程序 我需要许可密钥 要查找您的应用程序的公共许可密钥 请执行以下步骤 1 登录您发布应用的 Google Play 开发者控制台
  • 理解主定理

    通用形式 T n aT n b f n 所以我必须将 n logb a 与 f n 进行比较 if n logba gt f n is case 1 and T n n logb a if n logba lt f n is case 2
  • Symbian 的不同版本

    我必须在 Symbian 中构建一个项目 我有一些困惑 并有一些与 Symbian 版本相关的问题 有什么区别Symbian 3 S60 3rd edition and S60 5th edition 从编码角度来看 与 Symbian 3
  • 为同一轴上的抽动设置不同的颜色

    是否可以在同一轴上使用不同颜色或样式的抽动 tics 0 1 1 5 2我想要0和2有色red or bold 非常适合multiplots其中有关于相同测量值的图 并且您希望在不同的图中标记 y 或 x 范围 但又不会使其过载太多 现在对
  • 只比较时间,不比较日期?

    我需要编写一个方法来检查是否Time now位于商店的营业时间和打烊时间之间 营业时间和营业时间被保存为 Time 对象 但我无法直接比较它 因为商店将营业时间保存在2012 2 2所以开放时间大概是这样的 2012 02 02 02 30
  • Passport.js 中的本地和 Google 策略:序列化用户时出现问题

    我一直试图理解为什么即使身份验证本身正在工作 我也无法让用户在经过身份验证后保持登录状态 我什至在这里发布了一个问题 Passport js 本地策略未进行身份验证 https stackoverflow com questions 515
  • Google 服务的 Android 版本冲突

    我已经为此搜索了很多解决方案 但没有一个适合我的具体情况 我在 Gradle Sync 上收到此错误 错误 任务 app processDebugGoogleServices 执行失败 请通过更新 google services 插件的版本
  • 如何直观地显示 java ResultSet?

    我正在寻找一种在屏幕上显示 java sql ResultSet 的方法 最好内置于java或swing中 如果这两个都没有一个简单的好方法 我会考虑 spring How 循环 ResultSet 的结果并将其放入 TableModel
  • 如何在 Google 电子表格中添加标题

    我在用gdata spreadsheet 3 0jar 用于在 Google 电子表格中输入数据 我在用 new ListEntry getCustomElements setValueLocal Header Name Value 但我不
  • 有没有一种简单的方法来扩展现有的激活函数?我的自定义 softmax 函数返回: 操作具有“无”梯度

    我想通过仅使用向量中的前 k 个值来实现使 softmax 更快的尝试 为此 我尝试为张量流实现一个自定义函数以在模型中使用 def softmax top k logits k 10 values indices tf nn top k