DiffusionDet源码阅读(1)

2023-05-16

本文仅仅适用于已经通读过全文的小伙伴

本文代码节选自 mmdet 中的 DiffusionDet 代码,目前该代码还处于 Development 阶段,所以我博客里写的代码和之后的稳定版本可能稍有不同,不过不用担心,我们只看最关键的部分

DDPM中扩散部分有个参数 β \beta β:

q ( z t ∣ z t − 1 ) : = N ( z t ; 1 − β t z t − 1 , β t I ) q(z_t | z_{t-1}) := \mathcal{N} (z_{t}; \sqrt{1 - \beta_t} z_{t-1}, \beta_t \bf{I} ) q(ztzt1):=N(zt;1βt zt1,βtI)

这就是每次的加噪过程,也可以视为 z t − 1 z_{t-1} zt1先经过一个缩放,再加一个随机噪声之后,就成了 z t z_{t} zt
每次加噪声通过一个参数 β t \beta_t βt来控制,这个参数是人为给定的,而不是可学习的,由于:

q ( z t ∣ z 0 ) : = N ( z t ; α ˉ t z 0 , ( 1 − α ˉ t ) I ) q(z_t | z_{0}) := \mathcal{N} (z_{t}; \sqrt{ \bar{\alpha}_t } z_{0}, (1-\bar{\alpha}_t) \bf{I} ) q(ztz0):=N(zt;αˉt z0,(1αˉt)I)
即:

z t = α ˉ t z 0 + ϵ 1 − α ˉ t ,    w h e r e    ϵ ∈ N ( 0 , I ) z_t = \sqrt{ \bar{\alpha}_t } z_{0} + \epsilon \sqrt{1 - \bar{\alpha}_t}, \ \ where \ \ \epsilon \in \mathcal{N}(0, \bf{I}) zt=αˉt z0+ϵ1αˉt ,  where  ϵN(0,I)

在给定 z 0 z_{0} z0 的基础上, q ( z t ∣ z 0 ) q(z_t | z_{0}) q(ztz0) 也是一个高斯分布,其中:

α t = 1 − β t α ˉ t = Π s = 0 t α s \alpha_t = 1 - \beta_t \\ \bar{\alpha}_t = \Pi_{s=0}^t \alpha_s αt=1βtαˉt=Πs=0tαs

α ˉ t \bar{\alpha}_t αˉt 取值趋近于0时, z t z_t zt 可以视为一个标准的高斯分布,在DiffusionDet中, β 1 : T \beta_{1:T} β1:T取了一系列零到一,且逐渐变大的值,以下是生成 β \beta β 的代码,这里我们取 T = 1000 T=1000 T=1000,即共采样 1000 1000 1000

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(
        ((x / timesteps) + s) / (1 + s) * math.pi * 0.5)**2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

c o s ( x ) cos(x) cos(x) c o s 2 ( x ) cos^2(x) cos2(x) 两个函数的曲线,红线是前者,蓝线是后者,二者有同一个零点 ( π 2 , 0 ) (\frac{\pi}{2}, 0) (2π,0)

请添加图片描述

这是 β \beta β的曲线

请添加图片描述

接下来就是上边计算 α \alpha α α ˉ \bar{\alpha} αˉ之类的代码:

    def _build_diffusion(self):
        betas = cosine_beta_schedule(self.timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))

这三行计算了 β t \beta_t βt, α ˉ t \bar{\alpha}_t αˉt α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1,其长度都是 T T T

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

q ( z t ∣ z t − 1 ) : = N ( z t ; 1 − β t z t − 1 , β t I ) q(z_t | z_{t-1}) := \mathcal{N} (z_{t}; \sqrt{1 - \beta_t} z_{t-1}, \beta_t \bf{I} ) q(ztzt1):=N(zt;1βt zt1,βtI)

接下来计算 α ˉ t \sqrt{\bar{\alpha}_{t}} αˉt 1 − α ˉ t \sqrt{1 - \bar{\alpha}_{t}} 1αˉt log ⁡ ( 1 − α ˉ t ) \log{(1-\bar{\alpha}_{t})} log(1αˉt) 1 α ˉ t \frac{1}{\sqrt{\bar{\alpha}_{t}}} αˉt 1 1 α ˉ t − 1 \sqrt{\frac{1}{\bar{\alpha}_t} - 1} αˉt11

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))

DDPM文中假设,后验分布 q ( z t − 1 ∣ z t , z 0 ) q(z_{t-1} | z_t, z_0) q(zt1zt,z0)也是高斯分布,有:

q ( z t − 1 ∣ z t , z 0 ) = N ( z t − 1 ; μ ~ ( z t , z 0 ) , β t ~ I ) q(z_{t-1} | z_t, z_0) = \mathcal{N} (z_{t-1} ; \tilde{\mu}(z_t, z_0), \tilde{\beta_t} \bm{I}) q(zt1zt,z0)=N(zt1;μ~(zt,z0),βt~I)

算式整理后有:

μ ~ t ( z t , z 0 ) = α ˉ t − 1 β t 1 − α ˉ t z 0 + α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t z t \tilde{\mu}_t(z_t, z_0) = \frac{ \sqrt{\bar{\alpha}_{t-1}} \beta_t }{ 1 - \bar{\alpha}_t } z_{0} + \frac { \sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1}) } { 1 - \bar{\alpha}_t } z_{t} μ~t(zt,z0)=1αˉtαˉt1 βtz0+1αˉtαt (1αˉt1)zt

β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\beta}_{t} = \frac { 1 - \bar{\alpha}_{t-1} } { 1 - \bar{\alpha}_t } \beta_{t} β~t=1αˉt1αˉt1βt

接下来的几行代码用来计算这几个系数:

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))

以上就是函数 _build_diffusion 的全部内容,集中几个log项可能是之后计算loss用的

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DiffusionDet源码阅读(1) 的相关文章

  • DispatcherServlet 源码阅读(1)

    有时间还是应该多看看源码 DispatcherServlet 是一个实实在在的 Servlet xff0c 所以 Spring MVC 引入后不会改变 Servlet 容器的行为 xff0c 仍然是解析 web xml 部署文件 xff0c
  • glibc源码阅读

    FBI warning 本文仅仅是试图以二进制选手的方式来理解mallo c中所使用的堆机制 xff0c 不会对具体操作以及堆块结构作过多叙述 xff0c 敬请谅解 水平欠佳 xff0c 有问题也欢迎留言指出 先解释一些常用的宏与常量 变量
  • 使用DiffusionDet在mot数据集上训练

    数据集处理 在https github com facebookresearch detectron2 detectron2 data datasets builtin py 中 xff0c 可以看到 xff0c detectron2中可以
  • TeaPearce/Conditional_Diffusion_MNIST 源码阅读

    文章目录 tqdm超参数预运算nn Module register buffer绘制动画ddpmforward U net噪声预测模型信息向量掩码向量conext mask上采样层的信息融合恢复阶段 总结后记 tqdm dataset sp
  • DiffusionDet源码阅读(1)

    本文仅仅适用于已经通读过全文的小伙伴 本文代码节选自 mmdet 中的 DiffusionDet 代码 xff0c 目前该代码还处于 Development 阶段 xff0c 所以我博客里写的代码和之后的稳定版本可能稍有不同 xff0c 不
  • 每日lodash源码阅读(一)——createMathOperation

    每日lodash源码阅读 xff08 一 xff09 createMathOperation 一 写在前面二 使用举例三 源码分析add jscreateMathOperation js 一 写在前面 createMathOperation
  • kube-proxy源码阅读(iptables实现)

    Reference 文章目录 1 入口2 ProxyServer创建及调用3 ProxyServer 核心调用流程3 1 func o Options Run err3 2 func o Options runLoop error3 3 f
  • MSCKF-vio源码阅读

    作为一个菜狗来说 xff0c 一开始弄明白kf ekf等滤波方法实属不易 xff0c 但是一旦理解原理之后再发散到基于滤波的状态估计方法 xff0c 学习起来就会事半功倍 xff0c 就像导航包中的robot pose ekf xff0c
  • 【cartographer_slam源码阅读】4-6激光雷达数据的转换

    HandleLaserScanMessage 函数 作用 xff1a 利用 ToPointCloudWithIntensities函数 将ros中的数据转换为carto中定义的数据类型 xff1b 传入 HandleLaserScan 函数
  • DiffusionDet:Diffusion Model for Object Detection

    Diffusion Model for Object Detection 一种用于目标检测的扩散模型 Motivation 1 如何使用一种更简单的方法代替可查询的object queries 2 Bounding box的生成方式过去是三
  • ReentrantLock源码阅读(1)(JDK1.8)

    ReentrantLock 前言ReentrantLock JDK 1 8 实现了Lock接口Sync类NonfairSync类FairSync类重要属性和方法 总结 前言 最近在使用Java 并发包时遇到一些问题 xff0c 感觉对于其还
  • REDIS 源码阅读

    https redissrc readthedocs io en latest datastruct dict html 一个注释的开源项目 xff1a 书是redis的设计与实现 https github com huangz1990 r
  • 【FreeRTOS源码阅读】<2> task.c (1) 任务创建以及TCB、List的结构

    上篇讲述了list c关于链表操作的源码阅读 xff0c 此片文章将开始阅读task c task h相关结构体 由eTaskGetState返回的任务状态 typedef enum eRunning 61 0 一个任务查询自己的状态 xf
  • 【Python源码阅读】PYC 文件剖析

    pyc 文件相信大家见怪不怪 xff0c 大家经常在 pycache 里面见到这些文件 这些文件存储了 python 编译出来的字节码文件 xff0c 还有一些元信息 xff08 例如版本号 xff0c 对应文件的修改时间 xff09 接下
  • 源码阅读——validate-npm-package-name

    文章目录 前言一 源码阅读工具二 阅读源码1 目录结构2 package json3 index js 三 使用该包1 vue cli中使用2 create react app 中使用 总结 前言 validate npm package
  • A-LOAM源码阅读

    LOAM 论文地址 xff1a https www ri cmu edu pub files 2014 7 Ji LidarMapping RSS2014 v8 pdf A LOAM地址 xff1a https github com HKU
  • Deformable Detr代码阅读

    前言 本文主要是自己在阅读mmdet中Deformable Detr的源码时的一个记录 如有错误或者问题 欢迎指正 deformable attention的流程 首先zq即为object query 通过一个线性层 先预测出offset
  • leveldb官方手册摘录

    本文内容摘自leveldb官方手册 版权归其所有 CHAPTER 1 基本概念 leveldb是一个写性能十分优秀的存储引擎 是典型的LSM树 Log Structured Merge Tree 实现 LSM树的核心思想就是放弃部分读的性能
  • “npm create vite“ 是如何实现初始化 Vite 项目?

    欢迎关注我的公号 前端我废了 查看更多文章 前言 我们从 vite 的官方文档中看到 可以使用 npm yarn pnpm create 命令来快速初始化一个基于 Vite 的项目 其实很多框架或库都会开发相应的脚手架工具 用于快速初始化项
  • Quartz框架多个trigger任务执行出现漏执行的问题分析

    一 问题描述 使用Quartz配置定时任务 配置了超过10个定时任务 这些定时任务配置的触发时间都是5分钟执行一次 实际运行时 发现总有几个定时任务不能执行到 二 示例程序 1 简单介绍 采用spring quartz整合方案实现定时任务

随机推荐