HAN&DAI

  • 首页
  • 遥感应用
  • GIS应用
  • 机器学习
  • 实用工具
  • 文章链接
  • 遥感数据集
HAN&DAI
遥感与地理信息技术交流社区
  1. 首页
  2. pytorch使用
  3. 正文

Pytorch图像分割模型转ONNX

2024年11月8日 191点热度 1人点赞 0条评论

最近学习了一下图像分割模型部署——PyTorch转ONNX
参考了子豪兄的视频

一、主要学习内容如下:





即ONNX是个中间翻译器,帮助不同框架和不同设备更方便地迁移和部署。

二、具体操作

1. 安装配置环境

安装Pytorch
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
安装 ONNX
pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
安装推理引擎 ONNX Runtime
pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
验证安装配置成功

import torch
print('PyTorch 版本', torch.__version__)
#PyTorch 版本 1.10.0+cu113
import onnx
print('ONNX 版本', onnx.__version__)
#ONNX 版本 1.13.1
import onnxruntime as ort
print('ONNX Runtime 版本', ort.__version__)
#ONNX Runtime 版本 1.14.1

由于我已经安装过PyTorch和其他第三方包,因此只需要安装ONNX和ONNX Runtime即可。

2. 转ONNX

导入工具包

import torch
from torchvision import models

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

导入训练好的模型

model = torch.load('checkpoint/fruit30_pytorch_20220814.pth')
model = model.eval().to(device)

构造一个输入图像Tensor

x = torch.randn(1, 3, 256, 256).to(device)

输入Pytorch模型推理预测,获得1000个类别的预测结果

output = model(x)
print(output.shape)

Pytorch模型转ONNX模型

x = torch.randn(1, 3, 256, 256).to(device)

with torch.no_grad():
    torch.onnx.export(
        model,                   # 要转换的模型
        x,                       # 模型的任意一组输入
        'resnet18_fruit30.onnx', # 导出的 ONNX 文件名
        opset_version=11,        # ONNX 算子集版本
        input_names=['input'],   # 输入 Tensor 的名称(自己起名字)
        output_names=['output']  # 输出 Tensor 的名称(自己起名字)
    ) 

验证onnx模型导出成功

import onnx

# 读取 ONNX 模型
onnx_model = onnx.load('resnet18_fruit30.onnx')

# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)

print('无报错,onnx模型载入成功')

使用Netron可视化模型结构
Netron:https://netron.app
视频教程:https://www.bilibili.com/video/BV1TV4y1P7AP

适用到自己的网络模型:

# ------------------------------------------------------------------------------
# Modified based on https://github.com/HRNet/HRNet-Semantic-Segmentation
# ------------------------------------------------------------------------------

import argparse
import os
import pprint

import logging
import timeit

import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

import _init_paths
import models
import datasets
from configs import config
from configs import update_config
from utils.function import testval, test, test_ood
from utils.utils import create_logger

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

def parse_args():
    parser = argparse.ArgumentParser(description='Train segmentation network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default="configs/bandon/pidnet_small_bandon.yaml",
                        type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()
    update_config(config, args)

    return args

def main():
    args = parse_args()

    logger, final_output_dir, _ = create_logger(
        config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build model
    model = model = models.pidnet.get_seg_model(config, imgnet_pretrained=True)

    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pt')      

    logger.info('=> loading model from {}'.format(model_state_file))

    pretrained_dict = torch.load(model_state_file)
    if 'state_dict' in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    model_dict = model.state_dict()
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                        if k[6:] in model_dict.keys()}
    for k, _ in pretrained_dict.items():
        logger.info(
            '=> loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    #model = model.cuda()
    model = model.eval().to(device)

    # 包装模型
    #export_model = CustomExportModel(model).to(device)

    x = torch.randn(1, 3, 512, 512).to(device)

    with torch.no_grad():
        torch.onnx.export(
            model,            # 要自定义的模型
            x,                       # 模型的任意一组输入
            'NeSF-Net.onnx',         # 导出的 ONNX 文件名
            opset_version=11,        # ONNX 算子集版本
            input_names=['input'],   # 输入 Tensor 的名称(自己起名字)
            output_names=['output']  # 输出 Tensor 的名称(自己起名字)
        )
    print('finish!')

if __name__ == '__main__':
    main()
    import onnx
    # 读取 ONNX 模型
    onnx_model = onnx.load('NeSF-Net.onnx')

    # 检查模型格式是否正确
    onnx.checker.check_model(onnx_model)

    print('无报错,onnx模型载入成功')

Post Views: 196

相关文章:

  1. BUG: torch.pairwise_distance()计算欧式距离报错,出现size不匹配的情况
  2. BUG:ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found
  3. 在docker中使用tensorboard、jupyter notebook
  4. cuda 11.x如何配置tensorflow 1.x?(以3090为例)
本作品采用 知识共享署名 4.0 国际许可协议 进行许可
标签: BUG pytorch 实用技巧
最后更新:2024年11月8日

daidai

一个热爱RS和GIS技术的小姐姐!

打赏 点赞
< 上一篇
下一篇 >

文章评论

razz evil exclaim smile redface biggrin eek confused idea lol mad twisted rolleyes wink cool arrow neutral cry mrgreen drooling persevering
取消回复

文章目录
  • 一、主要学习内容如下:
  • 二、具体操作
浏览最多的文章
  • BUG:ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (1,462)
  • BUG:“ModuleNotFoundError: No module named '_ext'”的解决方案 (1,229)
  • 利用GEE下载指定区域Landsat8影像 (1,175)
  • 利用arcgis制作深度学习标签数据(以二分类为例) (899)
  • 利用传统机器学习方法进行遥感影像分类-以随机森林(RF)为例 (807)

COPYRIGHT © 2025 HAN&DAI. ALL RIGHTS RESERVED. QQ交流群:821388027

Theme Kratos Made By Seaton Jiang