将Hugging Face模型转换成LibTorch模型

2023-11-02

Hugging Face的模型

waifu-diffusion模型为例,给出的实现一般是基于diffuser库,示例代码如下:

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    'hakurei/waifu-diffusion',
    torch_dtype=torch.float32
).to('cuda')

prompt = "1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
with autocast("cuda"):
    image = pipe(prompt, guidance_scale=6)["sample"][0]  
    
image.save("test.png")

通过网络下载预训练模型,预训练模型直接加载,但其实这个模型是下载到了本地的,只不过看起来不是很轻松:

因为模型太大,分成了一些小的文件进行了下载,而且后面可以看出来模型实际上是由一些子模型组成的,所以这里面有几个比较大的文件应该是对应了unet、vae这种,看大小也差不多。

下载好了可以直接print(pipe),发现:

StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.11.0",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

果然是一系列的小模型以及一些不重要的参数,这个模型可以直接保存为.pth文件,同样也可以使用torch.load(pipe.pth)读入,但是在实例化模型的时候,会出现

Traceback (most recent call last):
  File "/home/gaoyi/example-app/test.py", line 59, in <module>
    traced_script_module = torch.jit.trace(model, example)
  File "/home/gaoyi/anaconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 803, in trace
    name = _qualified_name(func)
  File "/home/gaoyi/anaconda3/lib/python3.9/site-packages/torch/_jit_internal.py", line 1125, in _qualified_name
    raise RuntimeError("Could not get name of python class object")
RuntimeError: Could not get name of python class object

这是因为这个大家伙不能作为一个模型类加载,故也不能直接通过torch.jit.trace进行转化,我们换个方式,将子模型进行转化

模型转化

通过打印print(pipe.unet),可以看出这个unet是一个普通的网络,拥有一堆熟悉的网络层:

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0): Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=320, out_features=320, bias=True)
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): Linear(in_features=320, out_features=2560, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=1280, out_features=320, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=1024, out_features=320, bias=False)
                (to_v): Linear(in_features=1024, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=320, out_features=320, bias=True)
        )
        (1): Transformer2DModel(
       
        ...
        ...略
        ...
        
  (conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

好,那我们就可以将这个子模型进行转化,变成需要的LibTorch模型,但是我们不知道这个模型需要的输入,通过打印的信息我们知道了这个模型的名字是UNet2DConditionModel,所以我们可以从Hugging Face的官方文档进行查询:UNet2DConditionModel

查询发现模型的输入为:

但是具体的数值依旧不知道,这时候可以通过print(model.config)进行查看:

FrozenDict([('sample_size', 64), ('in_channels', 4), ('out_channels', 4), ('center_input_sample', False), 
('flip_sin_to_cos', True), ('freq_shift', 0), ('down_block_types', ['CrossAttnDownBlock2D', 
'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']), ('mid_block_type', 
'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 
'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D']), ('only_cross_attention', False), 
('block_out_channels', [320, 640, 1280, 1280]), ('layers_per_block', 2), ('downsample_padding', 1), 
('mid_block_scale_factor', 1), ('act_fn', 'silu'), ('norm_num_groups', 32), ('norm_eps', 1e-05), 
('cross_attention_dim', 1024), ('attention_head_dim', [5, 10, 20, 20]), ('dual_cross_attention', False), 
('use_linear_projection', True), ('class_embed_type', None), ('num_class_embeds', None), 
('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('_class_name', 'UNet2DConditionModel'), 
('_diffusers_version', '0.10.2'), ('_name_or_path', 
'/home/gaoyi/.cache/huggingface/diffusers/models--hakurei--waifu-diffusion/snapshots/55fd50bfae0dd8bcc4bd3a6f25cb167580b972a0/unet')])

一个大字典,找到我们所需要的('sample_size', 64), ('in_channels', 4), ('out_channels', 4),作为用于实例化的输入,此时我们的.py文件如下:

model = torch.load("pipe-unet.pth")

# print(model.config)
# print(model)

example = torch.rand(1, 4, 64, 64)
timestep = torch.rand(1)
encoder_hidden_states = torch.rand(1, 4, 64, 64)

traced_script_module = torch.jit.trace(model, (example, timestep, encoder_hidden_states))
traced_script_module.save("pipe-unet.pt")

但是报错mat1 can not be multiplied with mat2, shape 256x64 and 1024x320,大概是这么个问题,具体的信息就不粘贴了,既然是矩阵形状不对,那就改形状,之前理解的encoder_hidden_states形状与example应该是一样的,但看起来不对,可是改了1024x1024之后又遇到了新的问题,计算注意力的时候数据太多,接受的参数只有三个,所以干脆将encoder_hidden_states = torch.rand(1, 4, 1024),实测通过

之后的新问题,好像是实例化的时候输入元组的问题,具体如下:

RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, 
this is only valid if the container structure does not change based on the module's inputs. 
Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, 
use a `NamedTuple` instead). If you absolutely need this and know the side effects, 
pass strict=False to trace() to allow this behavior.

应该是需要在转换的时候传一个参数strict=False,调整完之后的代码如下:

model = torch.load("pipe-unet.pth")

# print(model.config)
# print(model)

example = torch.rand(1, 4, 64, 64)
timestep = torch.rand(1)
encoder_hidden_states = torch.rand(1, 4, 1024)

traced_script_module = torch.jit.trace(model, (example, timestep, encoder_hidden_states), strict=False)
traced_script_module.save("pipe-unet.pt")

成功保存!

模型测试

根据PyTorch官网的测试教程,编写相应的C++文件,然后使用CMake进行编译,最终生成example-app的可执行文件,运行:

./example-app ../pipe-unet.pt

输出ok,成功转化!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将Hugging Face模型转换成LibTorch模型 的相关文章

随机推荐

  • AI 协助办公 |记一次用 GPT-4 写一个消息同步 App

    GPT 4 最近风头正劲 作为 NebulaGraph 的研发人员的我自然是跟进新技术步伐 恰好 现在有一个将 Slack channel 消息同步到其他 IM 的需求 看看 GPT 4 能不能帮我完成这次的信息同步工具的代码编写工作 本文
  • 二叉树的翻转

    目录 一 题目 二 解题思路 1 二叉树翻转 2 具体步骤 迭代法 三 代码实现 一 题目 1 leetcode链接 力扣 2 题目内容 给你一棵二叉树的根节点 root 翻转这棵二叉树 并返回其根节点 示例 1 输入 root 4 2 7
  • LeetCode No3. 无重复字符的最长子串 题解

    文章目录 一 题目 二 算法思想 三 示例 四 代码 五 复杂度分析 六 算法评价 一 题目 给定一个字符串 s 请你找出其中不含有重复字符的 最长子串 的长度 示例 1 输入 s abcabcbb 输出 3 解释 因为无重复字符的最长子串
  • 从高中到大学 寻找真实的自己

    写在前面 这是这个寒假刚开始在CSDN上写博客的时候发的第一个blink 当时想说的话有点多 但blink的文字限制是1024字 所以那时控制了字数 现在放开重新写 写在正文 因为疫情原因在家上了差不多3个月的网课 大一回来过个寒假 再次回
  • 2020年研究生数学建模竞赛优秀论文汇总

    A题 ASIC 芯片上的载波恢复 DSP算法设计与实现论文1 论文2 论文3 论文4 论文5 B题 降低汽油精制过程中的辛烷值损失模型论文1 论文2 论文3 论文4 论文5 论文6 论文7 论文8 论文9 论文10 C题 面向康复工程的脑电
  • HTTP协议2)----对于传输层的详细讲解

    大家好 我是 兔7 一位努力学习C 的博主 如果文章知识点有错误的地方 请指正 和大家一起学习 一起进步 如有不懂 可以随时向我提问 我会全力讲解 如果感觉博主的文章还不错的话 希望大家关注 点赞 收藏三连支持一下博主哦 你们的支持是我创作
  • pythonfilter_Python如何用filter函数筛选数据

    一 filter函数简介 filter函数主要用来筛选数据 过滤掉不符合条件的元素 并返回一个迭代器对象 如果要转换为列表list或者元祖tuple 可以使用内置函数list 或者内置函数tuple 来转换 filter函数接收两个参数 第
  • Altium Designer可以实现选中整条同网络线路的快捷键

    选中一段线路 按Tab键 可以选中同网络的整条线路
  • Masked Autoencoders Are Scalable Vision Learners

    Masked Autoencoders Are Scalable Vision Learners Author Unit Facebook AI Research FAIR Authors Kaiming He
  • Finclip小程序目录结构与微信小程序目录结构

    Finclip小程序目录结构 小程序包含一个描述整体程序的 app 和多个描述各自页面的 page 一个小程序主体部分由三个文件组成 必须放在项目的根目录 如下 文件 必需 作用 app js 是 小程序逻辑 app json 是 小程序公
  • 两个无序的数组 如何进行合并 为一个有序的数组

    这里我们首先来看 自己也才毕业半年 这些题比较适合新手练练思想 技术之路且行且珍惜 算法绝对是核心竞争力 两个无序的数组 那么首先第一步合并 第二步 使用正则表达式去掉 第三步 split进行划分 第四步 最核心的排序 此处用了Arrays
  • MYSQL索引那些事

    一 关系型和非关系型的区别 以及使用场景 关系型数据库 采用关系模型来组织数据的数据库 关系模型就是二维表格模型 一张二维表的表名就是关系 二维表中的一行就是一条记录 二维表中的一列就是一个字段 优点 容易理解 使用方便 通用的 sql 语
  • Ceph OSD Down

    CEPH集群跑了一段时间后有几个OSD变成down的状态了 但是我用这个命令去activate也不行 ceph deploy osd activate osd1 dev sdb2 dev sdb1 只能把osd从集群中移除 然后再重建了 这
  • 【我的Android进阶之旅】如何快速寻找Android第三方开源库在Jcenter上的最新版本...

    问题描述 解决方法 先了解compile comsquareupokhttpokhttp240的意义 了解Jcenter和Maven jcenter Maven Central 理解jcenter和Maven Central 快速搜索方法1
  • 改造我们的学习

    我们知道 程序员必须得不断的学习 才能跟上日新月异的技术 但是很多朋友陷入了误区 比如学习C 总觉得我要把 C Primier 看完 再开始编程 学习图像处理也是 非要把数字图像处理与Opencv的书籍看完 才开始上机调试 最后云里雾里 感
  • 零基础Qt笔记<传智教育>Qt版本:2022 5.15

    目录 1 创建第一个Qt程序 2 命名规范以及快捷键 3 QPushBottom的创建 4 对象树 5 Qt中的坐标系 6 信号和槽 6 1 实现点击按钮关闭窗口 6 2 自定义的信号和槽 6 3 自定义的信号和槽发生重载的解决 6 4 信
  • 电话号码升位(拷贝构造函数)

    题目描述 定义一个电话号码类CTelNumber 包含1个字符指针数据成员 以及构造 析构 打印及拷贝构造函数 字符指针是用于动态创建一个字符数组 然后保存外来输入的电话号码 构造函数的功能是为对象设置键盘输入的7位电话号码 拷贝构造函数的
  • python编程实战(三):暴力破解WIFI密码!亲测运行有效!

    本文非原创 参考 Python破解WIFI密码详细介绍 对于代码有细微修改 增加注意事项介绍 声明 本文只是从技术的角度来阐述学习Pywifi库 并不建议大家做任何破坏性的操作和任何不当的行为 并不建议大家做任何破坏性的操作和任何不当的行为
  • js分治法入门级教程,二分搜索的解法

    一 分治法定义 在计算机科学中 分治法是一种很重要的算法 字面上的解释是 分而治之 分治法就是把一个复杂的问题分成两个或更多的相同或相似的子问题 再把子问题分成更小的子问题 直到最后子问题可以简单的直接求解 原问题的解即子问题的解的合并 分
  • 将Hugging Face模型转换成LibTorch模型

    Hugging Face的模型 以waifu diffusion模型为例 给出的实现一般是基于diffuser库 示例代码如下 import torch from torch import autocast from diffusers i