pytorch FX模型静态量化

2023-11-18


前言

以前面文章写到的mobilenet图像分类为例,本文主要记录一下pytorchh训练后静态量化的过程。


一、pytorch静态量化(手动版)

静态量化是最常用的量化形式,float32的模型量化成int8,模型大小大概变为原来的1/4,推理速度我在intel 8700k CPU上测试速度正好快4倍,但是在AMD的5800h CPU 上测试速度反而慢了两倍,应该是AMD不支持某些指令集加速。

踩坑:

之前手动添加量化节点的方式搞了好几天,最后模型是出来了,但是推理时候报错,大多数时候是RuntimeError: Could not run ‘--------’ with arguments from the ‘CPU’ backend,网上是说推理的时候没有安插QuantStub()和DeQuantStub(),可能是用的这个MobilenetV3网络结构复杂,某些地方没有手动添加到,这种方式肯定是可以成功的,只是比较麻烦容易出错。

# 加载模型
model = MobileNetV3_Large(2).to(device)  # 加载一个网络,我这边是二分类传了一个2
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint)
model.to('cpu').eval()

合并层对于一些重复使用的Block和nn.Sequential要打印出来看,然后append到mix_list 里面
比如

# 打印model
for name, module in model.named_children():
	print(name, module)

比如这里Sequential里面存在conv+bn+relu,append进去的应该是[‘bneck.0.conv1’, ‘bneck.0.bn1’,‘nolinear1’],但是nolinear1是个变量,也就是说某些时候是relu某些时候又不是,这种时候就要一个个分析判断好然后写代码,稍微复杂点就容易出错或者遗漏。

backend = "fbgemm"  # x86平台
model.qconfig = torch.quantization.get_default_qconfig(backend)
mix_list = [['conv1','bn1'], ['conv2','bn2']] # 合并层只支持conv+bn conv+relu conv+bn+relu等操作,具体可以查一下,网络中存在的这些操作都append到mix_list里面
model = torch.quantization.fuse_modules(model,listmix) # 合并某些层
model_fp32_prepared = torch.quantization.prepare(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)

有时候存在不支持的操作relu6这些要替换成relu,加法操作也要替换,最后还要输入一批图像校准模型等

self.skip_add = nn.quantized.FloatFunctional()
# forward的时候比如return a+b 改为return self.skip_add.add(a, b)

一系列注意事项操作完毕,最后推理各种报错,放弃了

二、使用FX量化

1.版本

fx量化版本也有坑,之前在torch 1.7版本操作总是报错搞不定,换成1.12.0版本就正常了,这一点非常重要。

2.代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torchvision
from torchvision import transforms
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization import get_default_qconfig
from torch import optim
import os
import time
from utils import load_data
from models.mobilenetv3copy import MobileNetV3_Large


def evaluate_model(model, test_loader, device, criterion=None):
    model.eval()
    model.to(device)
    running_loss = 0
    running_corrects = 0
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0
        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)
    return eval_loss, eval_accuracy


def quant_fx(model, data_loader):
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {"": qconfig}
    prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
    print("开始校准")
    calibrate(prepared_model, data_loader)  # 这是输入一批有代表性的数据来校准
    print("校准完毕")
    quantized_model = convert_fx(prepared_model)  # 转换
    return quantized_model


def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in train_loader:
            model(image)


if __name__ == "__main__":
    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu:0")
    model = MobileNetV3_Large(2)  # 加载自己的网络
    train_loader, test_loader = load_data(64, 8)  # 自己写一个pytorch加载数据的方法
    
    # quantization
    state_dict = torch.load('./mymodel.pth')  # 加载一个正常训练好的模型
    model.load_state_dict(state_dict)
    model.to('cpu')
    model.eval()
    quant_model = quant_fx(model, train_loader)  # 执行量化代码
    quant_model.eval()
    print("开始验证")
    eval_loss, eval_accuracy = evaluate_model(model=quant_model,
                                              test_loader=test_loader,
                                              device=cpu_device,
                                              criterion=nn.CrossEntropyLoss())
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
        -1, eval_loss, eval_accuracy))
    torch.jit.save(torch.jit.script(quant_model), 'outQuant.pth')  # 保存量化后的模型
    
    # 加载量化模型推理
    loaded_quantized_model = torch.jit.load('outQuant.pth')
    eval_loss, eval_accuracy = evaluate_model(model=quant_model,
                                              test_loader=test_loader,
                                              device=cpu_device,
                                              criterion=nn.CrossEntropyLoss())
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
        -1, eval_loss, eval_accuracy))

fx量化也不用管里面什么算子不支持之类的,开箱即用,以上代码参考pytorch官网https://pytorch.org/docs/stable/fx.html
最后验证模型精度下降0.02%可以忽略不计,pytorch量化的模型是不支持gpu推理的,只能在arm或者x86平台实现压缩提速。要用cuda的话要上tensorrt+onnx,以后完成了再讲。完整的训练模型量化模型的代码后面会放到github上面。
完整代码:https://github.com/Ysnower/pytorch-static-quant


总结

刚开始搞量化坑比较多,一个是某些操作不支持,合并层麻烦,另外有版本问题导致的报错可能搞很久,觉得有用的各位吴彦祖麻烦送个免费三连

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch FX模型静态量化 的相关文章

随机推荐

  • spring gateway 的搭建与配置

    步骤 建项目 给主启动类添加Eureka的注解 EnableEurekaClient 添加并配置application yml 第一步 新建gateway的项目 gateway 8205 需要用到的组件
  • el-descriptions的使用

    el descriptions的使用 解释 我们页面有很多无序的列表展示 为了高效得去开发我们得页面 可以借助于这个组件进行适应 图片 代码 template部分
  • MIPI D-PHY介绍(二) FPGA

    MIPI D PHY介绍 二 FPGA 随着移动设备的广泛普及 MIPI D PHY作为其最主要的物理层标准之一 被越来越多地使用在各种嵌入式系统中 本文将详细介绍MIPI D PHY的工作原理和在FPGA设计中的实现方法 MIPI D P
  • 在k8s集群内搭建Prometheus监控平台

    基本架构 Prometheus由SoundCloud发布 是一套由go语言开发的开源的监控 报警 时间序列数据库的组合 Prometheus的基本原理是通过HTTP协议周期性抓取被监控组件的状态 任意组件只要提供对应的HTTP接口就可以接入
  • .NetCore技术研究-ConfigurationManager在单元测试下的坑

    最近在将原有代码迁移 NET Core 代码的迁移基本很快 当然也遇到了不少坑 重构了不少 后续逐步总结分享给大家 今天总结分享一下ConfigurationManager遇到的一个问题 先说一下场景 迁移 NET Core后 已有的配置文
  • 如何使用Visual Studio Code运行C/C++程序

    与Visual Studio 2008 2010 集成开发工具不同 Visual Studio Code只是一个代码编辑器 在Windows环境下 需下载安装 C C 编译器 配置环境等 VS Code才可以编译代码和运行程序 1 下载安装
  • javaScript基础面试题 --- 原型链

    1 原型可以解决什么问题 对象共享属性和共享方法 2 谁有原型 函数有prototype 对象有 proto 3 查找顺序 当查询一个对象的属性时 JavaScript 会首先检查对象自身是否有这个属性 如果对象本身没有该属性 那么 JS
  • 使用python和snapshot备份ElasticSearch索引数据

    该python备份snapshot的索引数据脚本 通过Elasticsearch连接es 然后通过es indices get alias函数获取所有索引名称 通过列表的startswith函数剔除 开头的自带索引名称 然后把所有索引名称放
  • 多边形的面积

    1 三角形面积 xy平面内 有三角形123 如下图所示 图1 借助矢量叉积和点积 这个三角形的面积公式非常简单 这个面积是有符号的 1 2 3逆时针排列 则面积为正 1 2 3顺时针排列 则面积为负 这是对右手系的总结 如果从背面看这个坐标
  • 11月11日 自定义Events,将自定义Events分配给UI,给UI添加动画 UE4斯坦福 学习笔记

    自定义Events 在AttributeComponent的 h头文件上加上代码 自定义Event DECLARE DYNAMIC MULTICAST DELEGATE FourParams FOnHealthChanged AActor
  • 思科模拟器简单校园网设计,期末作业难度

    文章简介 本文用思科模拟器设计和规划了一个校园网络 相当于计算机网络相关专业期末作业难度 作者简介 网络工程师 希望能认识更多的小伙伴一起交流 可私信或QQ号 1686231613 一 网络需求分析 1 学校建有办公室 实验室 教学楼 学生
  • 【STM32】RS485通信使用DMA串口发送数据出现数据丢失、断包问题排查方法

    最近在搞这个Modbus协议 由于485协议是半双工的 区别于RS 232的全双工 考虑不周导致调试modbus协议时候出了不少问题 第一 大多数开发板上的485芯片是MAX485 发送和接收状态的切换是通过IO给到这个两个引脚不同的电平进
  • win 11又更新,新功能简直绝了!

    很早之前 咱就知道微软下半年将会有一次大动作 没错 就是发布Win11 22H2正式版 之前有说过9月份发 现在也确实做到了 微软现在已经面向190多个国家 地区推送了Windows 11 22H2正式版更新 更新之后版本号为22621 5
  • linux中通过sed命令通过正则表达式过滤出中文[^[\u4E00-\u9FA5A-Za-z0-9_]+$]

    linux中通过sed命令通过正则表达式过滤出中文 sed r s u4E00 u9FA5A Za z0 9 lt gt 0 9 a z A Z g zz txt gt a txt
  • flutter listview 滚动到底部_(五) Flutter入门学习 之 Widget滚动

    列表是移动端经常使用的一种视图展示方式 在Flutter中提供了ListView和GridView 为了可能展示出更好的效果 我这里提供了一段Json数据 所以我们可以先学习一下Json解析 一 JSON读取和解析 在开发中 我们经常会使用
  • sql注入原理及解决方案

    sql注入原理就是用户输入动态的构造了意外sql语句 造成了意外结果 是攻击者有机可乘 SQL注入 SQL注入 就是通过把SQL命令插入到Web表单递交或输入域名或页面请求的查询字符串 最终达到欺骗服务器执行恶意的SQL命令 比如先前的很多
  • 随机练习题:浅浅固定思路

    1 牛牛的10类人 2 牛牛的四叶玫瑰数 3 牛牛的替换 4 牛牛的素数判断 笔者开头感想 如今大部分高校已经开学 当然笔者也不列外 但是由于疫情的原因 笔者被迫在家上网课学习 一脸忧愁 而这恰恰给了笔者自学的机会 相信笔者会加油滴 按照时
  • Acwing-4366. 上课睡觉

    假设最终答案为每堆石子均为cnt个 cnt一定可以整除sum 石子的总数 我们可以依次枚举答案 sum小于等于10 6 所以cnt的数量等于sum约数的个数 10 6范围内 约数最多的数为720720 它的约数个数有240个 int范围内
  • 单边带(SSB)调制技术

    文章目录 单边带 SSB 调制技术 1 双边带简述 2 单边带调制 单边带 SSB 调制技术 1 双边带简述 首先简述一下双边带调制 所谓双边带 DSB double sideband 调制 本质上就是调幅 时域上将基带信号x t 和高频载
  • pytorch FX模型静态量化

    文章目录 前言 一 pytorch静态量化 手动版 踩坑 二 使用FX量化 1 版本 2 代码如下 总结 前言 以前面文章写到的mobilenet图像分类为例 本文主要记录一下pytorchh训练后静态量化的过程 一 pytorch静态量化