【Kaggle】Stable Diffusion - Image to Prompts竞赛代码初步理解

2023-11-18


在这里插入图片描述

一、前言

此次代码集成了 CLIP Interrogator、OFA 模型和 ViT 模型。

首先安装指定版本的 transformers 库:

transformers-4.18.0.dev0-py3-none-any.whl 是一个 transformers 库的文件,它的命名方式表示这是一个开发版本(dev)的预构建轮子(wheel)文件。

轮子文件是 Python 包的一种打包格式,可以通过 pip 安装。

如果您想要安装这个特定版本的 transformers 库,可以使用以下命令:

pip install transformers-4.18.0.dev0-py3-none-any.whl

请确保您位于包含该文件的目录,并且在运行该命令之前已经安装了适当的依赖项。

请注意,这是一个开发版本的预构建文件,可能包含尚未正式发布的功能或存在 bug。如果您只是想使用稳定版本的 transformers 库,建议使用正式发布的版本,例如:

pip install transformers

这将安装最新的稳定版本,而不是开发版本。

在比赛中,我们可以引入的方法为:

!pip install -q /kaggle/input/stable-diffusion-data/transformers-4.18.0.dev0-py3-none-any.whl

Kaggle 环境中使用 Jupyter NotebookJupyterLab 的一个指令。

  • ! 符号是 Jupyter Notebook 或 JupyterLab 中的一个魔术命令前缀,用于执行系统级命令。
  • pip install 是用于安装 Python 包的 pip 命令。
  • -q–quiet 参数是 pip 命令的选项之一,用于使安装过程静默,即不显示安装的详细信息。
  • /kaggle/input/stable-diffusion-data/transformers-4.18.0.dev0-py3-none-any.whl 是一个文件路径,表示要安装的 transformers 库的轮子文件(.whl 文件)。根据路径中的前缀 /kaggle/input,可以推断这是在 Kaggle 环境中安装位于 /kaggle/input 目录下的本地文件。

综上所述,该命令的含义是在 Kaggle 环境中安装指定路径下的 transformers-4.18.0.dev0-py3-none-any.whl 轮子文件,并且在安装过程中不显示详细信息。这将通过使用 pip 命令将该轮子文件作为本地文件进行安装。

二、导包

import os
import sys
import glob
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import OFATokenizer, OFAModel
from transformers.models.ofa.generate import sequence_generator
import gc

这是一个 Python 脚本中的导入语句,它用于将一些常用的 Python 和深度学习库导入到脚本中。以下是每个导入库的简要介绍:

  • os:提供了一个与操作系统交互的简单方法,例如读取和写入文件。
  • sys:提供了一个与 Python 解释器交互的方法,例如修改 sys.path 导入路径、获取命令行参数等。
  • glob:提供了一个通用文件和目录的匹配模式方法。
  • pathlib.Path:提供了一个简单的面向对象的路径类,用于操作文件和目录。
  • numpy:提供了一个用于进行数学和科学计算的 Python 库。
  • pandas:提供了一个用于数据分析的 Python 库,可以用于处理和操作大型数据集。
  • matplotlib.pyplot:提供了一个用于绘制数据可视化图形的 Python 库。
  • PIL.Image:提供了一个 Python 图像处理库,可以用于处理和操作图像。
  • torch:提供了一个 PyTorch 深度学习框架库,用于创建和训练神经网络模型。
  • torch.utils.data.Dataset:提供了一个 PyTorch 数据集抽象类,用于创建自定义数据集类。
  • torchvision.transforms:提供了一些 PyTorch 中用于数据增强和预处理的转换类。
  • transformers.OFATokenizer:提供了一个 OneFlow Aware Tokenizer(OFA)类,用于在 PyTorch 中进行模型搜索和架构优化。
  • transformers.OFAModel:提供了一个 OneFlow Aware Model(OFA)类,用于在 PyTorch 中进行模型搜索和架构优化。
  • transformers.models.ofa.generate.sequence_generator:提供了一个序列生成器,用于生成一系列优化的模型架构。
  • gc:是 Python 中的垃圾回收库,可以用于管理内存。
CKPT_DIR = "/kaggle/input/stable-diffusion-data/OFA-large-caption/"
IMAGE_DIR = "/kaggle/input/stable-diffusion-image-to-prompts/images"

BATCH_SIZE = 24

这些是一些变量的赋值语句:

  • CKPT_DIR:设置为 “/kaggle/input/stable-diffusion-data/OFA-large-caption/”,表示一个目录路径,指向存储 OneFlow Aware (OFA) 模型的检查点文件的位置。
  • IMAGE_DIR:设置为 “/kaggle/input/stable-diffusion-image-to-prompts/images”,表示一个目录路径,指向存储图像文件的位置。
  • BATCH_SIZE:设置为 24,表示批量处理数据时的批量大小,即一次传递给模型的样本数量。

这些变量的值可以根据需要进行调整,用于指定数据的路径、模型的保存位置以及批量处理数据时的批量大小。

三、加载预训练的 OFA 模型

mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
resolution = 480
patch_resize_transform = transforms.Compose([
        lambda image: image.convert("RGB"),
        transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
        transforms.ToTensor(), 
        transforms.Normalize(mean=mean, std=std)
    ])

tokenizer = OFATokenizer.from_pretrained(CKPT_DIR)
model = OFAModel.from_pretrained(CKPT_DIR, use_cache=False).cuda()
txt = " what does the image describe?"
inputs = tokenizer([txt], return_tensors="pt").input_ids

让我们逐行解读这段代码:

  • mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]:这行代码定义了 meanstd 两个变量,它们是归一化图像时使用的均值和标准差。这里设置的均值和标准差都是 [0.5, 0.5, 0.5],表示将图像的每个通道的像素值缩放到范围 [-1, 1]
  • resolution = 480:这行代码定义了 resolution 变量,它表示将图像调整为的分辨率大小。在这里,图像将被调整为 480 x 480 像素。
  • patch_resize_transform = transforms.Compose([…]):这行代码定义了一个转换序列 patch_resize_transform,它将应用于输入图像。转换序列中的每个操作按顺序应用于图像。
  • image.convert(“RGB”) 将图像转换为 RGB 模式,确保图像有三个通道。
  • transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC) 调整图像大小为给定的 resolution,使用双三次插值方法进行调整。
  • transforms.ToTensor() 将图像转换为张量形式,将像素值缩放到范围 [0, 1]
  • transforms.Normalize(mean=mean, std=std) 对图像进行标准化,将像素值归一化为均值为 mean,标准差为 std 的分布。
  • tokenizer = OFATokenizer.from_pretrained(CKPT_DIR):这行代码创建一个 OFATokenizer 对象,从预训练模型的检查点文件中加载 tokenizer
  • model = OFAModel.from_pretrained(CKPT_DIR, use_cache=False).cuda():这行代码创建一个 OFAModel 对象,并加载预训练模型的权重。use_cache = False 表示不使用缓存。
  • txt = " what does the image describe?":这行代码定义了一个字符串变量 txt,它包含了图像描述的文本。
  • inputs = tokenizer([txt], return_tensors=“pt”).input_ids:这行代码使用之前创建的 tokenizer 对象将文本 txt 编码为模型输入的张量形式。tokenizer([txt]) 将文本转换为 tokens,并返回一个字典对象。.input_ids 提取了输入 tokens 的张量表示。最终,inputs 变量将包含文本编码后的输入张量。

总的来说,这段代码的目的是为了准备图像和文本数据,以便将它们输入到 OFA 模型中进行处理和生成。

四、模型EDA

sample_images = glob.glob("/kaggle/input/stable-diffusion-image-to-prompts/images/*")[:7]
fig, ax = plt.subplots(7,1, figsize=(4,35))

for i,impath in enumerate(sample_images):
    image = Image.open(impath)
    image_t = patch_resize_transform(image).cuda().unsqueeze(0)
    out = model.generate(inputs.cuda(), patch_images=image_t.cuda(), num_beams=5, no_repeat_ngram_size=2)
    out_captions = tokenizer.batch_decode(out, skip_special_tokens=True)
    ax[i].imshow(image)
    ax[i].text(1.1, .5, out_captions[0], horizontalalignment='left', verticalalignment='center', transform=ax[i].transAxes)

让我们逐行解释这段代码:

  • sample_images = glob.glob(“/kaggle/input/stable-diffusion-image-to-prompts/images/*”)[:7]:这行代码使用 glob.glob 函数获取 /kaggle/input/stable-diffusion-image-to-prompts/images/ 目录下的图像文件路径,并选择前 7 个图像文件。这些文件路径被存储在 sample_images 列表中。
  • fig, ax = plt.subplots(7, 1, figsize=(4, 35)):这行代码创建了一个包含 7 行、1 列的子图布局,每个子图的大小为 (4, 35)。返回的 fig 对象表示整个图像,ax 对象是包含 7 个子图的数组。
  • for i, impath in enumerate(sample_images)::这是一个循环,遍历 sample_images 列表中的图像文件路径。enumerate 函数用于同时迭代列表中的元素和它们的索引。在每次迭代中,i 是索引,impath 是当前的图像文件路径。
  • image = Image.open(impath):这行代码使用 PIL 库的 Image.open 函数打开图像文件,并将图像对象存储在 image 变量中。
  • image_t = patch_resize_transform(image).cuda().unsqueeze(0):这行代码对打开的图像 image 应用之前定义的 patch_resize_transform 转换序列,将图像调整大小并进行标准化处理。然后,使用 .cuda() 将图像张量移动到 GPU 上,并使用 unsqueeze(0) 在批次维度上添加一个维度。
  • out = model.generate(inputs.cuda(), patch_images=image_t.cuda(), num_beams=5, no_repeat_ngram_size=2):这行代码使用预训练的 OFA 模型 model 生成文本。它接收一个输入文本的张量 inputs,以及调整大小和标准化后的图像张量 image_tnum_beams=5 表示使用束搜索方法生成多个可能的文本输出,no_repeat_ngram_size=2 表示生成的文本中不会有连续重复的 2-gram
  • out_captions = tokenizer.batch_decode(out, skip_special_tokens=True):这行代码使用 tokenizer 将模型生成的输出 out 解码为文本。skip_special_tokens = True 表示跳过特殊标记,如起始和结束标记。
  • ax[i].imshow(image):这行代码在第 i 个子图中显示图像。
  • ax[i].text(1.1, .5, out_captions[0], horizontalalignment=‘left’, verticalalignment=‘center’, transform=ax[i].transAxes):这行代码在第 i 个子图中添加文本标注。1.1, .5 是文本的位置坐标,horizontalalignment = ‘left’ 表示文本水平对齐方式为左对齐,verticalalignment = ‘center’ 表示文本垂直对齐方式为居中对齐。ax[i].transAxes 表示使用子图坐标系进行转换。

这段代码的目的是展示图像并在每张图像上显示由 OFA 模型生成的文本描述。它首先遍历了前 7 个图像文件的路径,然后对每张图像进行处理:调整大小、标准化,并传递给 OFA 模型生成文本描述。生成的文本描述被解码后,用作文本标注,并与相应的图像一起显示在子图中。

最终的结果是,在一个具有 7 行、1 列的图像布局中,显示了每张图像以及由 OFA 模型生成的相应文本描述。

五、Inference

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('../input/stable-diffusion-image-to-prompts/')
st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')

让我们逐行解读上面的代码:

  • sys.path.append(‘…/input/sentence-transformers-222/sentence-transformers’):这行代码将 ‘…/input/sentence-transformers-222/sentence-transformers’ 目录添加到 Python 的模块搜索路径中,以便能够导入其中的模块。
  • from sentence_transformers import SentenceTransformer, models:这行代码从 sentence_transformers 模块中导入 SentenceTransformermodels 类。这些类提供了用于处理和生成句子向量的功能。
  • comp_path = Path(‘…/input/stable-diffusion-image-to-prompts/’):这行代码创建了一个 Path
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Kaggle】Stable Diffusion - Image to Prompts竞赛代码初步理解 的相关文章

  • 如何在刻度标签和轴之间添加空间

    我已成功增加刻度标签的字体 但现在它们距离轴太近了 我想在刻度标签和轴之间添加一点呼吸空间 如果您不想全局更改间距 通过编辑 rcParams 并且想要更简洁的方法 请尝试以下操作 ax tick params axis both whic
  • Python、Tkinter、更改标签颜色

    有没有一种简单的方法来更改按钮中文本的颜色 I use button text input text here 更改按下后按钮文本的内容 是否存在类似的颜色变化 button color red Use the foreground设置按钮
  • InterfaceError:连接已关闭(使用 django + celery + Scrapy)

    当我在 Celery 任务中使用 Scrapy 解析函数 有时可能需要 10 分钟 时 我得到了这个信息 我用 姜戈 1 6 5 django celery 3 1 16 芹菜 3 1 16 psycopg2 2 5 5 我也使用了psyc
  • 如何在android上的python kivy中关闭应用程序后使服务继续工作

    我希望我的服务在关闭应用程序后继续工作 但我做不到 我听说我应该使用startForeground 但如何在Python中做到这一点呢 应用程序代码 from kivy app import App from kivy uix floatl
  • 如何在 Sublime Text 2 的 OSX 终端中显示构建结果

    我刚刚从 TextMate 切换到 Sublime Text 2 我非常喜欢它 让我困扰的一件事是默认的构建结果显示在 ST2 的底部 我的程序产生一些很长的结果 显示它的理想方式 如在 TM2 中 是并排查看它们 如何在 Mac 操作系统
  • 为 pandas 数据透视表中的每个值列定义 aggfunc

    试图生成具有多个 值 列的数据透视表 我知道我可以使用 aggfunc 按照我想要的方式聚合值 但是如果我不想对两列求和或求平均值 而是想要一列的总和 同时求另一列的平均值 该怎么办 那么使用 pandas 可以做到这一点吗 df pd D
  • 从 scikit-learn 导入 make_blobs [重复]

    这个问题在这里已经有答案了 我收到下一个警告 D Programming Python ML venv lib site packages sklearn utils deprecation py 77 DeprecationWarning
  • 运行多个 scrapy 蜘蛛的正确方法

    我只是尝试使用在同一进程中运行多个蜘蛛新的 scrapy 文档 http doc scrapy org en 1 0 topics practices html但我得到 AttributeError CrawlerProcess objec
  • Python 中的二进制缓冲区

    在Python中你可以使用StringIO https docs python org library struct html用于字符数据的类似文件的缓冲区 内存映射文件 https docs python org library mmap
  • python pandas 中的双端队列

    我正在使用Python的deque 实现一个简单的循环缓冲区 from collections import deque import numpy as np test sequence np array range 100 2 resha
  • 如何改变Python中特定打印字母的颜色?

    我正在尝试做一个简短的测验 并且想将错误答案显示为红色 欢迎来到我的测验 您想开始吗 是的 祝你好运 法国的首都是哪里 法国 随机答案不正确的答案 我正在尝试将其显示为红色 我的代码是 print Welcome to my Quiz be
  • 从 pygame 获取 numpy 数组

    我想通过 python 访问我的网络摄像头 不幸的是 由于网络摄像头的原因 openCV 无法工作 Pygame camera 使用以下代码就像魅力一样 from pygame import camera display camera in
  • 如何将 PIL 图像转换为 NumPy 数组?

    如何转换 PILImage来回转换为 NumPy 数组 这样我就可以比 PIL 进行更快的像素级转换PixelAccess允许 我可以通过以下方式将其转换为 NumPy 数组 pic Image open foo jpg pix numpy
  • 为美国东部以外地区的 Cloudwatch 警报发送短信?

    AWS 似乎没有为美国东部以外的 SNS 主题订阅者提供 SMS 作为协议 我想连接我的 CloudWatch 警报并在发生故障时接收短信 但无法将其发送到 SMS YES 经过一番挖掘后 我能够让它发挥作用 它比仅仅选择一个主题或输入闹钟
  • 在Python中重置生成器对象

    我有一个由多个yield 返回的生成器对象 准备调用该生成器是相当耗时的操作 这就是为什么我想多次重复使用生成器 y FunctionWithYield for x in y print x here must be something t
  • 在 Pandas DataFrame Python 中添加新列[重复]

    这个问题在这里已经有答案了 例如 我在 Pandas 中有数据框 Col1 Col2 A 1 B 2 C 3 现在 如果我想再添加一个名为 Col3 的列 并且该值基于 Col2 式中 如果Col2 gt 1 则Col3为0 否则为1 所以
  • 对输入求 Keras 模型的导数返回全零

    所以我有一个 Keras 模型 我想将模型的梯度应用于其输入 这就是我所做的 import tensorflow as tf from keras models import Sequential from keras layers imp
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo
  • 在 Python 类中动态定义实例字段

    我是 Python 新手 主要从事 Java 编程 我目前正在思考Python中的类是如何实例化的 我明白那个 init 就像Java中的构造函数 然而 有时 python 类没有 init 方法 在这种情况下我假设有一个默认构造函数 就像
  • Python:元类属性有时会覆盖类属性?

    下面代码的结果让我感到困惑 class MyClass type property def a self return 1 class MyObject object metaclass MyClass a 2 print MyObject

随机推荐

  • LRU LFU 概念、底层原理及其实现 超详细~

    0 前置提要 本篇约为8650字 阅读完需要约40 60分钟 主要介绍页面置换算法 LRU和LFU的原理及其实现 对应leetcode140和460 如果能给个赞就更好了 1 从内存置换算法说起 计算机的运行的程序和数据保存在内存中 内存的
  • C++ MYSQL 多线程并发查询处理数据

    1 ThreadPool hpp pragma once ifndef THREAD POOL H define THREAD POOL H include
  • ERROR: Could not find a version that satisfies the requirement keras-nightly (from versions: none)

    问题描述 这个错误发生于博主在python 3 8的环境里安装tensorflow 2 5 0时 执行pip install tensorflow 2 5 0 出现了以下错误 ERROR Could not find a version t
  • Java正则表达式详解

    1 1 正则表达式的概念以及演示 正则表达式可以用一些规定的字符来制定规则 并用来校验数据格式的合法性 正则表达式就是用来验证各种字符串的规则 它内部描述了一些规则 我们可以验证用户输入的字符串是否匹配这个规则 正则表达式是一种强大的校验机
  • 2023-8-23 堆排序

    题目链接 堆排序 include
  • python写程序计算无穷级数_[python][计算方法]利用无穷级数计算幂运算(开根号)...

    encoding gbk a的n次方函数 def exp a n ret 1 for i in range 0 n ret a return float ret n n 1 n 2 def getN minus n n x ret floa
  • 【教程】加速访问和下载github项目,原来替换一个域名就可以加速了

    目录 前言 gitee方法 更简便方法 使用教程 前言 大家平时下载github项目的时候 非常的慢 有时候浏览某个项目的md介绍时候 图片就是加载不出来 让人很苦恼 想锤电脑 gitee方法 于是有很多人都是用gitee的方法 先导入到g
  • 【存储管理】brk()系统调用

    尽管应用程序编程时很少直接调用brk 系统调用 但是它是最经常使用的系统调用 1 C语言中的malloc以及C 语言中的new都在间接的调用着brk 这个系统调用 内核中含有3GB的用户虚存空间 会部分映射到物理存储空间 用户程序经过编译
  • vue中怎么引入element以及使用的详细教程

    引入element 安装依赖 在使用 Element 之前 需要先安装 Element 的依赖库 可以使用 npm 或者 yarn 安装 npm npm i element ui S yarn yarn add element ui 引入C
  • Qt 如何关闭 Debug信息输出

    在pro文件中加上DEFINES QT NO DEBUG OUTPUT 然后重新构建一下程序 qDebug的信息就不再输出了 不过qWarning qCritical等信息仍然可以输出 类似的宏还有 QT NO INFO OUTPUT QT
  • 剑指Offer第五十八题:对称的二叉树

    题目描述 请实现一个函数 用来判断一颗二叉树是不是对称的 注意 如果一个二叉树同此二叉树的镜像是同样的 定义其为对称的 1 思路 我们通常有三种不同的二叉树遍历算法 即前序遍历 中序遍历和后序遍历 在这三种遍历算法中 都是先遍历左子结点再遍
  • 良许Linux

    Linux 服务器我们天天打交道 特别是 Linux 工程师更是如此 为了保证服务器的安全与性能 我们经常需要监控服务器的一些状态 以保证工作能顺利开展 本文介绍的几个命令 不仅仅适用于服务器监控 也适用于我们日常情况下的开发 1 watc
  • depcheck检测缺失哪些依赖包

    npm install g depcheck 如果不想全局安装 npm i depcheck后可以在package json的scripts中输入 check depcheck 之后使用 npm run check depcheck npm
  • umi-request 网络请求之路

    umi request 网络请求之路 背景 在做中台业务应用开发的过程中 我们发现在请求链路上存在以下问题 请求库各式各样 没有统一 每次新起应用都需要重复实现一套请求层逻辑 切换应用时需要重新学习请求库 API 各应用接口设计不一致 混乱
  • sql注入Less11-20

    Less 11 POST 1 先登录 在表格中输入admin admin 登录成功后为下图 2 在post data中输入以下 uname passwd 1 submit submit 返回的结果显示存在sql语法错误 证明存在注入漏洞 u
  • 修改别人代码的原则

    工作过程中难免会涉及到修改或维护别人写的代码 如 代码原作者请假 离职 或相关的bug落到了你的头上 或用别人写的通用方法不爽时 如果碰到修改别人的代码时 需要注意哪些事项呢 1 和原作者沟通 当用到了他人写的通用方法 又感觉不爽时 如果原
  • 各个版本chrome允许加载使用flash的方法

    根除办法 在html中嵌入标签 用户自动选择是否加载flash 69 0 之前的版本 1 打开 chrome settings content flash 2 禁止网站运行Flash gt 改为 Ask Default 3 允许 gt 添加
  • golang开发的准备 - gvm(go版本管理软件)的安装

    0 系统环境 ubuntu18 04 1 前置条件 sudo apt get install bison 2 安装步骤 1 从github下载安装包文件 git clone https github com moovweb gvm git
  • 【c++】14.编译proto和proto相关用法

    编译proto和proto相关用法 关于proto相关的知识可以参考系列博客 https blog csdn net daaikuaichuan category 9869251 html xx proto文件中如果要注释的话 注释符号也是
  • 【Kaggle】Stable Diffusion - Image to Prompts竞赛代码初步理解

    文章目录 一 前言 二 导包 三 加载预训练的 OFA 模型 四 模型EDA 五 Inference 六 安装并导入所有依赖项 七 设置配置 八 加载示例提交 九 Build index from images 十 CLIP interro