Tensorflow的妙用​

共 1929字,需浏览 4分钟

 ·

2020-11-07 07:44


向大家推荐一个 TensorFlow 工具———TensorFlow Hub,它包含各种预训练模型的综合代码库,这些模型稍作调整便可部署到任何设备上。只需几行代码即可重复使用经过训练的模型,例如 BERT 和 Faster R-CNN,实现这些些牛X的应用,简直和把大象装进冰箱一样简单。

第一步:安装 TensorFlow Hub

Tensorflow_hub 库可与  TensorFlow 一起安装(建议直接上TF2)

pip install "tensorflow>=2.0.0"
pip install --upgrade tensorflow-hub

使用时

import tensorflow as tf
import tensorflow_hub as hub

第二步:从 TF Hub 下载模型

TensorFlow Hub 在 hub.tensorflow.google.cn 中提供了一个开放的训练模型存储库。tensorflow_hub 库可以从这个存储库和其他基于 HTTP 的机器学习模型存储库中加载模型。

从 下载并解压缩模型后,tensorflow_hub 库会将这些模型缓存到文件系统上。下载位置默认为本地临时目录,但可以通过设置环境变量 TFHUB_CACHE_DIR(推荐)或传递命令行标记 --tfhub_cache_dir 进行自定义。

os.environ['TFHUB_CACHE_DIR'] = '/home/user/workspace/tf_cache'

值得注意的是,TensorFlow Hub Module仅为我们提供了包含模型体系结构的图形以及在某些数据集上训练的权重。大多数模块允许访问模型的内部层,可以根据不同的用例使用。但是,有些模块不能精细调整。在开始开发之前,建议在TensorFlow Hub网站中查看有关该模块的说明。

以目标检测为例:打开网站,动几下鼠标即可
https://hub.tensorflow.google.cn/

拿来直接用

module_handle = "https://hub.tensorflow.google.cn/google/faster_rcnn/openimages_v4/inception_resnet_v2/1" 
detector = hub.load(module_handle).signatures['default']
def load_img(path):
  img = tf.io.read_file(path)
  img = tf.image.decode_jpeg(img, channels=3)
  return img
def run_detector(detector, path):
  img = load_img(path)

  converted_img  = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]
  start_time = time.time()
  result = detector(converted_img)
  end_time = time.time()

  result = {key:value.numpy() for key,value in result.items()}

  print("Found %d objects." % len(result["detection_scores"]))
  print("Inference time: ", end_time-start_time)

  image_with_boxes = draw_boxes(
      img.numpy(), result["detection_boxes"],
      result["detection_class_entities"], result["detection_scores"])

  display_image(image_with_boxes)
run_detector(detector, downloaded_image_path)

无需重复训练,拿来即用!6不6?

代码参考:
https://tensorflow.google.cn/hub/tutorials/object_detection


浏览 42
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报