DSSM pytorch实现

2023-05-16

之前在网上找到了一个文本匹配实现仓库,但是没有提供DSSM的代码,我就根据那个代码实现以下DSSM。数据集采用的是蚂蚁金服的数据集。也参考过别人的代码,但是总感觉怪怪的,DSSM原文中,一个query有对应的正样本和负样本,因此在实现的时候分别计算query与正负样本的余弦相似度,最后拼接再接softmax,但是蚂蚁金服数据集中每一个样本都已一个query和doc,对应一个label,并没有成对的正负样本,因此在实现中遇到了困难,因此最后我索性直接将余弦值作为网络输出,貌似还取得了不错的效果,那么代码会有些许不同。
第一,损失函数采用了二分类损失函数:

class torch.nn.BCELoss(weight=None, size_average=True)

第二,判断类别时:

def correct_predictions(output_probabilities, targets):
    """
    Compute the number of predictions that match some target classes in the
    output of a model.
    Args:
        output_probabilities: A tensor of probabilities for different output
            classes.
        targets: The indices of the actual target classes.
    Returns:
        The number of correct predictions in 'output_probabilities'.
    """
    # _, out_classes = output_probabilities.max(dim=1)
    out_classes = output_probabilities.ge(0.5).byte().float()
    correct = (out_classes == targets).sum()
    return correct.item()

第三,网络结构设计如下:

class DSSM(nn.Module):

    def __init__(self, dropout=0.2,device="gpu"):
        super(DSSM, self).__init__()
        self.device = device
        self.embed = nn.Embedding(7901, 100)
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512,256)
        self.dropout = nn.Dropout(dropout)
        self.Sigmoid = nn.Sigmoid() #method1
        self.relu = nn.ReLU()

    def forward(self, a, b):
        a = self.embed(a).sum(1)
        b = self.embed(b).sum(1)

        a = self.relu(self.fc1(a)) #torch.tanh
        # a = self.dropout(a)
        a = self.relu(self.fc2(a))
        # a = self.dropout(a)
        a = self.relu(self.fc3(a))
        # a = self.dropout(a)

        b = self.relu(self.fc1(b))
        # b = self.dropout(b)
        b = self.relu(self.fc2(b))
        # b = self.dropout(b)
        b = self.relu(self.fc3(b))
        # b = self.dropout(b)

        cosine = torch.cosine_similarity(a, b, dim=1, eps=1e-8)  #计算两个句子的余弦相似度
        # cosine = self.Sigmoid(cosine-0.5)
        cosine = self.relu(cosine)
        cosine = torch.clamp(cosine,0,1)
        return cosine

这样在蚂蚁金服测试集的准确率可以达到77以上,如果cosine后面不接relu,我跑到了78以上,但是总感觉出现了过拟合现象。此外,加入dropout效果反而不好,可能这个网络本身就不复杂吧。
其他的训练代码我参考了:https://github.com/zhaogaofeng611/TextMatch

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DSSM pytorch实现 的相关文章

随机推荐

  • 完全二叉树学习

    定义 xff1a 假设高度为h xff0c 那么前h 1层都是满的 xff0c 最后一层 xff0c 从左向右 xff0c 连续集中在最左边 xff1b k层的完全二叉树总节点个数最小为2 k 1 xff0c 最大节点个数为2 k 1 可以
  • thrift例程编译报错原因和解决方法总结

    thrift里自带的turoral xff0c 使用make编译时经常会报错 xff0c 总结如下 xff1a 1 如果出现如下错误 xff1a error uint8 t does not name a type error uint32
  • C++11带来的move语义

    C 43 43 11带来了move语义 xff0c 可以有效的提高STL的效率 xff0c 这篇文章写的非常好 xff0c 可以参考 xff0c 这里对原文进行翻译 xff0c 加入我自己的理解 原文 xff1a http www cpro
  • C++11带来的lambda表达式

    C 43 43 11带来了lambda表达式 xff0c 可以简化程序的编写 xff0c 使代码更加清晰 现在按照步骤来介绍lambda表达式 xff1a 1 函数对象 又叫仿函数 xff0c 如果一个类或者结构体重载了operator 操
  • caffe中几个基本概念

    caffe中几个基本概念 1 caffe中的blob结构是用来进行数据存储 交换和处理网络中正向反向迭代时的数据和导数信息的数据结构 blob是caffe的标准数组结构 他提供了一个统一的内存接口 其将内部的cpu gpu数据之间的传输与存
  • 摄像头引脚定义

    摄像头引脚定义 1 NC NO CONNECT 2 AGND Power Analog ground 3 SIO D I O SCCB serial interface data I O 4 AVDD Power Analog power
  • Android7.0 JACK编译器不支持多用户同时编译的问题的解决

    xfeff xfeff Android7 0 xff08 也就是Android N xff09 上默认使用JACK编译器而不再使用openjdk了 xff0c 但发现JACK不是很好用 xff0c 比如最大的一个问题就是 xff0c 同一台
  • 【树莓派】死机自动重启、掉线自动重连

    目录 WIFI掉线自动重连 首先查看你的板子硬件型号 拿树莓派去做服务器就要配置下这两项 xff0c 保证随时能够VNC控制 WIFI掉线自动重连 http shumeipai nxez com 2017 01 25 raspberry p
  • open vswitch分析

    Open vSwitch 概述 Open vSwitch xff08 下面简称 OVS xff09 是一个高质量的 多层虚拟交换机 OVS 遵循开源 Apache2 0 许可 xff0c 通过可编程扩展 xff0c OVS 可以实现大规模网
  • C# 接口《通俗解释》

    原文地址 xff1a https www cnblogs com hamburger p 4681681 html 接口的定义 xff1a 接口是指定一组函数成员 xff0c 而不实现他们的引用类型 接口使用interface 关键字进行定
  • linux 如何查看指定动态库

    要查看 Linux 系统指定的动态库 xff0c 可以使用以下命令 xff1a 使用 ldconfig 命令 xff1a ldconfig p 该命令将显示系统已加载的所有动态库及其路径 如果要查找特定动态库 xff0c 可以使用 grep
  • Tortoisegit 恢复文件夹被删除的文件(被误删)

    关于Tortoisegit 恢复git文件夹中被删除的文件 xff1a 1 在git文件夹右键tortorisegit show log 2 选择版本 xff08 当时执行删除操作的版本 xff09 3 选择被delete掉的 xff0c
  • putty screen 快捷键

    使用putty的时候 xff0c 开启screen再detach xff0c 可以防止跑程序过程中断开连接而导致程序中断 总结了下putty与screen 相关的快捷键 目前常用的有如下几个 xff08 命令均在putty终端输入 xff0
  • Magento的不同版本(CE,EE,ECE)介绍

    Magento提供了三个不同的版本平台 xff0c 即Magento Community Edition xff08 CE xff09 社区版 xff0c Magento Enterprise Edition xff08 EE xff09
  • c语言初学,字母大小写转换

    这类题目主要通过ASCII码差值实现 xff0c A对应ASCII码十进制数字是65 xff0c a对应ASCII码十进制数字是97 xff0c 即大小写字母之间ASCII码差值为32 xff0c 想要将大写字母转换为小写字母可以将该字符A
  • matlab——subplot多子图共用一个colorbar,微调子图和colorbar位置

    用subplot命令画出多个图后 xff0c 需要让这些图共用一个colorbar 在这里与大家分享我的操作 xff0c 希望能帮助到有需要的人 备注 xff1a 从 R2019b 开始 xff0c 可以在分块图布局中显示共享颜色栏 xff
  • 远程连接服务器数据库报错:Host ‘XXXXXX’ is blocked because of many connection errors

    一 我遇到的问题描述 使用Navicat for mysql连接公司的服务器数据库 xff0c 报错 xff1a Host XXXXXX is blocked because of many connection errors 二 出现错误
  • android中MediaCodec硬编码中关键帧间隔时间设置问题

    在MediaCodec硬编码中设置 xff29 关键帧时间间隔 xff0c 在 xff21 xff30 xff29 中是这么设置的 mMediaCodec 61 MediaCodec createByCodecName debugger g
  • python3 网络编程问题——虚拟机centos7上运行tcp服务器,在主机win10上使用网络调试助手作为tcp客户端无法建立连接,提示1035错误:the socket is marked...

    前提 xff1a 主机和虚拟机都是在同一网段下 我的网络调试助手的连接结果如下图 xff1a 注意 红框中的提示 xff0c 连接超时的结果可能是由于以下两种可能的情况导致的 xff1a 1 服务器端口未开启监听 2 路由项被防火墙拦截 对
  • DSSM pytorch实现

    之前在网上找到了一个文本匹配实现仓库 xff0c 但是没有提供DSSM的代码 xff0c 我就根据那个代码实现以下DSSM 数据集采用的是蚂蚁金服的数据集 也参考过别人的代码 xff0c 但是总感觉怪怪的 xff0c DSSM原文中 xff