Score SDE 三种随机微分方程代码解读

2023-11-17

定义SDE类

定义了7个子函数

T: End time of the SDE.
sde:
marginal_prob: Parameters to determine the marginal distribution of the SDE, p t ( x ) p_t(x) pt(x).
prior_sampling: Generate one sample from the prior distribution, p T ( x ) p_T(x) pT(x).
prior_logp: Compute log-density of the prior distribution.
discretize: Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
Useful for reverse diffusion sampling and probabiliy flow sampling.
Defaults to Euler-Maruyama discretization.
reverse: Create the reverse-time SDE/ODE.

reverse

class SDE(abc.ABC):
  """SDE abstract class. Functions are designed for a mini-batch of inputs."""

  def __init__(self, N):
    """Construct an SDE.
    Args:
      N: number of discretization time steps.
    """
    super().__init__()
    self.N = N

  @property
  @abc.abstractmethod
  def T(self):
    """End time of the SDE."""
    pass

  @abc.abstractmethod
  def sde(self, x, t):
    pass

  @abc.abstractmethod
  def marginal_prob(self, x, t):
    """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
    pass

  @abc.abstractmethod
  def prior_sampling(self, shape):
    """Generate one sample from the prior distribution, $p_T(x)$."""
    pass

  @abc.abstractmethod
  def prior_logp(self, z):
    """Compute log-density of the prior distribution.
    Useful for computing the log-likelihood via probability flow ODE.
    Args:
      z: latent code
    Returns:
      log probability density
    """
    pass

  def discretize(self, x, t):
    """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
    Useful for reverse diffusion sampling and probabiliy flow sampling.
    Defaults to Euler-Maruyama discretization.
    Args:
      x: a torch tensor
      t: a torch float representing the time step (from 0 to `self.T`)
    Returns:
      f, G
    """
    dt = 1 / self.N
    drift, diffusion = self.sde(x, t)
    f = drift * dt
    G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
    return f, G

  def reverse(self, score_fn, probability_flow=False):
    """Create the reverse-time SDE/ODE.
    Args:
      score_fn: A time-dependent score-based model that takes x and t and returns the score.
      probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
    """
    N = self.N
    T = self.T
    sde_fn = self.sde
    discretize_fn = self.discretize

定义reverse-time SDE类

    # Build the class for reverse-time SDE.
    class RSDE(self.__class__):
      def __init__(self):
        self.N = N
        self.probability_flow = probability_flow

      @property
      def T(self):
        return T

      def sde(self, x, t):
        """Create the drift and diffusion functions for the reverse SDE/ODE."""
        drift, diffusion = sde_fn(x, t)
        score = score_fn(x, t)
        drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
        # Set the diffusion function to zero for ODEs.
        diffusion = 0. if self.probability_flow else diffusion
        return drift, diffusion

      def discretize(self, x, t):
        """Create discretized iteration rules for the reverse diffusion sampler."""
        f, G = discretize_fn(x, t)
        rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
        rev_G = torch.zeros_like(G) if self.probability_flow else G
        return rev_f, rev_G

    return RSDE()

定义VPSDE类

class VPSDE(SDE):
  def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    """Construct a Variance Preserving SDE.
    Args:
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)
    self.beta_0 = beta_min
    self.beta_1 = beta_max
    self.N = N
    self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
    self.alphas = 1. - self.discrete_betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

  @property
  def T(self):
    return 1

  def sde(self, x, t):
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    drift = -0.5 * beta_t[:, None, None, None] * x
    diffusion = torch.sqrt(beta_t)
    return drift, diffusion

  def marginal_prob(self, x, t):
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    mean = torch.exp(log_mean_coeff[:, None, None, None]) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std

  def prior_sampling(self, shape):
    return torch.randn(*shape)

  def prior_logp(self, z):
    shape = z.shape
    N = np.prod(shape[1:])
    logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.
    return logps

  def discretize(self, x, t):
    """DDPM discretization."""
    timestep = (t * (self.N - 1) / self.T).long()
    beta = self.discrete_betas.to(x.device)[timestep]
    alpha = self.alphas.to(x.device)[timestep]
    sqrt_beta = torch.sqrt(beta)
    f = torch.sqrt(alpha)[:, None, None, None] * x - x
    G = sqrt_beta
    return f, G

定义subVPSDE类

class subVPSDE(SDE):
  def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    """Construct the sub-VP SDE that excels at likelihoods.
    Args:
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)
    self.beta_0 = beta_min
    self.beta_1 = beta_max
    self.N = N

  @property
  def T(self):
    return 1

  def sde(self, x, t):
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    drift = -0.5 * beta_t[:, None, None, None] * x
    discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2)
    diffusion = torch.sqrt(beta_t * discount)
    return drift, diffusion


#边际概率函数,返回值是边际概率的mean和std
  def marginal_prob(self, x, t):
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    mean = torch.exp(log_mean_coeff)[:, None, None, None] * x
    std = 1 - torch.exp(2. * log_mean_coeff)
    return mean, std

  def prior_sampling(self, shape):
    return torch.randn(*shape)

  def prior_logp(self, z):
    shape = z.shape
    N = np.prod(shape[1:])
    return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.

定义VESDE类

class VESDE(SDE):
  def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
    """Construct a Variance Exploding SDE.
    Args:
      sigma_min: smallest sigma.
      sigma_max: largest sigma.
      N: number of discretization steps
    """
    super().__init__(N)
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
    self.N = N

  @property
  def T(self):
    return 1

  def sde(self, x, t):
    sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    drift = torch.zeros_like(x)
    diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
                                                device=t.device))
    return drift, diffusion

  def marginal_prob(self, x, t):
    std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    mean = x
    return mean, std

  def prior_sampling(self, shape):
    return torch.randn(*shape) * self.sigma_max

  def prior_logp(self, z):
    shape = z.shape
    N = np.prod(shape[1:])
    return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)

  def discretize(self, x, t):
    """SMLD(NCSN) discretization."""
    timestep = (t * (self.N - 1) / self.T).long()
    sigma = self.discrete_sigmas.to(t.device)[timestep]
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
                                 self.discrete_sigmas[timestep - 1].to(t.device))
    f = torch.zeros_like(x)
    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    return f, G

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Score SDE 三种随机微分方程代码解读 的相关文章

随机推荐

  • Windows Terminal 配置(Neovim配置)

    Neovim的意义 vim是在linux上非常好用的编辑器 IDE 毕竟高度可配置 如果你是一名linux的开发人员 当切换到windows上时可以通过通过这个快速适应 配置完成后你可以直接在powershell中通过vim的命令进行编辑文
  • ceph集群换盘

    一引言 某地项目运行两年后磁盘批量报错 利用smartctl检测发现出现大量扇区错误 但并未达到彻底无法读写程度 统计下来发现数量接近40块 考虑批次换盘 坏盘期间为了保证不影响业务 需拟定一个较好的方案 二 方案 在查阅一堆资料后 发现无
  • 地图切片工具集合

    最近比较感兴趣地图瓦片生成 搜到不少有用的工具 发现以下三种已经比较成熟 文档齐全 当然还有很多专门的工具或者更好而且没有发现 欢迎补充 MapTiler gdal2tiles GUI版 目前版本为alpha 能够切成符合tms标准格式的瓦
  • 数据库中的关键字——字段(列)、记录(元组)、表、主键、外键

    一 字段 列 某一个事物的一个特征 或者说是属性 其中姓名就是员工的一个属性 可称之为字段 二 记录 元组 事物所有特征的组合 可以描述一个具体的事物 三 表 记录的组合 表示 同一类 事物的组合 四 主键 能唯一标识信息的事物 在说主键之
  • JAVA基础原理篇_1.1—— 关于JVM 、JDK以及 JRE

    目录 一 关于JVM JDK以及 JRE 1 JVM 2 JDK 3 JRE 二 为什么说 Java 语言 编译与解释并存 2 2 将高级编程语言按照程序的执行方式分为两种 2 2 Java的执行过程 2 3 所以为什么Java语言 编译与
  • Win7缺少d3dcompiler_43.dll文件如何处理?

    其实很多用户玩单机游戏或者安装软件的时候就出现过这种问题 如果是新手第一时间会认为是软件或游戏出错了 其实并不是这样 其主要原因就是你电脑系统的该dll文件丢失了或者损坏了 这时你只需下载这个d3dcompiler 43 dll文件进行安装
  • 使用tar --checkpoint提权操作 详解--checkpoint-action的参数及作用

    如果管理员给予了某个普通用户tar命令的超级管理员操作 那么我们可以使用tar命令进行提权 命令如下 sudo u root tar cf dev null exploit checkpoint 1 checkpoint action ex
  • error:03000086:digital envelope routines::initialization error

    error 03000086 digital envelope routines initialization error 问题原因分析 1 node版本问题 2 具体错误原因 ERR OSSL EVP UNSUPPORTED 错误SSL
  • 51单片机——LED灯

    如下图所示是51单片机的开发板原理图 我们想要让二极管D1亮 只需要把p20口置低电平即可 只需要把P2寄存器第0位置0 LED原理解释 CPU配置寄存器的值来控制硬件电路达到我们预期效果 例程1 点亮第一个LED include
  • MongoDB 非正常关机/意外关机(拉电闸)后无法启动的解决方案

    一 环境 Host CentOS 7 9 Version MongoDB 5 Install 二进制 二 说明 公司某天电闸突然跳闸 导致服务器重启后 伴随的自启动服务 MongoDB 启动失败 具体报错如下所示 三 排查 1 查看启动状态
  • IDEA:Warning: No artifacts configured FIX

    问题 办法 Warning No artifacts configured 警告 未配置项目 给idea项目添加tomcat的时候出现 解决办法 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 现在还不够 16 17
  • QT连接SQLserver详细教程

    Qt 连接 SQL Sever数据库 环境 一 配置 ODBC数据源 一 在SQL Sever Manger中添加 新的用户 1 打开如下自带的 MSS Management Studio 2 登录时选择 Windows 身份验证 3 去往
  • 在python中使用python-docx实现word文档自动化

    五一马上就要结束了 趁着今天休息的一天 给大家说说在python对办公文档处理 文章中说要详细的介绍python中几个对文档处理的库 今天就介绍一下word文档处理的python docx库 好了废话不多说开始吧 哈哈哈哈哈哈哈啊哈哈 py
  • 学习MongoDB 三: MongoDB无法启动的解决方法

    一简介 我们之前介绍了MongoDB入门 安装与配置 我们今天在打开MongDB时 我们先运行cmd exe进入dos命令界面 然后进入cd D mongodb bin目录下 启动服务或者mongo命令都报了错误 二 解决 1 net st
  • vue实现三级联动

    div div
  • 图片在盒子内等比展示不变形

    通过这个属性 object fit cover 使用场景如下 fatherBox 父盒子要有宽高 width 240px height 240px sonBox 子盒子 width 100 height 100 object fit cov
  • 7种Git错误以及解决方法

    使用Git的时候如果出现报错 要会解决Git错误 以下整理了七种Git错误以及解决的方法 1 当出现fatal not a git repository or any of the parent directories git时 说明不是一
  • vector find() 用法

    int main vector
  • STM32的PA0输出高电平的具体库函数代码操作

    在STM32中 可以使用库函数控制PA0输出高电平 具体的代码如下 初始化GPIOA的引脚模式 设置PA0为输出模式 GPIO InitTypeDef GPIO InitStruct HAL RCC GPIOA CLK ENABLE GPI
  • Score SDE 三种随机微分方程代码解读

    定义SDE类 定义了7个子函数 T End time of the SDE sde marginal prob Parameters to determine the marginal distribution of the SDE p t