# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utilities for Estimator to use Distribute Coordinator."""

import copy

import six

from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib

# pylint: disable=protected-access
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
PS = dc._TaskType.PS
WORKER = dc._TaskType.WORKER

# pylint: enable=protected-access


def _count_ps(cluster_spec):
  """Counts the number of parameter servers in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_ps` does not expect empty cluster_spec.')

  return len(cluster_spec.as_dict().get(PS, []))


def _count_worker(cluster_spec, chief_task_type):
  """Counts the number of workers (including chief) in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_worker` does not expect empty cluster_spec.')

  return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
      cluster_spec.as_dict().get(chief_task_type, [])))


def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
  """Returns the global id of the given task type in a cluster."""
  if not task_type:
    return 0

  # Sort task names in the cluster by "chief"/"master", "evaluator", "worker"
  # and "ps". More details can be found in the documentation of
  # `tf.estimator.RunConfig.global_id_in_cluster`.
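  # For example (illustrative): in a cluster with {chief: 1 task,
  # worker: 2 tasks, ps: 2 tasks}, the resulting global ids are
  # chief -> 0, worker -> 1 and 2, and ps -> 3 and 4.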
  task_type_ordered_list = []
  if chief_task_type in cluster_spec.jobs:
    task_type_ordered_list = [chief_task_type]
  task_type_ordered_list.extend([
      t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
  ])
  if PS in cluster_spec.jobs:
    task_type_ordered_list.append(PS)

  # Find the right global_id for the current task.
  next_global_id = 0
  for t in task_type_ordered_list:
    if t == task_type:
      return next_global_id + task_id
    # `cluster_spec.job_tasks` returns all task addresses of type `t`.
    next_global_id += len(cluster_spec.job_tasks(t))

  # Reaching this point means `task_type` was not found in
  # `task_type_ordered_list`, which should never happen.
  raise RuntimeError('Internal Error: `task_type` ({}) is not in '
                     'cluster_spec ({}).'.format(task_type, cluster_spec))


def _init_run_config_from_worker_context(config, worker_context):
  """Initializes run config from distribute coordinator's worker context."""

  # pylint: disable=protected-access
  config._service = None
  config._cluster_spec = worker_context.cluster_spec
  config._task_type = worker_context.task_type
  config._task_id = worker_context.task_id
  config._evaluation_master = worker_context.master_target
  config._master = worker_context.master_target
  config._is_chief = worker_context.is_chief

  if config._cluster_spec:
    # Distributed mode.
    if config._task_type != EVALUATOR:

      config._num_ps_replicas = _count_ps(config._cluster_spec)
      config._num_worker_replicas = _count_worker(
          config._cluster_spec, chief_task_type=CHIEF)
      config._global_id_in_cluster = _get_global_id(
          config._cluster_spec,
          config._task_type,
          config._task_id,
          chief_task_type=CHIEF)
    else:
      # Evaluator task should not be aware of the other tasks.
      config._cluster_spec = server_lib.ClusterSpec({})
      config._num_ps_replicas = 0
      config._num_worker_replicas = 0
      config._global_id_in_cluster = None  # undefined
  else:
    # Local mode.
    config._global_id_in_cluster = 0
    config._num_ps_replicas = 0
    config._num_worker_replicas = 1


def init_run_config(config, tf_config):
  """Initializes RunConfig for distribution strategies."""
  # pylint: disable=protected-access
  if (config._experimental_distribute and
      config._experimental_distribute.train_distribute):
    if config._train_distribute:
      raise ValueError('Either `train_distribute` or '
                       '`experimental_distribute.train_distribute` can be set.')
    config._train_distribute = config._experimental_distribute.train_distribute

  if (config._experimental_distribute and
      config._experimental_distribute.eval_distribute):
    if config._eval_distribute:
      raise ValueError('Either `eval_distribute` or '
                       '`experimental_distribute.eval_distribute` can be set.')
    config._eval_distribute = config._experimental_distribute.eval_distribute

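  # For reference, `tf_config` here is the parsed TF_CONFIG environment
  # variable; host names and ports below are illustrative only:
  #   {
  #     'cluster': {
  #       'chief': ['host0:2222'],
  #       'worker': ['host1:2222', 'host2:2222'],
  #       'ps': ['host3:2222']
  #     },
  #     'task': {'type': 'worker', 'index': 0}
  #   }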
  cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
  config._init_distributed_setting_from_environment_var({})

  # Use distribute coordinator with STANDALONE_CLIENT mode if
  # `experimental_distribute.remote_cluster` is set.
  if (config._train_distribute and config._experimental_distribute and
      config._experimental_distribute.remote_cluster):
    if cluster_spec:
      raise ValueError('Cannot set both "cluster_spec" of TF_CONFIG and '
                       '`experimental_distribute.remote_cluster`')
    config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
    config._cluster_spec = config._experimental_distribute.remote_cluster
    logging.info('RunConfig initialized for Distribute Coordinator with '
                 'STANDALONE_CLIENT mode')
    return

  # Don't use distribute coordinator if this is local training, the cluster
  # has a MASTER job, or `train_distribute` is not specified.
  if (not cluster_spec or 'master' in cluster_spec.jobs or
      not config._train_distribute):
    config._distribute_coordinator_mode = None
    config._init_distributed_setting_from_environment_var(tf_config)
    config._maybe_overwrite_session_config_for_distributed_training()
    logging.info('Not using Distribute Coordinator.')
    return

  # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
  assert tf_config

  # Set only the cluster_spec since the other distributed settings will come
  # from distribute coordinator.
  config._cluster_spec = cluster_spec
  config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
  logging.info('RunConfig initialized for Distribute Coordinator with '
               'INDEPENDENT_WORKER mode')


def should_run_distribute_coordinator(config):
  """Checks the config to see whether to run distribute coordinator."""
  # pylint: disable=protected-access
  if (not hasattr(config, '_distribute_coordinator_mode') or
      config._distribute_coordinator_mode is None):
    logging.info('Not using Distribute Coordinator.')
    return False
  if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
      config._distribute_coordinator_mode not in [
          dc.CoordinatorMode.STANDALONE_CLIENT,
          dc.CoordinatorMode.INDEPENDENT_WORKER
      ]):
    logging.warning('Unexpected distribute_coordinator_mode: %r',
                    config._distribute_coordinator_mode)
    return False
  if not config.cluster_spec:
    logging.warning('Running `train_and_evaluate` locally, ignoring '
                    '`experimental_distribute_coordinator_mode`.')
    return False
  return True


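# Sketch of the expected call pattern from Estimator's `train_and_evaluate`
# (the exact caller code lives in the Estimator library and may differ):
#
#   if should_run_distribute_coordinator(run_config):
#     train_and_evaluate(estimator, train_spec, eval_spec, executor_cls)
#   else:
#     ...  # the regular, non-coordinator execution path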
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: The evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In standalone client mode, we don't need to run hooks on all threads:
    # running logging hooks on every thread would flood the screen, and a
    # tensor passed to one hook can only be fetched from the graph in which it
    # is defined. Other hooks, such as checkpointing hooks, will be added by
    # MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train, which would take the distribute
    # coordinator path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again.
    # This function calls estimator.evaluate, which would take the distribute
    # coordinator path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # In INDEPENDENT_WORKER mode, the cluster_spec comes from the TF_CONFIG
    # environment variable.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)


# TODO(yuefengz): maybe merge the following two functions?
# pylint: disable=protected-access
def estimator_train(estimator, train_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `train` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._train_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=  # pylint: disable=protected-access
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`estimator.train`')

  if estimator._config._train_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.train` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._train_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    train_distributed_fn(local_estimator, strategy, chief_hooks)
    return local_estimator

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.train_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)


def estimator_evaluate(estimator, evaluate_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `evaluate` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._eval_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`Estimator.evaluate`')

  if estimator._config._eval_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.evaluate` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._eval_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for evaluation."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._eval_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    return evaluate_distributed_fn(local_estimator, strategy, chief_hooks)

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)

# pylint: enable=protected-access