论文:https://arxiv.org/abs/2010.11929
pytorch代码:https://github.com/lucidrains/vit-pytorch
不了解Transformer的,建议先看这篇:https://blog.csdn.net/czt_666/article/details/118113634
架构
如上图所示,ViT 的基本步骤:
-
图片切分为图块
- 所有图块作为输入,输入经过一个线性映射并将图块展平为一维向量(patch embedding)
- 嵌入可学习的类别
- 展平的图块嵌入位置(图片无空间信息)
- Transformer Encoder
- MLP Head
- 分类
Transformer Encoder
ViT 的Transformer Encoder(左图)和 Attention is all you need的Transformer编码器(右图)类似,使用了 Multi-Head Attention 和 MLP两种残差块,不同的是归一化层Norm前置了,还有就是MLP和前馈网络,不过二者差不多。
ViT 的Transformer Encoder的表达式如下:
其中,公式(1)为步骤1~4,稍后再做解释。
公式(2) 为残差块 Multi-Head Attention ,图中仅显示了一个编码器,我们称之为子编码器,
z
l
−
1
z_{l-1}
zl−1为上一个子编码器的输出,
LN
\text {LN}
LN为LayerNorm归一化层,MSA为Multi-Head self-attention,MSA为残差,加上
z
l
−
1
z_{l-1}
zl−1就是完整的残差结构。
公式(3) 为残差块 MLP, Multi-Head Attention的输出作为残差块 MLP的输入,MLP函数为感知机。
步骤1~4
1. 图片切分为图块
标准的transformer的输入是1维的token embedding。为了处理二维图像,
- 图像尺寸为
H
×
W
H\times W
H×W
- 图块的尺寸为
P
×
P
P\times P
P×P
- 图块的数量为
N
=
H
W
P
2
N=\frac{HW}{P^2}
N=P2HW
2. 线性映射和展平(patch embedding)
将图块展平,并使用可训练的线性投影映射到隐矢量D的大小,将此投影的输出称为patch embedding。
3. 嵌入可学习的类别
类似BERT的[class] token,作者为patch embedding序列
(
z
0
0
=
x
c
l
a
s
s
)
(z_0^0=x_{class})
(z00=xclass)预先准备了一个可学习的embedding。
4. 嵌入位置
位置embedding会添加到patch embedding中,以保留位置信息。作者使用标准的可学习1D位置embedding,因为作者没有观察到使用更高级的2D感知位置embedding可显着提高性能。embedding向量的结果序列用作编码器的输入。
具体表达式为公式(1)
后言
在计算机视觉中,卷积结构仍然占主导地位。 受NLP中Transformer扩展成功的启发,作者尝试将标准Transformer直接应用于图像,并进行最少的修改。为此,作者将图像拆分为小块,并提供这些小块的线性嵌入序列作为Transformer的输入。图像图块与NLP应用程序中的token(words)的处理方式相同,以监督方式对模型进行图像分类训练。
当在公共ImageNet-21k数据集或内部JFT-300M数据集上进行预训练时,ViT在多个图像识别基准上达到或超越了最新水平。特别是,最佳模型
- 在ImageNet上达到88.55%的精度
- 在ImageNet-ReaL上达到90.72%的精度
- 在CIFAR-100上达到94.55%的精度
- 在19个任务的VTAB上达到77.63%的精度。