# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utilities for Estimator to use Distribute Coordinator."""

import copy

import six

from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib

# pylint: disable=protected-access
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
PS = dc._TaskType.PS
WORKER = dc._TaskType.WORKER

# pylint: enable=protected-access


def _count_ps(cluster_spec):
  """Counts the number of parameter servers in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_ps` does not expect empty cluster_spec.')

  return len(cluster_spec.as_dict().get(PS, []))


def _count_worker(cluster_spec, chief_task_type):
  """Counts the number of workers (including chief) in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_worker` does not expect empty cluster_spec.')

  return (len(cluster_spec.as_dict().get(WORKER, [])) +
          len(cluster_spec.as_dict().get(chief_task_type, [])))


def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
  """Returns the global id of the given task type in a cluster."""
  if not task_type:
    return 0

  # Sort task names in cluster by "chief"/"master", "evaluator", "worker"
  # and "ps". More details can be found at the documentation of
  # `tf.estimator.RunConfig.global_id_in_cluster`.
  task_type_ordered_list = []
  if chief_task_type in cluster_spec.jobs:
    task_type_ordered_list = [chief_task_type]
  task_type_ordered_list.extend([
      t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
  ])
  if PS in cluster_spec.jobs:
    task_type_ordered_list.append(PS)

  # Find the right global_id for current task.
  next_global_id = 0
  for t in task_type_ordered_list:
    if t == task_type:
      return next_global_id + task_id
    # `cluster_spec.job_tasks` returns all task addresses of type `t`.
    next_global_id += len(cluster_spec.job_tasks(t))
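
  # As a hypothetical illustration: for a cluster with one 'chief' task, two
  # 'worker' tasks and one 'ps' task, `task_type_ordered_list` is
  # ['chief', 'worker', 'ps'], so the global ids are chief:0 -> 0,
  # worker:0 -> 1, worker:1 -> 2 and ps:0 -> 3.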

  # It is unexpected that it passes through all task_types in
  # `task_type_ordered_list`.
  raise RuntimeError('Internal error: `task_type` ({}) is not in '
                     'cluster_spec ({}).'.format(task_type, cluster_spec))


def _init_run_config_from_worker_context(config, worker_context):
  """Initializes run config from distribute coordinator's worker context."""

  # pylint: disable=protected-access
  config._service = None
  config._cluster_spec = worker_context.cluster_spec
  config._task_type = worker_context.task_type
  config._task_id = worker_context.task_id
  config._evaluation_master = worker_context.master_target
  config._master = worker_context.master_target
  config._is_chief = worker_context.is_chief

  if config._cluster_spec:
    # Distributed mode.
    if config._task_type != EVALUATOR:
      config._num_ps_replicas = _count_ps(config._cluster_spec)
      config._num_worker_replicas = _count_worker(
          config._cluster_spec, chief_task_type=CHIEF)
      config._global_id_in_cluster = _get_global_id(
          config._cluster_spec,
          config._task_type,
          config._task_id,
          chief_task_type=CHIEF)
    else:
      # Evaluator task should not be aware of the other tasks.
      config._cluster_spec = server_lib.ClusterSpec({})
      config._num_ps_replicas = 0
      config._num_worker_replicas = 0
      config._global_id_in_cluster = None  # undefined
  else:
    # Local mode.
    config._global_id_in_cluster = 0
    config._num_ps_replicas = 0
    config._num_worker_replicas = 1


def init_run_config(config, tf_config):
  """Initializes RunConfig for distribution strategies."""
  # pylint: disable=protected-access
  if (config._experimental_distribute and
      config._experimental_distribute.train_distribute):
    if config._train_distribute:
      raise ValueError('Either `train_distribute` or '
                       '`experimental_distribute.train_distribute` can be set.')
    config._train_distribute = config._experimental_distribute.train_distribute

  if (config._experimental_distribute and
      config._experimental_distribute.eval_distribute):
    if config._eval_distribute:
      raise ValueError('Either `eval_distribute` or '
                       '`experimental_distribute.eval_distribute` can be set.')
    config._eval_distribute = config._experimental_distribute.eval_distribute

  cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
  config._init_distributed_setting_from_environment_var({})

  # Use distribute coordinator with STANDALONE_CLIENT mode if
  # `experimental_distribute.remote_cluster` is set.
  if (config._train_distribute and config._experimental_distribute and
      config._experimental_distribute.remote_cluster):
    if cluster_spec:
      raise ValueError('Cannot set both "cluster_spec" of TF_CONFIG and '
                       '`experimental_distribute.remote_cluster`')
    config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
    config._cluster_spec = config._experimental_distribute.remote_cluster
    logging.info('RunConfig initialized for Distribute Coordinator with '
                 'STANDALONE_CLIENT mode')
    return
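
  # Illustration (hypothetical addresses): setting
  # `experimental_distribute.remote_cluster` to e.g.
  # {'worker': ['host1:2222', 'host2:2222']}, together with a
  # `train_distribute` strategy and no TF_CONFIG cluster, takes the
  # STANDALONE_CLIENT branch above; otherwise the TF_CONFIG-based checks below
  # decide between local execution and INDEPENDENT_WORKER mode.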

  # Don't use distribute coordinator if it is local training, the cluster has
  # a 'master' job, or `train_distribute` is not specified.
  if (not cluster_spec or 'master' in cluster_spec.jobs or
      not config._train_distribute):
    config._distribute_coordinator_mode = None
    config._init_distributed_setting_from_environment_var(tf_config)
    config._maybe_overwrite_session_config_for_distributed_training()
    logging.info('Not using Distribute Coordinator.')
    return

  # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
  assert tf_config

  # Set the cluster_spec only since the distributed setting will come from
  # distribute coordinator.
  config._cluster_spec = cluster_spec
  config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
  logging.info('RunConfig initialized for Distribute Coordinator with '
               'INDEPENDENT_WORKER mode')


def should_run_distribute_coordinator(config):
  """Checks the config to see whether to run distribute coordinator."""
  # pylint: disable=protected-access
  if (not hasattr(config, '_distribute_coordinator_mode') or
      config._distribute_coordinator_mode is None):
    logging.info('Not using Distribute Coordinator.')
    return False
  if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
      config._distribute_coordinator_mode not in [
          dc.CoordinatorMode.STANDALONE_CLIENT,
          dc.CoordinatorMode.INDEPENDENT_WORKER
      ]):
    logging.warning('Unexpected distribute_coordinator_mode: %r',
                    config._distribute_coordinator_mode)
    return False
  if not config.cluster_spec:
    logging.warning('Running `train_and_evaluate` locally, ignoring '
                    '`experimental_distribute_coordinator_mode`.')
    return False
  return True


def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: The evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen; also
    # a tensor passed to one hook can only be fetched with the graph where the
    # tensor is defined. Other hooks such as checkpointing hooks will be added
    # by MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
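    # In INDEPENDENT_WORKER mode each worker runs in its own process, so the
    # concerns above do not apply and every worker keeps the caller's hooks;
    # in STANDALONE_CLIENT mode only the chief does.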
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train, which will use the distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again.
    # This function calls estimator.evaluate, which will use the distribute
    # coordinator path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable in
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)


# TODO(yuefengz): maybe merge the following two functions?
# pylint: disable=protected-access
def estimator_train(estimator, train_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `train` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._train_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=  # pylint: disable=protected-access
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`estimator.train`')

  if estimator._config._train_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.train` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._train_distribute.__class__.__name__)
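
  # The worker fn below mirrors `_worker_fn` in `train_and_evaluate` above: it
  # deep-copies the estimator, rebinds the given strategy, and passes the
  # caller's hooks only on the chief before delegating to
  # `train_distributed_fn`.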

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    train_distributed_fn(local_estimator, strategy, chief_hooks)
    return local_estimator

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.train_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)


def estimator_evaluate(estimator, evaluate_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `evaluate` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._eval_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`Estimator.evaluate`')

  if estimator._config._eval_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.evaluate` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._eval_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for evaluation."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._eval_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    return evaluate_distributed_fn(local_estimator, strategy, chief_hooks)

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)

# pylint: enable=protected-access
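

# A rough usage sketch (hypothetical cluster values; these helpers are normally
# driven by Estimator internals rather than called directly): with TF_CONFIG
# set to something like
#   {'cluster': {'chief': ['host0:2222'], 'worker': ['host1:2222']},
#    'task': {'type': 'worker', 'index': 0}}
# and a `RunConfig` whose `train_distribute` is set, `init_run_config` selects
# INDEPENDENT_WORKER mode, `should_run_distribute_coordinator` returns True,
# and `train_and_evaluate` above hands `_worker_fn`/`_eval_fn` to
# `dc.run_distribute_coordinator`.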