Distributed BatchNorm in TensorFlow
Introduction
When running distributed training with $B$ data points on each of $N$ machines, many layers will, in theory at least, behave identically to running a batch of size $N * B$ on a single machine, because most layers process the elements of a batch independently of one another. However, layers that operate at the batch level, particularly during training, may give different results in the single-machine and distributed settings. In this post we will see an example of how BatchNorm works when running distributed training with TensorFlow using TPUs and tf.keras.layers. This post assumes you know how BatchNormalisation works. Here is one tutorial, but there is an endless number of them if you search on Google.
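The claim that most layers are unaffected by splitting a batch can be illustrated with a toy example in plain Python. This is only a sketch: dense_relu is a hypothetical stand-in for a per-example layer, and the two shards play the role of two machines.

```python
def dense_relu(batch, weight=2.0, bias=-1.0):
    """A toy per-example layer: each output depends on one input only."""
    return [max(0.0, weight * x + bias) for x in batch]


full_batch = [-1.0, 0.0, 0.5, 1.0, 2.0, 3.0]
shards = [full_batch[:3], full_batch[3:]]  # two hypothetical machines

# Apply the layer to the whole batch on a "single machine" ...
single_machine = dense_relu(full_batch)
# ... and to each shard separately, concatenating the results afterwards.
distributed = [y for shard in shards for y in dense_relu(shard)]

print(single_machine == distributed)
```

Because no output depends on any other example in the batch, the two results match exactly. BatchNorm breaks this property in training mode, since its output depends on the statistics of the whole batch it sees.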
You can use TPUs for free on Colab and you can run this code as a notebook here. If using Colab, make sure you select TPU under Hardware accelerator by going to Runtime > Change Runtime Type in the top menu. The purpose of this blog is to demonstrate how distributed batch normalisation works, so we won’t go into the details of using TPUs, and the code here might not reflect best practices for training with TPUs. Please consult TensorFlow’s guide if you want to learn more about how to use TPUs.
Setting up
import tensorflow as tf
import numpy as np
First some setup code taken from the guide.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches
WARNING:tensorflow:TPU system grpc://10.98.27.170:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
INFO:tensorflow:Initializing the TPU system: grpc://10.98.27.170:8470
INFO:tensorflow:Finished initializing TPU system.
All devices: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]
tpu_strategy = tf.distribute.TPUStrategy(resolver)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
Now create a BatchNormExperiment class that lets us get results easily. The experiment class creates a dataset and runs forward passes in the appropriate way depending on whether or not we are running in a distributed fashion.
class BatchNormExperiment(object):
    def __init__(self, data, bn_layer, strategy=None):
        self.bn_layer = bn_layer
        self.strategy = strategy
        if self.strategy is not None:
            # Distributed: each replica gets a per-device batch of 32
            self.ds = strategy.experimental_distribute_datasets_from_function(
                lambda _: tf.data.Dataset.from_tensor_slices(data).batch(32)
            )
        else:
            # Single device: one batch of 32 * 8 = 256 examples
            self.ds = tf.data.Dataset.from_tensor_slices(data).batch(32 * 8)

    @tf.function
    def apply_bn_dist(self, iterator, training):
        def _f(x, training):
            return self.bn_layer(x, training=training)

        result = self.strategy.run(_f, args=(next(iterator), training))
        # Concatenate the per-replica outputs back into one batch
        return tf.concat(result.values, axis=0)

    @tf.function
    def apply_bn(self, iterator, training):
        return self.bn_layer(next(iterator), training=training)

    def get_results(self):
        bn = self.apply_bn if self.strategy is None else self.apply_bn_dist
        iterator = iter(self.ds)
        result = []
        # STEPS passes in training mode followed by one pass in inference mode
        for _ in range(STEPS):
            result.append(bn(iterator, training=True))
        val_result = bn(iterator, training=False)
        return tf.concat(result, axis=0), val_result
A simple function to calculate the max difference between two arrays / tensors, since we will be comparing tensors which will be close but not identical:
def max_diff(x, y):
    if isinstance(x, tf.Tensor):
        x = x.numpy()
    if isinstance(y, tf.Tensor):
        y = y.numpy()
    return np.abs(np.array(x) - np.array(y)).max()
Now create a fake dataset consisting of 101 random matrices of size 256 x 64, the first 100 of which will be used to call batchnorm in training mode and the last one in inference mode. Note that we are not doing any training here. Training and inference refer to how the normalisation is done:
- In training mode, normalisation is done with the stats obtained from the input batch
- In inference mode, normalisation is done with moving stats that are calculated using earlier batches that were run in training mode
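The two modes can be sketched for a single feature in plain Python. This is a simplified illustration rather than the Keras implementation: the helper names are hypothetical, the learnable scale and offset are left out, and the momentum (0.99) and epsilon (0.001) values are assumed to match the Keras defaults.

```python
import math
import statistics

MOMENTUM = 0.99  # assumed: Keras default momentum
EPSILON = 1e-3   # assumed: Keras default epsilon


def normalise(batch, mean, var):
    """Normalise each value with the given mean and variance."""
    return [(x - mean) / math.sqrt(var + EPSILON) for x in batch]


def bn_train_step(batch, moving_mean, moving_var):
    """Training mode: use batch stats, then update the moving stats."""
    mean = statistics.fmean(batch)
    var = statistics.pvariance(batch, mu=mean)  # population variance
    moving_mean = MOMENTUM * moving_mean + (1 - MOMENTUM) * mean
    moving_var = MOMENTUM * moving_var + (1 - MOMENTUM) * var
    return normalise(batch, mean, var), moving_mean, moving_var


def bn_infer_step(batch, moving_mean, moving_var):
    """Inference mode: use the moving stats and leave them unchanged."""
    return normalise(batch, moving_mean, moving_var)


batch = [1.0, 2.0, 3.0, 4.0]
# Moving stats start at mean 0 and variance 1 in this sketch.
train_out, mm, mv = bn_train_step(batch, moving_mean=0.0, moving_var=1.0)
infer_out = bn_infer_step(batch, mm, mv)
print(train_out)
print(infer_out)
```

The same input gives different outputs in the two modes, because after a single update the moving stats are still far from the batch stats.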
tf.keras.backend.clear_session()
STEPS = 100
data = tf.random.normal([32 * 8 * (STEPS + 1), 64])
bn_layer = tf.keras.layers.BatchNormalization()
bn_layer.build([None, 64])
init = bn_layer.get_weights()
Distributed BatchNorm
This layer normalises the per-device batches separately on each device in training mode. However, during inference all examples should be treated independently, so the same moving stats need to be used on every device. The layer therefore averages the moving stats across devices. In the source code, the tf.Variable instances for the moving mean and variance are defined with these two settings:

- synchronization=tf.VariableSynchronization.ON_READ: according to the documentation, this means they will be aggregated when read, which includes when applying batch normalisation
- aggregation=tf.VariableAggregation.MEAN: indicates what aggregation should be used, here the mean
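As a rough mental model of these two settings, the sketch below simulates a variable that keeps one copy per replica and only averages the copies when it is read. The class and its methods are hypothetical and are not part of TensorFlow; the real behaviour is implemented inside tf.distribute.

```python
class OnReadMeanVariable:
    """Hypothetical stand-in for an ON_READ variable with MEAN aggregation."""

    def __init__(self, initial_value, num_replicas):
        # Every replica starts from the same value but owns its own copy.
        self.values = [initial_value] * num_replicas

    def assign(self, replica_id, value):
        # Each replica updates only its local copy; nothing is synchronised.
        self.values[replica_id] = value

    def read(self):
        # Aggregation happens here, at read time.
        return sum(self.values) / len(self.values)


moving_var = OnReadMeanVariable(initial_value=1.0, num_replicas=4)
for replica_id, local_var in enumerate([0.8, 1.2, 0.9, 1.1]):
    moving_var.assign(replica_id, local_var)

aggregated = moving_var.read()  # mean of the four per-replica copies
print(moving_var.values, aggregated)
```

The point of the design is that updates stay cheap and local during training, and the cost of combining the copies is only paid when a single value is needed.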
Let us create a distributed batchnorm layer and initialise it with the same weights as bn_layer
with tpu_strategy.scope():
    bn_dist = tf.keras.layers.BatchNormalization()
    bn_dist.build([None, 64])
    bn_dist.set_weights(init)
Set up the experiments and get results
exp_regular = BatchNormExperiment(data, bn_layer, strategy=None)
exp_dist = BatchNormExperiment(data, bn_dist, tpu_strategy)
trn_regular, val_regular = exp_regular.get_results()
trn_dist, val_dist = exp_dist.get_results()
Train results will be different when using regular distributed batchnorm, since each sub-batch is normalised using stats calculated separately on its own device.
max_diff(trn_regular, trn_dist)
1.4629191
Val results will also be different, since the batch norm stats are first calculated per sub-batch and then averaged across the devices. The average of the per-device batch means is the same as the mean across all devices, but the average of the per-device variances is not the variance across all devices.
max_diff(val_regular, val_dist)
0.09319067
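The point about variances can be checked with a small, self-contained sketch in plain Python. It uses made-up numbers and two hypothetical devices with equal sub-batch sizes; it illustrates the arithmetic only, not what TensorFlow executes.

```python
import statistics

# Two hypothetical per-device sub-batches of equal size (made-up values).
device_batches = [[1.0, 2.0, 3.0], [10.0, 11.0, 12.0]]
full_batch = [x for batch in device_batches for x in batch]

per_device_means = [statistics.fmean(b) for b in device_batches]
per_device_vars = [statistics.pvariance(b) for b in device_batches]

# Averaging the per-device means recovers the full-batch mean ...
avg_of_means = statistics.fmean(per_device_means)
global_mean = statistics.fmean(full_batch)

# ... but averaging the per-device variances misses the spread between devices.
avg_of_vars = statistics.fmean(per_device_vars)
global_var = statistics.pvariance(full_batch)

# For equal-sized sub-batches the missing term is the variance of the means.
between_device_var = statistics.pvariance(per_device_means)

print(avg_of_means, global_mean)
print(avg_of_vars, global_var)
print(avg_of_vars + between_device_var)
```

The averaged variance is always too small by the variance of the per-device means. With i.i.d. data, as in this post, the sub-batch means are close to each other, so the gap is small but not zero.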
We observe that whilst the moving mean is close to the non-distributed one, the moving variance is different.
mov_mean_regular, mov_var_regular = bn_layer.get_weights()[2:]
mov_mean_dist, mov_var_dist = bn_dist.get_weights()[2:]
print('Moving mean diff', max_diff(mov_mean_regular, mov_mean_dist))
print('Moving var diff', max_diff(mov_var_regular, mov_var_dist))
Moving mean diff 3.1432137e-09
Moving var diff 0.019875884
Note that get_weights does the aggregation and returns a single value of the weight, whereas bn_dist.moving_mean will return the per-device values
mov_var_dist_per_device = bn_dist.moving_variance
len(mov_var_dist_per_device.values)
8
Observe that the mean of these and mov_var_dist are identical
max_diff(tf.reduce_mean(mov_var_dist_per_device.values, axis=0), mov_var_dist)
0.0
# To avoid errors below delete everything corresponding to the distributed batch norm layer
del exp_dist, bn_dist, trn_dist, val_dist, mov_mean_dist, mov_var_dist, mov_var_dist_per_device
Synchronised BatchNorm
In synchronised BatchNorm, as implemented in tf.keras.layers.experimental.SyncBatchNormalization, batch stats are aggregated across the devices in training mode and the batches are normalised by the resulting values. From the source code we see that first you find sum(x), sum(x^2) and the batch_size for each replica:
local_sum = tf.reduce_sum(y, axis=axes, keepdims=True)
local_squared_sum = tf.reduce_sum(tf.square(y), axis=axes,
                                  keepdims=True)
batch_size = tf.cast(tf.shape(y)[axes[0]],
                     tf.float32)
Then you aggregate these across replicas using all_reduce.
y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                       local_squared_sum)
global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                           batch_size)
At this point each replica has a copy of the global sum, squared sum and global batch size which it can use to normalise its own subset of the data
axes_vals = [(tf.shape(y))[axes[i]]
             for i in range(1, len(axes))]
multiplier = tf.cast(tf.reduce_prod(axes_vals),
                     tf.float32)
multiplier = multiplier * global_batch_size
mean = y_sum / multiplier
y_squared_mean = y_squared_sum / multiplier
# var = E(x^2) - E(x)^2
variance = y_squared_mean - tf.square(mean)
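To see why aggregating these three quantities is enough, here is a small plain-Python sketch with made-up numbers. The all_reduce_sum helper is a hypothetical stand-in for replica_ctx.all_reduce, and the data is a single feature split across two imaginary replicas.

```python
import statistics

# One feature, split across two hypothetical replicas (made-up values).
replica_batches = [[1.0, 2.0, 3.0], [10.0, 11.0, 12.0]]


def all_reduce_sum(local_values):
    """Stand-in for an all-reduce: every replica receives the total."""
    return sum(local_values)


# Step 1: each replica computes its local statistics.
local_sums = [sum(b) for b in replica_batches]
local_squared_sums = [sum(x * x for x in b) for b in replica_batches]
local_batch_sizes = [float(len(b)) for b in replica_batches]

# Step 2: the statistics are summed across replicas.
y_sum = all_reduce_sum(local_sums)
y_squared_sum = all_reduce_sum(local_squared_sums)
global_batch_size = all_reduce_sum(local_batch_sizes)

# Step 3: global mean and variance via var = E(x^2) - E(x)^2.
mean = y_sum / global_batch_size
variance = y_squared_sum / global_batch_size - mean ** 2

# Reference values computed directly on the undivided batch.
full_batch = [x for b in replica_batches for x in b]
expected_mean = statistics.fmean(full_batch)
expected_variance = statistics.pvariance(full_batch)

print(mean, expected_mean)
print(variance, expected_variance)
```

Sums, squared sums and counts all add up across replicas, so the aggregated values give exactly the full-batch mean and variance, up to floating point error. That is why the synchronised layer can match the single-device results.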
Let us initialise and run a synchronised BatchNorm layer
with tpu_strategy.scope():
    sync_bn_dist = tf.keras.layers.experimental.SyncBatchNormalization()
    sync_bn_dist.build([None, 64])
    sync_bn_dist.set_weights(init)
exp_dist_sync = BatchNormExperiment(data, sync_bn_dist, tpu_strategy)
trn_dist_sync, val_dist_sync = exp_dist_sync.get_results()
Now the results should be the same up to a small epsilon in both training and inference modes assuming that the same data has been used in each case
max_diff(trn_regular, trn_dist_sync)
1.9073486e-06
max_diff(val_regular, val_dist_sync)
7.1525574e-07
Unsurprisingly both the moving stats are close this time
mov_mean_dist_sync, mov_var_dist_sync = sync_bn_dist.get_weights()[2:]
print('Moving mean diff', max_diff(mov_mean_regular, mov_mean_dist_sync))
print('Moving var diff', max_diff(mov_var_regular, mov_var_dist_sync))
Moving mean diff 4.4237822e-09
Moving var diff 2.9802322e-07
del exp_dist_sync, sync_bn_dist, trn_dist_sync, val_dist_sync, mov_mean_dist_sync, mov_var_dist_sync
Finally, we can see that even in non-synchronised distributed batchnorm the same moving stats are applied in each replica during inference since, as noted before, the moving stats are averaged across devices. We copy the weights from the regular batchnorm layer, bn_layer, where the moving stats have already been updated, and then run a non-synchronised distributed batchnorm step in inference mode.
with tpu_strategy.scope():
    bn_copy = tf.keras.layers.BatchNormalization()
    bn_copy.build([None, 64])
    bn_copy.set_weights(bn_layer.get_weights())
exp_copy = BatchNormExperiment(data, bn_copy, tpu_strategy)
# Create iterator and advance to last batch which was used to get the "val" results
itr = iter(exp_copy.ds)
for i in range(STEPS):
    _ = next(itr)
val_copy = exp_copy.apply_bn_dist(itr, training=False)
As expected this yields nearly the same values as val_regular.
max_diff(val_regular, val_copy)
4.7683716e-07