编写同时在PyTorch和Tensorflow上工作的代码

共 2687字,需浏览 6分钟

 ·

2021-08-10 18:43


点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

“库开发人员不再需要在框架之间进行选择。”

来自德国图宾根人工智能中心的研究人员介绍了一种新的Python框架EagerPy,EagerPy允许开发人员编写独立于PyTorch和TensorFlow等流行框架的代码。

在最近发表的一篇关于EagerPy的文章中,研究人员写道,库开发人员不再关注框架依赖性。他们的新Python框架,急切地解决了它们的重新实现和代码复制障碍。

例如,Foolbox是一个构建在EagerPy之上的Python库。该库是用EagerPy而不是NumPy重写的,以实现在PyTorch和TensorFlow中开发,并且只有一个代码库,没有代码重复。Foolbox是一个对机器学习模型进行对抗性攻击的库。


框架无关的重要性


为了解决框架之间的差异,作者探索了句法偏差。在PyTorch的情况下,使用In-place的梯度需要使用**_grad_()「,而反向传播是使用」backward**()调用的。

然而,TensorFlow提供了一个高级管理器和像「tape.gradient」这样的函数来查询梯度。即使在句法层面,这两个框架也有很大的不同。例如,对于参数,dim vs axis;对于函数,sum vs reduce_sum。

这就是“EagerPy ”发挥作用的地方。它通过提供一个统一的API来解决PyTorch和TensorFlow之间的差异,该API透明地映射到各种底层框架,而无需计算开销。

EagerPy允许你编写自动使用PyTorch、TensorFlow、JAX和NumPy的代码。”

研究人员写道,EagerPy专注于Eager执行,此外,它的方法是透明的,用户可以将与框架无关的EagerPy代码与特定于框架的代码结合起来。

TensorFlow引入的eager执行模块和PyTorch的相似特性使eager执行成为主流,框架更加相似。然而,尽管PyTorch和TensorFlow2之间有这些相似之处,但编写框架无关的代码并不简单。在语法层面,这些框架中用于自动微分的api是不同的。

自动微分是指用算法求解微分方程。它的工作原理是链式规则,也就是说,求解函数的导数可以归结为基本的数学运算(加、减、乘、除)。这些算术运算可以用图形格式表示。EagerPy特别使用了一种函数式的方法来自动区分。

下面是一段来自文档的代码片段:

import eagerpy as ep

x = ep.astensor(x)

def loss_fn(x):
 #这个函数接受并返回一个eager张量
    return x.square().sum()

print(loss_fn(x))

# PyTorchTensor(tensor(14.))

print(ep.value_and_grad(loss_fn, x))

首先定义第一个函数,然后根据其输入进行微分。然后传递给「ep.value_and_grad」 来得到函数的值及其梯度。

此外,norm函数现在可以与PyTorch、TensorFlow、JAX和NumPy中的原生张量和数组一起使用,与本机代码相比几乎没有任何开销。它也适用于GPU张量。

import torch

norm(torch.tensor([1.2.3.]))

import tensorflow as tf

norm(tf.constant([1.2.3.]))

总之,EagerPy 旨在提供以下功能:

  • 为快速执行提供统一的API

  • 维护框架的本机性能

  • 完全可链接的API

  • 全面的类型检查支持

研究人员声称,这些属性使得使用这些属性比底层框架特定的api更容易、更安全。尽管有这些变化和改进,但EagerPy 背后的团队还是确保了eagerpy API遵循了NumPy、PyTorch和JAX设置的标准。


入门「EagerPy」


使用pip从PyPI安装最新版本:

python3 -m pip install eagerpy
import eagerpy as ep

def norm(x):

    x = ep.astensor(x)

    result = x.square().sum().sqrt()

    return result.raw

下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 29
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报