Overview
Paper: "AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE"
Venue: International Conference on Learning Representations, 2021 (ICLR 2021)
Paper link: https://openreview.net/pdf?id=YicbFdNTTy
Source code: https://github.com/lucidrains/vit-pytorch (unofficial)
Introduction
Network Architecture
Figure from: https://github.com/lucidrains/vit-pytorch
Model Variants
Common configurations:
| Model | Layers | Hidden size D | MLP size | Heads | Params |
| --- | --- | --- | --- | --- | --- |
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
In addition, the patch size is appended to the model name: for example, ViT-L/16 denotes the Large variant with the image split into 16×16 patches.
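As a quick sanity check on the table, the patch count and an approximate parameter count can be derived from the configuration alone. The sketch below is pure Python and not from the paper or the repository; it assumes the attention inner dimension equals D (true for the configurations above with 64- or 80-dimensional heads) and includes biases everywhere, so it only roughly matches real implementations.

```python
def num_patches(image_size, patch_size):
    # Number of non-overlapping patches a square image is split into.
    return (image_size // patch_size) ** 2

def vit_param_count(depth, dim, mlp_dim, image_size=224, patch_size=16,
                    channels=3, num_classes=1000):
    # Rough parameter estimate for a ViT encoder plus classification head.
    patch_dim = channels * patch_size * patch_size
    n = num_patches(image_size, patch_size)
    embed = patch_dim * dim + dim          # patch projection (weight + bias)
    pos = (n + 1) * dim                    # positional embedding (incl. cls slot)
    cls = dim                              # class token
    # Per encoder layer: QKV + output projection, 2-layer MLP, 2 LayerNorms.
    attn = 3 * (dim * dim + dim) + (dim * dim + dim)
    mlp = (dim * mlp_dim + mlp_dim) + (mlp_dim * dim + dim)
    norms = 2 * 2 * dim
    layer = attn + mlp + norms
    head = 2 * dim + (dim * num_classes + num_classes)  # final LN + classifier
    return embed + pos + cls + depth * layer + head

print(num_patches(224, 16))            # 196 patches for a /16 model at 224x224
print(vit_param_count(12, 768, 3072))  # ~86.6M, consistent with the 86M row
```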
Source Code (from vit-pytorch)
```python
import torch
from torch import nn
from einops import repeat
from einops.layers.torch import Rearrange

# Helper: expand an int into an (h, w) tuple.
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Split the image into patches and project each patch to a dim-dimensional embedding.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Positional embedding (one extra slot for the class token)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # Class token (acts like a learnable class query vector)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # Transformer encoder; the Transformer class is defined elsewhere in the repository.
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # Prepend the class token along the sequence dimension
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings so each patch's location is encoded
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # Either mean-pool over all tokens or take only the class token
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)
```
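The one non-obvious step above is the `Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)')` line, which cuts the image into patches and flattens each one. A minimal sketch of the same operation, assuming NumPy is available (this is an illustrative re-implementation, not the repository's code):

```python
import numpy as np

def to_patches(x, p1, p2):
    # NumPy equivalent of Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)').
    b, c, H, W = x.shape
    h, w = H // p1, W // p2
    x = x.reshape(b, c, h, p1, w, p2)        # split H and W into patch grids
    x = x.transpose(0, 2, 4, 3, 5, 1)        # -> b, h, w, p1, p2, c
    return x.reshape(b, h * w, p1 * p2 * c)  # flatten each patch to a vector

x = np.arange(16).reshape(1, 1, 4, 4)        # one 4x4 single-channel "image"
patches = to_patches(x, 2, 2)
print(patches[0].tolist())
# Each row is one 2x2 patch, scanned left-to-right, top-to-bottom:
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```

Running this on a 4×4 image with 2×2 patches makes the ordering explicit: patches are enumerated row-major over the grid, and within each patch the pixels are flattened row-major with channels last.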
Note: the above reflects the author's personal understanding; corrections are welcome.