使用 Tensorflow.js 计算损失梯度

2024-02-08

我正在尝试使用 Tensorflow.js 计算与网络可训练权重相关的损失梯度,以便将这些梯度应用于我的网络权重。在 python 中,这可以使用 tf.gradients() 函数轻松完成,该函数需要两个表示 dx 和 dy 的最小输入。 但是,我无法重现 Tensorflow.js 中的行为。我不确定我对权重损失梯度的理解是否错误,或者我的代码是否包含错误。

我花了一些时间分析 tfjs-node 包的核心代码,以了解当我们调用函数 tf.model.fit() 时它是如何完成的,但到目前为止收效甚微。

let model = build_model(); //Two stacked dense layers followed by two parallel dense layers for the output
let loss = compute_loss(...); //This function returns a tf.Tensor of shape [1] containing the mean loss for the batch.
const f = () => loss;
const grad = tf.variableGrads(f);
grad(model.getWeights());

model.getWeights() 函数返回一个 tf.variable() 数组,因此我假设该函数会计算每一层的 dL/dW,稍后我可以将其应用于网络的权重,但是,情况并非如此,因为我得到这个错误:

Error: Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.

我不太明白这个错误是什么意思。 那么我应该如何使用 Tensorflow.js 计算损失的梯度(类似于 Python 中的 tf.gradients())?

编辑 : 这是计算损失的函数:

function compute_loss(done, new_state, memory, agent, gamma=0.99) {
    let reward_sum = 0.;
    if(done) {
        reward_sum = 0.;
    } else {
        reward_sum = agent.call(tf.oneHot(new_state, 12).reshape([1, 9, 12]))
                    .values.flatten().get(0);
    }

    let discounted_rewards = [];
    let memory_reward_rev = memory.rewards;
    for(let reward of memory_reward_rev.reverse()) {
        reward_sum = reward + gamma * reward_sum;
        discounted_rewards.push(reward_sum);
    }
    discounted_rewards.reverse();

    let onehot_states = [];
    for(let state of memory.states) {
        onehot_states.push(tf.oneHot(state, 12));
    }
    let init_onehot = onehot_states[0];

    for(let i=1; i<onehot_states.length;i++) {
        init_onehot = init_onehot.concat(onehot_states[i]);
    }

    let log_val = agent.call(
        init_onehot.reshape([memory.states.length, 9, 12])
    );

    let disc_reward_tensor = tf.tensor(discounted_rewards);
    let advantage = disc_reward_tensor.reshapeAs(log_val.values).sub(log_val.values);
    let value_loss = advantage.square();
    log_val.values.print();

    let policy = tf.softmax(log_val.logits);
    let logits_cpy = log_val.logits.clone();

    let entropy = policy.mul(logits_cpy.mul(tf.scalar(-1))); 
    entropy = entropy.sum();

    let memory_actions = [];
    for(let i=0; i< memory.actions.length; i++) {
        memory_actions.push(new Array(2000).fill(0));
        memory_actions[i][memory.actions[i]] = 1;
    }
    memory_actions = tf.tensor(memory_actions);
    let policy_loss = tf.losses.softmaxCrossEntropy(memory_actions.reshape([memory.actions.length, 2000]), log_val.logits);

    let value_loss_copy = value_loss.clone();
    let entropy_mul = (entropy.mul(tf.scalar(0.01))).mul(tf.scalar(-1));
    let total_loss_1 = value_loss_copy.mul(tf.scalar(0.5, dtype='float32'));

    let total_loss_2 = total_loss_1.add(policy_loss);
    let total_loss = total_loss_2.add(entropy_mul);
    total_loss.print();
    return total_loss.mean();

}

EDIT 2:

我设法使用compute_loss作为model.compile()上指定的损失函数。但是,它只需要两个输入(预测、标签),所以它不适合我,因为我想输入多个参数。

我真的对这件事迷失了。


错误说明了一切。 您的问题与 tf.variableGrads 有关。loss应该是使用所有可用的计算得出的标量tf张量运算符。loss不应返回问题中所示的张量。

以下是损失应该是什么的示例:

const a = tf.variable(tf.tensor1d([3, 4]));
const b = tf.variable(tf.tensor1d([5, 6]));
const x = tf.tensor1d([1, 2]);

const f = () => a.mul(x.square()).add(b.mul(x)).sum(); // f is a function
// df/da = x ^ 2, df/db = x 
const {value, grads} = tf.variableGrads(f); // gradient of f as respect of each variable

Object.keys(grads).forEach(varName => grads[varName].print());

/!\ 请注意,梯度是根据使用创建的变量来计算的tf.variable

Update:

您没有按应有的方式计算梯度。这是修复方法。

function compute_loss(done, new_state, memory, agent, gamma=0.99) {
    const f = () => { let reward_sum = 0.;
    if(done) {
        reward_sum = 0.;
    } else {
        reward_sum = agent.call(tf.oneHot(new_state, 12).reshape([1, 9, 12]))
                    .values.flatten().get(0);
    }

    let discounted_rewards = [];
    let memory_reward_rev = memory.rewards;
    for(let reward of memory_reward_rev.reverse()) {
        reward_sum = reward + gamma * reward_sum;
        discounted_rewards.push(reward_sum);
    }
    discounted_rewards.reverse();

    let onehot_states = [];
    for(let state of memory.states) {
        onehot_states.push(tf.oneHot(state, 12));
    }
    let init_onehot = onehot_states[0];

    for(let i=1; i<onehot_states.length;i++) {
        init_onehot = init_onehot.concat(onehot_states[i]);
    }

    let log_val = agent.call(
        init_onehot.reshape([memory.states.length, 9, 12])
    );

    let disc_reward_tensor = tf.tensor(discounted_rewards);
    let advantage = disc_reward_tensor.reshapeAs(log_val.values).sub(log_val.values);
    let value_loss = advantage.square();
    log_val.values.print();

    let policy = tf.softmax(log_val.logits);
    let logits_cpy = log_val.logits.clone();

    let entropy = policy.mul(logits_cpy.mul(tf.scalar(-1))); 
    entropy = entropy.sum();

    let memory_actions = [];
    for(let i=0; i< memory.actions.length; i++) {
        memory_actions.push(new Array(2000).fill(0));
        memory_actions[i][memory.actions[i]] = 1;
    }
    memory_actions = tf.tensor(memory_actions);
    let policy_loss = tf.losses.softmaxCrossEntropy(memory_actions.reshape([memory.actions.length, 2000]), log_val.logits);

    let value_loss_copy = value_loss.clone();
    let entropy_mul = (entropy.mul(tf.scalar(0.01))).mul(tf.scalar(-1));
    let total_loss_1 = value_loss_copy.mul(tf.scalar(0.5, dtype='float32'));

    let total_loss_2 = total_loss_1.add(policy_loss);
    let total_loss = total_loss_2.add(entropy_mul);
    total_loss.print();
    return total_loss.mean().asScalar();
}

return tf.variableGrads(f);
}

请注意,您很快就会遇到内存消耗问题。建议将功能区分为tf.tidy处理张量。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 Tensorflow.js 计算损失梯度 的相关文章

  • 如何在光标下的所有元素上调用 mouseover?

    我有一个网络应用程序 每次单击时都会创建一个点 见下文 当我将鼠标悬停在一堆点上时 我希望光标下的每个点都会触发 mouseover 或 mouseenter 事件 然而 只有一个事件被触发 即堆栈 顶部 的点的事件 当鼠标移动到一堆多个点
  • 获取 Node.js npm 命令以在公司代理后面工作

    我正在尝试安装凉亭 npm install g 鲍尔 但我从我们的代理收到身份验证错误 npm http 407http registry npmjs org bower http registry npmjs org bower 错误代码
  • 如何根据按钮单击折叠和展开 Kendo UI 树视图中的所有树节点?

    这是行不通的 您可以使用此代码 1 崩溃 折叠kendoTree查看文档 http docs kendoui com api web treeview methods collapse treeview kendoTreeView var
  • 如何使用 axios / jest 测试失败的请求

    我创建了一个非常小的应用程序 如果您传递硬币和数量 它可以计算为某些加密货币支付的总价格 我想测试错误 但我总是收到 收到的承诺已解决而不是被拒绝 我相信这是因为如果 url 错误 axios 仍然会解决承诺 我遇到的第二个问题是 我尝试测
  • 通知用户消息仍在输入中

    我正在使用 Laravel 5 6 7 Socket IO 和 vue js 我没有使用 Pusher 和 redis 下面是我的代码 用于向与我一对一聊天的用户发送消息 var url http localhost 6001 apps M
  • 如何获取 RxJSSubject 或 Observable 的当前值?

    我有 Angular 2 服务 import Storage from storage import Injectable from angular2 core import Subject from rxjs Subject Inject
  • 使用 NodeJS 让 Discord 机器人发送带有消息的图片

    我有几张照片 全部在 imgur 上 带有直接图像链接 格式 https i imgur com XXXXXX jpg https i imgur com XXXXXX jpg 以及用 NodeJS 制作的 Discord 机器人 我发送这
  • 向对象添加元素

    我需要填充一个 json 文件 现在我有这样的东西 element id 10 quantity 1 我需要添加另一个 元素 我的第一步是使用该 json 将该 json 放入对象类型中cart JSON parse 现在我需要添加新元素
  • 在 Angular2 项目中集成 Treant-js

    我正在尝试在 Angular2 项目中使用 treant js 但我正在努力解决如何正确集成它的问题 我有一个工作正常的 JavaScript HTML 示例 我正在尝试在 Angular2 中工作 我创建了一个组件 从 npm 添加了 t
  • 具有行组的 JQuery 斑马条纹表

    我通常将斑马条纹表行设置为奇数 偶数 如下所示 效果很好 table tbody tr visible even this addClass even table tbody tr visible odd this addClass odd
  • Bing.com 如何创建放大的缩略图?

    当我使用 Bing com 搜索图像时 我发现它们的图像经过精心裁剪和排序 当您将鼠标放在图像上时 会弹出另一个窗口 其中显示放大的图像 我想在我的程序中做同样的事情 我检查了他们页面的源代码 他们正在使用 javascript 但我仍然不
  • 使用 Javascript / Jquery 的本地存储(不使用 HTML5)

    我想在 javascript 或 jquery 中复制本地存储概念 类似于 HTML5 但不幸的是我不知道如何开始 任何人都可以建议如何使用 javascript 或 jquery 实现本地存储 不使用 HTML5 这是一个有点愚蠢的差事
  • 引入 V8 后,Google Apps 脚本无法为其他用户完全执行

    我编写了一个脚本 得到了这里好心人的大力帮助 该脚本使用 Google Sheets 脚本复制 Google Drive 上的文件夹 和内容 它运行了很长一段时间 但后来我启用了 V8 引擎 现在已禁用 问题是 它仍然适用于我 也许还有其他
  • JQuery DataTable 单元格从行单击

    我正在尝试在 jquery 数据表上实现一个函数 该函数返回单击行的第一列和第四列 我正在遵循这个示例 它允许我操作单击的行http datatables net examples api select single row html ht
  • 使用 onBlur 事件上的值更新 React 输入文本字段

    我有以下输入字段 在模糊时 该函数调用服务来更新服务器的输入值 完成后 它会更新输入字段 我怎样才能让它发挥作用 我可以理解为什么它不允许我更改字段 但我能做些什么才能使其工作 我无法使用defaultValue因为我会将这些字段更改为其他
  • 理论上防止 WebSocket 中第一个收到的消息丢失

    服务器端代码发送消息立即地连接打开后 它向客户端发送初始配置 问候语 以下代码是在客户端 var sock new WebSocket url sock addEventListener error processError sock ad
  • Meteor.js 登录事件

    因此 我对 Meteor 框架和 JavaScript 总体来说还很陌生 但我正在使用该框架开发一个小项目 以尝试让自己达到标准 基本上我正在开发一个微博客网站 目前 用户可以通过多种服务登录 fb google 等 我通过插入所需 url
  • 为什么我需要 $(document.body) 来使用 Mootools Element 方法扩展 document.body?

    因此 在尝试让我的应用程序在最新的 IE 上运行后 结果发现 IE 不喜欢以下代码 document body getElement className Firefox 和 Chrome 响应良好 但是document bodyIE 上没有
  • 如何根据所需表单输入的值更改 CSS 样式

    我想知道如何编写 javascript 来改变所需的表单元素的样式 如果它们有价值的话就改变它们 我想要做的是当所需的文本字段为空时 在它们周围有一个彩色边框 并在它们有值时删除边框样式 我想做的是编写一个 javascript 函数来检查
  • Serviceworker Bug event.respondWith

    我的 serviceworker 的逻辑是 当发生获取事件时 它首先获取包含一些布尔值 而不是 event request url 的端点 并根据我正在调用的值检查该值event respondWith 对于当前的获取事件 我正在提供来自缓

随机推荐

  • 如何在知道线程 id 的情况下获取消息线程 URL?

    有如果我有消息 ID 如何构建链接以查看 facebook com 上的消息 http facebook stackoverflow com questions 7747622 how can i construct a link to v
  • jquery mobile 和 ui 不兼容

    尽管有很多人提到类似的兼容性问题 但 50 的问题在 StackOverflow 上得到了解决 我希望我的问题能够成为 51 49 考虑这段代码
  • macOS 公证:找不到 altool

    我想开始构建一个公证自动化脚本 但是 当我尝试在终端中使用 xcrun altool 时 出现以下错误 xcrun error unable to find utility altool not a developer tool or in
  • 如何正确引用本地XML Schema文件?

    我在 XML 文件中引用 XML 架构时遇到此问题 我的 XSD 位于此路径中 C environment workspace maven ws ProjectXmlSchema email xsd 但是 当我在 XML 文件中尝试像这样查
  • 服务器标记格式不正确

    这真是太愚蠢了 但却让我快疯了
  • 堆叠 UITableViews 不会在其视图下方传递触摸事件

    我将 3 个 UIView 堆叠在一起 UI表格视图平面视图根视图 TableView 位于顶部 rootView 位于底部 rootView 不可见 因为 TableView 在它上面 我在 rootView 中实现了以下代码 code
  • 错误 TS2707 通用类型“ɵɵDirectiveDeclaration”需要 6 到 8 个类型参数

    安装角度材料并将角度材料导入 app module ts 添加到项目后 我遇到错误 并且到目前为止所有解决方案都不起作用 我的角度为 14 节点为 16 第一个错误 实际上要长得多 Error node modules angular cd
  • 如何使用 Python 从巨大的 Excel 工作表中提取特定行的数据?

    我需要获取其中包含某些关键字 名称 的特定数据行并将它们写入另一个文件 起始文件是 1 5 GB Excel 文件 我不能只是打开它并将其另存为不同的格式 我应该如何使用 python 处理这个问题 我是 xlrd 的作者和维护者 请编辑您
  • 如何提高Python循环速度?

    我有一个包含 370k 记录的数据集 存储在 Pandas Dataframe 中 需要集成 我尝试了多处理 线程 Cpython 和循环展开 但我没有成功 显示的计算时间是 22 小时 任务如下 matplotlib inline fro
  • 开发游戏服务器用什么语言好?

    我只是想知道什么语言是开发支持大量 数千 用户的游戏服务器的不错选择 我涉足Python 但意识到这太麻烦了 因为它不会跨核心产生线程 意味着8个核心服务器 1个核心服务器 我也不太喜欢这种语言 自我 的东西让我感到恶心 我知道 C 就性能
  • 在xamarin forms pcl项目中打开远程pdf的最佳方法

    在适用于 Ios 和 Android 的 xamarin pcl 应用程序中 在服务器上加载 pdf 的最佳方式是什么 是否有一个好的 nuget 或者我们必须编写自定义渲染器 在应用程序中打开 PDF 您有几个选项 iOS 在其 WebV
  • 使用 Cython 将 Python 链接到共享库

    我正在尝试集成用以下语言编写的第三方库C和我的python应用程序使用Cython 我已经为测试编写了所有 python 代码 我无法找到设置此功能的示例 我有一个pyd pyx我手动创建的文件 第三方给了我一个header file h
  • 使用Delphi RTTI获取接口的字符串名称

    我已经证明我可以使用 Delphi 2010 从其 GUID 获取接口的名称 例如 IMyInterface 转换为字符串 IMyInterface 我想在 Delphi 7 中实现此目的 为了兼容性 这可能吗 或者是存在基本的编译器限制
  • 哪种数据结构最适合 VirtualStringTree?

    我想每个曾经使用过Delphi的VirtualStringTree的人都会同意它是一个很棒的控件 它是一个 虚拟 控件 您的数据必须保存在其他地方 所以我在想什么数据结构最适合这样的任务 IMO认为数据结构必须支持层次结构 它必须快速且易于
  • 扩展器的默认控制模板

    有人 可能使用 Blend 可以为我提供 WPF Expander 的工作默认 ControlTemplate 吗 我想做一些细微的修改 但似乎找不到有效模板的来源 提前致谢 我有混合 可以帮助你 这是 Blend 为我生成的内容
  • 根据日期分割数据框

    我正在尝试根据日期将数据框分成两个 此处的相关问题已解决 根据日期将数据帧分成两部分 https stackoverflow com questions 37532098 split dataframe into two on the ba
  • Chrome 语音合成具有较长的文本

    我在 Chrome 33 中尝试使用语音合成 API 时遇到问题 它可以完美地处理较短的文本 但如果我尝试较长的文本 它就会停在中间 一旦停止后 语音合成将无法在 Chrome 中的任何地方工作 直到浏览器重新启动 示例代码 http js
  • 责任链模式是否可以很好地替代一系列条件?

    当您需要按特定顺序执行一系列操作时 责任链模式是否可以很好地替代一系列条件 用这样的条件替换简单的方法是个好主意吗 public class MyListener implements MyHttpListener if false the
  • 线程安全类是否应该在其构造函数末尾有一个内存屏障?

    当实现一个线程安全的类时 我是否应该在其构造函数末尾包含一个内存屏障 以确保任何内部结构在可以访问之前已完成初始化 或者消费者有责任在使实例可供其他线程使用之前插入内存屏障 简化问题 下面的代码中是否存在竞争危险 由于初始化和线程安全类的访
  • 使用 Tensorflow.js 计算损失梯度

    我正在尝试使用 Tensorflow js 计算与网络可训练权重相关的损失梯度 以便将这些梯度应用于我的网络权重 在 python 中 这可以使用 tf gradients 函数轻松完成 该函数需要两个表示 dx 和 dy 的最小输入 但是