计算测试集每个类别的熵以测量 pytorch 上的不确定性

2024-01-11

我正在尝试使用 MC Dropout 方法和此链接中提出的解决方案来计算图像分类任务的数据集的每一类的熵,以测量 pytorch 上的不确定性
在 pytorch 上使用 MC Dropout 测量不确定性 https://stackoverflow.com/questions/63285197/measuring-uncertainty-using-mc-dropout-on-pytorch

首先,我计算了不同前向传递中每批每个类的平均值 (class_mean_batch),然后计算了所有测试加载程序 (classes_mean),然后进行了一些转换以获取 (total_mean) 以使用它来计算熵,如下面的代码所示

def mcdropout_test(batch_size,n_classes,model,T):

    #set non-dropout layers to eval mode
    model.eval()

    #set dropout layers to train mode
    enable_dropout(model)
    
    softmax = nn.Softmax(dim=1)
    classes_mean = []
       
    for images,labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        classes_mean_batch = []
            
        with torch.no_grad():
          output_list = []
          
          #getting outputs for T forward passes
          for i in range(T):
            output = model(images)
            output = softmax(output)
            output_list.append(torch.unsqueeze(output, 0))
            
        
        concat_output = torch.cat(output_list,0)
        
        # getting mean of each class per batch across multiple MCD forward passes
        for i in range (n_classes):
          mean = torch.mean(concat_output[:, : , i])
          classes_mean_batch.append(mean)
        
        # getting mean of each class for the testloader
        classes_mean.append(torch.stack(classes_mean_batch))
        

    total_mean = []
    concat_classes_mean = torch.stack(classes_mean)

    for i in range (n_classes):
      concat_classes = concat_classes_mean[: , i]
      total_mean.append(concat_classes)


    total_mean = torch.stack(total_mean)
    total_mean = np.asarray(total_mean.cpu())
 
    epsilon = sys.float_info.min
    # Calculating entropy across multiple MCD forward passes 
    entropy = (- np.sum(total_mean*np.log(total_mean + epsilon), axis=-1)).tolist()
    for i in range(n_classes):
      print(f'The uncertainty of class {i+1} is {entropy[i]:.4f}')
    
    

任何人都可以纠正或确认我用来计算每个类的熵的实现。


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

计算测试集每个类别的熵以测量 pytorch 上的不确定性 的相关文章

  • Java机器学习库可以商用吗? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 有谁知道我可以将其用于商业产品的优秀 Java 机器学习库吗 不幸的是 Weka 和 Rapidmin
  • 如何从已安装的云端硬盘文件夹中永久删除?

    我编写了一个脚本 在每次迭代后将我的模型和训练示例上传到 Google Drive 以防发生崩溃或任何阻止笔记本运行的情况 如下所示 drive path drive My Drive Colab Notebooks models if p
  • R 中多类分类的 ROC 曲线

    我有一个包含 6 个类别的数据集 我想绘制多类别分类的 ROC 曲线 Achim Zeileis 给出的第一个答案非常好 R中使用rpart包的ROC曲线 https stackoverflow com questions 30818188
  • 如何使用 pytorch 同时迭代两个数据加载器?

    我正在尝试实现一个接收两张图像的暹罗网络 我加载这些图像并创建两个单独的数据加载器 在我的循环中 我想同时遍历两个数据加载器 以便我可以在两个图像上训练网络 for i data in enumerate zip dataloaders1
  • PyTorch 给出 cuda 运行时错误

    我对我的代码做了一些小小的修改 以便它不使用 DataParallel and DistributedDataParallel 代码如下 import argparse import os import shutil import time
  • Weka J48 分类器:无法处理数字类?

    我现在尝试使用 Weka 在我的训练数据上构建 J48 C4 5 分类器模型 首先我这样做 这似乎很顺利 java Xmx10G cp weka weka jar weka core converters TextDirectoryLoad
  • Pytorch TypeError:eq() 收到无效的参数组合

    num samples 10 def predict x sampled models guide None None for in range num samples yhats model x data for model in sam
  • 如何屏蔽 PyTorch 权重参数中的权重?

    我正在尝试在 PyTorch 中屏蔽 强制为零 特定权重值 我试图掩盖的权重是这样定义的def init class LSTM MASK nn Module def init self options inp dim super LSTM
  • ID3和C4.5:“增益比”如何标准化“增益”?

    ID3算法使用 信息增益 度量 C4 5 使用 增益比 度量 即信息增益除以SplitInfo 然而SplitInfo对于记录在不同结果之间平均分配的分割 该值较高 否则较低 我的问题是 这如何帮助解决信息增益偏向于具有多种结果的分裂的问题
  • 如何以干净高效的方式在 pytorch 中获得小批量?

    我试图做一件简单的事情 即使用火炬通过随机梯度下降 SGD 训练线性模型 import numpy as np import torch from torch autograd import Variable import pdb def
  • 增量决策树 C++ 实现

    有谁知道决策树分类器的增量实现吗 这样 当您将新实例添加到训练集中时 它可以根据现有决策树分类器以低计算量并尽可能快地生成最佳决策树分类器 换句话说 我有一个最优决策树分类器集A 其中命名为T 1 现在我想添加实例X to set A并找到
  • 如何使用 lstm 执行多类多输出分类

    I have multiclass multioutput classification see https scikit learn org stable modules multiclass html https scikit lear
  • 使用 scikit-learn 在朴素贝叶斯分类器中混合类别数据和连续数据

    我正在使用 Python 中的 scikit learn 开发分类算法来预测某些客户的性别 除此之外 我想使用朴素贝叶斯分类器 但我的问题是我混合了分类数据 例如 在线注册 接受电子邮件通知 等 和连续数据 例如 年龄 长度 会员资格 等
  • 如何从 PyTorch 模型的特定层获取输出?

    如何从预训练的 PyTorch 模型 例如 ResNet 或 VGG 中提取特定层的特征 而无需再次进行前向传递 新答案 Edit torchvision v0 11 0 中有一个新功能 允许提取特征 https github com py
  • 如何在pytorch中动态索引张量?

    例如 我有一个张量 tensor torch rand 12 512 768 我得到了一个索引列表 说它是 0 2 3 400 5 32 7 8 321 107 100 511 我希望从给定索引列表的维度 2 上的 512 个元素中选择 1
  • PyInstaller 可执行文件无法获取 TorchScript 源代码

    我正在尝试使包含 PyTorch 的脚本在 Windows 中可执行 我的脚本的导入是 import numpy core multiarray which is a workaround for ImportError numpy cor
  • 将 Pytorch 模型 .pth 转换为 onnx 模型

    我有一个预训练的模型 其格式为 pth 扩展名 我想将其转换为 Tensorflow protobuf 但我没有找到任何方法来做到这一点 我见过 onnx 可以将模型从 pytorch 转换为 onnx 然后从 onnx 转换为 Tenso
  • PyTorch 中的后向函数

    我对 pytorch 的后向功能有一些疑问 我认为我没有得到正确的输出 import numpy as np import torch from torch autograd import Variable a Variable torch
  • Pytorch LSTM:计算交叉熵损失的目标维度

    我一直在尝试在 Pytorch 中使用 LSTM LSTM 后跟自定义模型中的线性层 但在计算损失时出现以下错误 Assertion cur target gt 0 cur target lt n classes failed 我用以下函数
  • 如何在 Google Colab 上安装 PyTorch v1.0.0+?

    PyTorch v1 0 0 稳定版是发布于 2018 年 12 月 8 日 https github com pytorch pytorch releases tag v1 0 0成为之后7个月前宣布 https code fb com

随机推荐

  • ASP.NET MVC 中部分视图的正确位置是什么?

    有人会确认 ASP NET MVC 中部分视图的最佳位置吗 我的想法是 如果这是一个将在许多地方使用的全球视图 那么就可以共享 如果它是视图的一部分 并被包装到部分视图中以使代码阅读更容易 那么它应该进入 Views Controller
  • 理解从先序遍历构造树的伪代码

    我需要做一些类似于这个问题中描述的任务 根据给定的前序遍历构造树 https stackoverflow com questions 4908545 construct tree with pre order traversal given
  • 如何使用 WebGL 和 GLSL 在 J/s 文件中运行 Shadertoy 中的着色器?

    我是着色器编程新手 我想使用 WebGL 和 GLSL 创建一个着色器 为了了解它的实际工作原理 我想测试 Shadertoy 的着色器 但是如何从 Shadertoy 获取代码并实际在 J S 文件中运行它呢 您是否只需将 Shadert
  • 以编程方式从“p”和“q”生成“d”(RSA)

    我有两个号码 p and q 我知道我能得到phi p 1 q 1 然后ed 1 mod phi 但我不确定我明白这意味着什么 我写了一些Python p NUM q NUM e NUM phi p 1 q 1 d 1 phi float
  • 回显所有 json_encoded 行

    我正在尝试循环访问数据库并输出与连接表匹配的所有行 我有以下两个表 任务项目存储与项目相关的所有数据 加入任务项存储玩家 ID 和玩家拥有的物品之间的关联 JS 传入查询表所需的所有信息 getJSON phpscripts php pla
  • 尝试使用 Protocol Buffers - Google 的数据交换格式时,goog 未定义错误

    我正在尝试使用 Protocol Buffers Google 的数据交换格式https github com google protobuf tree master js https github com google protobuf
  • plpgsql For循环中的Select语句创建多个CSV文件

    我想重复以下查询 8760 次 将一年中每个小时的 2 替换为 1 到 8760 我们的想法是每小时创建一个单独的 CSV 文件以进行进一步处理 COPY SELECT FROM public completedsolarirad2012
  • ZF2 toRoute 与 https

    我们正在使用 Zend Framework 2 并使用toRoute在我们的控制器中重定向到不同的位置 例如 this gt redirect gt toRoute home 无论如何 是否可以使用此方法或替代方法将其重定向到 https
  • 如何嵌入文件以供以后解析执行使用

    我本质上是想浏览一个 html 文件的文件夹 我想将它们嵌入到二进制文件中 并能够根据请求解析它们以用于模板执行目的 如果我措辞不当 请原谅 任何想法 提示 技巧或更好的方法来实现这一点都非常感谢 Template Files type T
  • Base64 java 中的文件编码失败

    我有这个类来编码和解码文件 当我使用 txt 文件运行该类时 结果成功 但是 当我使用 jpg 或 doc 运行代码时 我无法打开该文件 或者它不等于原始文件 我不知道为什么会发生这种情况 我修改了这个类http myjeeva com c
  • 在 Node 中通过“_id”搜索 MongoDB 条目的正确方法

    我在用着MongoDb 作为 的一部分MongoJS in Node 这是 MongoJS 的文档 https github com gett mongojs 我正在尝试根据条目在 Node 内进行调用 id场地 使用香草时MongoDB从
  • 如何改变gvim中的左边距

    我在 XP 上有 gvim 7 3 我的问题是 当我编辑文件并关闭行号时 文本距离左窗口边距太近 我不想添加前导空白 我想增加边距 当我有行号时 我不喜欢 左窗口边框和行号之间有足够的空间 行号和文本之间有足够的空间 但是当行号关闭时就没有
  • 如何获取隐藏数据库的数据库模式?

    我的客户是一家牙科诊所 购买了一款诊所管理软件 该软件安装在他们的本地服务器上 包括患者数据库 时间表和各种医疗记录 现在他们希望我为他们编写一些他们的软件包中未提供的实用程序 为此我需要能够查询该数据库 我尝试致电软件制造商的技术支持 帕
  • Azure AD - 仅应用程序令牌中缺少角色声明

    当我尝试从 Nodejs 后端服务器获取仅应用程序令牌时 如下所述here https learn microsoft com en us graph auth v2 service 4 get an access token 有时role
  • 如何在 Vim 中创建文件夹(优先使用 NERDTree)?

    我知道如何创建重命名 删除和移动文件NERDTree 只需按m then either a d or m 但我不知道如何创建文件夹 有谁知道如何做到这一点NERDTree 或者只是以 vim 的原生 方式 You use m a并放置一个尾
  • ##+#. 是什么意思?是什么意思?

    谷歌几乎是不可能的 因此我的理解仅限于阅读 slime 源代码的上下文线索 也许它是 common lisp 中对象系统的一部分 类似 自己 的东西 片段 cond swank backend sbcl with new stepper p
  • 基于列子集修剪 NA - 更优雅的解决方案?

    stackoverflow 社区的新年难题 通过阅读过去的帖子和答案很有帮助 这是我的第一个问题 我找到了解决方法 但我想知道是否可以建议其他方法 解决方案 我正在尝试从大型文件中删除尾随的 NAdata frame 但这些 NA 只出现在
  • jQuery UI DatePicker - 禁用除每月最后一天之外的所有日期

    我正在尝试使用 jquery UI 日期选择器来显示仅可选择该月最后一天的日历 我已成功使用 beforeShowDay 事件禁用一周中的几天 但不确定如何使用它来禁用除该月最后一天之外的所有内容 beforeShowDay 会为日历上显示
  • Android - 仅垂直布局

    如何确保我的应用程序仅适用于垂直布局 我努力了android screenOrientation portrait 但这似乎并不能解决问题 您需要添加到所有活动中 而不仅仅是一项活动 我认为您了解设置是每个应用程序范围内的 但事实并非如此
  • 计算测试集每个类别的熵以测量 pytorch 上的不确定性

    我正在尝试使用 MC Dropout 方法和此链接中提出的解决方案来计算图像分类任务的数据集的每一类的熵 以测量 pytorch 上的不确定性 在 pytorch 上使用 MC Dropout 测量不确定性 https stackoverf