如果手动使用 BCE 损失,则所有梯度值计算为“无”

2023-12-27

我正在研究一个多输出模型,在计算总体损失之前,我需要权衡所有输出损失。我有一个定制的model. fit() 训练循环 https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit为了达成这个。

由于我需要计算所有四个输出的样本损失并在应用权重后融合这些样本损失,因此我定制了标准代码。现在,损失是按样本计算的,但在计算梯度时,所有梯度值都计算为“无”。我试着把tape.watch(loss)也,但它不起作用。请帮我解决这个问题。

class CustomModel(keras.Model):
    def train_step(self, data):
        print(tf.executing_eagerly())
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        alpha = 0.1
        loss = 0
        y_pred_all = []

        with tf.GradientTape() as tape:
            bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
            for spl in range(1 if np.shape(x)[0] == None else np.shape(x)[0]):
                tape.watch(loss)
                tape.watch(loss_mean)
                tape.watch(loss_element)
                x_spl = np.reshape(x[spl], (1, np.shape(x)[1], np.shape(x)[2], np.shape(x)[3]))
                y_pred = self(x_spl, training=True)  # Forward pass
                y_pred_all.append(y_pred)
                loss_element = bce(y[spl], y_pred)
                loss_mean = [np.mean(loss_element[0]), np.mean(loss_element[1]), np.mean(loss_element[2]), np.mean(loss_element[3])]
                id = np.argmin(loss_mean)
                for i, ele in enumerate(loss_mean):
                    if i == id:
                        loss_mean[i] *= 1
                    else:
                        loss_mean[i] *= alpha

                loss = loss + np.sum(loss_mean)

        # Compute gradients
        trainable_vars = self.trainable_variables

        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred_all)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

UPDATE我按照建议做了一些更改@rvinas现在它正在计算梯度,没有任何错误,但我不确定我所做的更改是否正确:

class CustomModel(keras.Model):
    def train_step(self, data):
        # print(tf.executing_eagerly())
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        alpha = 0.1
        loss = tf.Variable(0, dtype='float32')
        y_pred_all = []

        with tf.GradientTape() as tape:
            bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
            for spl in tf.range(1 if tf.shape(x)[0] == None else tf.shape(x)[0]):
                loss_mean=tf.convert_to_tensor([])
                x_spl =  tf.reshape(x[spl], (1, tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]))
                y_pred = self(x_spl, training=True)  # Forward pass
                y_pred_all.append(y_pred)
                loss_element = bce(y[spl], y_pred)
                loss_mean = [tf.reduce_mean(loss_element[0]), tf.reduce_mean(loss_element[1]), tf.reduce_mean(loss_element[2]), tf.reduce_mean(loss_element[3])]

                id = tf.argmin(loss_mean)
                for i, ele in enumerate(loss_mean):
                    if i == id:
                        loss_mean[i] = tf.multiply(loss_mean[i], 1)
                    else:
                        loss_mean[i] = tf.multiply(loss_mean[i], alpha)

                loss = tf.add(loss, tf.add(tf.add(tf.add(loss_mean[0],loss_mean[1]), loss_mean[2]), loss_mean[3]))

        # Compute gradients
        trainable_vars = self.trainable_variables

        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred_all)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

出现 NaN 梯度是因为您正在使用 NumPy 运算(例如np.sum, np.reshape, ...),这会导致图表断开连接。相反,我们只需要使用张量流运算来实现逻辑。


例如,可以按如下方式实现评论部分中描述的权重:

bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
with tf.GradientTape() as tape:
    # Compute element-wise losses
    y_pred = self(x, training=True)
    losses = bce(y, y_pred)  # Shape=(bs, 4)

    # Find maximum loss for each sample
    idx_max = tf.argmax(losses, axis=-1)  # Shape=(bs,)
    idx_max_onehot = tf.one_hot(idx_max, depth=y.shape[-1])  # Shape=(bs, 4)

    # Create weights tensor
    weight_max = 1
    weight_others = 0.1
    weights = idx_max_onehot * weight_max + (1 - idx_max_onehot) * weight_others

    # Aggregate losses
    losses = tf.reduce_sum(weights * losses, axis=-1)
    loss = tf.reduce_mean(losses)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如果手动使用 BCE 损失,则所有梯度值计算为“无” 的相关文章

  • 如何生成给定范围内的回文数列表?

    假设范围是 1 X 120 这是我尝试过的 gt gt gt def isPalindrome s check if a number is a Palindrome s str s return s s 1 gt gt gt def ge
  • Pycharm Python 控制台不打印输出

    我有一个从 Pycharm python 控制台调用的函数 但没有显示输出 In 2 def problem1 6 for i in range 1 101 2 print i end In 3 problem1 6 In 4 另一方面 像
  • 导入错误:没有名为 _ssl 的模块

    带 Python 2 7 的 Ubuntu Maverick 我不知道如何解决以下导入错误 gt gt gt import ssl Traceback most recent call last File
  • 如何打印没有类型的defaultdict变量?

    在下面的代码中 from collections import defaultdict confusion proba dict defaultdict float for i in xrange 10 confusion proba di
  • SQL Alchemy 中的 NULL 安全不等式比较?

    目前 我知道如何表达 NULL 安全的唯一方法 SQL Alchemy 中的比较 其中与 NULL 条目的比较计算结果为 True 而不是 NULL 是 or field None field value 有没有办法在 SQL Alchem
  • 如何使用 Scrapy 从网站获取所有纯文本?

    我希望在 HTML 呈现后 可以从网站上看到所有文本 我正在使用 Scrapy 框架使用 Python 工作 和xpath body text 我能够获取它 但是带有 HTML 标签 而且我只想要文本 有什么解决办法吗 最简单的选择是ext
  • Spark的distinct()函数是否仅对每个分区中的不同元组进行洗牌

    据我了解 distinct 哈希分区 RDD 来识别唯一键 但它是否针对仅移动每个分区的不同元组进行了优化 想象一个具有以下分区的 RDD 1 2 2 1 4 2 2 1 3 3 5 4 5 5 5 在此 RDD 上的不同键上 所有重复键
  • __del__ 真的是析构函数吗?

    我主要用 C 做事情 其中 析构函数方法实际上是为了销毁所获取的资源 最近我开始使用python 这真的很有趣而且很棒 我开始了解到它有像java一样的GC 因此 没有过分强调对象所有权 构造和销毁 据我所知 init 方法对我来说在 py
  • Python 中的二进制缓冲区

    在Python中你可以使用StringIO https docs python org library struct html用于字符数据的类似文件的缓冲区 内存映射文件 https docs python org library mmap
  • NameError:名称“urllib”未定义”

    CODE import networkx as net from urllib request import urlopen def read lj friends g name fetch the friend list from Liv
  • 在pyyaml中表示具有相同基类的不同类的实例

    我有一些单元测试集 希望将每个测试运行的结果存储为 YAML 文件以供进一步分析 YAML 格式的转储数据在几个方面满足我的需求 但测试属于不同的套装 结果有不同的父类 这是我所拥有的示例 gt gt gt rz shorthand for
  • 使用 OpenPyXL 迭代工作表和单元格,并使用包含的字符串更新单元格[重复]

    这个问题在这里已经有答案了 我想使用 OpenPyXL 来搜索工作簿 但我遇到了一些问题 希望有人可以帮助解决 以下是一些障碍 待办事项 我的工作表和单元格数量未知 我想搜索工作簿并将工作表名称放入数组中 我想循环遍历每个数组项并搜索包含特
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • Python 3 中“map”类型的对象没有 len()

    我在使用 Python 3 时遇到问题 我得到了 Python 2 7 代码 目前我正在尝试更新它 我收到错误 类型错误 map 类型的对象没有 len 在这部分 str len seed candidates 在我像这样初始化它之前 se
  • 在Python中重置生成器对象

    我有一个由多个yield 返回的生成器对象 准备调用该生成器是相当耗时的操作 这就是为什么我想多次重复使用生成器 y FunctionWithYield for x in y print x here must be something t
  • 检查所有值是否作为字典中的键存在

    我有一个值列表和一本字典 我想确保列表中的每个值都作为字典中的键存在 目前我正在使用两组来确定字典中是否存在任何值 unmapped set foo set bar keys 有没有更Pythonic的方法来测试这个 感觉有点像黑客 您的方
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo
  • 循环标记时出现“ValueError:无法识别的标记样式 -d”

    我正在尝试编码pyplot允许不同标记样式的绘图 这些图是循环生成的 标记是从列表中选取的 为了演示目的 我还提供了一个颜色列表 版本是Python 2 7 9 IPython 3 0 0 matplotlib 1 4 3 这是一个简单的代
  • 使用基于正则表达式的部分匹配来选择 Pandas 数据帧的子数据帧

    我有一个 Pandas 数据框 它有两列 一列 进程参数 列 包含字符串 另一列 值 列 包含相应的浮点值 我需要过滤出部分匹配列 过程参数 中的一组键的子数据帧 并提取与这些键匹配的数据帧的两列 df pd DataFrame Proce
  • Pandas 与 Numpy 数据帧

    看这几行代码 df2 df copy df2 1 df 1 df 1 values 1 df2 ix 0 0 我们的教练说我们需要使用 values属性来访问底层的 numpy 数组 否则我们的代码将无法工作 我知道 pandas Data

随机推荐