PyTorch 源码解读之 torch.serialization & torch.hub

视学算法

共 11722字,需浏览 24分钟

 ·

2021-10-29 16:07


作者 | 123456 
来源 | OpenMMLab 
编辑 | 极市平台

导读

 

本文解读基于PyTorch 1.7版本,对torch.serialization、torch.save和torch.hub展开介绍。

torch.serialization

torch.serialization 实现对 PyTorch 对象结构的二进制序列化和反序列化,其中序列化由 torch.save 实现,反序列化由 torch.load 实现。

torch.save

torch.save 主要使用 pickle 来进行二进制序列化:

def save(obj, # 待序列化的对象
         f: Union[str, os.PathLike, BinaryIO], # 带写入的文件
         pickle_module=pickle, # 默认使用 pickle 进行序列化
         pickle_protocol=DEFAULT_PROTOCOL, # 默认使用 pickle 第2版协议
         _use_new_zipfile_serialization=True)
 -> None:
 # pytorch 1.6 之后默认使用基于 zipfile 的存储文件格式, 如果想用旧的格式, 
                                                       # 可设为False. torch.load 同时支持新旧格式文件的读取.

    # 如果使用 dill 进行序列化操作, dill的版本需大于 0.3.1.
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'wb'as opened_file:
        # 基于 zipfile 的存储格式
        if _use_new_zipfile_serialization:
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        # 以二进制方式写入文件
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

可以看到核心函数是 _save()_legacy_save() ,接下来分别介绍,我们首先介绍_save()函数

def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {} # 暂存具体数据内容以及其对应的key

    def persistent_id(obj):
        if torch.is_storage(obj): # 如果是需要存储的数据内容
            storage_type = normalize_storage_type(type(obj)) # 存储类型,int, float, ...
            obj_key = str(obj._cdata) # 数据内容对应的key. 在load时根据key读取数据
            location = location_tag(obj) # cpu 还是cuda
            serialized_storages[obj_key] = obj # 数据及其对应的key

            return ('storage', storage_type, obj_key, location, obj.size()) # 注意这里没有具体数据,只返回数据相关的信息
        return None

    data_buf = io.BytesIO() # 开辟 buffer
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) # 对象的结构信息即将写入 data_buf 中
    pickler.persistent_id = persistent_id # 将对象的结构信息写入 data_buf 中,具体数据内容暂存在 serialized_storages 中
    pickler.dump(obj) # 对对象执行写入操作,写入过程会调 persistent_id 函数
    data_value = data_buf.getvalue() # 将写入的对象的结构信息取出来
    zip_file.write_record('data.pkl', data_value, len(data_value)) # 写入到存储文件 zip_file 中,注意这里写入的信息只是对象的结构
                                                                   # 信息(通过 data.pkl 来标识),具体数据内容还未写入

    for key in sorted(serialized_storages.keys()): # 写入数据内容
        name = f'data/{key}' # 数据的名字
        storage = serialized_storages[key] # 具体数据内容
        if storage.device.type == 'cpu'# 数据在 cpu 上
            num_bytes = storage.size() * storage.element_size() # 计算占用的字节数
            zip_file.write_record(name, storage.data_ptr(), num_bytes) # 写入数据
        else# 数据在 cuda 上
            buf = io.BytesIO() # 开辟 buffer
            storage._write_file(buf, _should_read_directly(buf), False# 将 cuda 上的数据复制到内存中
            buf_value = buf.getvalue() # 读取内存中的数据
            zip_file.write_record(name, buf_value, len(buf_value)) # 写入数据

总的来说 _save() 函数在将对象二进制序列化的过程中,首先写入对象的结构信息,之后再写入具体的数据内容。
接下来介绍_legacy_save()函数:

def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        if isinstance(obj, type) and issubclass(obj, nn.Module): # 记录 source code
            if obj in serialized_container_types: # 如果已经记录过一样的,不需要重复记录
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj) # 读取 source code
                source = ''.join(source_lines) # 读取 source code
            except Exception: # 找不到的话,打印warning
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        elif torch.is_storage(obj): # 与上面 `_save()` 中 `persistent_id()` 的对应内容类似
            view_metadata: Optional[Tuple[str, int, int]]
            obj = cast(Storage, obj)
            storage_type = normalize_storage_type(type(obj))
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            is_view = obj._cdata != obj._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage', storage_type, obj_key, location, obj.size(),
                    view_metadata)
        return None

    # 记录一些系统信息
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) # 记录 MAGIC_NUMBER,用于load时验证文件是否损坏
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) # 记录 pickle 协议,用于load时验证pickle协议是否一致
    pickle_module.dump(sys_info, f, protocol=pickle_protocol) # 记录一些系统信息
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol) # 对象的结构信息即将写入文件中
    pickler.persistent_id = persistent_id  # 将对象的结构信息写入 data_buf 中,具体数据内容暂存在 serialized_storages 中
    pickler.dump(obj) # 执行写入操作,期间会调用 persistent_id() 函数

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) # 写入具体数据对应的 key
    f.flush() # 刷新缓存区
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f), True# 写入具体数据

可以看到_legacy_save()_save() 在序列化的过程中,整体的pipeline是类似的,只是写入的内容有轻微差别。

torch.load

torch.load 主要使用 pickle 来进行二进制反序列化。

def load(f, # 待反序列化的文件
         map_location=None, # 将对象放到cpu或cuda上,默认与文件里对象的location一致
         pickle_module=pickle, # 默认使用pickle来反序列化
         **pickle_load_args)
:

    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys(): # 默认使用 utf-8 解码
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb'as opened_file:
        if _is_zipfile(opened_file): # 如果是基于 zipfile 的存储格式
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile): # 如果存的torchscript文件,用torch.jit.load().否则用_load()反序列化
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file)
                return _load(opened_zipfile, map_location, pickle_module,
                             **pickle_load_args)
        # 对二进制文件,用_legacy_load()反序列化
        return _legacy_load(opened_file, map_location, pickle_module,
                            **pickle_load_args)

可以看到核心函数是_load()_legacy_load(),接下来分别介绍,我们首先介绍_load()函数:

def _load(zip_file,
          map_location,
          pickle_module,
          pickle_file='data.pkl'# 注意这里的'data.pkl'与_save()中的一一对应
          **pickle_load_args)
:

    restore_location = _get_restore_location(map_location) # 根据map_location来生成restore_location函数,用于将数据放在cpu或cuda上

    loaded_storages = {}

    def load_tensor(data_type, size, key, location):
        name = f'data/{key}' # 数据的key,用于寻找数据
        dtype = data_type(0).dtype # 数据类型,比如 int, float, ...

        storage = zip_file.get_storage_from_record(name, size, dtype).storage() # 从文件中找到数据
        loaded_storages[key] = restore_location(storage, location) # 放到 cpu 或 cuda 上

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple) # save_id = ('storage', storage_type, obj_key, location, obj.size())
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        data_type, key, location, size = data # data_type, key, location, size = storage_type, obj_key, location, obj.size()
        if key not in loaded_storages:
            load_tensor(data_type, size, key, _maybe_decode_ascii(location))
        storage = loaded_storages[key]
        return storage

    data_file = io.BytesIO(zip_file.get_record(pickle_file)) # 读取对象的配置文件`data.pkl`,存储的对象的结构信息
    unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load # 用于读取具体数据的persistent_load函数
    result = unpickler.load() # 执行读取操作

    torch._utils._validate_loaded_sparse_tensors()

    return result

总的来说 _load() 函数在将对象二进制反序列化的过程中,在构建对象结构信息的同时,就已经将具体的数据内容加载进来了。
_legacy_load()函数与它不同,_legacy_load()是先构建对象结构信息,再加载具体的数据。

def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location) # 根据map_location来生成restore_location函数,用于将数据放在cpu或cuda上

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}

        # 由于不是基于 zipfile 的存储格式,报错退出,之后代码不会执行
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:
            ...

    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple) # saved_id = ('storage', storage_type, obj_key, location, obj.size(), view_metadata)
                                           # or saved_id = ('module', obj, source_file, source)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data) # 检查source code是否一致
            return data[0]
        elif typename == 'storage'# 注意这里并没有载入具体数据,只是恢复了对象的结构信息
            data_type, root_key, location, size, view_metadata = data
            location = _maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(
                    obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset +
                                                             view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f) # 检查文件是否支持seek(), tell()方法。seek()用于定位到文件任意位置,tell()返回指针在文件的当前位置
    f_should_read_directly = _should_read_directly(f) # 是否二进制可读,比如如果是zip文件,则为False。
                                                      # 但由于传进来的文件格式不是zip格式,这里一般为True

    if f_should_read_directly and f.tell() == 0:
        try:
            return legacy_load(f) # 因为不是 zip 格式,报错退出
        except tarfile.TarError:
            if _is_zipfile(f): # 一般不执行
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
                ) from None
            f.seek(0# 定位到文件初始位置

    if not hasattr(f,
                   'readinto'and (380) <= sys.version_info < (382):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER: # 检查MAGIC_NUMBER是否一致
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args) 
    if protocol_version != PROTOCOL_VERSION: # 检查pickle协议是否一致
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args) # 读取一些系统信息
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load 
    result = unpickler.load() # 调用persistent_load()函数读取对象的结构信息,注意此时还未读取具体的数据

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) # 读取数据对应的key,到这里可以发现pickle_module.load()
                                                                          # 出的结果和上面`_legacy_save()`函数中dump的内容一一对应

    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys: # 读取具体的数据
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset,
                                                 f_should_read_directly)
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result

load()_legacy_load()中都有_get_restore_location()函数生成restore_location(obj,location)函数,它决定将读取的对象(obj)放到 CPU or CUDA (location)上,接下来我们介绍_get_restore_location()

def _cpu_deserialize(obj, location): # 将对象放到cpu上,注意可能返回None
    if location == 'cpu':
        return obj


def _cuda_deserialize(obj, location): # 将对象放到指定的cuda device上,注意可能返回None
    if location.startswith('cuda'):
        device = validate_cuda_device(location) # 验证是否有显卡,以及给定的device id是否超过当前机器拥有的显卡数量
        if getattr(obj, "_torch_load_uninitialized"False):
            storage_type = getattr(torch.cuda, type(obj).__name__)
            with torch.cuda.device(device):
                return storage_type(obj.size())
        else:
            return obj.cuda(device)

_package_registry = []

def register_package(priority, tagger, deserializer):
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()

register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)


def default_restore_location(storage, location): # 按先cpu后cuda的优先级将数据放入cpu或cuda上
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " + location +
                       ")")


def _get_restore_location(map_location):
    if map_location is None:
        restore_location = default_restore_location # map_location = None : 放到location记录的cpu or cuda上
    elif isinstance(map_location, dict):

        def restore_location(storage, location): # map_location = {'cpu': 'cuda:0'} : 如果location是'cpu',则放到'cuda:0'上;
                                                 # 否则仍放到'cpu'上
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes): # map_location = 'cuda:0' : 不管location是什么,都放到'cuda:0'上

        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):

        def restore_location(storage, location): # map_location = torch.device('cpu') : 不管location是什么,都放到'cpu'上
            return default_restore_location(storage, str(map_location))
    else:

        def restore_location(storage, location): # 可以替换default_restore_location函数,map_location是一个函数
                                                 # 比如 map_location = lambda storage, location: storage.cuda(1) 表示
                                                 # 不管location是什么,都放到'cuda:1'上
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    return restore_location

以上是torch.serialization的源码分析,torch.serialization主要包含torch.save()torch.load()函数,其中torch.save()主要通过调用_save()or_legacy_save()实现,torch.load()主要通过调用_load()or_legacy_load().torch.load()中的map_location参数通过_get_restore_location()函数决定将对象反序列化到 CPU 还是 CUDA 上。

torch.hub

torch.hub 提供了一系列 pretrained models 来方便大家使用,我们以https:// github.com/ pytorch/ vision为例,介绍怎样使用 torch.hub 提供的接口来调用 torchvision 里的 model。

torch.hub主要提供了三个接口torch.hub.list(),torch.hub.help(),torch.hub.load(),我们依次介绍。

torch.hub.list() 会从给定的 GitHub repo 中寻找 hubconf.py(此文件导入 repo 里提供的所有 models),然后返回一个 list,里面包含了提供的 model 类名。https://github.com/pytorch/vision下的 hubconf.py 文件内容如下:

# Optional list of dependencies required by the package
dependencies = ['torch']

# classification
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
    resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
    mnasnet1_3

# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
    deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large

接下来我们分析torch.hub.list()代码:

def list(github, # repo的名字,比如`pytorch/vision`。注意没有前缀`https://github.com/`,代码里hard code进去了
         force_reload=False)
:
 # 是否要重新下载 repo
    repo_dir = _get_cache_or_reload(github, force_reload, True# 根据repo的地址下载到本地,然后返回下载到本地的repo的路径
                                                                # 可以通过torch.hub.get_dir()得到下载的根目录,提前通过
                                                                # torch.hub.set_dir(string)设置下载的根目录

    sys.path.insert(0, repo_dir) # 本地的repo路径加入到搜索路径中,优先级最高

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) # MODULE_HUBCONF = 'hubconf.py',
                                                                                # 从本地的repo中找到'hubconf.py',
                                                                                # 并解析得到'hubconf.py'里提供的所有module

    sys.path.remove(repo_dir) # 从搜索路径中删除本地repo路径

    # We take functions starts with '_' as internal helper functions
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')] # 注意这里类名的开头
                                                                                                             # 如果是'_'的话,
                                                                                                             # 将会被滤掉

    return entrypoints # list(string), repo提供的model类名

# An example:
print(torch.hub.list('pytorch/vision'True))

# print info:
'''
['alexnet', 'deeplabv3_mobilenet_v3_large', 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'densenet121', 'densenet161',
'densenet169', 'densenet201', 'fcn_resnet101', 'fcn_resnet50', 'googlenet', 'inception_v3', 'lraspp_mobilenet_v3_large',
'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext50_32x4d', 'shufflenet_v2_x0_5',
'shufflenet_v2_x1_0', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19', 'vgg19_bn', 'wide_resnet101_2', 'wide_resnet50_2']
'''

torch.hub.help()会返回给定 repo 下给定 module 的文档:

def help(github,
         model, # module的名字,比如'resnet50'
         force_reload=False)
:

    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model) # 从所有modules(hub_module)里找到给定module(model)

    return entry.__doc__ # 返回`model`的文档

# An example:
print(torch.hub.help('pytorch/vision''resnet18'True))

# print info:
'''
ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" `_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
'''

torch.hub.load()会返回实例化后的 module:

def load(repo_or_dir, # 本地的路径,或者github上的repo名
         model,
         *args, # 用于实例化 module
         **kwargs)
:
 # 用于实例化 module
    source = kwargs.pop('source''github').lower() # repo 默认从github上寻找
    force_reload = kwargs.pop('force_reload'False)
    verbose = kwargs.pop('verbose'True# 如果True,打印一些log

    if source not in ('github''local'): # 要么从github上找repo,要么从本地找repo
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')

    if source == 'github':
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)

    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model


def _load_local(hubconf_dir, model, *args, **kwargs):
    sys.path.insert(0, hubconf_dir)

    hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
    hub_module = import_module(MODULE_HUBCONF, hubconf_path)

    entry = _load_entry_from_hubconf(hub_module, model) # 找到指定的module
    model = entry(*args, **kwargs) # 实例化 module

    sys.path.remove(hubconf_dir)

    return model


# An example:
resnet18 = torch.hub.load('pytorch/vision''resnet18', pretrained=True# 载入预训练权重

以上以 pytorch/vision 为例介绍了torch.hub的使用。实际上只要一个 GitHub repo 里有 hubconf.py 的文件,都可以使用 torch.hub 提供的接口,比如一个简单的例子 。

原文链接:https://zhuanlan.zhihu.com/p/364239544

如果觉得有用,就请分享到朋友圈吧!


浏览 96
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报