Official | Keras Distributed Training Tutorial
Overview
The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to let users enable distributed training with minimal changes to existing models and training code.
This tutorial uses tf.distribute.MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model's variables to each processor. Then it uses all-reduce to combine the gradients from all processors and applies the combined value to every copy of the model.
MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the distribution strategy guide.
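Not part of the original tutorial, but as a quick sanity check you can list the GPUs TensorFlow sees before choosing a strategy; by default MirroredStrategy mirrors the model across all visible GPUs. A minimal sketch, assuming a recent TF 2.x install:
import tensorflow as tf
# Sketch: list the GPUs visible to this TensorFlow process.
gpus = tf.config.list_physical_devices('GPU')
print('Visible GPUs:', [gpu.name for gpu in gpus])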
Keras API
This example uses the tf.keras API to build the model and training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.
Import dependencies
from __future__ import absolute_import, division, print_function, unicode_literals
# Import TensorFlow and TensorFlow Datasets
try:
  !pip install -q tf-nightly
except Exception:
  pass

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()

import os
print(tf.__version__)
2.1.0-dev20191004
Download the dataset
Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.
Setting with_info to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...
/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
InsecureRequestWarning)
WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and:
`tf.data.TFRecordDataset(path)`
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.
Define distribution strategy
Create a MirroredStrategy object. This will handle distribution, and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
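The output above shows a single replica. If you want to mirror across a specific subset of GPUs rather than every visible device, MirroredStrategy also accepts an explicit device list. A minimal sketch, where the device names are placeholders for whatever your machine actually exposes:
# Sketch: restrict mirroring to an explicit list of devices.
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])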
Setup input pipeline
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
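As noted above, the learning rate is usually tuned along with the global batch size. One common heuristic (an illustration, not something this tutorial prescribes) is to scale a single-replica base rate linearly by the number of replicas:
# Sketch: linear learning-rate scaling by replica count.
# BASE_LEARNING_RATE is a hypothetical single-GPU learning rate.
BASE_LEARNING_RATE = 1e-3
LEARNING_RATE = BASE_LEARNING_RATE * strategy.num_replicas_in_sync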
Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label
Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
Create the model
Create and compile the Keras model in the context of strategy.scope.
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Define the callbacks
The callbacks used here are:
TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.
Model Checkpoint: This callback saves the model after every epoch.
Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.
For illustrative purposes, add a print callback to display the learning rate in the notebook.
# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()]
Train and evaluate
Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.
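The training call itself does not appear in this repost; a minimal sketch of that step under the setup above (the 12-epoch count matches the decay schedule defined earlier, but is an assumption here):
model.fit(train_dataset, epochs=12, callbacks=callbacks)

# After training, restore the latest checkpoint and evaluate on the test data.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))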