行业资讯 2025年08月6日
0 收藏 0 点赞 150 浏览 4670 个字
摘要 :

文章目录 一、安装ONNX Runtime 二、安装ONNX导出工具 三、快速上手示例 (一)PyTorch CV示例 (二)PyTorch NLP示例 (三)TensorFlow CV示例 (四)scikit-learn CV……




  • 一、安装ONNX Runtime
  • 二、安装ONNX导出工具
  • 三、快速上手示例
    • (一)PyTorch CV示例
    • (二)PyTorch NLP示例
    • (三)TensorFlow CV示例
    • (四)scikit-learn CV示例

    ONNX Runtime(ORT)作为一款高性能的推理引擎,能够支持像PyTorch、TensorFlow和scikit-learn等多种深度学习框架,为开发者提供了极大的便利。本文将详细介绍如何在Python环境中安装ONNX Runtime,并通过不同框架的示例展示其在模型推理中的具体应用。

    一、安装ONNX Runtime

    ONNX Runtime针对不同的硬件环境,提供了两个主要的Python安装包,大家可以根据自身的硬件条件来选择合适的版本。

    • CPU版本:这个版本适用于基于Arm架构的CPU以及macOS系统。在安装时,只需在命令行中输入:
    pip install onnxruntime
    
    • GPU版本(CUDA 12.x):若你的设备支持CUDA 12.x,那么可以选择这个版本,以充分利用GPU的计算能力提升推理速度。安装命令如下:
    pip install onnxruntime-gpu
    
    • GPU版本(CUDA 11.8):如果你的CUDA版本是11.8,就需要从Azure DevOps Feed进行安装,具体命令为:
    pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple/
    

    二、安装ONNX导出工具

    要想使用ONNX Runtime进行模型推理,首先得把模型转换成ONNX格式。不同的深度学习框架,安装ONNX导出工具的方式也不一样。

    • PyTorch:PyTorch本身就自带了ONNX支持,安装PyTorch就相当于安装了导出工具。安装命令为:
    pip install torch
    
    • TensorFlow:TensorFlow需要额外安装tf2onnx来实现对ONNX导出的支持,在命令行输入以下命令即可完成安装:
    pip install tf2onnx
    
    • scikit-learn:对于scikit-learn框架,需要安装skl2onnx,安装命令如下:
    pip install skl2onnx
    

    三、快速上手示例

    接下来,我们通过不同框架在计算机视觉(CV)和自然语言处理(NLP)领域的示例,来看看如何导出模型并使用ONNX Runtime进行推理。

    (一)PyTorch CV示例

    1. 导出模型:在这个示例中,假设已经有一个训练好的PyTorch模型model,并且确定了模型运行的设备device(可以是GPU或CPU)。
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch模型,device是设备(如GPU或CPU)
    model = ...  # 初始化你的模型
    device = ...  # 设备
    
    # 生成一个随机的输入数据,形状为(1, 28, 28),并将其移动到指定设备上
    input_data = torch.randn(1, 28, 28).to(device)
    # 将模型导出为ONNX格式,指定输入和输出的名称
    onnx.export(model, input_data, \"fashion_mnist_model.onnx\", 
                input_names=[\'input\'], output_names=[\'output\'])
    
    1. 加载ONNX模型并进行推理:完成模型导出后,就可以加载ONNX模型并进行推理了。
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载之前导出的ONNX模型
    onnx_model = onnx.load(\"fashion_mnist_model.onnx\")
    # 检查模型的完整性和正确性
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话,用于执行模型推理
    ort_sess = ort.InferenceSession(\'fashion_mnist_model.onnx\')
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    # 运行推理,将输入数据转换为numpy数组后传入,得到输出结果
    outputs = ort_sess.run(None, {\'input\': x.numpy()})
    

    (二)PyTorch NLP示例

    1. 导出模型:这里假设已经有一个用于自然语言处理的PyTorch模型model
    import torch
    import torch.onnx as onnx
    
    # 假设model是你的PyTorch NLP模型
    model = ...  # 初始化你的模型
    
    text = \"示例文本\"
    # 假设text_pipeline是对输入文本进行预处理的函数,将文本转换为张量
    text_tensor = torch.tensor(text_pipeline(text)) 
    offsets = torch.tensor([0])
    
    # 将模型导出为ONNX格式,指定输入、输出名称以及动态轴信息
    onnx.export(model, (text_tensor, offsets), \"ag_news_model.onnx\", 
                input_names=[\'input\', \'offsets\'], output_names=[\'output\'],
                dynamic_axes={\'input\': {0: \'batch_size\'}, \'output\': {0: \'batch_size\'}})
    
    1. 加载ONNX模型并进行推理:同样,导出模型后就可以进行加载和推理操作。
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    onnx_model = onnx.load(\"ag_news_model.onnx\")
    # 检查模型是否正确
    onnx.checker.check_model(onnx_model)
    
    # 创建推理会话
    ort_sess = ort.InferenceSession(\'ag_news_model.onnx\')
    
    # 假设text是输入文本,offsets是偏移量
    text = ...  # 初始化输入文本
    offsets = torch.tensor([0])
    
    # 运行推理,将输入数据转换为numpy数组后传入,获取输出结果
    outputs = ort_sess.run(None, {\'input\': text.numpy(), \'offsets\': offsets.numpy()})
    

    (三)TensorFlow CV示例

    1. 导出模型:以加载预训练的ResNet50模型为例,展示如何将其转换为ONNX格式。
    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50
    import tf2onnx
    
    # 加载预训练的ResNet50模型,权重为imagenet上的预训练权重
    model = ResNet50(weights=\'imagenet\')
    
    # 定义输入的张量规格,这里假设输入形状为(None, 224, 224, 3),数据类型为float32
    spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name=\"input\"),)
    # 指定导出的ONNX模型保存路径
    output_path = model.name + \".onnx\"
    # 将Keras模型转换为ONNX模型,指定操作集版本为13
    model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)
    
    1. 加载ONNX模型并进行推理:完成模型转换后,就可以使用ONNX Runtime进行推理了。
    import onnxruntime as rt
    
    # 创建推理会话,指定使用CPU执行提供程序
    providers = [\'CPUExecutionProvider\']
    m = rt.InferenceSession(output_path, providers=providers)
    
    # 假设x是输入数据
    x = ...  # 初始化输入数据
    
    # 获取模型输出的名称列表
    output_names = [n.name for n in model_proto.graph.output]
    # 运行推理,传入输入数据,得到预测结果
    onnx_pred = m.run(output_names, {\"input\": x})
    

    (四)scikit-learn CV示例

    1. 导出模型:以鸢尾花数据集为例,训练一个逻辑回归模型并将其转换为ONNX格式。
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType
    
    # 加载鸢尾花数据集
    iris = load_iris()
    # 提取特征和标签
    X, y = iris.data, iris.target
    # 将数据集划分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    
    # 创建逻辑回归模型并进行训练
    clr = LogisticRegression()
    clr.fit(X_train, y_train)
    
    # 定义输入数据的类型,这里假设输入是一个二维张量,形状为(None, 4),数据类型为float
    initial_type = [(\'float_input\', FloatTensorType([None, 4]))]
    # 将训练好的逻辑回归模型转换为ONNX格式
    onx = convert_sklearn(clr, initial_types=initial_type)
    # 将转换后的ONNX模型保存到文件中
    with open(\"logreg_iris.onnx\", \"wb\") as f:
        f.write(onx.SerializeToString())
    
    1. 加载ONNX模型并进行推理:保存好模型后,就可以加载并进行推理。
    import numpy
    import onnxruntime as rt
    
    # 创建推理会话
    sess = rt.InferenceSession(\"logreg_iris.onnx\")
    # 获取输入的名称
    input_name = sess.get_inputs()[0].name
    
    # 假设X_test是测试数据
    # 运行推理,将测试数据转换为float32类型后传入,得到预测结果
    pred_onx = sess.run(None, {input_name: X_test.astype(numpy.float32)})[0]
    

    通过上述内容,相信大家已经掌握了在Python中使用ONNX Runtime进行模型推理的方法,包括安装相关工具以及不同框架下的具体实践。希望这些示例能帮助大家在实际项目中更好地应用ONNX Runtime,提升模型推理的效率。

微信扫一扫

支付宝扫一扫

版权: 转载请注明出处:https://www.zuozi.net/10465.html

管理员

相关推荐
2025-08-06

文章目录 一、Reader 接口概述 1.1 什么是 Reader 接口? 1.2 Reader 与 InputStream 的区别 1.3 …

988
2025-08-06

文章目录 一、事件溯源 (一)核心概念 (二)Kafka与Golang的优势 (三)完整代码实现 二、命令…

465
2025-08-06

文章目录 一、证明GC期间执行native函数的线程仍在运行 二、native线程操作Java对象的影响及处理方…

348
2025-08-06

文章目录 一、事务基础概念 二、MyBatis事务管理机制 (一)JDBC原生事务管理(JdbcTransaction)…

456
2025-08-06

文章目录 一、SnowFlake算法核心原理 二、SnowFlake算法工作流程详解 三、SnowFlake算法的Java代码…

517
2025-08-06

文章目录 一、本地Jar包的加载操作 二、本地Class的加载方法 三、远程Jar包的加载方式 你知道Groo…

832
发表评论
暂无评论

还没有评论呢,快来抢沙发~

助力内容变现

将您的收入提升到一个新的水平

点击联系客服

在线时间:08:00-23:00

客服QQ

122325244

客服电话

400-888-8888

客服邮箱

122325244@qq.com

扫描二维码

关注微信客服号