TorchVision重磅升级:支持多权重的API
点蓝色字关注“机器学习算法工程师”
设为星标,干货直达!
TorchVision库近期增加了一个向后兼容的API,以用于构建具有多权重支持的模型。新的 API 允许在同一模型上加载不同的预训练权重,并保存了重要的元数据,如分类标签和使用模型所需的预处理转换(eval模式)。这篇博文将介绍这个新 API的特性以及和现有 API 的主要区别。
现有API的局限
TorchVision前提供预训练模型,可以作为迁移学习的起点或在计算机视觉应用程序中原样使用。实例化预训练模型并进行预测的典型方法是:
import torch
from PIL import Image
from torchvision import models as M
from torchvision.transforms import transforms as T
img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# 步骤1:初始化模型
model = M.resnet50(pretrained=True)
model.eval()
# 步骤2:定义并初始化推理所用的数据变换
preprocess = T.Compose([
T.Resize([256, ]),
T.CenterCrop(224),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 步骤3:对图像进行预处理并进行推理
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
# 步骤4:后处理模型预测结果,并得到预测类别
class_id = prediction.argmax().item()
score = prediction[class_id].item()
with open("imagenet_classes.txt", "r") as f:
categories = [s.strip() for s in f.readlines()]
category_name = categories[class_id]
print(f"{category_name}: {100 * score}%")
上述方式存在以下几个局限:
无法支持多个预训练权重:由于参数pretrained是布尔值,我们只能提供一组权重。当我们显著提高现有模型的准确性并且我们希望将这些改进提供给社区时,这构成了严重的限制。而且也无法在不同数据集上提供相同模型的预训练权重。 缺少推理所需的预处理转换:用户被迫在使用模型之前要定义必要的转换。推理转换通常与用于估计权重的训练过程和数据集相关联。这些转换中的任何细微差异(例如插值、调整大小/裁剪大小等)都可能导致准确性大幅降低,甚至模型无法使用。 缺乏元数据:用户无法获得与权重相关的关键信息。例如,需要查看外部资源和文档以查找类别标签、训练策略、准确度指标等内容。
新的 API (prototype API,原型API)解决了上述限制,并减少了标准任务所需的样板代码量。
原型 API 概述
首先让我们看看如何使用新的 API 实现与上述完全相同的结果:
from PIL import Image
from torchvision.prototype import models as PM
img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# 步骤1:初始化模型
weights = PM.ResNet50_Weights.IMAGENET1K_V1
model = PM.resnet50(weights=weights)
model.eval()
# 步骤2:初始化推理所用的数据变换
preprocess = weights.transforms()
# 步骤3:对图像进行预处理并进行推理
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
# 步骤4:后处理模型预测结果,并得到预测类别
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}*%*")
可以看到,新的 API 消除了上述限制。下面让我们详细地说明这些新功能。
多权重支持
新 API 的核心就是我们能够为同一模型变体定义多个不同的权重。每个模型构建方法(例如 resnet50)都有一个关联的 Enum 类(例如 ResNet50_Weights),其条目数与可用的预训练权重的数量一样多。此外,每个 Enum 类都有一个 DEFAULT 别名,它指向特定模型的最佳可用权重。这允许希望始终使用最佳可用权重的用户在不修改其代码的情况下这样做。
下面是采用不同的权重来初始化ResNet50:
from torchvision.prototype.models import resnet50, ResNet50_Weights
# Legacy weights with accuracy 76.130%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# New weights with accuracy 80.858%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Best available weights (currently alias for IMAGENET1K_V2)
model = resnet50(weights=ResNet50_Weights.DEFAULT)
# No weights - random initialization
model = resnet50(weights=None)
相关的元数据和预处理转换
每个模型的权重都与元数据相关联。我们存储的信息类型取决于模型的任务(分类、检测、分割等)。典型信息包括指向训练策略的PR、插值方式、类别和验证指标等信息。这些值可以通过 meta 属性以如下方式访问:
from torchvision.prototype.models import ResNet50_Weights
# Accessing a single record
size = ResNet50_Weights.IMAGENET1K_V2.meta["size"]
# Iterating the items of the meta-data dictionary
for k, v in ResNet50_Weights.IMAGENET1K_V2.meta.items():
print(k, v)
此外,每个权重都与必要的预处理转换相关联。所有当前的预处理转换都是 JIT 可编写脚本的,并且可以通过 transforms 属性访问。在将它们与数据一起使用之前,需要初始化/构造转换。完成这种惰性初始化方案是为了确保解决方案具有内存效率。转换的输入可以是 PIL.Image 或使用 torchvision.io 读取的张量。
from torchvision.prototype.models import ResNet50_Weights
# Initializing preprocessing at standard 224x224 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()
# Initializing preprocessing at 400x400 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms(crop_size=400, resize_size=400)
# Once initialized the callable can accept the image data:
# img_preprocessed = preprocess(img)
将权重与其元数据和预处理相关联将提高透明度和可重复性,并更容易记录一组权重是如何产生的。
通过名字获取权重
将权重与其属性(元数据、预处理可调用对象等)直接联系在一起是我们的实现使用枚举而不是字符串的原因。然而,对于只有权重名称可用的情况,我们提供了一种能够将权重名称链接到其枚举的方法:
from torchvision.prototype.models import get_weight
# Weights can be retrieved by name:
assert get_weight("ResNet50_Weights.IMAGENET1K_V1") == ResNet50_Weights.IMAGENET1K_V1
assert get_weight("ResNet50_Weights.IMAGENET1K_V2") == ResNet50_Weights.IMAGENET1K_V2
# Including using the DEFAULT alias:
assert get_weight("ResNet50_Weights.DEFAULT") == ResNet50_Weights.IMAGENET1K_V2
弃用项
在新的 API 中,不推荐使用之前用于将权重加载到完整模型或其主干的 boolean pretrained 和 pretrained_backbone 参数。当前实现完全向后兼容,因为它将旧参数无缝映射到新参数。对新构建器使用旧参数会发出以下弃用警告:
>>> model = torchvision.prototype.models.resnet50(pretrained=True)
UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
UserWarning:
Arguments other than a weight enum or `None` for 'weights' are deprecated.
The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`.
You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
此外,构建器方法需要使用关键字参数。不推荐使用位置参数,使用它们会发出以下警告:
>>> model = torchvision.prototype.models.resnet50(None)
UserWarning:
Using 'weights' as positional parameter(s) is deprecated.
Please use keyword parameter(s) instead.
测试新的API
迁移到新 API 非常简单。两个 API 之间的以下方法调用都是等效的:
# Using pretrained weights:
torchvision.prototype.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
torchvision.models.resnet50(pretrained=True)
torchvision.models.resnet50(True)
# Using no weights:
torchvision.prototype.models.resnet50(weights=None)
torchvision.models.resnet50(pretrained=False)
torchvision.models.resnet50(False)
请注意,原型功能仅适用于 TorchVision 的nightly版本,因此要使用它,您需要按如下方式安装它:
conda install torchvision -c pytorch-nightly
有关安装 nightly 的替代方法,请查看 PyTorch 下载页面(https://pytorch.org/get-started/locally/)。您还可以从最新的 main 源安装 TorchVision;有关更多信息,请查看TorchVision代码仓库(https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md)。
使用新 API 访问SOTA模型权重
如果您仍然不想尝试新的 API,这里还有一个这样做的理由。我们最近更新了我们的训练策略(https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/),并在许多模型中实现了 SOTA 准确度。改进后的权重可以通过新的 API 轻松访问。以下是模型改进后的性能概述:
Model | Old Acc@1 | New Acc@1 |
---|---|---|
EfficientNet B1 | 78.642 | 79.838 |
MobileNetV3 Large | 74.042 | 75.274 |
Quantized ResNet50 | 75.92 | 80.282 |
Quantized ResNeXt101 32x8d | 78.986 | 82.574 |
RegNet X 400mf | 72.834 | 74.864 |
RegNet X 800mf | 75.212 | 77.522 |
RegNet X 1 6gf | 77.04 | 79.668 |
RegNet X 3 2gf | 78.364 | 81.198 |
RegNet X 8gf | 79.344 | 81.682 |
RegNet X 16gf | 80.058 | 82.72 |
RegNet X 32gf | 80.622 | 83.018 |
RegNet Y 400mf | 74.046 | 75.806 |
RegNet Y 800mf | 76.42 | 78.838 |
RegNet Y 1 6gf | 77.95 | 80.882 |
RegNet Y 3 2gf | 78.948 | 81.984 |
RegNet Y 8gf | 80.032 | 82.828 |
RegNet Y 16gf | 80.424 | 82.89 |
RegNet Y 32gf | 80.878 | 83.366 |
ResNet50 | 76.13 | 80.858 |
ResNet101 | 77.374 | 81.886 |
ResNet152 | 78.312 | 82.284 |
ResNeXt50 32x4d | 77.618 | 81.198 |
ResNeXt101 32x8d | 79.312 | 82.834 |
Wide ResNet50 2 | 78.468 | 81.602 |
Wide ResNet101 2 | 78.848 | 82.51 |
本文翻译自Introducing TorchVision’s New Multi-Weight Support API:https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/
推荐阅读
辅助模块加速收敛,精度大幅提升!移动端实时的NanoDet-Plus来了!
机器学习算法工程师
一个用心的公众号