Beam Search源码理解

2023-05-16

本文的beam search源码来自:CodeBERT/model.py at master · microsoft/CodeBERT (github.com)https://github.com/microsoft/CodeBERT/blob/master/CodeBERT/code2nl/model.py

理解过程中加入了注释:

class Beam(object):
    def __init__(self, size,sos,eos):
        self.size = size
        self.tt = torch.cuda

        self.scores = self.tt.FloatTensor(size).zero_()
        # 大小为[beam_size],记录当前每个beam的分数总和
        
        self.prevKs = []
        # 记录每一步选取的是第几个beam,便于最后回溯生成结果
        
        self.nextYs = [self.tt.LongTensor(size)
                       .fill_(0)]
        # nextYs: [seq_len=1, beam_size],随着预测过程seq_len逐渐增加,表示每一步的输出结果
        # seq_len即为time_step
        
        self.nextYs[0][0] = sos
        
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
        # batch: [beam_size, seq_len],用于加入到下一次模型的输入中。
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        '''
        更新beam中的信息
        wordLk: [beam_size, vocab_size],上一个时间节点每个beam的模型预测结果,需要用LogSoftMax进行归一化
        '''

        numWords = wordLk.size(1)
        # numWords: vocab_size
        

        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # scores: [beam_size]
            # wordLk是当前的分数,scores是之前的分数,加起来得到beamLk: [beam_size, vocab_size]
            

            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
                    # 把第i个beam的概率全部设置为负无穷
        else:
            beamLk = wordLk[0]
            # beamLk: [vocab_size] 刚开始只有第一个beam
        
        flatBeamLk = beamLk.view(-1) # beamlLk展开
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) # topk个最好分数
        
        self.scores = bestScores
        # scores: [beam_size]

        prevK = bestScoresId // numWords
        # prevK: [beam_size]
        self.prevKs.append(prevK)
        # prevKs: [time_step, beam_size] 记录了每个时间节点的结果来自于第几个beam
        self.nextYs.append((bestScoresId - prevK * numWords))
        # nextYs: [seq_len, beam_size] 记录了每个事件节点选取的id, seq_len即time_step
        
        # 对nextYs的最后一个时间节点进行遍历,检查是否出现了结束符
        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))
                # i 表示第几个beam
                # 若出现结束符,将(总分数,句子长度,beam的id)三元组加入到finished列表中
                # finished列表中存的是已经结束的beam的信息

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            # 当nextYs中最后一个时间点的第一个id为结束符时,将eosTop设置为True
            self.eosTop = True

    def done(self):
        # 当eosTop为True且已经结束的beam数大于等于beam_size的时候就结束。
        return self.eosTop and len(self.finished) >=self.size

    def getFinal(self):
        
        if len(self.finished) == 0:
            # 这里的情况就是所有beam的句子长度都达到了max_length但没有任何一个产生了结束符
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
            # 这种情况下就手动将第0个beam设置为已经结束
        self.finished.sort(key=lambda a: -a[0])
        # 将finished按beam的分数由大到小排序
        if len(self.finished) != self.size:
            # 将没有结束的句子也按(分数,长度,beam_id)三元组的形势加入到finished中
            unfinished=[]
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i)) 
            unfinished.sort(key=lambda a: -a[0])
            self.finished+=unfinished[:self.size-len(self.finished)]
        # 已经结束的beam排在未结束的句子前面
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        回溯,生成结果
        """
        
        # beam_res 传入的就是finished列表,由get_final得到
        hyps=[]
        for _,timestep, k in beam_res:
            # k是指该结果来自于第几个beam
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                # prevKs: [time_step, beam_size] 记录了每个时间节点的结果来自于第几个beam
                hyp.append(self.nextYs[j+1][k])
                # nextYs: [time_step, beam_size] 记录了每个beam的每一步选择,将该id加入到hyp中
                k = self.prevKs[j][k]
                # k为结果来自于第几个beam
            hyps.append(hyp[::-1])# hyp反过来加入到hyps中
        # 最后得到的hyps:[beam_size, ~]列表,~即长度不一,是每一个beam的预测结果,按分数大小排列
        return hyps
    
    def buildTargetTokens(self, preds):
        # preds即为getHyp产生的hyps,记录了每个beam产生的结果,按分数大小排列
        # 这个函数的目的是截断eos之后的结果
        sentence=[]
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok==self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Beam Search源码理解 的相关文章

  • Powershell 错误:方法调用...不包含名为“replace”的方法

    我想使用 PowerShell 搜索并替换 xml 文件中的字符串 我试过这个 gc d test xml replace 1234 xxxx sc d test xml 这对于我的 test xml 文件效果很好 我的 test xml
  • 实现快速 Javascript 搜索?

    基本上 我有一个带有文本框的页面和 ul 列在其下面 这 ul 由用户的朋友列表填充 用户开始在文本框中输入朋友的名字 例如按 r 我想立即更新 ul 每次按键仅显示名字以 R 开头的朋友 例如 Richard Redmond Raheem
  • 在python中删除链表中的节点

    删除链表中的节点 这个实现有什么问题 def delete self val tmp self head prev None while tmp if val tmp data self size 1 if prev None self h
  • 以编程方式在 App Store 上运行搜索?

    是否可以从我的应用程序中打开 App Store 应用程序并运行搜索 我想看看是否有一个 appstore 类型的 URL 可以使用 就像 mailto 和 sms 分别打开邮件和短信一样 有谁知道这是否可能 编辑 更多信息 我一直在尝试使
  • JAVA:如何搜索地图?

    我有一个 Map 其键为字符串 其值为集合 包含整数 假设我的钥匙看起来像 苹果 香蕉 橙色 等 用户输入文本 我将其保存为字符串变量 如何在我的地图中搜索相同的密钥 因此 如果用户输入 apple 我如何将该字符串提供给方法并让该方法在我
  • 在一个后台为MYSQL的网站上集成搜索

    我有一个位置搜索website http www jammulinks com对于一个城市 我们首先收集该城市所有可能类别的数据 如学校 学院 百货商店等 并将其信息存储在单独的表中 因为每个条目除了名称 地址和电话号码外都有不同的详细信息
  • 在 Meteor 应用程序中实现 MongoDB 2.4 的全文搜索

    我正在考虑向 Meteor 应用程序添加全文搜索 我知道 MongoDB 现在支持此功能 但我对实现有一些疑问 启用文本搜索功能的最佳方法是什么 textSearchEnabled true 在 Meteor 应用程序中 有没有办法添加索引
  • 在 .csv 文件中搜索 C 中的名称匹配项

    我目前有一个 csv 文件 其中包含三个字段 用户 密码 类型 例如 我的文件如下所示 michael sun123 user joseph sierra7 user isaac apple2 sysop 我想从这样的文件中读取并检查用户
  • PHP、in_array 和数组中的快速搜索(到最后)

    我对在数组中进行快速搜索的更好方法有疑问 我正在谈论一个特定的情况 假设我有一个数组 L A B C 当我开始时 当程序运行时 L 可能会增长 但到最后 当我进行搜索时 一个可能的原因是 L A B C D E 事实是 当我搜索时 我想要找
  • 使用 PHP MySql 进行关键字搜索?

    我的 mysql 表中有标题 varchar 描述 text 关键字 varchar 字段 我保留了关键字字段 因为我认为我只会在这个字段中搜索 但我现在需要在所有三个字段中进行搜索 所以对于关键字 word1 word2 word3 我的
  • 使用 Fortran 进行数组问题的二分查找

    我正在使用 Schaum 的 Fortran 77 编程概要 一书 其中有一个关于使用括号值组方法进行二分搜索的示例 首先这是代码 INTEGER X 100 INTEGER RANGE INTEGER START FINISH PRINT
  • 检查 Bash 数组中是否存在元素[重复]

    这个问题在这里已经有答案了 我想知道是否有一种有效的方法来检查 Bash 数组中是否存在元素 我正在寻找类似于我可以在Python中做的事情 例如 arr a b c d if d in arr do your thing else do
  • grep 查找 Unix 中的特殊字符

    我有一个日志文件 application log 其中可能包含以下多行普通和特殊字符字符串 Q 我想搜索包含这个特殊字符串的行号 grep Q application log 上述命令不返回任何结果 获取行号的正确语法是什么 Tell gr
  • PHP 搜索部分字符串

    如何在键入时搜索部分字符串 不使用 MySQL 例如 MySQL 中的 LIKE 函数 但在搜索字符串时使用 PHP 例如 但这显然行不通 但是有没有一个函数可以搜索部分字符串 那太好了 EDIT 如果它在数组中怎么办 如果我使用 strp
  • 将默认搜索文本添加到搜索框 html

    我正在努力将 搜索 文本添加到搜索框 我正在努力实现 onfocus 消失文本 And onblur 重新出现文本 到目前为止 我已经实现了这一点 但我必须将其硬编码为 html eg
  • 使 IPTC 数据可搜索

    我对 IPTC 元数据有疑问 是否可以通过 IPTC 元数据 关键字 搜索不在数据库中的图像并显示它们 我将如何执行此操作 我只需要一个基本的想法 我知道 PHP 有 iptcparse 函数 我已经编写了一个函数来获取画廊文件夹和所有子目
  • 常用姓名别名/昵称数据库

    我参与了一个 SQL NET 项目 该项目将搜索名称列表 我正在寻找一种方法来返回类似名字的人的一些结果 如果搜索 Tom 结果将包括 Thom Thomas 等 这是文件还是 Web 服务并不重要 设计示例 Table Names has
  • 使用 dismax 处理程序进行通配符搜索?

    我已成功索引文件 并且希望能够使用通配符进行搜索 我目前正在使用 dismaxRequestHandler QueryType dismax 进行搜索 以便我可以搜索查询的所有字段 像 computer 这样的常规搜索会返回结果 但 com
  • Bing 搜索 API 和 Azure

    我正在尝试以编程方式在 Microsoft Bing 搜索引擎上执行搜索 这是我的理解 有一个 Bing Search API 2 0 很快就会被替换 2012 年 8 月 1 日 新的 API 称为 Windows Azure Marke
  • 为什么 C# Array.BinarySearch 这么快?

    我已经实施了一个很简单用于在整数数组中查找整数的 C 中的 binarySearch 实现 二分查找 static int binarySearch int arr int i int low 0 high arr Length 1 mid

随机推荐

  • NVIDIA显卡及架构介绍

    版权申明 未经博主同意 xff0c 谢绝转载 xff01 xff08 请尊重原创 xff0c 博主保留追究权 xff09 xff1b 本博客的内容来自于 xff1a NVIDIA显卡及架构介绍 xff1b 学习 合作与交流联系q384660
  • 脉冲神经网络资料汇总

    往期文章推荐 xff1a 损失函数与代价函数 神经网络从入门到精通 脉冲神经网络综述笔记 版权申明 未经博主同意 xff0c 谢绝转载 xff01 xff08 请尊重原创 xff0c 博主保留追究权 xff09 xff1b 本博客的内容来自
  • 什么是NAS

    一 NAS是什么 简单的说就是连接在网络上 xff0c 让大家可以透过网络 xff08 内网 xff0c 外网 xff09 来进行储存和读取资料的设备 通俗点说 xff0c 就是有一台很小很小的台式主机 xff0c 里面只装了很多颗的磁盘
  • numba安装与使用

    一 numba是什么 Numba是一个针对Python的开源JIT编译器 xff0c 由Anaconda公司主导开发 xff0c 可以对Python原生代码进行CPU和GPU加速 Numba对NumPy数组和函数非常友好 解释器可以参考第四
  • 目标检测中算法评价指标FPS

    一 FPS 每秒传输帧数 Frames Per Second 是什么 FPS就是目标网络每秒可以处理 xff08 检测 xff09 多少帧 多少张图片 FPS简单来理解就是图像的刷新频率 xff0c 也就是每秒多少帧 假设目标检测网络处理1
  • pytorch版本对计算能力的要求

    一 pytorch对计算能力要求 首先查看pytorch是否可用cuda完整流程应该是先查看是否在当前环境下的python In span class token punctuation span span class token numb
  • 在VS2013中配置boost_1_58_0过程和遇到的的问题

    Boost是为C 43 43 语言标准库提供扩展的一些C 43 43 程序库的总称 Boost库是一个可移植 提供源代码的C 43 43 库 xff0c 作为标准库的后备 xff0c 是C 43 43 标准化进程的开发引擎之一 xff0c
  • C语言学习专栏(1):易忘点

    C语言学习专栏系列 xff1a 版权申明 未经博主同意 xff0c 谢绝转载 xff01 xff08 请尊重原创 xff0c 博主保留追究权 xff09 xff1b 本博客的内容来自于 xff1a C语言学习专栏 xff08 1 xff09
  • git如何配置模板文件

    git如何创建模板文件 创建xxx template文件 xff0c 其内容为团队制定的Git提交注释规范 xff0c 如 xff1a Desgraption Date Author 通过git config命令配置commit templ
  • iOS很坑的error:

    iOS错误如下 error using bridging headers with module interfaces is unsupported 仔细看好错误类型 xff0c 是关于swift混合编译问题 解决办法 完美解决 xff0c
  • 使用Hexo搭建个人博客,绑定GitHub以及个人域名

    文章目录 前言安装Git安装Nodejs安装Hexo创建一个根目录安装Hexo验证安装是否成功初始化网址安装网址依赖开启本地服务 托管到Git配置git的SSH在github上配置秘钥 托管到GitHub配置仓库地址hexo安装部署的命令验
  • ubuntu定时任务的设置

    ubuntu 定时执行任务需要进行如下操作 xff1a span class token comment 使用 crontab 添加定时任务 span span class token comment 1 打开定时任务 span span
  • linux静态库、linux动态库制作、使用,动态库报错:error while loading shared libraries: libxxx.so: cannot open shared o

    接上一篇 xff1a linux C C 43 43 程序编译 gcc编译器基础使用 编译阶段 编译优化 命令大全 g 43 43 适用 本次来分享linux下C C 43 43 程序的静态库和动态库的制作和使用 xff0c 不废话 xff
  • SpringBoot热部署四步完成(idea2021.1)

    1 在pom xml文件中设置 xff08 两小步 xff09 span class token number 1 1 span xff1a 在 span class token generics span class token punc
  • spring boot中.yml配置日志文件格式正确运行出错(logging level)

    yml文件配置logging出错 格式如下 logging span class token operator span level span class token operator span com span class token p
  • 基本类型的字面值及其类型转换

    基本类型的字面值及其类型转换 一 基本类型的字面值二 类型转换 一 基本类型的字面值 1 整数字面值是int类型 2 byte xff0c short xff0c char三种比int小的整数可以用范围内的值直接赋值 3 浮点数的字面值是d
  • 使用idea创建servlet程序(idea:2021.2)

    使用idea创建servlet程序 1 Feil gt New gt Project 2 创建一个java项目 创建好之后项目结构如下图 右键项目点击Add Frameworks Support 勾选Web Application如下图 x
  • Java笔记(1)——绪论

    1 Java程序的总结 编写 xff1a 将编写的java程序保存在以 java 结尾的源文件中 编译 xff1a 使用javac exe命令编译java源文件 运行 xff1a 使用java exe命令解释运行字节码文件 2 一个Java
  • idea导入第三方jar包并打包在项目中

    IDEA项目引入第三方jar包 1 在resource创建lib文件并导入第三方jar包2 在pom xml文件中进行配置3 刷新maven 1 在resource创建lib文件并导入第三方jar包 2 在pom xml文件中进行配置 3
  • Beam Search源码理解

    本文的beam search源码来自 xff1a CodeBERT model py at master microsoft CodeBERT github com https github com microsoft CodeBERT b