【深度学习】Attention提速9倍!FlashAttention燃爆显存,Transformer上下文长度史诗级提升...

2023-11-11

转载自 | 新智元

  继超快且省内存的注意力算法FlashAttention爆火后,升级版的2代来了。

FlashAttention-2是一种从头编写的算法,可以加快注意力并减少其内存占用,且没有任何近似值。

比起第一代,FlashAttention-2速度提升了2倍。

甚至,相较于PyTorch的标准注意力,其运行速度最高可达9倍。

c71ca931d8f6a7d641510b793b0e5374.png

一年前,StanfordAILab博士Tri Dao发布了FlashAttention,让注意力快了2到4倍,如今,FlashAttention已经被许多企业和研究室采用,广泛应用于大多数LLM库。

如今,随着长文档查询、编写故事等新用例的需要,大语言模型的上下文以前比过去变长了许多——GPT-4的上下文长度是32k,MosaicML的MPT上下文长度是65k,Anthropic的Claude上下文长度是100k。

但是,扩大Transformer的上下文长度是一项极大的挑战,因为作为其核心的注意力层的运行时间和内存要求,是输入序列长度的二次方。

Tri Dao一直在研究FlashAttention-2,它比v1快2倍,比标准的注意力快5到9倍,在A100上已经达到了225 TFLOP/s的训练速度!

291a980fa22a9a0d3fafd53f2339bf7f.png

论文地址:https://tridao.me/publications/flash2/flash2.pdf

项目地址:https://github.com/Dao-AILab/flash-attention

FlashAttention-2:更好的算法、并行性和工作分区

端到端训练GPT模型,速度高达225 TFLOP/s

虽说FlashAttention在发布时就已经比优化的基线快了2-4倍,但还是有相当大的进步空间。

比方说,FlashAttention仍然不如优化矩阵乘法(GEMM)运算快,仅能达到理论最大FLOPs/s的25-40%(例如,在A100 GPU上的速度可达124 TFLOPs/s)。

e2ec03458c0549774809b49b13ef1d6b.png

GEMM如何用于卷积

在过去的几个月里,研究人员一直在开发FlashAttention-2,它的性能指标比第一代更强。

研究人员表示,2代相当于完全从头重写,使用英伟达的CUTLASS 3.x及其核心库CuTe。从速度上看,FlashAttention-2比之前的版本快了2倍,在A100 GPU上的速度可达230 TFLOPs/s。

当使用端到端来训练GPT之类的语言模型时,研究人员的训练速度高达225 TFLOPs/s(模型的FLOP利用率为72%)。

对注意力计算重新排序

我们知道,FlashAttention是一种对注意力计算进行重新排序的算法,利用平铺、重新计算来显著加快计算速度,并将序列长度的内存使用量从二次减少到线性。

54b5db4aef8f1bb0f27cac6079513963.png

研究人员将输入块从HBM(GPU内存)加载到SRAM(快速缓存),并对该模块执行注意,更新HBM中的输出。

由于没有将大型中间注意力矩阵写入HBM,内存的读/写量也跟着减少,进而带来了2-4倍的执行时间加速。

下图是FlashAttention的前向传递图:通过平铺和softmax重新缩放,研究人员人员按模块进行操作,避免从HBM读取或是写入,同时获得正确输出,无需近似。

a9d257e813da8dfa6213cd7a2e21bed8.png

然而,FlashAttention仍然存在一些低效率的问题,这是由于不同线程块之间的工作划分并不理想,以及GPU上的warp——导致低占用率或不必要的共享内存读写。

更少的non-matmul FLOP(非矩阵乘法浮点计算数)

研究人员通过调整FlashAttention的算法来减少non-matmul FLOP的次数。这非常重要,因为现代GPU有专门的计算单元(比如英伟达GPU上的张量核心),这就使得matmul的速度更快。

例如,A100 GPU FP16/BF16 matmul的最大理论吞吐量为312 TFLOPs/s,但non-matmul FP32的理论吞吐量仅为 19.5 TFLOPs/s。

另外,每个非matmul FLOP比matmul FLOP要贵16倍。

所以为了保持高吞吐量,研究人员希望在matmul FLOP上花尽可能多的时间。

研究人员还重新编写了FlashAttention中使用的在线softmax技巧,以减少重新缩放操作的数量,以及边界检查和因果掩码操作,而无需更改输出。

更好的并行性

FlashAttention v1在批大小和部数量上进行并行化处理。研究人员使用1个线程块来处理一个注意力头,共有 (batch_size * head number) 个线程块。

b30fe42e379a078c3da4e8479f256847.png

在前向处理(左图)中,研究者将Worker(线程块)并行化,每个Worker负责处理注意力矩阵的一个行块。在后向处理过程中(右图),每个Worker处理注意力矩阵的一个列块

每个线程块都在流式多处理器 (SM)运行,例如,A100 GPU上有108个这样的处理器。当这个数字很大(比如 ≥80)时,这种调度是有效的,因为在这种情况下,可以有效地使用GPU上几乎所有的计算资源。

在长序列的情况下(通常意味着更小批或更少的头),为了更好地利用GPU上的多处理器,研究人员在序列长度的维度上另外进行了并行化,使得该机制获得了显著加速。

更好的工作分区

即使在每个线程块内,研究人员也必须决定如何在不同的warp(线程束)之间划分工作(一组32个线程一起工作)。研究人员通常在每个线程块使用4或8个warp,分区方案如下图所示。

研究人员在FlashAttention-2中改进了这种分区,减少了不同warp之间的同步和通信量,从而减少共享内存读/写。

2482f207c16cbf1c6939498f6d9dbb1d.png

对于每个块,FlashAttention将K和V分割到4个warp上,同时保持Q可被所有warp访问。这称为「sliced-K」方案。

然而,这样做的效率并不高,因为所有warp都需要将其中间结果写入共享内存,进行同步,然后再将中间结果相加。

而这些共享内存读/写会减慢FlashAttention中的前向传播速度。

在FlashAttention-2中,研究人员将Q拆分为4个warp,同时保持所有warp都可以访问K和V。

在每个warp执行矩阵乘法得到Q K^T的一个切片后,它们只需与共享的V切片相乘,即可得到相应的输出切片。

这样一来,warp之间就不再需要通信。共享内存读写的减少就可以提高速度。

新功能:头的维度高达256,多查询注意力

FlashAttention仅支持最大128的头的维度,虽说适用于大多数模型,但还是有一些模型被排除在外。

FlashAttention-2现在支持256的头的维度,这意味着GPT-J、CodeGen、CodeGen2以及Stable Diffusion 1.x等模型都可以使用FlashAttention-2来获得加速和节省内存。

v2还支持多查询注意力(MQA)以及分组查询注意力(GQA)。

9c8dec98b13dcab53bfa07e248e5efb5.png

GQA为每组查询头共享单个key和value的头,在多头和多查询注意之间进行插值

这些都是注意力的变体,其中多个查询头会指向key和value的同一个头,以减少推理过程中KV缓存的大小,并可以显著提高推理的吞吐量。

注意力基准

研究人员人员在A100 80GB SXM4 GPU 上测量不同设置(有无因果掩码、头的维度是64或128)下不同注意力方法的运行时间。

b6b07cce0cd75d79ffa6270ba016fc37.png

研究人员发现FlashAttention-2比第一代快大约2倍(包括在xformers库和Triton中的其他实现)。

与PyTorch中的标准注意力实现相比,FlashAttention-2的速度最高可达其9倍。

24cb777eb5bb60e59bcfa04a1a81ea31.png

A100 GPU上的前向+后向速度

只需在H100 GPU上运行相同的实现(不需要使用特殊指令来利用TMA和第四代Tensor Core等新硬件功能),研究人员就可以获得高达335 TFLOPs/s的速度。

97a9341d4a4cce71948feb54e11a1d22.png

H100 GPU上的前向+后向速度

当用于端到端训练GPT类模型时,FlashAttention-2能在A100 GPU上实现高达225TFLOPs/s的速度(模型FLOPs利用率为72%)。

与已经非常优化的FlashAttention模型相比,端到端的加速进一步提高了1.3倍。

55e580b56a0068c0f895f01360b39da7.png

未来的工作

速度上快2倍,意味着研究人员可以用与之前训练8k上下文模型相同的成本,来训练16k上下文长度的模型。这些模型可以理解长篇书籍和报告、高分辨率图像、音频和视频。

同时,FlashAttention-2还将加速现有模型的训练、微调和推理。

在不久的将来,研究人员还计划扩大合作,使FlashAttention广泛适用于不同类型的设备(例如H100 GPU、AMD GPU)以及新的数据类型(例如fp8)。

下一步,研究人员计划针对H100 GPU进一步优化FlashAttention-2,以使用新的硬件功能(TMA、第四代Tensor Core、fp8等等)。

将FlashAttention-2中的低级优化与高级算法更改(例如局部、扩张、块稀疏注意力)相结合,可以让研究人员用更长的上下文来训练AI模型。

研究人员也很高兴与编译器研究人员合作,使这些优化技术更好地应用于编程。

作者介绍

Tri Dao曾在斯坦福大学获得了计算机博士学位,导师是Christopher Ré和Stefano Ermon。

根据主页介绍,他将从2024年9月开始,任职普林斯顿大学计算机科学助理教授。

e173b2188ce829985b89b526022579e8.png

Tri Dao的研究兴趣在于机器学习和系统,重点关注高效训练和长期环境:

- 高效Transformer训练和推理 - 远程记忆的序列模型 - 紧凑型深度学习模型的结构化稀疏性。

值得一提的是,Tri Dao今天正式成为生成式AI初创公司Together AI的首席科学家。

a46bf87702da89c7884ca820cc3ff10b.png

参考资料:

https://princeton-nlp.github.io/flash-atttention-2/

 
 

1d47dce4fe64c8661e65947191e4779b.jpeg

 
 
 
 
 
 
 
 
往期精彩回顾




适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【深度学习】Attention提速9倍!FlashAttention燃爆显存,Transformer上下文长度史诗级提升... 的相关文章

  • 深度好文:最全的大模型 RAG 技术概览

    本文是对检索增强生成 Retrieval Augmented Generation RAG 技术和算法的全面研究 对各种方法进行了系统性的梳理 涉及了 RAG 流程中的数据拆分 向量化 查询重写 查询路由等等 在做 RAG 的小伙伴一定知道
  • 比尔盖茨与萨姆.奥尔特曼的对话及感想

    谈话内容 比尔 盖茨 嘿 萨姆 萨姆 奥尔特曼 嘿 比尔 比尔 盖茨 你好吗 萨姆 奥尔特曼 哦 天哪 这真的太疯狂了 我还好 这是一个非常激动人心的时期 比尔 盖茨 团队情况怎么样 萨姆 奥尔特曼 我想 你知道很多人都注意到了这样一个事实
  • 用通俗易懂的方式讲解:图解 Transformer 架构

    文章目录 用通俗易懂方式讲解系列 1 导语 2 正文开始 现在我们开始 编码 从宏观视角看自注意力机制 从微观视角看自注意力机制 通过矩阵运算实现自注意力机制
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统

    传统搜索系统基于关键字匹配 在面向 游戏攻略 技术图谱 知识库等业务场景时 缺少对用户问题理解和答案二次处理能力 本文探索使用大语言模型 Large Language Model LLM 通过其对自然语言理解和生成的能力 揣摩用户意图 并对
  • 【路径规划】基于A*算法路径规划研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码实现
  • 蒙特卡洛在发电系统中的应用(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码实现
  • 基于java的ssh医院在线挂号系统设计与实现

    基于java的ssh医院在线挂号系统设计与实现 I 引言 A 研究背景和动机 基于Java的SSH医院在线挂号系统设计与实现的研究背景和动机 随着信息技术的迅速发展和应用 医院在线挂号系统已成为医院管理的重要组成部分 传统的挂号方式存在许多
  • 让CHAT介绍下V2ray

    CHAT回复 V2Ray是一个网络工具 主要用于科学上网和保护用户的网络安全 它的名字源自Vmess Ray 光线 通过使用新的网络协议 为用户提供稳定且灵活的代理服务 下面是一些V2Ray的主要特性 1 多协议支持 V2Ray 提供了大量
  • 面对AI革新时,Soul App等社交应用的“出圈”解法是什么?

    2023年初 ChatGPT掀开海内外互联网 AI革新 的序幕 公众在惊讶于ChatGPT对于海量信息富有逻辑的整合归纳 帮助大家提升工作及学习效率之余 更为期待的莫过于有一天人工智能的 意识觉醒 十余年前由斯派克 琼斯 Spike Jon
  • 扬帆证券:三只松鼠去年扣非净利预增超1.4倍

    在 高端性价比 战略驱动下 三只松鼠 300783 重拾增势 1月15日晚间 三只松鼠发布成绩预告 预计2023年度净赢利为2亿元至2 2亿元 同比增加54 97 至70 47 扣非后净赢利为1亿元至1 1亿元 同比增速达146 9 至17
  • 毕业设计:基于卷积神经网络的验证码识别系统 机器视觉 人工智能

    目录 前言 设计思路 一 课题背景与意义 二 算法理论原理 2 1 字符分割算法 2 2 深度学习 三 检测的实现 3 1 数据集 3 2 实验环境搭建 3 3 实验及结果分析 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实
  • 性能大减80%,英伟达芯片在华“遇冷”,我方霸气回应:不强求

    中国这么大一块市场 谁看了不眼馋 在科技实力大于一切的今天 高端芯片的重要性不言而喻 作为半导体产业发展过程中不可或缺的一环 芯片技术也一直是我国技术发展的一大 心病 在美西方等国的联手压制下 我国芯片技术发展处处受阻 至今也未能在高端芯片
  • 强烈推荐收藏!LlamaIndex 官方发布高清大图,纵览高级 RAG技术

    近日 Llamaindex 官方博客重磅发布了一篇博文 A Cheat Sheet and Some Recipes For Building Advanced RAG 通过一张图给开发者总结了当下主流的高级RAG技术 帮助应对复杂的生产场
  • 如何快速申请GPT账号?

    详情点击链接 如何快速申请GPT账号 一OpenAI 1 最新大模型GPT 4 Turbo 2 最新发布的高级数据分析 AI画图 图像识别 文档API 3 GPT Store 4 从0到1创建自己的GPT应用 5 模型Gemini以及大模型
  • 如何用GPT进行论文润色与改写?

    详情点击链接 如何用GPT GPT4进行论文润色与改写 一OpenAI 1 最新大模型GPT 4 Turbo 2 最新发布的高级数据分析 AI画图 图像识别 文档API 3 GPT Store 4 从0到1创建自己的GPT应用 5 模型Ge
  • 2023最新pytorch安装(超详细版)

    前言 一 判断是否有Nvidia 英伟达显卡 二 CPU版 2 1 安装Anaconda 2 2 创建虚拟环境 2 3安装pytorch 2 4 验证pytorch是否安装成功 三 GPU版 3 1 安装Anaconda 3 2 创建虚拟环
  • AI在保护环境、应对气候变化中的作用

    对于AI生命周期数据领域的全球领导者而言 暂时搁置我们惯常的AI见解和AI生命周期数据内容产出 来认识诸如世界地球日这样的自然环境类活动日 似乎是个奇怪的事情 我们想要知道 数据是否真的会影响我们的地球环境 简而言之 是 确实如此 但作为一
  • 基于节点电价的电网对电动汽车接纳能力评估模型研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码 数据
  • 基于节点电价的电网对电动汽车接纳能力评估模型研究(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码 数据
  • 深度学习(5)--Keras实战

    一 Keras基础概念 Keras是深度学习中的一个神经网络框架 是一个高级神经网络API 用Python编写 可以在TensorFlow CNTK或Theano之上运行 Keras优点 1 允许简单快速的原型设计 用户友好性 模块化和可扩

随机推荐

  • presto集群安装

    presto集群安装 整合hive 张映 发表于 2019 11 07 分类目录 hadoop spark scala 标签 hive presto Presto是一个运行在多台服务器上的分布式系统 完整安装包括一个coordinator
  • 微信小程序中使用ECharts--折线图、柱状图、饼图等

    微信小程序开发者的反馈 表示他们强烈需要像 ECharts 这样的可视化工具 但是微信小程序是不支持 DOM 操作的 Canvas 接口也和浏览器不尽相同 因此 ECharts 团队和微信小程序官方团队合作 提供了 ECharts 的微信小
  • EF codeFirst Database.SetInitializer的四种选项

    public MyDbContext base name Default 不会创建数据库 生产环境建议用这个设置 表通过sql来创建或修改 Database SetInitializer
  • Redis基础知识(一):学习笔记

    文章目录 一 nosql和redis 二 在Centos 7 中用Docker安装并运行Redis 一 安装并运行 二 其他操作 三 redis数据库常见操作 三 基本数据类型 一 String字符串 二 list列表 三 set集合 四
  • 【ACM出版 + EI检索稳定】第六届高性能编译、计算和通信国际会议(HP3C 2022)

    第六届高性能编译 计算和通信国际会议 HP3C 2022年6月23 25日 中国吉林 2022年第六届高性能编译 计算和通信国际会议 HP3C 2022 将于2022年6月23 25日在中国吉林举行 本次会议由东北电力大学主办 东北电力大学
  • 再谈 Java中Runnable和Thread的区别

    在面试中老有面试官问这个问题 个人觉得这个问题问的没有技术 一个死记硬背就能回答的很好 但是更深的回答是什么了 那就是直接回答源码吧 thread类实现了runnable 接口 Runnable就是一个借口 只能我们去实现了才能用 对吧 不
  • js 日期格式 转换 yyyy-MM-dd

    之前js获取到数据库的Date 总是显示成 后来知道是js 的Date 格式不能直接转换常用的yyyy MM dd 的格式 Date prototype yyyymmdd function var yyyy this getFullYear
  • H5的本地存储(localStorage)和会话(sessionStorage)还有cookie的使用与注意事项

    目录 本地存储使用的时候注意 js代码如下 cookie使用的时候注意 open in browser与open with live server的区别 最后是总代码如下 本地存储使用的时候注意 js代码如下 本地存储的使用 localSt
  • Linux系统下运行jar文件,提示:No main manifest attribute, in XXX.jar

    在Linux系统下执行java jar XXX jar com HelloWorld往往会提示 No main manifest attribute in XXX jar 原因如下 正常情况下 java打包成jar包需要在MANIFEST
  • 数据化看板(vue+Element+Echarts)

    效果图 由于一些原因无法上传整体的代码 只能放一些部分页面的代码 Element ui Echarts Echarts实例 分割线 首页代码
  • node_modules安装及node-sass使用

    node modules安装及node sass使用 终端进入项目文件夹运行以下命令 一 安装node modules 1 输入npm init 一直回车 会生成package json文件 2 npm i D vue 然后就成功node
  • 第十二届蓝桥杯嵌入式总结及分享

    蓝桥杯嵌入式总结以及分享 这是我第一次写CSDN的文章 如果有什么写的不好的 还请各位见谅 不喜勿喷 谢谢 经过这两个星期时间准备省赛 两个星期准备国赛 最终取得了国二的成绩 个人认为这个比赛就是耗时长 我对源码不够熟悉 所以导致时间基本都
  • 走得最慢的人,只要他不丧失目标,也比漫无目的地徘徊的人走得快。

    走得最慢的人 只要他不丧失目标 也比漫无目的地徘徊的人走得快 莱辛 有着坚定明确的目标 且知道如何做能达成目标 没有追求 未来迷茫 或许大家都想当第一种人 但可能在不知不觉中就成了第二种人 自己也不知道 或是因为目标太大 难以后继 最终失却
  • 1049 数列的片段和

    给定一个正数数列 我们可以从中截取任意的连续的几个数 称为片段 例如 给定数列 0 1 0 2 0 3 0 4 我们有 0 1 0 1 0 2 0 1 0 2 0 3 0 1 0 2 0 3 0 4 0 2 0 2 0 3 0 2 0 3
  • Xshell缺少.dll文件解决方案

    背景 安装Xshell时报错 由于找不到MSVCR110 dll文件 无法继续执行代码 重新安装程序可能会解决此问题 这种情况是缺少运行库 解决方案 微软常用运行库 https pan baidu com s 1kgJRky3cicOMxR
  • 计算机打不开excel表格,excel表格打不开怎么办?excel表格打不开的原因及解决方法...

    今天有网友反映 他昨天做的Excel表格打不开了 但其他Excel表格是可以打开的 非常郁闷 那么Excel表格打不开是什么原因呢 Excel表格打不开怎么办呢 下面脚本之家小编就来说说有关造成Excel表格打不开的几种原因及解决办法 一
  • 基于Loung Attention+LSTM的机器翻译模型

    目录 需要掌握的基础知识 1 Encoder Decoder架构 2 LSTM模型原理 3 Attention机制 基于Loung Attention LSTM的机器翻译模型 模型 数据 训练 基于Bahdanau Attention LS
  • 大数据安全治理平台建设方案

    近年来 随着大数据应用的普及 在新基建 智慧城市 云端应用等大背景趋势下 给我们日常生活便来了很多方便 同时也派生出更多网络安全风险 如企业数据泄露 欺诈 数据违规使用 个人隐私泄露以及企业内部各种威胁和潜在风险 数据是宝贵的资源和财富 当
  • LCD操作原理

    一 LCD原理介绍 LCD内部内部结构 1 lcd由Framebuffer lcd屏幕 信号线 电子枪 lcd控制器组成 2 Framebuffer提供显示数据 lcd屏幕显示 信号线传输Frambuffer中的数据和lcd控制器发出的信号
  • 【深度学习】Attention提速9倍!FlashAttention燃爆显存,Transformer上下文长度史诗级提升...

    转载自 新智元 继超快且省内存的注意力算法FlashAttention爆火后 升级版的2代来了 FlashAttention 2是一种从头编写的算法 可以加快注意力并减少其内存占用 且没有任何近似值 比起第一代 FlashAttention