【深度学习】去掉softmax后Transformer会更好吗?复旦&华为诺亚提出SOFT:轻松搞定线性近似...

2023-10-30

作者丨happy  编辑丨极市平台

导读

 

本文介绍了复旦大学&华为诺亚提出的一种新颖的softmax-free的Transformer—SOFT。所提SOFT显著改善了现有ViT方案的计算效率,更为关键的是:SOFT的线性复杂度可以允许更长的token序列,进而取得更佳的精度-复杂度均衡。

1121dcb7a061554840bb34e746bf6461.png

论文链接:https://arxiv.org/pdf/2110.11945.pdf

代码链接:https://github.com/fudan-zvg/SOFT

项目链接:https://fudan-zvg.github.io/SOFT/

本文是复旦大学&华为诺亚关于Transformer中自注意力机制复杂度的深度思考,首次提出了一种新颖的softmax-free 的Transformer 。本文从softmax self-attention局限性出发,分析了其存在的挑战;然后由此提出了线性复杂度的SOFT;再针对线性SOFT存在的训练问题,提出了一种具有理论保证的近似方案。所提SOFT在ImageNet分类任务上取得了比已有CNN、Transformer更佳的精度-复杂度均衡。

Abstract

ViT通过图像块序列化+自注意力机制将不同CV任务性能往前推了一把。然而,自注意力机制会带来更高的计算复杂度与内存占用。在NLP领域已有不同的方案尝试采用线性复杂度对自注意力进行近似。然而,本文的深入分析表明:NLP中的近似方案在CV中缺乏理论支撑或者无效。

我们进一步分析了其局限性根因:softmax self-attention 。具体来说,传统自注意力通过计算token之间的点乘并归一化得到自注意力。softmax操作会对后续的线性近似带来极大挑战。基于该发现,本文首次提出了SOFT(softmax-free transformer )。

为移除自注意力中的softmax,我们采用高斯核函数替代点乘相似性且无需进一步的归一化。这就使得自注意力矩阵可以通过低秩矩阵分析近似 。近似的鲁棒性可以通过计算其MP逆(Moore-Penrose Inverse)得到。

ImageNet数据集上的实验结果表明:所提SOFT显著改善了现有ViT方案的计算效率 。更为关键的是:SOFT的线性复杂度可以允许更长的token序列,进而取得更佳的精度-复杂度均衡。

Contributation

本文的贡献主要包含以下几点:

  • 提出一种新颖的线性空间、时间复杂度softmax-free Transformer

  • 所提注意力矩阵近似可以通过具有理论保证的矩阵分解算法 计算得到;

  • 所提SOFT在ImageNet图像分类任务上取得了比其他ViT方案更佳的精度-复杂度均衡 (见下图a)。

9256067e70ca0052f00a7002bcdb0cc5.png

Method

Softmax-free self-attention formulation

c301262996904c72d21e4880073a9ace.png

上图给出了本文所提SOFT架构示意图。我们首先来看一下该注意力模块的设计。给定包含n个token的输入序列,自注意力旨在挖掘所有token对之间的相关性

具体来说,X首先线性投影为三个维的query、key以及values:

自注意力可以表示为如下广义形式:

自注意力的关键函数包含一个非线性函数与一个相关函数。自注意力的常规配置定义如下:

虽然该softmax自注意力已成为首选且很少受到质疑,但是它并不适合进行线性化。为构建线性自注意力设计,我们引入了一种sfotmax-free自注意力函数:通过高斯核替换点乘操作。定义如下:

为保持注意力矩阵的对称性,我们设置投影矩阵相同,即。所提自注意力矩阵定义如下:

为描述的简单性,我们定义为矩阵形式:。所提自注意力矩阵S具有三个重要属性:

  • 对称性

  • 所有元素均在[0,1]范围内;

  • 所有对角元素具有最大值1。

我们发现:当采用无线性化的核自注意力矩阵时,transformer的训练难以收敛 。这也就解释了:为何softmax自注意力在transformer中如此流行。

Low-rank regularization via matrix decomposition with linear complxity

为解决收敛于二次复杂度问题,我们利用了矩阵分解作为带低秩正则的统一解,这就使得模型复杂度大幅下降,且无需计算全部的自注意力矩阵。

作出上述选择因为在于:S为半正定矩阵,且无需后接归一化。我们将S表示为块矩阵形式:

其中,。通过上述分解,注意力矩阵可以近似表示为:

其中,表示A的MP逆。更多关于MP逆的信息建议查看原文,这里略过。

在上述公式,A和B是S通过随机采样m个token得到的子矩阵,可表示为(我们将其称之为bottleneck token )。然而,我们发现:随机采样对于m非常敏感。因此,我们通过利用结构先验探索两种额外的方案:

  • 采用一个核尺寸为k、stride为k的卷积学习;

  • 采用一个核尺寸为k、stride为k的均值池化生成。

通过实验对比发现:卷积层学习 具有更好的精度 。由于K与Q相等,因此。给定m个token,我们计算A和P:

最终,我们得到了SOFT的正则化后的自注意力矩阵:

03694eaa8cc38476f53526538fa2c72e.png

上图Algorithm1给出所提SOFT流程,它涉及到了MP逆计算。一种精确且常用的计算MP逆的方法是SVD,然而SVD对于GPU训练不友好。为解决该问题,我们采用了Newton-Raphson方法,见上图Algorithm2:一种迭代算法。与此同时,作者还给出了上述迭代可以最终收敛到MP逆的证明。对该证明感兴趣的同时强烈建议查看原文公式,哈哈。

Instantiations

上面主要聚焦于softmax-free self-attention 模块的介绍,接下来我们将介绍如何利用SOFT模块构建Transformer模型。我们以图像分类任务为切入点,以PVT作为基础架构并引入所提SOFT模块构建最终的SOFT模型,同时还在stem部分进行了微小改动。下表给出了本文所提方案在不同容量大小下的配置信息。

0e92faa9c68879b257d71fe18745e703.png

Experiments

19b429dbb4c67428a0fe21e0b8954649.png

上表对比了所提方案与现有线性Transformer模型的性能,从中可以看到:

  • 相比基线Transformer,线性Transformer能够大幅降低内存占用与FLOPs,同时保持相当参数量;

  • 所提SOFT在所有线性方案中取得了最佳分类精度;

  • 所提SOFT与其他线性方案的推理速度相当,训练速度稍慢。

776d08231aab1990d945ab96a9f1e060.png

上图给出了不同方案的序列长度与内存占用之间的关系,从中可以看到:所提SOFT确实具有线性复杂度的内存占用

f4b1f5d51dfd7e269adbae5f3dce7424.png

上表给出了所提方案与其他CNN、Transformer的性能对比,从中可以看到:

  • 总体来说,ViT及其变种具有比CNN更高的分类精度;

  • 相比ViT、DeiT等Transformer方法以及RegNet等CNN方法,所提SOFT取得了最佳性能;

  • 相比PVT,所提方案具有更高的分类精度,直接验证了所提SOFT模块的有效性;

  • 相比Twins与Swin,所提SOFT具有相当的精度,甚至更优性能。

4fc05596a926ebb54c35b0dae7eec771.png

此外,作者还在NLP任务上进行了对比,见上表,很明显:SOFT又一次胜出

312ef073b0a52cf21bc40678fd1b145c.png

 
 
 
 
 
 
 
 
往期精彩回顾




适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载黄海广老师《机器学习课程》视频课黄海广老师《机器学习课程》711页完整版课件

本站qq群554839127,加入微信群请扫码:

5b879a9833362af22748bbe8312b48a2.png

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【深度学习】去掉softmax后Transformer会更好吗?复旦&华为诺亚提出SOFT:轻松搞定线性近似... 的相关文章

  • (discord.py) 尝试更改成员角色时,“用户”对象没有属性“角色”

    因此 我正在尝试编写一个机器人 让某人在命令中指定的主持人指定的一段时间内暂停角色 我知道该变量称为 小时 即使它目前以秒为单位 我稍后会解决这个问题 基本上 它是由主持人在消息 暂停 personmention numberofhours
  • 如何在刻度标签和轴之间添加空间

    我已成功增加刻度标签的字体 但现在它们距离轴太近了 我想在刻度标签和轴之间添加一点呼吸空间 如果您不想全局更改间距 通过编辑 rcParams 并且想要更简洁的方法 请尝试以下操作 ax tick params axis both whic
  • Python、Tkinter、更改标签颜色

    有没有一种简单的方法来更改按钮中文本的颜色 I use button text input text here 更改按下后按钮文本的内容 是否存在类似的颜色变化 button color red Use the foreground设置按钮
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • 如何等到 Excel 计算公式后再继续 win32com

    我有一个 win32com Python 脚本 它将多个 Excel 文件合并到电子表格中并将其另存为 PDF 现在的工作原理是输出几乎都是 NAME 因为文件是在计算 Excel 文件内容之前输出的 这可能需要一分钟 如何强制工作簿计算值
  • Spark的distinct()函数是否仅对每个分区中的不同元组进行洗牌

    据我了解 distinct 哈希分区 RDD 来识别唯一键 但它是否针对仅移动每个分区的不同元组进行了优化 想象一个具有以下分区的 RDD 1 2 2 1 4 2 2 1 3 3 5 4 5 5 5 在此 RDD 上的不同键上 所有重复键
  • 从 scikit-learn 导入 make_blobs [重复]

    这个问题在这里已经有答案了 我收到下一个警告 D Programming Python ML venv lib site packages sklearn utils deprecation py 77 DeprecationWarning
  • 在 NumPy 中获取 ndarray 的索引和值

    我有一个 ndarrayA任意维数N 我想创建一个数组B元组 数组或列表 其中第一个N每个元组中的元素是索引 最后一个元素是该索引的值A 例如 A array 1 2 3 4 5 6 Then B 0 0 1 0 1 2 0 2 3 1 0
  • IRichBolt 在storm-1.0.0 和 pyleus-0.3.0 上运行拓扑时出错

    我正在运行风暴拓扑 pyleus verbose local xyz topology jar using storm 1 0 0 pyleus 0 3 0 centos 6 6并得到错误 线程 main java lang NoClass
  • python pandas 中的双端队列

    我正在使用Python的deque 实现一个简单的循环缓冲区 from collections import deque import numpy as np test sequence np array range 100 2 resha
  • Python:字符串不会转换为浮点数[重复]

    这个问题在这里已经有答案了 我几个小时前写了这个程序 while True print What would you like me to double line raw input gt if line done break else f
  • Pandas Dataframe 中 bool 值的条件前向填充

    问题 如何转发 fill boolTruepandas 数据框中的值 如果是当天的第一个条目 True 到一天结束时 请参阅以下示例和所需的输出 Data import pandas as pd import numpy as np df
  • 如何将 numpy.matrix 提高到非整数幂?

    The 运算符为numpy matrix不支持非整数幂 gt gt gt m matrix 1 0 0 5 0 5 gt gt gt m 2 5 TypeError exponent must be an integer 我想要的是 oct
  • ExpectedFailure 被计为错误而不是通过

    我在用着expectedFailure因为有一个我想记录的错误 我现在无法修复 但想将来再回来解决 我的理解expectedFailure是它会将测试计为通过 但在摘要中表示预期失败的数量为 x 类似于它如何处理跳过的 tets 但是 当我
  • Python - 按月对日期进行分组

    这是一个简单的问题 起初我认为很简单而忽略了它 一个小时过去了 我不太确定 所以 我有一个Python列表datetime对象 我想用图表来表示它们 x 值是年份和月份 y 值是此列表中本月发生的日期对象的数量 也许一个例子可以更好地证明这
  • 通过数据框与函数进行交互

    如果我有这样的日期框架 氮 EG 00 04 NEG 04 08 NEG 08 12 NEG 12 16 NEG 16 20 NEG 20 24 datum von 2017 10 12 21 69 15 36 0 87 1 42 0 76
  • 设置 torch.gather(...) 调用的结果

    我有一个形状为 n x m 的 2D pytorch 张量 我想使用索引列表来索引第二个维度 可以使用 torch gather 完成 然后然后还设置新值到索引的结果 Example data torch tensor 0 1 2 3 4
  • glpk.LPX 向后兼容性?

    较新版本的glpk没有LPXapi 旧包需要它 我如何使用旧包 例如COBRA http opencobra sourceforge net openCOBRA Welcome html 与较新版本的glpk 注意COBRA适用于 MATL
  • 循环标记时出现“ValueError:无法识别的标记样式 -d”

    我正在尝试编码pyplot允许不同标记样式的绘图 这些图是循环生成的 标记是从列表中选取的 为了演示目的 我还提供了一个颜色列表 版本是Python 2 7 9 IPython 3 0 0 matplotlib 1 4 3 这是一个简单的代
  • PyAudio ErrNo 输入溢出 -9981

    我遇到了与用户相同的错误 Python 使用 Pyaudio 以 16000Hz 录制音频时出错 https stackoverflow com questions 12994981 python error audio recording

随机推荐

  • Ubuntu16.04下搭建LAMP环境

    Ubuntu16 04下搭建LAMP环境 Ubuntu16 04下搭建LAMP环境 1 安装 Apache2 2 重启 apache2 3 测试apache2是否安装成功 4 安装php7 5 测试php是否安装成功 6 安装mysql数据
  • 序列化与反序列化之Flatbuffers(一):初步使用

    序列化与反序列化之Flatbuffers 一 初步使用 一 前言 在MNN中 一个训练好的静态模型是经过Flatbuffers序列化之后保存在硬盘中的 这带来两个问题 1 为什么模型信息要序列化不能直接保存 2 其他框架如caffe和onn
  • 深度学习在目标视觉检测中的应用进展与展望

    前言 文章综述了深度学习在目标视觉检测中的应用进展与展望 首先对目标视觉检测的基本流程进行总结 并介绍了目标视觉检测研究常用的公共数据集 然后重点介绍了目前发展迅猛的深度学习方法在目标视觉检测中的最新应用进展 最后讨论了深度学习方法应用于目
  • ORAN专题系列-0: O-RAN快速索引

    专题一 O RAN的快速概述 ORAN专题系列 1 什么是开放无线接入网O RAN ORAN专题系列 1 什么是开放无线接入网O RAN 文火冰糖的硅基工坊的博客 CSDN博客 什么是oran ORAN专题系列 2 O RAN的系统架构 O
  • C和C++安全编码笔记:动态内存管理

    4 1 C内存管理 C标准内存管理函数 1 malloc size t size 分配size个字节 并返回一个指向分配的内存的指针 分配的内存未被初始化为一个已知值 2 aligned alloc size t alignment siz
  • Spring Aop自定义注解用在Controller层

    前提项目用的框架是SpringMVC 切面类 Aspect Component 把这个注掉是为了不让Spring中扫描 应该让SpringMVC扫描 public class SysLogAop Pointcut annotation co
  • 图像识别毕业设计 opencv实现植物识别算法系统 - python 深度学习

    文章目录 0 前言 2 相关技术 2 1 VGG Net模型 2 2 VGG Net在植物识别的优势 1 卷积核 池化核大小固定 2 特征提取更全面 3 网络训练误差收敛速度较快 3 VGG Net的搭建 3 1 Tornado简介 1 优
  • Maven项目的jdk版本修改

    Maven项目的jdk版本修改 修改的办法有以下三种 一 选择项目 gt 右键 gt build path Configure build path 选择旧的jre 1 5 gt remove删除 gt add Library 添加新的jr
  • Activity 工作流引擎

    Activiti工作流引擎使用详解 http blog csdn net m0 37327416 article details 71743368 Activity用户手册 http www mossle com docs activiti
  • SpringBoot笔记:SpringBoot 集成 Dataway(一)

    文章目录 1 什么是 Dataway 2 主打场景 3 技术架构 4 整合SpringBoot 4 1 maven 依赖 4 2 初始化脚本 4 3 整合 SpringBoot 5 Dataway 接口管理 6 Mybatis 语法支持 7
  • Kafka3.0.0版本——文件清理策略

    目录 一 文件清理策略 1 1 文件清理策略的概述 1 2 文件清理策略的官方文档 1 3 日志超过了设置的时间如何处理 1 3 1 delete日志删除 将过期数据删除 1 3 2 compact日志压缩 一 文件清理策略 1 1 文件清
  • 【Pytorch】利用Pytorch+GRU实现情感分类(附源码)

    在这个实验中 数据的预处理过程以及网络的初始化及模型的训练等过程同前文 利用Pytorch LSTM实现中文新闻分类 具体这里就不再重复解释了 如果有读者在对数据集的预处理过程中有疑问 请参考我的其他博客 里面对这些方法均有我的一些个人体会
  • 稀缺原理

    不管是什么东西 只要你晓得会失去它 自然就会爱上它了 稀缺原理 机会越少见 价值似乎就越高 对失去某种东西的恐惧似乎比对获得同一物品的渴望 更能激发人们的行动力 稀缺原理的力量来源 1 基本可以根据获得一样东西的难易程度 迅速 准确的判断它
  • plsql developer 终极注册码

    product code 4v6hkjs66vc944tp74p3e7t4gs6duq4m4szbf3t38wq2 serial number 1412970386 password xs374ca 手机扫一扫 欢迎关注公众号 关注程序员成
  • python:从键盘输入一个字符,判别它是否大写字母,如果是,将它转换成小写字母;如果不是,则不转换。然后输出最后得到的字符。

    letter str input 请输入一个字母 if letter lt Z 凡是小于大写Z的都要转换成小写 print 转换小写字母为 letter lower lower 方法可以把大写转换成小写 else print 转换大写字母为
  • 网络协程编程

    一 背景 为什么需要网络协程 1 协程 纤程并不是一个新概念2 大并发 高性能对于服务端的高要求3 移动设备的快速增长加大了服务端大并发压力4 Go 语言的兴起将协程带到了一个新的高度支持协程的编程语言 1 Go 语言 非常容易支持大并发
  • Eigen入门之密集矩阵 1 -- 类Matrix介绍

    简介 本篇介绍Eigen中的Matrix类 在Eigen中 矩阵和向量的类型都用Matrix来表示 向量是一种特殊的矩阵 其只有一行或者一列 Matrix构造 在Matrix h中 定义了Matrix类 其中的构造器包括如下的5个 可以看到
  • python爬虫可以做什么呢?

    1 收集数据 Python爬虫程序可用于收集数据 这是最直接和最常用的方法 由于爬虫程序是一个程序 程序运行得非常快 不会因为重复的事情而感到疲倦 因此使用爬虫程序获取大量数据变得非常简单 快速 2 数据储存 Python爬虫可以将从各个网
  • 【防攻世界】misc解题思路-学习笔记

    前言 靶场地址 防攻世界 一 Cat falg 丢进 kali 或者其他Linux系统直接 cat flag 二 MeowMeow可爱的小猫 这道题就很离谱 flag需要用010工具打开 拉到最后就可以看到文字样式 组起来就是 CatCTF
  • 【深度学习】去掉softmax后Transformer会更好吗?复旦&华为诺亚提出SOFT:轻松搞定线性近似...

    作者丨happy 编辑丨极市平台 导读 本文介绍了复旦大学 华为诺亚提出的一种新颖的softmax free的Transformer SOFT 所提SOFT显著改善了现有ViT方案的计算效率 更为关键的是 SOFT的线性复杂度可以允许更长的