使用tensorrt加速深度学习模型推断

2023-12-05

此博客介绍如何将resnet101模型在CIFAR100数据集的分类任务,使用tensorrt部署。

完整代码如下

1.import以及数据加载、构建engine函数

import argparse
import os

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models

import time


import numpy as np
import tensorrt as trt
import common
import torchvision.transforms as transforms

TRT_LOGGER = trt.Logger()
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # 指定0号GPU可用


# mean and std of cifar100 dataset
CIFAR100_TRAIN_MEAN = (
    0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401,
                      0.2564384629170883, 0.27615047132568404)

def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    cifar100_test = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test)
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_test_loader


def ONNX_build_engine(onnx_file_path, trt_file):
    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    explicit_batch = 1 << (int)(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    batch_size = 64  
    with trt.Builder(G_LOGGER) as builder, builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, G_LOGGER) as parser:
        builder.max_batch_size = batch_size
        config = builder.create_builder_config()
        config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, common.GiB(1))
        config.set_flag(trt.BuilderFlag.FP16)
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parser.parse(model.read())
        print('Completed parsing of ONNX file')
        print('Building an engine from file {}; this may take a while...'.format(
            onnx_file_path))

        profile = builder.create_optimization_profile()
        profile.set_shape("input", (1, 3, 32, 32),
                          (1, 3, 32, 32), (batch_size, 3, 32, 32))
        config.add_optimization_profile(profile)
        engine = builder.build_serialized_network(network, config)
        print("Completed creating Engine")
        with open(trt_file, "wb") as f:
            f.write(engine)
        return engine

2.导入官方模型及CIFAR100数据集


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('-gpu', action='store_true',
                        default=True, help='use gpu or not')
    parser.add_argument('-b', type=int, default=32,
                        help='batch size for dataloader')
    args = parser.parse_args()
    print(args)

    cifar100_test_loader = get_test_dataloader(
        CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD,
        num_workers=1,
        batch_size=args.b)


    device = "cuda" if args.gpu else "cpu"
    net = models.resnet101(pretrained=True)
    net = net.to(device)
    # # print(net)
    net.eval()

3.不采用tensort的推断时间

#%%
    t1 = time.time()
    for n_iter, (image, label) in enumerate(cifar100_test_loader):
        pred = net(image.to(device))
        # print(pred.shape)
    t2 = time.time()
    print(t2-t1)

耗时约为8~9s。

4.采用tensort加速—使用tensorrt 库

4.1 导出onnx模型

#%% save onnx 
    input = torch.rand([1, 3, 32, 32]).to(device)
    onnx_file = "resnet101.onnx"

    if  os.path.exists(onnx_file):
        os.remove(onnx_file)
    torch.onnx.export(net, input, onnx_file,
                      input_names=['input'],  # the model's input names
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}},
                      # opset_version=12,
                      )
    print("onnx file generated!")

4.2 生成tensorrt engine 文件

# %%generate tensorrt engine file
    trt_file = "resnet101.trt"

    ONNX_build_engine(onnx_file, trt_file)
    print("trt file generated!")

4.3 deserialize

    trt_file = "resnet101.trt"
    runtime = trt.Runtime(TRT_LOGGER)
    with open(trt_file, 'rb') as f:
        engine = runtime.deserialize_cuda_engine(f.read())
        print("Completed creating Engine")
    context = engine.create_execution_context()
    context.set_binding_shape(0, (16, 3, 32, 32))

    inputs, outputs, bindings, stream = common.allocate_buffers(engine, 32)

4.4 推断

    t1 = time.time()
    label_ls = []
    pred_ls = []
    for n_iter, (image, label) in enumerate(cifar100_test_loader):
        # print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(cifar100_test_loader)))
        # print(image)
        inputs[0].host = image.numpy()

        trt_outputs = common.do_inference(
            context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size=32)
        label_ls.extend(label.numpy())
        pred_ls.extend(np.array(trt_outputs[0]).reshape(
            [-1, 100]).argmax(1).tolist())
        # print((np.array(pred_ls)[:10000]==np.array(label_ls)[:10000]).sum())
    t2 = time.time()
    print(t2-t1)

耗时约为4.3s,是用我的笔记本 上的GPU RTX 3050可以实现两倍左右的加速。

5.采用tensort加速—使用torch2trt库

nvidia还有torch2trt Python包,可用于一键tensorrt加速。

其安装可参考 https://github.com/NVIDIA-AI-IOT/torch2trt .

git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
python setup.py install

torch2trt的使用可参考 github torch2trt

    from torch2trt import torch2trt
    inputs = torch.rand([1, 3, 32, 32]).to(device)
    model_trt = torch2trt(net, [inputs], fp16_mode=True)

    t1 = time.time()
    label_ls = []
    pred_ls = []
    for n_iter, (image, label) in enumerate(cifar100_test_loader):

        output_trt = model_trt(image.to(device))

    t2 = time.time()
    print(t2-t1)

使用起来不要太easy!

完整代码可参考 https://github.com/L0-zhang/tentorrt_demo/tree/main

参考文献

[1] csdn pytorch TensorRT 官方例子
[2] https://github.com/NVIDIA-AI-IOT/torch2trt
[3] https://github.com/L0-zhang/tentorrt_demo/tree/main

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用tensorrt加速深度学习模型推断 的相关文章

随机推荐

  • Latex正文引用图片编号,防止某张图片删除或调整导致正文序号对应错误

    一 背景 Latex真的是非常好用的论文排版工具 虽然不像word一样是 所见即所得 的可视化方式 但完全不用管格式 包括图片的排版 文字的缩进等等 这在word里调整起来真的是非常麻烦 特别是某个段落 图片修改后 又要重新调整格式 非常的
  • Ubuntu20.04安装向日葵、开机自启、解决windows系统远程黑屏(笔记)

    这里写目录标题 动机 1 Ubuntu20 04 安装向日葵 2 设置开机自启 3 解决windows不可远程的问题 4 大公告成 动机 办公室有个工作站 要比我的笔记本的CPU稍微好一点 用来跑陆面过程 我信心满满的装了个Ubuntu20
  • 什么是离岸公司?有什么作用?

    离岸公司是泛指在离岸法区内依据其离岸公司法规范成立的有限责任公司或股份有限公司 这些公司不能在注册地经营 而主要是在离岸法区以外的地方开展业务活动 离岸公司的主要特点包括高度保密性 无外汇管制和减免税务负担 离岸公司的作用主要有以下几个方面
  • 销售人员一定要知道的6种获取电话号码的方法

    对于销售来说 电话销售是必须要知道的销售方法 也是销售生涯中的必经之路 最开始我们并不清楚这么电话是从哪里来的 也不清楚是通过哪些方法渠道获取 那么今天就来分享给各位销售人员获取客户电话号码的方法 1 打印自己的名片 在工作当中少不了接触其
  • 5.【自动驾驶与机器人中的SLAM技术】2D点云的scan matching算法 和 检测退化场景的思路

    目录 1 基于优化的点到点 线的配准 2 对似然场图像进行插值 提高匹配精度 3 对二维激光点云中会对SLAM功能产生退化场景的检测 4 在诸如扫地机器人等这样基于2D激光雷达导航的机器人 如何处理悬空 低矮物体 5 也欢迎大家来我的读书号
  • 大Ⅲ周记11

    1 本周学习了mysql数据库操作的相关知识 根据课设要求完成了压降系统数据库表的设计 2 计算机网络完成了所有章节的作业 开始进入复习阶段 预计下周完成一至二章的复习作业
  • leetcode:93. 复原 IP 地址

    复原 IP 地址 中等 1 4K 相关企业 有效 IP 地址 正好由四个整数 每个整数位于 0 到 255 之间组成 且不能含有前导 0 整数之间用 分隔 例如 0 1 2 201 和 192 168 1 1 是 有效 IP 地址 但是 0
  • 最近在对接电商供应链,说说开放平台API接口

    B2B电商开放平台的设计需要从以下几面去思考 开放平台API接口 的接入 主要是从功能需求的角度 设计满足业务需求的接口及对应的字段 平台与商家之间信息的对接 对接的方法有哪些 对接过程中需要可能会遇到什么问题 同步开关及权限的设计 处理信
  • 鸿蒙4.0开发笔记之ArkTS装饰器语法基础@Prop@Link@State状态装饰器(十二)

    文章目录 一 哪些是状态装饰器 二 State Prop Link状态传递的核心规则 三 状态装饰器练习 一 哪些是状态装饰器 1 State 被装饰拥有其所属组件的状态 可以作为其子组件单向和双向同步的数据源 当其数值改变时 会引起相关组
  • Nessus简单介绍与安装

    1 Nessus简介 Nessus号称是世界上最流行的漏洞扫描程序 全世界有超过75000个组织在使用它 该工具提供完整的电脑漏洞扫描服务 并随时更新其漏洞数据库 Nessus不同于传统的漏洞扫描软件 Nessus可同时在本机或远端上遥控
  • WebGL笔记:矩阵平移的数学原理和实现

    矩阵平移的数学原理 让向量OA位移 x方向 tx y方向 ty z方向 tz 最终得到向量OB 矩阵平移的应用 再比如我要让顶点的x移动0 1 y移动0 2 z移动0 3 1 顶点着色器核心代码
  • 有效表达观点的艺术

    有效表达观点的艺术 在人际交往中 有效地表达自己的观点是建立良好关系和实现有效沟通的关键 然而 这并不总是易如反掌 有时候 我们可能会遇到表达困难 或者我们的观点可能被误解 本文将探讨如何有效地表达观点 以及掌握说话的艺术的重要性 首先 清
  • 人工智能:开启未来商业新篇章

    人工智能 开启未来商业新篇章 随着科技的快速发展 人工智能 AI 在商业领域的应用越来越广泛 成为企业把握未来商业机遇的重要方向 本文将探讨人工智能如何重塑商业格局 为企业提供新的增长点 以及企业如何抓住AI的商业契机 一 AI重塑商业格局
  • 机器人学英语

    我的prompt i want to you act as an english language teacher asistant to help me study english you could teach me in such a
  • 详解Hotspot的经典7种垃圾收集器原理特点与组合搭配

    详解Hotspot的经典7种垃圾收集器原理特点与组合搭配 HotSpot共有7种垃圾收集器 3个新生代垃圾收集器 3个老年代垃圾收集器 以及G1 一共构成7种可供选择的垃圾收集器组合 新生代与老年代垃圾收集器之间形成6种组合 每个新生代垃圾
  • 在深圳月入一万的很丢人吗

    在深圳 月入一万的收入是否丢人 这是一个很主观的问题 因为每个人的生活需求和价值观不同 从经济学的角度来看 深圳作为中国的经济特区和一线城市 其生活成本相对较高 从这个角度看 月入一万的收入在某种程度上可能不足以满足一些人的生活需求 根据最
  • 给自己泡了一壶茶

    清晨 当第一缕阳光透过窗户照亮了房间 我慵懒地爬起床 开始享受新的一天 我泡了一壶早茶 浅浅的茶香立刻弥漫在空气中 让我感到宁静而放松 我坐在窗边 静静地看着窗外的世界 清晨的街道上 行人和车辆都还不多 显得格外的宁静 微风吹过树叶 带来阵
  • 拍图识字软件哪个好用?这些好用的软件推荐给你们

    在快节奏的现代生活中 你可能会遇到需要从图片中获取文字信息的情况 无论是读书 工作还是生活中 有时候会需要从图片中提取文字 当你收到了一份手写的便签或菜单 上面的字迹可能很模糊 或者你需要在没有文字的地方快速获取信息 这时 你可能会想 如果
  • 详解十大经典排序算法(四):希尔排序(Shell Sort)

    算法原理 希尔排序是一种基于插入排序的排序算法 也被称为缩小增量排序 它通过将待排序的序列分割成若干个子序列 对每个子序列进行插入排序 然后逐步缩小增量 最终使整个序列有序 算法描述 希尔排序 Shell Sort 是一种基于插入排序的算法
  • 使用tensorrt加速深度学习模型推断

    使用tensorrt加速深度学习模型推断 1 import以及数据加载 构建engine函数 2 导入官方模型及CIFAR100数据集 3 不采用tensort的推断时间 4 采用tensort加速 使用tensorrt 库 4 1 导出o