ONNX 模型中心¶
ONNX 模型中心是开始使用 ONNX 模型动物园 中最先进的预训练 ONNX 模型的简单快捷方式。此外,这使得研究人员和模型开发者有机会与更广泛的社区分享他们的预训练模型。
安装¶
ONNX 模型中心在 ONNX 1.11.0 之后可用。
基本用法¶
ONNX 模型中心能够从任何 Git 仓库下载、列出和查询训练好的模型,并默认为官方 ONNX 模型动物园。本节将演示一些基本功能。
首先请使用以下命令导入中心
from onnx import hub
按名称下载模型¶
load 函数将默认在模型动物园中搜索具有匹配名称的最新模型,将此模型下载到本地缓存,并将模型加载到 ModelProto 对象中,以便与 ONNX 运行时一起使用。
model = hub.load("resnet50")
从自定义仓库下载¶
任何具有正确结构的仓库都可以成为 ONNX 模型中心。要从其他中心下载,或在主模型中心指定特定分支或提交,可以提供 repo 参数
model = hub.load("resnet50", repo="onnx/models:771185265efbdc049fb223bd68ab1aeb1aecde76")
列出和检查模型¶
模型中心提供 API,用于查询模型动物园以了解更多可用模型。这不会下载模型,而只是返回有关匹配给定参数的模型信息
# List all models in the onnx/models:main repo
all_models = hub.list_models()
# List all versions/opsets of a specific model
mnist_models = hub.list_models(model="mnist")
# List all models matching a given "tag"
vision_models = hub.list_models(tags=["vision"])
还可以使用 get_model_info 函数在下载前检查模型的元数据
print(hub.get_model_info(model="mnist", opset=8))
这将打印类似以下内容
ModelInfo(
model=MNIST,
opset=8,
path=vision/classification/mnist/model/mnist-8.onnx,
metadata={
'model_sha': '2f06e72de813a8635c9bc0397ac447a601bdbfa7df4bebc278723b958831c9bf',
'model_bytes': 26454,
'tags': ['vision', 'classification', 'mnist'],
'io_ports': {
'inputs': [{'name': 'Input3', 'shape': [1, 1, 28, 28], 'type': 'tensor(float)'}],
'outputs': [{'name': 'Plus214_Output_0', 'shape': [1, 10], 'type': 'tensor(float)'}]},
'model_with_data_path': 'vision/classification/mnist/model/mnist-8.tar.gz',
'model_with_data_sha': '1dd098b0fe8bc750585eefc02013c37be1a1cae2bdba0191ccdb8e8518b3a882',
'model_with_data_bytes': 25962}
)
本地缓存¶
ONNX 模型中心将下载的模型本地缓存到可配置的位置,这样后续调用 hub.load 就不需要网络连接。
默认缓存位置¶
中心客户端按以下顺序查找默认缓存位置
如果定义了
ONNX_HOME环境变量,则为$ONNX_HOME/hub如果定义了
XDG_CACHE_HOME环境变量,则为$XDG_CACHE_HOME/hub~/.cache/onnx/hub,其中~是用户主目录
设置缓存位置¶
要手动设置缓存位置,请使用
hub.set_dir("my/cache/directory")
此外,还可以使用以下命令检查缓存位置
print(hub.get_dir())
更多缓存详情¶
要清除模型缓存,只需使用 shutil 或 os 等 Python 实用程序删除缓存目录。此外,还可以选择使用 force_reload 选项覆盖缓存的模型
model = hub.load("resnet50", force_reload=True)
我们包含此标志是为了完整性,但请注意,缓存中的模型通过 sha256 散列进行区分,因此在正常使用中不需要 force_reload 标志。最后,我们注意到模型缓存目录结构将镜像由清单的 model_path 字段指定的目录结构,但文件名通过模型 SHA256 散列进行区分。
这样,模型缓存是人类可读的,可以区分多个版本的模型,并且如果它们具有相同的名称和散列,可以在不同的中心之间重用缓存的模型。
架构¶
ONNX 中心由两个主要组件组成:客户端和服务器。客户端代码目前包含在 onnx 包中,可以指向一个服务器,该服务器以托管在 GitHub 仓库中的 ONNX_HUB_MANIFEST.json 的形式存在,例如 ONNX 模型动物园中的文件。此清单文件是一个 JSON 文档,列出了所有模型及其元数据,旨在实现编程语言无关。以下是一个格式良好的模型清单条目示例
{
"model": "BERT-Squad",
"model_path": "text/machine_comprehension/bert-squad/model/bertsquad-8.onnx",
"onnx_version": "1.3",
"opset_version": 8,
"metadata": {
"model_sha": "cad65b9807a5e0393e4f84331f9a0c5c844d9cc736e39781a80f9c48ca39447c",
"model_bytes": 435882893,
"tags": ["text", "machine comprehension", "bert-squad"],
"io_ports": {
"inputs": [
{
"name": "unique_ids_raw_output___9:0",
"shape": ["unk__475"],
"type": "tensor(int64)"
},
{
"name": "segment_ids:0",
"shape": ["unk__476", 256],
"type": "tensor(int64)"
},
{
"name": "input_mask:0",
"shape": ["unk__477", 256],
"type": "tensor(int64)"
},
{
"name": "input_ids:0",
"shape": ["unk__478", 256],
"type": "tensor(int64)"
}
],
"outputs": [
{
"name": "unstack:1",
"shape": ["unk__479", 256],
"type": "tensor(float)"
},
{
"name": "unstack:0",
"shape": ["unk__480", 256],
"type": "tensor(float)"
},
{
"name": "unique_ids:0",
"shape": ["unk__481"],
"type": "tensor(int64)"
}
]
},
"model_with_data_path": "text/machine_comprehension/bert-squad/model/bertsquad-8.tar.gz",
"model_with_data_sha": "c8c6c7e0ab9e1333b86e8415a9d990b2570f9374f80be1c1cb72f182d266f666",
"model_with_data_bytes": 403400046
}
}
这些重要字段是
model:用于查询的模型名称model_path:存储在 Git LFS 中的模型的相对路径。onnx_version:模型的 ONNX 版本opset_version:opset 的版本。如果未指定,客户端将下载最新的 opset。metadata/model_sha:可选的模型 sha 规范,用于提高下载安全性metadata/tags:可选的高级标签,帮助用户按给定类型查找模型
metadata 字段中的所有其他字段对于客户端来说都是可选的,但为用户提供了重要详细信息。
添加到 ONNX 模型中心¶
贡献官方模型¶
将模型添加到官方 onnx/models 版本模型中心的最简单方法是遵循 这些指南 贡献您的模型。一旦贡献完成,请确保您的模型在其 README.md 中有一个 Markdown 表 (示例)。模型中心清单生成器将从这些 Markdown 表中提取信息。要运行生成器
git clone https://github.com/onnx/models.git
git lfs pull --include="*" --exclude=""
cd models/workflow_scripts
python generate_onnx_hub_manifest.py
生成新清单后,将其添加到 onnx/models 的拉取请求中
托管您自己的 ONNX 模型中心¶
要托管您自己的模型中心,请在 GitHub 仓库的顶层添加一个 ONNX_HUB_MANIFEST.json (示例)。您的清单条目至少应包含本文档 架构部分 中提到的字段。提交后,检查您是否可以使用本文档的“从自定义仓库下载”部分下载模型。
如有任何问题请提出¶
有关 ONNX 模型问题或 SHA 不匹配问题,请在 [Model Zoo]/(https://github.com/onnx/models/issues) 中提出问题。
有关 ONNX 模型中心使用中的其他问题,请在此仓库 https://github.com/onnx/onnx/issues 中提出。