ViT常见的模型规格以及源码记录

2023-11-08

ViT常见的模型规格以及源码记录

综述

论文题目:《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》

会议时间:International Conference on Learning Representations, 2021 (ICLR, 2021)

论文地址:https://openreview.net/pdf?id=YicbFdNTTy

论文源码:https://github.com/lucidrains/vit-pytorch(非官方)

介绍

网络结构

在这里插入图片描述

图片引自:https://github.com/lucidrains/vit-pytorch

模型规格

常见规格:

Model Layers Hidden size D MLP size Heads Params
ViT-Base 12 768 3072 12 86M
ViT-Large 24 1024 4096 16 307M
ViT-Huge 32 1280 5120 16 632M

另外,还会添加patch大小,例如:ViT-L/16表示使用 16 × 16 16\times16 16×16的patch大小切分图片。

源码实现

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        # 位置编码
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # 类别token(类似类别查询向量)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        # 类别token先与特征合并(沿序列方向合并)
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        # 加上位置编码,用于表示图片patch的位置
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

注:以上仅是笔者个人见解,若有问题,欢迎指正。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

ViT常见的模型规格以及源码记录 的相关文章

随机推荐

  • BERTopic

    论文标题 BERTopic Neural topic modeling with a class based TF IDF procedure 论文作者 Maarten Grootendorst 论文链接 https arxiv org p
  • vue 引入weixin-js-sdk报错: import wx from ‘weixin-js-sdk‘ wx=‘undefined‘

    vue 中通过 npm 引入 weixin js sdk 使用 wx config 时报错了 c0e6 189 Uncaught in promise TypeError Cannot read property config of und
  • 分支限界法解作业分配问题的实现(C++)

    include
  • Mac版本的After Effects 2023中英文切换方法

    打开ae模板会发现有许多系统的表达式错误 这些错误时由于系统语言不通导致的 只要更改下ae界面语言即可 那么如何将中文版的After Effects 2023 Mac版切换成英文版呢 新版本已经不能通过首选项更改语言设置了 要从applic
  • 国内直接下载google play谷歌商店apk安装包的网站【https://apkpure.com/】

    https apkpure com 这里可以直接下载google play 谷歌商店中的app
  • RedisTemplate使用最详解(一)--- opsForValue()

    1 set K var1 V var2 新增一个字符串类型的值 var1是key var2是值 key存在就覆盖 不存在新增 redisTemplate opsForValue set BBB 你好 2 set K key V value
  • $Luogu[P3673]$小清新计数题

    这他妈什么玩意儿 这里是可爱的链接菌 转化模型 对于第 i 句话 第 p 句话为真话 将 i p 连一条白边 第 p 句话为假话 将 i p 连一条黑边 显然我们的图会是一片基环树森林 并且边为无向边 白边连的两点真假相同 黑边相反 那么要
  • python存csv中文乱码问题

    这两天做了一个小测试是抓的天气信息本来想存数据库 后来觉得还是存csv比较好 使用方便 但是在使用的过程中 发现存中文的时候会出现乱码的情况 查了一下资料 跟大家分享一下python3中存csv乱码的问题 亲测在python2中是不能设置这
  • Linux脚本- 将当前文件夹以及所有子文件夹下的所有.cpp文件,拷贝到指定文件路径下

    需求 将当前文件夹以及所有子文件夹下的所有 cpp文件 拷贝到指定文件路径 home majn llvm project llvm cpp test suite下 以下是一个用于实现该功能的 Bash 脚本 它会递归地查找当前文件夹和所有子
  • mpvue 未找到入口 app.json 文件

    从网上下了个mpvue的程序下来 说是直接用微信打开目录就ok了 但是打开之后发现编译直接出错了 说 未找到入口 app json 文件 懵逼啊 原来要先运行 npm intall 安装依赖包 然后再运行 npm run dev 执行一下m
  • SQL Server数据导入导出工具BCP详解

    bcp是SQL Server中负责导入导出数据的一个命令行工具 它是基于DB Library的 并且能以并行的方式高效地导入导出大批量的数据 bcp可以将数据库的表或视图直接导出 也能通过SELECT FROM语句对表或视图进行过滤后导出
  • 磁盘分区基础和LINUX上硬盘分区设备号解释

    现在就开始讲讲分区 先明确一下概念 主分区 一块物理硬盘上可以被独立使用的一部分 一个硬盘最多可以有4个主分区 扩展分区 为了突破一个物理硬盘只能有4个分区的限制 引入了扩展分区 扩展分区和主分区的地位相当 但是扩展分区本身不能被直接使用
  • linux之getopt 函数

    命令行参数解析函数 getopt getopt 函数声明如下 include
  • mysql日期相减取小时

    mysql日期相减取小时 TIMESTAMPDIFF HOUR a StartTime a EndTime 转载于 https www cnblogs com penghq p 8657064 html
  • 各国语言对应翻译表

    为了工作方便 自己做了一个地区语言的英文翻译 让自己可以更快的找到自己需要的地方 同时 分享给大家 谢谢 中文 各国语言 翻译 序号 中文 翻译 1 阿尔巴尼亚语 2 阿拉伯语 3 阿姆哈拉语 4 阿塞拜疆语 Az rbaycan 5 爱尔
  • 本地springboot项目上传到gitee

    1 在gitee上新建一个仓库 创建后可以拿到仓库地址 https gitee com ouyangshuiming linux test git 2 选中 创建git仓库 3 4 最后一步 一定记得这里要写上一段话 才能成功提交 比如gi
  • Elasticsearch的一些基本概念

    文章目录 基本概念 文档和索引 JSON文档 元数据 索引 REST API 节点和集群 节点 Master eligible节点和Master节点 Data Node 和 Coordinating Node 其它节点 分片 Primary
  • 如何找到电脑自带的浏览器

    1 找到电脑自带的浏览器 首先就是进入你的C盘 然后在C盘里找到自己的如下路径 C Program Files internet explorer 找到成功 完成
  • Conan

    环境 ubuntu bionic的docker image shell docker run it ubuntu bionic 预装工具 shell apt get install cmake 安装conan shell pip3 inst
  • ViT常见的模型规格以及源码记录

    ViT常见的模型规格以及源码记录 综述 介绍 网络结构 模型规格 源码实现 综述 论文题目 AN IMAGE IS WORTH 16X16 WORDS TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE