DDIM模型代码解析(一)

2023-10-27

目录

预备知识

main.py

解析命令行参数

解析配置文件


预备知识

由于代码中除了一些必要的对模型、数据进行操作的PyTorch函数外,还有一些辅助显示训练等过程有关信息的,或辅助对文件目录进行操作的库。因此,建议读者先对这些库进行了解,试着写一写示例代码,理解库中函数的使用方法后再阅读下面的讲解,这样可以更顺畅。

import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os

main.py

首先对输出的选项进行设定,让输出的内容不按科学计数法模式。

torch.set_printoptions(sci_mode=False)  # 设置为不按照科学计数法表示输出

然后程序进入main()函数中,在main函数中完成了以下任务:

  • 解析命令行参数
  • 解析配置文件
  • 打印相关信息
  • 扩散过程实例化
  • 完成采样 / 测试 / 训练过程

后面我们逐一进行代码分析。

def main():
    args, config = parse_args_and_config()  # 解析命令行参数和配置文件
    logging.info("Writing log file to {}".format(args.log_path))  # 显示日志存储路径信息
    logging.info("Exp instance id = {}".format(os.getpid()))  # 显示进程id信息
    logging.info("Exp comment = {}".format(args.comment))  # 显示实验注释信息

    try:
        runner = Diffusion(args, config)  # 构建扩散运行实例对象
        if args.sample:  # 如果是采样操作,就执行采样函数
            runner.sample()
        elif args.test:  # 如果是测试模型,就执行测试函数
            runner.test()
        else:  # 否则就执行训练函数
            runner.train()
    except Exception:  # 如果报错就输出错误信息日志
        logging.error(traceback.format_exc())

    return 0

解析命令行参数

对命令行参数的解析在parse_args_and_config函数中完成,每一个参数的含义以注释的形式标明,如果有异议欢迎在评论中指出。

def parse_args_and_config():
    parser = argparse.ArgumentParser(description=globals()["__doc__"])

    parser.add_argument(  # config文件路径
        "--config", type=str, required=True, help="Path to the config file"
    )
    parser.add_argument("--seed", type=int, default=1234, help="Random seed")  # 随机种子
    parser.add_argument(  # 用于保存运行相关数据的路径
        "--exp", type=str, default="exp", help="Path for saving running related data."
    )
    parser.add_argument(  # log日志文件夹名称
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. "
        "Will be the name of the log folder.",
    )
    parser.add_argument(  # 实验注释
        "--comment", type=str, default="", help="A string for experiment comment"
    )
    parser.add_argument(  # logging日志的级别: info, debug, warning, critical
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    parser.add_argument("--test", action="store_true", help="Whether to test the model")  # 是否测试模型
    parser.add_argument(  # 是否从模型产生采样
        "--sample",
        action="store_true",
        help="Whether to produce samples from the model",
    )
    parser.add_argument("--fid", action="store_true")  # FID指标
    parser.add_argument("--interpolation", action="store_true")  # 插值
    parser.add_argument(  # 是否为继续训练
        "--resume_training", action="store_true", help="Whether to resume training"
    )
    parser.add_argument(  # 采样的文件夹名称
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    parser.add_argument(  # 无交互
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument("--use_pretrained", action="store_true")  # 使用预训练
    parser.add_argument(  # 采样类型
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    parser.add_argument(  # 跳跃类型
        "--skip_type",
        type=str,
        default="uniform",
        help="skip according to (uniform or quadratic)",
    )
    parser.add_argument(  # 步数
        "--timesteps", type=int, default=1000, help="number of steps involved"
    )
    parser.add_argument(  # \eta超参数用于控制方差
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    parser.add_argument("--sequence", action="store_true")  # 是否为序列

    args = parser.parse_args()  # 解析参数
    args.log_path = os.path.join(args.exp, "logs", args.doc)  # log日志路径: exp/logs/$doc$
    
    ...

解析配置文件

解析配置文件的过程也是在parse_args_and_config函数中,args.config应该是bedroom,celeba,church,cifar10中的一个。这样我们可以直接打开文件夹configs中对应数据集的yaml配置文件,此时config为字典类型。经过dict2namespace函数,将字典类型转换为argparse中命名空间的形式。

def parse_args_and_config():
    ...

    # parse config file
    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    ...

转换函数如下:

def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

之后还有一步设定tensorboard日志的路径,可以在训练时用tensorboard查看训练进度信息:

def parse_args_and_config():
    ...

    tb_path = os.path.join(args.exp, "tensorboard", args.doc)  # tensorboard日志路径: exp/tensorboard/$doc$

    ...

之后会执行训练 / 采样 / 测试不同的代码部分:

首先看一下对于训练会执行的代码:

  • 创建log日志文件夹
  • 创建tensorboard日志文件夹
  • 设置logging的logger
def parse_args_and_config():
    ...

    if not args.test and not args.sample:
        if not args.resume_training:
            if os.path.exists(args.log_path):  # 如果log输出路径存在的话
                overwrite = False  # 选择不覆盖
                if args.ni:  # 如果ni为True
                    overwrite = True  # 选择覆盖
                else:
                    response = input("Folder already exists. Overwrite? (Y/N)")  # 询问是否覆盖
                    if response.upper() == "Y":  # 如果Y, 则选择覆盖原有log
                        overwrite = True

                if overwrite:  # 如果选择覆盖
                    shutil.rmtree(args.log_path)  # 删除原有log文件路径
                    shutil.rmtree(tb_path)  # 删除原有tensorboard文件路径
                    os.makedirs(args.log_path)  # 创建新的log文件路径
                    if os.path.exists(tb_path):  # 如果tensorboard文件路径存在, 就删除它
                        shutil.rmtree(tb_path)
                else:  # 如果选择不覆盖, 则提示文件夹存在, 程序停止
                    print("Folder exists. Program halted.")
                    sys.exit(0)
            else:  # 如果log输出路径不存在就创建路径
                os.makedirs(args.log_path)

            with open(os.path.join(args.log_path, "config.yml"), "w") as f:
                yaml.dump(new_config, f, default_flow_style=False)

        new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
        # setup logger
        level = getattr(logging, args.verbose.upper(), None)  # 20 (logging.INFO) 或者其它的级别
        if not isinstance(level, int):  # 如果为None的话就会报错
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()  # 将log在CLI输出的handler
        handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))  # 将log在文件输出的handler
        formatter = logging.Formatter(  # 控制log输出格式的formatter
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"  # INFO - __main__ - ... - ....
        )
        handler1.setFormatter(formatter)  # 设置CLI输出handler的格式
        handler2.setFormatter(formatter)  # 设置文件输出handler的格式
        logger = logging.getLogger()  # root logger
        logger.addHandler(handler1)  # 添加CLI输出handler
        logger.addHandler(handler2)  # 添加文件输出handler
        logger.setLevel(level)  # 设定root logger的级别

    ...

然后是采样 / 测试会执行的代码:

  • 设置logging的logger
  • 对于采样,会创建图像文件夹
def parse_args_and_config():
    ...


    else:
        level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(level, int):
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler1.setFormatter(formatter)
        logger = logging.getLogger()
        logger.addHandler(handler1)
        logger.setLevel(level)

        if args.sample:  # 如果是采样
            os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)  # 创建目录: exp/image_samples
            args.image_folder = os.path.join(  # 添加图像文件夹参数: exp/image_samples/$image_folder$
                args.exp, "image_samples", args.image_folder
            )
            if not os.path.exists(args.image_folder):  # 如果图像文件夹不存在就创建一个
                os.makedirs(args.image_folder)
            else:  # 如果图像文件夹存在
                if not (args.fid or args.interpolation):
                    overwrite = False
                    if args.ni:
                        overwrite = True
                    else:
                        response = input(
                            f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
                        )
                        if response.upper() == "Y":
                            overwrite = True

                    if overwrite:  # 如果覆盖, 删除并新建文件夹
                        shutil.rmtree(args.image_folder)
                        os.makedirs(args.image_folder)
                    else:
                        print("Output image folder exists. Program halted.")
                        sys.exit(0)

    ...

最后是对PyTorch进行设置:

  • device
  • 随机种子
  • causes cuDNN to benchmark multiple convolution algorithms and select the fastest.
def parse_args_and_config():
    ...

    # add device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config

至此,就基本结束main.py的学习了,后面讲进入Diffusion类中查看具体初始化、训练、采样、测试这些函数是如何实现的了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

DDIM模型代码解析(一) 的相关文章

  • 计算另一个字符串中多个字符串的出现次数

    在 Python 2 7 中 给定以下字符串 Spot是一只棕色的狗 斑点有棕色的头发 斑点的头发是棕色的 查找字符串中 Spot brown 和 hair 总数的最佳方法是什么 在示例中 它将返回 8 我正在寻找类似的东西string c
  • 如何在 __init__ 中使用await设置类属性

    我如何定义一个类await在构造函数或类体中 例如我想要的 import asyncio some code class Foo object async def init self settings self settings setti
  • VSCode Settings.json 丢失

    我正在遵循教程 并尝试将 vscode 指向我为 Scrapy 设置的虚拟工作区 但是当我在 VSCode 中打开设置时 工作区设置 选项卡不在 用户设置 选项卡旁边 我还尝试通过以下方式手动转到文件 APPDATA Code User s
  • 从Django中具有外键关系的两个表中检索数据? [复制]

    这个问题在这里已经有答案了 This is my models py file from django db import models class Author models Model first name models CharFie
  • 在 Django Admin 中调整字段大小

    在管理上添加或编辑条目时 Django 倾向于填充水平空间 但在某些情况下 当编辑 8 个字符宽的日期字段或 6 或 8 个字符的 CharField 时 这确实是一种空间浪费 字符宽 然后编辑框最多可容纳 15 或 20 个字符 我如何告
  • Pycharm 在 os.path 连接上出现“未解析的引用”

    将pycharm升级到2018 1 并将python升级到3 6 5后 pycharm报告 未解析的引用 join 最新版本的 pycharm 不会显示以下行的任何警告 from os path import join expanduser
  • GUI(输入和输出矩阵)?

    我需要创建一个 GUI 将数据输入到矩阵或表格中并读取此表单数据 完美的解决方案是限制输入表单仅允许float 例如 A 1 02 0 25 0 30 0 515 0 41 1 13 0 15 1 555 0 25 0 14 1 21 2
  • 为什么一旦我离开内置的运行服务器,Django 就无法找到我的管理媒体文件?

    当我使用内置的简单服务器时 一切正常 管理界面很漂亮 python manage py runserver 但是 当我尝试使用 wsgi 服务器为我的应用程序提供服务时django core handlers wsgi WSGIHandle
  • Java 和 Python 可以在同一个应用程序中共存吗?

    我需要一个 Java 实例直接从 Python 实例数据存储中获取数据 我不知道这是否可能 数据存储是否透明 唯一 或者每个实例 如果它们确实可以共存 都有其单独的数据存储 总结一下 Java 应用程序如何从 Python 应用程序的数据存
  • 使用 Python Oauthlib 通过服务帐户验证 Google API

    我不想使用适用于 Python 的 Google API 客户端库 但仍想使用 Python 访问 Google APIOauthlib https github com idan oauthlib 创建服务帐户后谷歌开发者控制台 http
  • 无法导入 langchain.agents.load_tools

    我正在尝试使用 LangChain Agents 但无法导入 load tools 版本 langchain 0 0 27 我尝试过这些 from langchain agents import initialize agent from
  • python的shutil.move()在linux上是原子的吗?

    我想知道python的shutil move在linux上是否是原子的 如果源文件和目标文件位于两个不同的分区上 行为是否不同 或者与它们存在于同一分区上时的行为相同吗 我更关心的是如果源文件和目标文件位于同一分区上 shutil move
  • 当字段是数字时怎么说...在 mongodb 中匹配?

    所以我的结果中有一个名为 城市 的字段 结果已损坏 有时它是一个实际名称 有时它是一个数字 以下代码显示所有记录 db zips aggregate project city substr city 0 1 sort city 1 我需要修
  • 如何将 ascii 值列表转换为 python 中的字符串?

    我在 Python 程序中有一个列表 其中包含一系列数字 这些数字本身就是 ASCII 值 如何将其转换为可以在屏幕上回显的 常规 字符串 您可能正在寻找 chr gt gt gt L 104 101 108 108 111 44 32 1
  • Python GTK+ 画布

    我目前正在通过 PyGobject 学习 GTK 需要画布之类的东西 我已经搜索了文档 发现两个小部件似乎可以完成这项工作 GtkDrawingArea 和 GtkLayout 我需要一些基本函数 如 fillrect 或 drawline
  • 在 Google App Engine 中,如何避免创建具有相同属性的重复实体?

    我正在尝试添加一个事务 以避免创建具有相同属性的两个实体 在我的应用程序中 每次看到新的 Google 用户登录时 我都会创建一个新的播放器 当新的 Google 用户在几毫秒内进行多个 json 调用时 我当前的实现偶尔会创建重复的播放器
  • 如何使用 AWS Lambda Python 读取 AWS S3 存储的 Word 文档(.doc 和 .docx)文件内容?

    我的场景是 我尝试使用 python 实现从 Aws Lambda 读取 AWS 存储的 S3 word 文档 doc 和 docx 文件内容 下面的代码是我使用的 我的问题是我可以获取文件名 但无法读取内容 def lambda hand
  • 每当使用 import cv2 时 OpenCV 都会出错

    我在终端上使用 pip3 install opencv contrib python 安装了 cv2 并且它工作了 但是每当我尝试导入 cv2 或运行导入了 cv2 的 vscode 文件时 在 python IDLE 上它都会说 Trac
  • 如何在 Flask 中的视图函数/会话之间传递复杂对象

    我正在编写一个 Web 应用程序 当 且仅当 用户登录时 该应用程序从第三方服务器接收大量数据 这些数据被解析为自定义对象并存储在list 现在 用户在应用程序中使用这些数据 调用不同的视图 例如发送不同的请求 我不确定什么是最好的模式在视
  • NLTK:查找单词大小为 2k 的上下文

    我有一个语料库 我有一个词 对于语料库中该单词的每次出现 我想获取一个包含该单词之前的 k 个单词和该单词之后的 k 个单词的列表 我在算法上做得很好 见下文 但我想知道 NLTK 是否提供了一些我错过的功能来满足我的需求 def size

随机推荐

  • JetBrains:推出“新一代 IDE ”!VS Code 对手来了

    近期 JetBrains 在官方博客宣布 推出一款有点不一样的轻量级编辑器 Fleet 并称其为 下一代 IDE 官方地址 https www jetbrains com zh cn fleet 官网介绍中说明 以 20 年的 IDE 开发
  • flink-python的安装

    一 下载flink flink flink python at master apache flink GitHub 二 安装pyflink yum install maven 安装maven 3 1 1以上版本 https ci apac
  • 小程序授权登陆流程

    小程序授权登陆流程 1 当用户进入微信小程序时 首先我们先判断用户是否授权过此小程序 wx getSetting wx getSetting方法获取用户的当前设置 查看是否授权 sucsess res gt 调用成功的回调函数 if res
  • 影视剪辑,PR剪辑软件两个转场教程

    一 古风渐变擦除转场 拖入视频1和视频2 将视频2放到视频1上面的轨道 2者重叠部分就是转场部分 效果 渐变擦除 拖到视频2 在开头K关键帧 效果控件 渐变擦除 过渡完成 K帧调到100 在2段视频交接处 K帧 过渡完成调到0 为了使效果更
  • 数据中台-让数据用起来-8

    文章目录 第八章 数据资产管理 8 1 数据资产的定义和3个特征 8 2 数字资产管理现状和调整 8 3 数据资产管理的4个目标 8 4 数据资产管理在数据中台架构中的位置 8 5 数据治理 8 5 1 数据治理的6个目标 8 5 2 数据
  • 【无人机路径规划】基于IRM和RRTstar进行无人机路径规划(Matlab代码实现)

    欢迎来到本博客 博主优势 博客内容尽量做到思维缜密 逻辑清晰 为了方便读者 座右铭 行百里者 半于九十 本文目录如下 目录 1 概述 2 运行结果 3 参考文献 4 Matlab代码 文章详细讲解 1 概述 本文将无人机路径规划这一非线性规
  • 2021年最新,解决xgboost安装问题:xgboost.core.XGBoostError: XGBoost Library (xgboost.dll) could not be loaded.

    1 环境 python 3 7版本 64位的 原来 python3 8版本的安装不了 平台不支持 2 直接pip3 install xgboost 3 然后有出错提示xgboost core XGBoostError XGBoost Lib
  • 2020-09-26

    package main 本文通过golang 实现msgpack字节流 参见 https github com hashicorp memberlist git util go decode encode import bytes fmt
  • 打印准考证服务器异常显示,注意了!打印准考证时,你可能遇到这些问题!

    原标题 注意了 打印准考证时 你可能遇到这些问题 2019年研究生准考证下载打印开放时间为 12月14日 12月24日 考生们一定要留心 不要错过打印时间 准考证打印流程 第一步 登录中国研究生招生信息网 并填写用户名和密码 第二步 登录完
  • Docker启动提示:Cannot connect to the Docker daemon...

    执行docker image导入时 提示 Cannot connect to the Docker daemon at unix var run docker sock Is the docker daemon running 执行dock
  • 实战:如何修改vscode作为git默认的编辑器-20211108

    目录 文章目录 目录 实验环境 实验软件 无 1 问题 如何修改vscode作为git默认的编辑器 2 配置方法 1 查看当前环境 2 开始配置 3 验证 关于我 最后 实验环境 win10 git version 2 17 0 windo
  • 硬盘的读写原理

    硬盘的种类主要是SCSI IDE 以及现在流行的SATA等 任何一种硬盘的生产都要一定的标准 随着相应的标准的升级 硬盘生产技术也在升级 比如 SCSI标准已经经历了SCSI 1 SCSI 2 SCSI 3 其中目前咱们经常在服务器网站看到
  • el-date-picker 限制固定开始时间与结束日期,用户只能在此范围内选择

    今天拿到的需求是 开始时间与结束时间是固定的 用户只能在这个范围内选择 为了用户体验好点 我选择了把不能选的日期直接置灰这种实现效果 效果如下 能清楚的看到 2023 01 04 之前的日期都不能选择 当前时间限制 开始范围是2023 01
  • handler机制的原理面试,技术水平真的很重要!真香

    面试如作战 我们看战争影视剧的时候 经常看到这些剧作往往主要聚焦于作战过程 战场战略 对战前准备给的篇幅往往很少 实际上 战前准备也是关键的一环 没有充足的粮草 车马 兵器的准备 别说赢得战争 投入战斗都不可能 这个道理在面试中也是一样 如
  • Linux环境项目以jar包形式启动,指定环境配置文件

    nohup java jar xxx jar spring profiles active DEV gt xxx logs txt
  • 选择排序和冒泡排序算法

    冒泡排序算法 Test public void sort2 int array 1 34 4 56 67 7 89 for int i 0 i lt array length 1 i for int j 0 j lt array lengt
  • 7-16 求符合给定条件的整数集 (15分)

    7 16 求符合给定条件的整数集 15分 给定不超过6的正整数A 考虑从A开始的连续4个数字 请输出所有由它们组成的无重复数字的3位数 输入格式 输入在一行中给出A 输出格式 输出满足条件的的3位数 要求从小到大 每行6个整数 整数间以空格
  • 基于CRNN的中文车牌识别

    1 概述 目前HyperLRP是一个开源的 基于深度学习高性能中文车牌识别库 本文主要在其基础上进行改动 自己训练一个crnn车牌识别模型 2 可识别的车牌类型 单行蓝牌 单行黄牌 新能源车牌 白色警用车牌 使馆 港澳车牌 教练车牌 3 可
  • 在windows上配置VScode支持ARM GCC开发环境

    简单有效的在windows上 配置VS Code 以支持GCC开发环境 没有什么花里胡哨的 需要用到的工具 Visual Studio Code 编辑工具 ARM GCC 交叉编译工具链 Msys2 命令行开发环境 mingw window
  • DDIM模型代码解析(一)

    目录 预备知识 main py 解析命令行参数 解析配置文件 预备知识 由于代码中除了一些必要的对模型 数据进行操作的PyTorch函数外 还有一些辅助显示训练等过程有关信息的 或辅助对文件目录进行操作的库 因此 建议读者先对这些库进行了解