##### Copyright 2018 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Distributed training with TensorFlow

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/guide/distributed_training"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/distributed_training.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/guide/distributed_training.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/distributed_training.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

## Overview

`tf.distribute.Strategy` is a TensorFlow API to distribute training across multiple GPUs, multiple machines, or TPUs. Using this API, you can distribute your existing models and training code with minimal code changes.

`tf.distribute.Strategy` has been designed with these key goals in mind:

* Easy to use and support multiple user segments, including researchers, machine learning engineers, etc.
* Provide good performance out of the box.
* Easy switching between strategies.

You can distribute training using `tf.distribute.Strategy` with a high-level API like Keras `Model.fit`, as well as [custom training loops](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch) (and, in general, any computation using TensorFlow).

In TensorFlow 2.x, you can execute your programs eagerly, or in a graph using [`tf.function`](function.ipynb). `tf.distribute.Strategy` intends to support both these modes of execution, but works best with `tf.function`. Eager mode is only recommended for debugging purposes and not supported for `tf.distribute.TPUStrategy`. Although training is the focus of this guide, this API can also be used for distributing evaluation and prediction on different platforms.

You can use `tf.distribute.Strategy` with very few changes to your code, because the underlying components of TensorFlow have been changed to become strategy-aware. This includes variables, layers, models, optimizers, metrics, summaries, and checkpoints.

In this guide, you will learn about various types of strategies and how you can use them in different situations. To learn how to debug performance issues, check out the [Optimize TensorFlow GPU performance](gpu_performance_analysis.md) guide.

Note: For a deeper understanding of the concepts, watch the deep-dive presentation—[Inside TensorFlow: `tf.distribute.Strategy`](https://youtu.be/jKV53r9-H14). This is especially recommended if you plan to write your own training loop.


## Set up TensorFlow

In [2]:
import tensorflow as tf



## Types of strategies

`tf.distribute.Strategy` intends to cover a number of use cases along different axes. Some of these combinations are currently supported and others will be added in the future. Some of these axes are:

- *Synchronous vs asynchronous training:* These are two common ways of distributing training with data parallelism. In sync training, all workers train over different slices of input data in sync and aggregate gradients at each step. In async training, all workers train independently over the input data and update variables asynchronously. Typically, sync training is supported via all-reduce and async training through a parameter server architecture.
- *Hardware platform:* You may want to scale your training onto multiple GPUs on one machine, or multiple machines in a network (with 0 or more GPUs each), or on Cloud TPUs.
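
As a rough illustration of the difference between the two training modes, the sketch below simulates two workers minimizing a toy one-parameter loss in plain Python. It is not how TensorFlow implements either mode. The function names (`grad`, `sync_step`, `async_step`), the learning rate, and the data are assumptions made for this example.

```python
# Toy simulation of synchronous vs. asynchronous data-parallel SGD.
# Loss per example: (w - x)^2, so the gradient w.r.t. w is 2 * (w - x).

def grad(w, batch):
    """Mean gradient of (w - x)^2 over one worker's slice of data."""
    return sum(2.0 * (w - x) for x in batch) / len(batch)

def sync_step(w, worker_batches, lr):
    """All workers read the same w; gradients are aggregated, then applied once."""
    grads = [grad(w, batch) for batch in worker_batches]
    return w - lr * sum(grads) / len(grads)

def async_step(w, worker_batches, lr):
    """Each worker applies its own update without waiting for the others.

    In this sketch every worker reads w before any worker writes, so the
    later updates are computed from a stale value.
    """
    stale_w = w
    for batch in worker_batches:
        w = w - lr * grad(stale_w, batch)
    return w

worker_batches = [[1.0, 3.0], [5.0, 7.0]]  # one slice of data per worker
w_sync = sync_step(0.0, worker_batches, lr=0.1)
w_async = async_step(0.0, worker_batches, lr=0.1)
print(w_sync, w_async)
```

The two results differ because the synchronous step applies one aggregated update, while the asynchronous step applies each worker's update separately.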

In order to support these use cases, TensorFlow has `MirroredStrategy`, `TPUStrategy`, `MultiWorkerMirroredStrategy`, `ParameterServerStrategy`, `CentralStorageStrategy`, as well as other strategies available. The next section explains which of these are supported in which scenarios in TensorFlow. Here is a quick overview:

| Training API             | `MirroredStrategy` | `TPUStrategy` | `MultiWorkerMirroredStrategy` | `CentralStorageStrategy` | `ParameterServerStrategy` |
| :----------------------- | :----------------- | :------------ | :---------------------------- | :----------------------- | :------------------------ |
| **Keras `Model.fit`**    | Supported          | Supported     | Supported                     | Experimental support     | Experimental support      |
| **Custom training loop** | Supported          | Supported     | Supported                     | Experimental support     | Experimental support      |
| **Estimator API**        | Limited Support    | Not supported | Limited Support               | Limited Support          | Limited Support           |

Note: [Experimental support](https://www.tensorflow.org/guide/versions#what_is_not_covered) means the APIs are not covered by any compatibility guarantees.

Warning: Estimator support is limited. Basic training and evaluation are experimental, and advanced features—such as scaffold—are not implemented. You should be using Keras or custom training loops if a use case is not covered. Estimators are not recommended for new code. Estimators run `v1.Session`-style code which is more difficult to write correctly, and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our [compatibility guarantees](https://tensorflow.org/guide/versions), but will receive no fixes other than security vulnerabilities. Go to the [migration guide](https://tensorflow.org/guide/migrate) for details.

### MirroredStrategy

`tf.distribute.MirroredStrategy` supports synchronous distributed training on multiple GPUs on one machine. It creates one replica per GPU device. Each variable in the model is mirrored across all the replicas. Together, these variables form a single conceptual variable called `MirroredVariable`. These variables are kept in sync with each other by applying identical updates.

Efficient all-reduce algorithms are used to communicate the variable updates across the devices. All-reduce aggregates tensors across all the devices by adding them up, and makes them available on each device. It’s a fused algorithm that is very efficient and can reduce the overhead of synchronization significantly. There are many all-reduce algorithms and implementations available, depending on the type of communication available between devices. By default, it uses the NVIDIA Collective Communication Library ([NCCL](https://developer.nvidia.com/nccl)) as the all-reduce implementation. You can choose from a few other options or write your own.
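
To make the result of an all-reduce concrete, here is a minimal plain-Python sketch. It only shows what each device holds afterwards, not how NCCL or any other implementation moves the data. The helper `all_reduce_sum` and the gradient values are hypothetical.

```python
# Toy all-reduce (sum): every device contributes a tensor and every device
# ends up with the element-wise sum. Real implementations avoid gathering
# everything in one place; this sketch only reproduces the end result.

def all_reduce_sum(per_device_tensors):
    """Return one summed copy per device for equal-length lists of floats."""
    summed = [sum(values) for values in zip(*per_device_tensors)]
    return [list(summed) for _ in per_device_tensors]

# Hypothetical gradients computed on two devices for a 3-element variable.
per_device_grads = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
reduced = all_reduce_sum(per_device_grads)

# Every device now holds the same aggregated gradient, so applying it keeps
# the mirrored copies of the variable identical.
print(reduced[0] == reduced[1])
```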

Here is the simplest way of creating `MirroredStrategy`:

In [3]:
mirrored_strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




This will create a `MirroredStrategy` instance that uses all the GPUs visible to TensorFlow, with NCCL as the cross-device communication.

If you wish to use only some of the GPUs on your machine, you can do so like this:

In [4]:
mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


If you wish to override the cross-device communication, you can do so using the `cross_device_ops` argument by supplying an instance of `tf.distribute.CrossDeviceOps`. Currently, `tf.distribute.HierarchicalCopyAllReduce` and `tf.distribute.ReductionToOneDevice` are two options other than `tf.distribute.NcclAllReduce`, which is the default.

In [5]:
mirrored_strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


### TPUStrategy

`tf.distribute.TPUStrategy` lets you run your TensorFlow training on [Tensor Processing Units (TPUs)](tpu.ipynb). TPUs are Google's specialized ASICs designed to dramatically accelerate machine learning workloads. They are available on [Google Colab](https://colab.research.google.com/), the [TPU Research Cloud](https://sites.research.google/trc/), and [Cloud TPU](https://cloud.google.com/tpu).

In terms of distributed training architecture, `TPUStrategy` is the same as `MirroredStrategy`: it implements synchronous distributed training. TPUs provide their own implementation of efficient all-reduce and other collective operations across multiple TPU cores, which are used in `TPUStrategy`.

Here is how you would instantiate `TPUStrategy`:

Note: To run any TPU code in Colab, you should select TPU as the Colab runtime. Refer to the [Use TPUs](tpu.ipynb) guide for a complete example.

```python
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu=tpu_address)
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
tpu_strategy = tf.distribute.TPUStrategy(cluster_resolver)
```

The `TPUClusterResolver` instance helps locate the TPUs. In Colab, you don't need to specify any arguments to it.

If you want to use this for Cloud TPUs:

- You must specify the name of your TPU resource in the `tpu` argument.
- You must initialize the TPU system explicitly at the *start* of the program. This is required before TPUs can be used for computation. Initializing the TPU system also wipes out the TPU memory, so it's important to complete this step first in order to avoid losing state.

### MultiWorkerMirroredStrategy

`tf.distribute.MultiWorkerMirroredStrategy` is very similar to `MirroredStrategy`. It implements synchronous distributed training across multiple workers, each with potentially multiple GPUs. Similar to `tf.distribute.MirroredStrategy`, it creates copies of all variables in the model on each device across all workers.

Here is the simplest way of creating `MultiWorkerMirroredStrategy`:

In [6]:
strategy = tf.distribute.MultiWorkerMirroredStrategy()



INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)


INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO


`MultiWorkerMirroredStrategy` has two implementations for cross-device communications. `CommunicationImplementation.RING` is [RPC](https://en.wikipedia.org/wiki/Remote_procedure_call)-based and supports both CPUs and GPUs. `CommunicationImplementation.NCCL` uses NCCL and provides state-of-the-art performance on GPUs, but it doesn't support CPUs. `CommunicationImplementation.AUTO` defers the choice to TensorFlow. You can specify them in the following way:


In [7]:
communication_options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=communication_options)



INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)




INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.NCCL


One of the key differences between multi-worker training and multi-GPU training is the multi-worker setup. The `'TF_CONFIG'` environment variable is the standard way in TensorFlow to specify the cluster configuration to each worker that is part of the cluster. Learn more in the [setting up TF_CONFIG section](#TF_CONFIG) of this document.
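
As a minimal sketch of what such a configuration can look like, the snippet below builds a `'TF_CONFIG'` value for the first worker of a two-worker cluster using only the standard library. The host names and port are placeholders, not real addresses, and each worker would set its own `index`.

```python
import json
import os

# Hypothetical two-worker cluster; replace the addresses with real ones.
tf_config = {
    "cluster": {
        "worker": ["host1.example.com:12345", "host2.example.com:12345"],
    },
    # Each worker sets its own index; this is what the first worker would use.
    "task": {"type": "worker", "index": 0},
}

# Set the variable in the worker's process before creating the strategy.
os.environ["TF_CONFIG"] = json.dumps(tf_config)

# Read it back to confirm that it round-trips as valid JSON.
parsed = json.loads(os.environ["TF_CONFIG"])
print(parsed["task"])
```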

For more details about `MultiWorkerMirroredStrategy`, consider the following tutorials:

- [Multi-worker training with Keras Model.fit](../tutorials/distribute/multi_worker_with_keras.ipynb)
- [Multi-worker training with a custom training loop](../tutorials/distribute/multi_worker_with_ctl.ipynb)

### ParameterServerStrategy

Parameter server training is a common data-parallel method to scale up model training on multiple machines. A parameter server training cluster consists of workers and parameter servers. Variables are created on parameter servers and they are read and updated by workers in each step. Check out the [Parameter server training](../tutorials/distribute/parameter_server_training.ipynb) tutorial for details.

In TensorFlow 2, parameter server training uses a central coordinator-based architecture via the `tf.distribute.experimental.coordinator.ClusterCoordinator` class.

In this implementation, the `worker` and `parameter server` tasks run `tf.distribute.Server`s that listen for tasks from the coordinator. The coordinator creates resources, dispatches training tasks, writes checkpoints, and deals with task failures.

In the program running on the coordinator, you will use a `ParameterServerStrategy` object to define a training step and use a `ClusterCoordinator` to dispatch training steps to remote workers. Here is the simplest way to create them:

```python
strategy = tf.distribute.experimental.ParameterServerStrategy(
    tf.distribute.cluster_resolver.TFConfigClusterResolver(),
    variable_partitioner=variable_partitioner)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)
```

To learn more about `ParameterServerStrategy`, check out the [Parameter server training with Keras Model.fit and a custom training loop](../tutorials/distribute/parameter_server_training.ipynb) tutorial.

Note: You will need to configure the `'TF_CONFIG'` environment variable if you use `TFConfigClusterResolver`. It is similar to [`'TF_CONFIG'`](#TF_CONFIG) in `MultiWorkerMirroredStrategy` but has additional caveats.

In TensorFlow 1, `ParameterServerStrategy` is available only with an Estimator via the `tf.compat.v1.distribute.experimental.ParameterServerStrategy` symbol.

Note: This strategy is [`experimental`](https://www.tensorflow.org/guide/versions#what_is_not_covered) as it is currently under active development.

### CentralStorageStrategy
`tf.distribute.experimental.CentralStorageStrategy` does synchronous training as well. Variables are not mirrored. Instead, they are placed on the CPU, and operations are replicated across all local GPUs. If there is only one GPU, all variables and operations will be placed on that GPU.

Create an instance of `CentralStorageStrategy` by:


In [8]:
central_storage_strategy = tf.distribute.experimental.CentralStorageStrategy()

INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:localhost/replica:0/task:0/device:CPU:0'], variable_device = '/job:localhost/replica:0/task:0/device:CPU:0'


This will create a `CentralStorageStrategy` instance that uses all visible GPUs and the CPU. Updates to variables on replicas will be aggregated before being applied to the variables.

Note: This strategy is [`experimental`](https://www.tensorflow.org/guide/versions#what_is_not_covered), as it is currently a work in progress.

### Other strategies

In addition to the above strategies, there are two other strategies which might be useful for prototyping and debugging when using `tf.distribute` APIs.

#### Default Strategy

The Default Strategy is a distribution strategy which is present when no explicit distribution strategy is in scope. It implements the `tf.distribute.Strategy` interface but is a pass-through and provides no actual distribution. For instance, `Strategy.run(fn)` will simply call `fn`. Code written using this strategy should behave exactly as code written without any strategy. You can think of it as a "no-op" strategy.

The Default Strategy is a singleton, so you cannot create more instances of it. It can be obtained using `tf.distribute.get_strategy` outside any explicit strategy's scope (the same API that can be used to get the current strategy inside an explicit strategy's scope).

In [9]:
default_strategy = tf.distribute.get_strategy()

This strategy serves two main purposes:

* It allows writing distribution-aware library code unconditionally. For example, in `tf.keras.optimizers` you can use `tf.distribute.get_strategy` and use that strategy for reducing gradients—it will always return a strategy object on which you can call the `Strategy.reduce` API.


In [10]:
# In optimizer or other library code
# Get currently active strategy
strategy = tf.distribute.get_strategy()
strategy.reduce("SUM", 1., axis=None)  # reduce some values

1.0

* Similar to library code, it can be used to write end users' programs to work with and without distribution strategy, without requiring conditional logic. Here's a sample code snippet illustrating this:

In [11]:
if tf.config.list_physical_devices('GPU'):
  strategy = tf.distribute.MirroredStrategy()
else:  # Use the Default Strategy
  strategy = tf.distribute.get_strategy()

with strategy.scope():
  # Do something interesting
  print(tf.Variable(1.))

<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>


#### OneDeviceStrategy

`tf.distribute.OneDeviceStrategy` is a strategy to place all variables and computation on a single specified device.

```python
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
```

This strategy is distinct from the Default Strategy in a number of ways. In the Default Strategy, the variable placement logic remains unchanged when compared to running TensorFlow without any distribution strategy. But when using `OneDeviceStrategy`, all variables created in its scope are explicitly placed on the specified device. Moreover, any functions called via `OneDeviceStrategy.run` will also be placed on the specified device.

Input distributed through this strategy will be prefetched to the specified device. In the Default Strategy, there is no input distribution.

Similar to the Default Strategy, this strategy could also be used to test your code before switching to other strategies which actually distribute to multiple devices/machines. This will exercise the distribution strategy machinery somewhat more than the Default Strategy, but not to the full extent of using, for example, `MirroredStrategy` or `TPUStrategy`. If you want code that behaves as if there is no strategy, then use the Default Strategy.

So far you've learned about different strategies and how you can instantiate them. The next few sections show the different ways in which you can use them to distribute your training.

## Use tf.distribute.Strategy with Keras Model.fit

`tf.distribute.Strategy` is integrated into `tf.keras`, which is TensorFlow's implementation of the [Keras API specification](https://keras.io/api/). `tf.keras` is a high-level API to build and train models. By integrating into the `tf.keras` backend, it's seamless for you to distribute your training written in the Keras training framework [using Model.fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit).

Here's what you need to change in your code:

1. Create an instance of the appropriate `tf.distribute.Strategy`.
2. Move the creation of the Keras model, optimizer, and metrics inside `strategy.scope`. The code in the model's `call()`, `train_step()`, and `test_step()` methods will then be distributed and executed on the accelerator(s).

TensorFlow distribution strategies support all types of Keras models—[Sequential](https://www.tensorflow.org/guide/keras/sequential_model), [Functional](https://www.tensorflow.org/guide/keras/functional), and [subclassed](https://www.tensorflow.org/guide/keras/custom_layers_and_models).

Here is a snippet of code to do this for a very simple Keras model with one `Dense` layer:

In [12]:
mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(1, input_shape=(1,),
                            kernel_regularizer=tf.keras.regularizers.L2(1e-4))])
  model.compile(loss='mse', optimizer='sgd')

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




This example uses `MirroredStrategy`, so you can run this on a machine with multiple GPUs. `strategy.scope()` indicates to Keras which strategy to use to distribute the training. Creating models/optimizers/metrics inside this scope allows you to create distributed variables instead of regular variables. Once this is set up, you can fit your model like you would normally. `MirroredStrategy` takes care of replicating the model's training on the available GPUs, aggregating gradients, and more.

In [13]:
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=2)
model.evaluate(dataset)



Epoch 1/2
10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.5412
Epoch 2/2
10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.2392

10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.1268


0.1268472820520401

Here a `tf.data.Dataset` provides the training and eval input. You can also use NumPy arrays:

In [14]:
import numpy as np

inputs, targets = np.ones((100, 1)), np.ones((100, 1))
model.fit(inputs, targets, epochs=2, batch_size=10)

Epoch 1/2
10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.1058
Epoch 2/2
10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0468


<keras.src.callbacks.history.History at 0x7fab904bcfa0>

In both cases—with `Dataset` or NumPy—each batch of the given input is divided equally among the multiple replicas. For instance, if you are using the `MirroredStrategy` with 2 GPUs, each batch of size 10 will be divided among the 2 GPUs, with each receiving 5 input examples in each step. Each epoch will then train faster as you add more GPUs. Typically, you would want to increase your batch size as you add more accelerators, so as to make effective use of the extra computing power. You will also need to re-tune your learning rate, depending on the model. You can use `strategy.num_replicas_in_sync` to get the number of replicas.

In [15]:
mirrored_strategy.num_replicas_in_sync

1

In [16]:
# Compute a global batch size using a number of replicas.
BATCH_SIZE_PER_REPLICA = 5
global_batch_size = (BATCH_SIZE_PER_REPLICA *
                     mirrored_strategy.num_replicas_in_sync)
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)
dataset = dataset.batch(global_batch_size)

LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15, 20: 0.175}
learning_rate = LEARNING_RATES_BY_BATCH_SIZE[global_batch_size]
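
The batch arithmetic above can be checked without any accelerators. The following plain-Python sketch splits one global batch into equal per-replica slices. The helper `split_global_batch` is hypothetical, and the replica count is assumed to be 2 rather than read from a strategy. For simplicity it rejects batches that do not divide evenly, which is a restriction of this sketch only.

```python
def split_global_batch(global_batch, num_replicas):
    """Split one global batch into equal per-replica slices."""
    if len(global_batch) % num_replicas != 0:
        raise ValueError("global batch must divide evenly among replicas")
    per_replica = len(global_batch) // num_replicas
    return [global_batch[i * per_replica:(i + 1) * per_replica]
            for i in range(num_replicas)]

# Assumed replica count; with tf.distribute this would come from
# strategy.num_replicas_in_sync.
num_replicas = 2
batch_size_per_replica = 5
global_batch = list(range(batch_size_per_replica * num_replicas))

slices = split_global_batch(global_batch, num_replicas)
print([len(s) for s in slices])
```

With a global batch of 10 examples and 2 replicas, each replica receives 5 examples per step, matching the description above.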

### What's supported now?

| Training API      | `MirroredStrategy` | `TPUStrategy` | `MultiWorkerMirroredStrategy` | `ParameterServerStrategy` | `CentralStorageStrategy` |
| ----------------- | ------------------ | ------------- | ----------------------------- | ------------------------- | ------------------------ |
| Keras `Model.fit` | Supported          | Supported     | Supported                     | Experimental support      | Experimental support     |

### Examples and tutorials

Here is a list of tutorials and examples that illustrate the above integration end-to-end with Keras `Model.fit`:

1. [Tutorial](../tutorials/distribute/keras.ipynb): Training with `Model.fit` and `MirroredStrategy`.
2. [Tutorial](../tutorials/distribute/multi_worker_with_keras.ipynb): Training with `Model.fit` and `MultiWorkerMirroredStrategy`.
3. [Guide](tpu.ipynb): Contains an example of using `Model.fit` and `TPUStrategy`.
4. [Tutorial](../tutorials/distribute/parameter_server_training.ipynb): Parameter server training with `Model.fit` and `ParameterServerStrategy`.
5. [Tutorial](https://www.tensorflow.org/text/tutorials/bert_glue): Fine-tuning BERT for many tasks from the GLUE benchmark with `Model.fit` and `TPUStrategy`.
6. TensorFlow Model Garden [repository](https://github.com/tensorflow/models/tree/master/official) containing collections of state-of-the-art models implemented using various strategies.

## Use tf.distribute.Strategy with custom training loops

As demonstrated above, using `tf.distribute.Strategy` with Keras `Model.fit` requires changing only a couple of lines of your code. With a little more effort, you can also use `tf.distribute.Strategy` [with custom training loops](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch).

If you need more flexibility and control over your training loops than is possible with Estimator or Keras, you can write custom training loops. For instance, when using a GAN, you may want to take a different number of generator or discriminator steps each round. Similarly, the high level frameworks are not very suitable for Reinforcement Learning training.

The `tf.distribute.Strategy` classes provide a core set of methods to support custom training loops. Using these may require minor restructuring of the code initially, but once that is done, you should be able to switch between GPUs, TPUs, and multiple machines simply by changing the strategy instance.

Below is a brief snippet illustrating this use case for a simple training example using the same Keras model as before.


First, create the model and optimizer inside the strategy's scope. This ensures that any variables created with the model and optimizer are mirrored variables.

In [17]:
with mirrored_strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(1, input_shape=(1,),
                            kernel_regularizer=tf.keras.regularizers.L2(1e-4))])
  optimizer = tf.keras.optimizers.SGD()

Next, create the input dataset and call `tf.distribute.Strategy.experimental_distribute_dataset` to distribute the dataset based on the strategy.

In [18]:
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(1000).batch(
    global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

Then, define one step of the training. Use `tf.GradientTape` to compute gradients and optimizer to apply those gradients to update your model's variables. To distribute this training step, put it in a function `train_step` and pass it to `tf.distribute.Strategy.run` along with the dataset inputs you got from the `dist_dataset` created before:

In [19]:
# Sets `reduction=NONE` to leave it to tf.nn.compute_average_loss() below.
loss_object = tf.keras.losses.BinaryCrossentropy(
  from_logits=True,
  reduction=tf.keras.losses.Reduction.NONE)

def train_step(inputs):
  features, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(features, training=True)
    per_example_loss = loss_object(labels, predictions)
    loss = tf.nn.compute_average_loss(per_example_loss)
    model_losses = model.losses
    if model_losses:
      loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss

@tf.function
def distributed_train_step(dist_inputs):
  per_replica_losses = mirrored_strategy.run(train_step, args=(dist_inputs,))
  return mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

A few other things to note in the code above:

  1. You used `tf.nn.compute_average_loss` to reduce the per-example prediction losses to a scalar. `tf.nn.compute_average_loss` sums the per-example losses and divides the sum by the global batch size. This matters because, after the gradients are calculated on each replica, they are aggregated across the replicas by **summing** them.

  By default, the global batch size is taken to be `tf.distribute.get_strategy().num_replicas_in_sync * tf.shape(per_example_loss)[0]`. It can also be specified explicitly as a keyword argument `global_batch_size=`. Without short batches, the default is equivalent to `tf.nn.compute_average_loss(..., global_batch_size=global_batch_size)` with the `global_batch_size` defined above. (For more on short batches and how to avoid or handle them, see the [Custom Training tutorial](../tutorials/distribute/custom_training.ipynb).)

  2. You used `tf.nn.scale_regularization_loss` to scale regularization losses registered with the `Model` object, if any, by `1/num_replicas_in_sync` as well. For those regularization losses that are input-dependent, it falls on the modeling code, not the custom training loop, to perform the averaging over the per-replica(!) batch size; that way the modeling code can remain agnostic of replication while the training loop remains agnostic of how regularization losses are computed.

  3. When you call `apply_gradients` within a distribution strategy scope, its behavior is modified. Specifically, before applying gradients on each parallel instance during synchronous training, it performs a sum-over-all-replicas of the gradients.

  4. You also used the `tf.distribute.Strategy.reduce` API to aggregate the results returned by `tf.distribute.Strategy.run` for reporting. `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can `reduce` them to get an aggregated value, or call `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.
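
The scaling arithmetic in points 1 to 3 can be checked with a small, framework-free sketch. It does not call TensorFlow; the two-replica setup, the loss values, and the helper name `scaled_replica_loss` are illustrative assumptions that only mimic the behavior described above.

```python
# Toy setup (assumed): 2 replicas, 2 examples each, so the global batch size is 4.
per_replica_example_losses = [[1.0, 3.0], [2.0, 6.0]]
num_replicas = len(per_replica_example_losses)
global_batch_size = sum(len(losses) for losses in per_replica_example_losses)


def scaled_replica_loss(example_losses, global_batch_size):
    """Sums per-example losses and divides by the GLOBAL batch size."""
    return sum(example_losses) / global_batch_size


replica_losses = [
    scaled_replica_loss(losses, global_batch_size)
    for losses in per_replica_example_losses
]

# Summing across replicas recovers the mean over the global batch.
summed_loss = sum(replica_losses)
global_mean = sum(sum(l) for l in per_replica_example_losses) / global_batch_size

# Dividing by the LOCAL batch size instead would overcount by num_replicas.
naive_sum = sum(sum(l) / len(l) for l in per_replica_example_losses)

# A regularization loss is identical on every replica, so each replica
# contributes 1/num_replicas of it and the cross-replica sum counts it once.
reg_loss = 0.5
scaled_reg = reg_loss / num_replicas
total_reg = scaled_reg * num_replicas

print(summed_loss, global_mean, naive_sum, total_reg)
```

Here the summed, globally scaled losses equal the global mean, while the locally averaged variant is larger by a factor of the replica count.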



Finally, once you have defined the training step, you can iterate over `dist_dataset` and run the training in a loop:

In [20]:
for dist_inputs in dist_dataset:
  print(distributed_train_step(dist_inputs))

tf.Tensor(0.9024367, shape=(), dtype=float32)
tf.Tensor(0.8953863, shape=(), dtype=float32)
tf.Tensor(0.8884038, shape=(), dtype=float32)
tf.Tensor(0.88148874, shape=(), dtype=float32)
tf.Tensor(0.87464076, shape=(), dtype=float32)
tf.Tensor(0.86785895, shape=(), dtype=float32)
tf.Tensor(0.86114323, shape=(), dtype=float32)
tf.Tensor(0.8544927, shape=(), dtype=float32)
tf.Tensor(0.84790725, shape=(), dtype=float32)
tf.Tensor(0.841386, shape=(), dtype=float32)
tf.Tensor(0.83492863, shape=(), dtype=float32)
tf.Tensor(0.8285344, shape=(), dtype=float32)
tf.Tensor(0.82220304, shape=(), dtype=float32)
tf.Tensor(0.8159339, shape=(), dtype=float32)
tf.Tensor(0.8097264, shape=(), dtype=float32)
tf.Tensor(0.8035801, shape=(), dtype=float32)
tf.Tensor(0.79749453, shape=(), dtype=float32)
tf.Tensor(0.79146886, shape=(), dtype=float32)
tf.Tensor(0.785503, shape=(), dtype=float32)
tf.Tensor(0.779596, shape=(), dtype=float32)
tf.Tensor(0.77374756, shape=(), dtype=float32)
tf.Tensor(0.7679571, shape=(), dtype=float32)
...

tf.Tensor(0.51752573, shape=(), dtype=float32)
tf.Tensor(0.5142692, shape=(), dtype=float32)
tf.Tensor(0.51104385, shape=(), dtype=float32)
tf.Tensor(0.50784945, shape=(), dtype=float32)
tf.Tensor(0.50468564, shape=(), dtype=float32)
tf.Tensor(0.50155205, shape=(), dtype=float32)
tf.Tensor(0.49844825, shape=(), dtype=float32)
tf.Tensor(0.4953741, shape=(), dtype=float32)
tf.Tensor(0.49232918, shape=(), dtype=float32)
tf.Tensor(0.4893132, shape=(), dtype=float32)
tf.Tensor(0.48632562, shape=(), dtype=float32)
tf.Tensor(0.4833664, shape=(), dtype=float32)
tf.Tensor(0.4804351, shape=(), dtype=float32)
tf.Tensor(0.47753143, shape=(), dtype=float32)
tf.Tensor(0.47465506, shape=(), dtype=float32)
tf.Tensor(0.47180572, shape=(), dtype=float32)
tf.Tensor(0.46898302, shape=(), dtype=float32)
tf.Tensor(0.4661867, shape=(), dtype=float32)
tf.Tensor(0.46341658, shape=(), dtype=float32)
tf.Tensor(0.4606722, shape=(), dtype=float32)
tf.Tensor(0.4579534, shape=(), dtype=float32)
tf.Tensor(0.4552598, shape=(), dtype=float32)
...

tf.Tensor(0.33074695, shape=(), dtype=float32)
tf.Tensor(0.32916442, shape=(), dtype=float32)
tf.Tensor(0.3275946, shape=(), dtype=float32)
tf.Tensor(0.3260375, shape=(), dtype=float32)
tf.Tensor(0.3244928, shape=(), dtype=float32)
tf.Tensor(0.3229605, shape=(), dtype=float32)
tf.Tensor(0.32144046, shape=(), dtype=float32)
tf.Tensor(0.31993246, shape=(), dtype=float32)
tf.Tensor(0.3184365, shape=(), dtype=float32)
tf.Tensor(0.31695238, shape=(), dtype=float32)
tf.Tensor(0.31548, shape=(), dtype=float32)
tf.Tensor(0.31401917, shape=(), dtype=float32)
tf.Tensor(0.3125699, shape=(), dtype=float32)
tf.Tensor(0.31113195, shape=(), dtype=float32)
tf.Tensor(0.30970532, shape=(), dtype=float32)
tf.Tensor(0.3082898, shape=(), dtype=float32)
tf.Tensor(0.30688527, shape=(), dtype=float32)
tf.Tensor(0.3054917, shape=(), dtype=float32)
tf.Tensor(0.30410892, shape=(), dtype=float32)
tf.Tensor(0.3027368, shape=(), dtype=float32)
tf.Tensor(0.30137527, shape=(), dtype=float32)
tf.Tensor(0.3000242, shape=(), dtype=float32)
...

In the example above, you iterated over `dist_dataset` to provide input to your training. `tf.distribute.Strategy.experimental_make_numpy_dataset` is also available to support NumPy inputs. You can use this API to create a dataset before calling `tf.distribute.Strategy.experimental_distribute_dataset`.
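
As a rough sketch of feeding NumPy arrays, the snippet below builds the dataset with `tf.data.Dataset.from_tensor_slices` instead of the strategy's NumPy helper, which is a swapped-in approach that assumes the arrays fit in memory. The array shapes, the per-replica batch size of 4, and the variable names are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

# Assumed setup: MirroredStrategy uses all visible GPUs, or one CPU replica.
strategy = tf.distribute.MirroredStrategy()
global_batch_size = 4 * strategy.num_replicas_in_sync

# Illustrative in-memory NumPy data: 16 examples with one feature each.
features = np.random.rand(16, 1).astype("float32")
labels = np.ones((16, 1), dtype="float32")

# Build a tf.data pipeline from the arrays, batch by the GLOBAL batch size,
# then let the strategy split each batch across replicas.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

num_batches = sum(1 for _ in dist_dataset)
print(num_batches)
```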

Another way of iterating over your data is to use iterators explicitly. This is useful when you want to run for a given number of steps rather than iterate over the entire dataset. To do so, first create an iterator and then explicitly call `next` on it to get the input data.

In [21]:
iterator = iter(dist_dataset)
for _ in range(10):
  print(distributed_train_step(next(iterator)))

tf.Tensor(0.27504745, shape=(), dtype=float32)
tf.Tensor(0.2738936, shape=(), dtype=float32)
tf.Tensor(0.2727481, shape=(), dtype=float32)
tf.Tensor(0.27161098, shape=(), dtype=float32)
tf.Tensor(0.27048206, shape=(), dtype=float32)
tf.Tensor(0.26936132, shape=(), dtype=float32)
tf.Tensor(0.26824862, shape=(), dtype=float32)
tf.Tensor(0.26714393, shape=(), dtype=float32)
tf.Tensor(0.26604718, shape=(), dtype=float32)
tf.Tensor(0.26495826, shape=(), dtype=float32)
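
When the requested number of steps exceeds the number of batches, the iterator runs out partway through. The sketch below shows one way to handle that, assuming the iterator signals exhaustion with `StopIteration` as ordinary Python iterators do and that the dataset is non-empty. The helper `run_steps` is hypothetical, and a plain list and `sum` stand in for `dist_dataset` and `distributed_train_step`.

```python
def run_steps(dataset, step_fn, num_steps):
    """Runs `step_fn` on `num_steps` batches, restarting the pass when exhausted."""
    iterator = iter(dataset)
    results = []
    for _ in range(num_steps):
        try:
            batch = next(iterator)
        except StopIteration:
            # The dataset is exhausted: begin a new pass (epoch) over it.
            iterator = iter(dataset)
            batch = next(iterator)
        results.append(step_fn(batch))
    return results


# Stand-ins (assumed): three "batches" and a step function that sums a batch.
toy_dataset = [[1, 2], [3, 4], [5, 6]]
losses = run_steps(toy_dataset, sum, num_steps=5)
print(losses)
```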


This covers the simplest case of using the `tf.distribute.Strategy` API to distribute custom training loops.

### What's supported now?

| Training API         | `MirroredStrategy` | `TPUStrategy` | `MultiWorkerMirroredStrategy` | `ParameterServerStrategy` | `CentralStorageStrategy` |
| :------------------- | :----------------- | :------------ | :---------------------------- | :------------------------ | :----------------------- |
| Custom training loop | Supported          | Supported     | Supported                     | Experimental support      | Experimental support     |

### Examples and tutorials

Here are some examples of using distribution strategies with custom training loops:

1. [Tutorial](../tutorials/distribute/custom_training.ipynb): Training with a custom training loop and `MirroredStrategy`.
2. [Tutorial](../tutorials/distribute/multi_worker_with_ctl.ipynb): Training with a custom training loop and `MultiWorkerMirroredStrategy`.
3. [Guide](tpu.ipynb): Contains an example of a custom training loop with `TPUStrategy`.
4. [Tutorial](../tutorials/distribute/parameter_server_training.ipynb): Parameter server training with a custom training loop and `ParameterServerStrategy`.
5. TensorFlow Model Garden [repository](https://github.com/tensorflow/models/tree/master/official) containing collections of state-of-the-art models implemented using various strategies.


## Other topics

This section covers some topics that are relevant to multiple use cases.

<a name="TF_CONFIG"></a>
### Setting up the TF\_CONFIG environment variable

For multi-worker training, as mentioned before, you need to set up the `'TF_CONFIG'` environment variable for each binary running in your cluster. The `'TF_CONFIG'` environment variable is a JSON string which specifies what tasks constitute a cluster, their addresses and each task's role in the cluster. The [`tensorflow/ecosystem`](https://github.com/tensorflow/ecosystem) repo provides a Kubernetes template, which sets up `'TF_CONFIG'` for your training tasks.

There are two components of `'TF_CONFIG'`: a cluster and a task.

- A cluster provides information about the training cluster, which is a dict consisting of different types of jobs such as workers. In multi-worker training, there is usually one worker that takes on a little more responsibility, such as saving checkpoints and writing summary files for TensorBoard, in addition to what a regular worker does. Such a worker is referred to as the "chief" worker, and it is customary that the worker with index `0` is appointed as the chief worker (in fact, this is how `tf.distribute.Strategy` is implemented).
- A task, on the other hand, provides information about the current task. The first component, cluster, is the same for all workers, while the second component, task, is different on each worker and specifies the type and index of that worker.

One example of `'TF_CONFIG'` is:

```python
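import json
import os

# Example for the second worker (index 1); replace the placeholder
# "hostN:port" entries with real addresses.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", "host3:port"],
        "ps": ["host4:port", "host5:port"]
    },
    "task": {"type": "worker", "index": 1}
})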
```


This `'TF_CONFIG'` specifies that there are three workers and two `"ps"` tasks in the `"cluster"`, along with their hosts and ports. The `"task"` part specifies the role of the current task in the `"cluster"`: worker `1` (the second worker). Valid roles in a cluster are `"chief"`, `"worker"`, `"ps"`, and `"evaluator"`. There should be no `"ps"` job except when using `tf.distribute.experimental.ParameterServerStrategy`.
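
Because only the `"task"` part differs between processes, the per-task value can be generated from one shared cluster description. The following is a minimal sketch using only the standard library. The helpers `make_tf_config` and `is_chief` are hypothetical and not part of TensorFlow, and the chief rule (an explicit `"chief"` task, otherwise worker `0`) is an assumption based on the convention described above.

```python
import json
import os

VALID_TASK_TYPES = {"chief", "worker", "ps", "evaluator"}


def make_tf_config(cluster, task_type, task_index):
    """Returns the TF_CONFIG JSON string for one task (hypothetical helper)."""
    if task_type not in VALID_TASK_TYPES:
        raise ValueError(f"Unknown task type: {task_type!r}")
    # Only validate the index for job types that the cluster actually lists.
    if task_type in cluster and not 0 <= task_index < len(cluster[task_type]):
        raise ValueError(f"No {task_type} task with index {task_index}")
    return json.dumps({
        "cluster": cluster,
        "task": {"type": task_type, "index": task_index},
    })


def is_chief(tf_config):
    """Assumed rule: an explicit chief, else worker 0 when no chief job exists."""
    task = tf_config["task"]
    if task["type"] == "chief":
        return True
    return (task["type"] == "worker" and task["index"] == 0
            and "chief" not in tf_config["cluster"])


# Placeholder addresses, as in the example above.
cluster = {
    "worker": ["host1:port", "host2:port", "host3:port"],
    "ps": ["host4:port", "host5:port"],
}

# Each process sets its own value; this one plays the second worker.
os.environ["TF_CONFIG"] = make_tf_config(cluster, "worker", 1)
current = json.loads(os.environ["TF_CONFIG"])
print(current["task"], is_chief(current))
```

The cluster dict stays identical across processes, so only the task type and index need to vary per launch.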

## What's next?

`tf.distribute.Strategy` is actively under development. Try it out and provide your feedback using [GitHub issues](https://github.com/tensorflow/tensorflow/issues/new).