# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distributed_table."""

import copy
import os

from absl.testing import parameterized

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import coordinator_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.saving import save as keras_save
from tensorflow.python.module import module
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import save as tf_save


# Every test runs once per table-initializer flavour.
source_combination = combinations.combine(source=["textfile", "keyvaluetensor"])

# Save/load tests additionally run once per loading API.
source_and_load_combination = combinations.combine(
    source=["textfile", "keyvaluetensor"], load=["tf_load", "keras_load"])


class DistributedTableTest(test.TestCase, parameterized.TestCase):
  """Tests `StaticHashTable` creation and use with `ParameterServerStrategyV2`."""

  @classmethod
  def setUpClass(cls):
    """Starts one multi-process cluster that all tests in the class share."""
    super(DistributedTableTest, cls).setUpClass()
    cls.cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=2, num_ps=3, rpc_layer="grpc")
    cls.cluster_resolver = cls.cluster.cluster_resolver

  @classmethod
  def tearDownClass(cls):
    """Stops the cluster started in `setUpClass`."""
    super(DistributedTableTest, cls).tearDownClass()
    cls.cluster.stop()

  def make_initializer(self, init_source, vals):
    """Builds a lookup-table initializer mapping index `i` to `vals[i]`.

    Args:
      init_source: Either "textfile" (values are written one per line to a file
        in the test temp dir and keyed by line number) or "keyvaluetensor"
        (keys `0..len(vals)-1` and `vals` are passed as constant tensors).
      vals: The sequence of values to store in the table.

    Returns:
      A `lookup_ops.TextFileInitializer` or a
      `lookup_ops.KeyValueTensorInitializer`.

    Raises:
      ValueError: If `init_source` is not one of the two supported names.
    """
    if init_source == "textfile":
      file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
      with open(file_path, "w") as f:
        f.write("\n".join(str(v) for v in vals) + "\n")
      return lookup_ops.TextFileInitializer(
          filename=file_path,
          key_dtype=dtypes.int64,
          key_index=lookup_ops.TextFileIndex.LINE_NUMBER,
          value_dtype=dtypes.int64,
          value_index=lookup_ops.TextFileIndex.WHOLE_LINE)
    elif init_source == "keyvaluetensor":
      keys_tensor = constant_op.constant(
          list(range(len(vals))), dtype=dtypes.int64)
      vals_tensor = constant_op.constant(vals)
      return lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor)
    else:
      # `str()` keeps this a ValueError even when `init_source` is None (the
      # default in `createStaticHashTable`); plain `+` would raise TypeError.
      raise ValueError("Unrecognized init_source: " + str(init_source))

  def createStaticHashTable(self,
                            init_source=None,
                            vals=None,
                            default_value=None,
                            initializer=None):
    """Creates a `StaticHashTable`.

    Args:
      init_source: Initializer flavour passed to `make_initializer`; only used
        when `initializer` is not given.
      vals: Values passed to `make_initializer`; only used when `initializer`
        is not given.
      default_value: Value the table returns for missing keys.
      initializer: Optional ready-made initializer. When omitted, one is built
        from `init_source` and `vals`.

    Returns:
      The object returned by `lookup_ops.StaticHashTable(...)`.
    """
    # Compare against None explicitly so the decision does not depend on the
    # truthiness of the initializer object.
    if initializer is None:
      initializer = self.make_initializer(init_source, vals)
    return lookup_ops.StaticHashTable(
        initializer=initializer, default_value=default_value)

  def makeDatasetFromTensorWithoutUsingResource(self, input_context, tensor):
    """Returns a dataset made from `tensor`. To be called in a dataset_fn."""
    global_batch_size = 24
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = dataset_ops.DatasetV2.from_tensors(tensor).repeat().batch(
        batch_size, drop_remainder=True)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
    return dataset

  @combinations.generate(source_combination)
  def testCreateDistributedTableInScope(self, source):
    """A table created under the strategy scope is a `DistributedTable`."""
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    self.assertIsInstance(lookuptable, ps_values.DistributedTable)
    self.assertEqual(self.evaluate(lookuptable.size()), 3)

    # Lookup on the coordinator.
    output = lookuptable.lookup(
        constant_op.constant([0, 1, -1], dtype=dtypes.int64))
    self.assertAllEqual([0, 1, -2], output)
    self.assertEqual(lookuptable.size(), 3)

  @combinations.generate(source_combination)
  def testCopyDistributedTable(self, source):
    """`copy.copy` of a distributed table shares all of its attributes."""
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    new_table = copy.copy(lookuptable)
    # No new coordinator instance or distributed tables are created.
    self.assertDictEqual(lookuptable.__dict__, new_table.__dict__)

  @combinations.generate(source_combination)
  def testCreateLookupInDatasetFnUnderScope(self, source):
    """A table created inside `dataset_fn` is not a `DistributedTable`."""
    # TODO(wxinyi): Warn the user of the inefficiency of this workflow (i.e.
    # creating `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to
    # be distributed with `distribute_datasets_from_function` and
    # `create_per_worker_dataset`.
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():

      def dataset_fn(input_context):
        some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)
        lookuptable = self.createStaticHashTable(
            init_source=source, vals=[0, 1, 2], default_value=-2)

        self.assertNotIsInstance(lookuptable, ps_values.DistributedTable)

        # Key 10 is not in the table, so the lookup yields default_value (-2).
        generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
        dataset = self.makeDatasetFromTensorWithoutUsingResource(
            input_context, generation_tensor)
        return dataset

      @def_function.function
      def per_worker_dataset_fn():
        return strategy.distribute_datasets_from_function(dataset_fn)

      per_worker_dataset = coordinator.create_per_worker_dataset(
          per_worker_dataset_fn)
      per_worker_iterator = iter(per_worker_dataset)

      @def_function.function
      def worker_fn(iterator):
        return math_ops.reduce_sum(next(iterator))

      result = []
      for _ in range(10):
        result.append(
            coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

      for r in result:
        returned_input = r.fetch()
        # Presumably 24 elements per batch, each equal to -2.
        self.assertAllClose(-48, returned_input)

  @combinations.generate(source_combination)
  def testAccessingResourceHandleInDatasetFnWithoutMap(self, source):
    """A distributed table can be looked up directly inside `dataset_fn`."""

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    def dataset_fn(input_context):
      some_out_of_range_tensor = constant_op.constant(10, dtype=dtypes.int64)

      self.assertIsInstance(lookuptable, ps_values.DistributedTable)

      # Key 10 is not in the table, so the lookup yields default_value (-2).
      generation_tensor = lookuptable.lookup(some_out_of_range_tensor)
      dataset = self.makeDatasetFromTensorWithoutUsingResource(
          input_context, generation_tensor)
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return math_ops.reduce_sum(next(iterator))

    result = []
    for _ in range(10):
      result.append(
          coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

    for r in result:
      returned_input = r.fetch()
      # Presumably 24 elements per batch, each equal to -2.
      self.assertAllClose(-48, returned_input)

  @combinations.generate(source_combination)
  def testAccessingResourceHandleInDatasetFnWithMapFnDefinedInside(
      self, source):
    """`dataset.map(lookuptable.lookup)` works when mapped inside `dataset_fn`."""

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    def dataset_fn(input_context):
      generation_tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      dataset = self.makeDatasetFromTensorWithoutUsingResource(
          input_context, generation_tensor)
      dataset = dataset.map(lookuptable.lookup)
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return math_ops.reduce_sum(next(iterator))

    result = []
    for _ in range(10):
      # batch_size == 24 and each input is [0, 1, -2]
      result.append(
          coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

    for r in result:
      returned_input = r.fetch()
      self.assertAllClose(-24, returned_input)

  @combinations.generate(source_combination)
  def testAccessingResourceHandleInDatasetFnWithMapFnDefinedOutside(
      self, source):
    """Mapping works when the map function is defined outside `dataset_fn`."""
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    coordinator = coordinator_lib.ClusterCoordinator(strategy=strategy)

    with strategy.scope():
      lookuptable = self.createStaticHashTable(
          init_source=source, vals=[0, 1, 2], default_value=-2)

    def map_fn(vals):
      return lookuptable.lookup(vals)

    def dataset_fn(input_context):
      generation_tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      dataset = self.makeDatasetFromTensorWithoutUsingResource(
          input_context, generation_tensor)
      dataset = dataset.map(map_fn)
      return dataset

    @def_function.function
    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return math_ops.reduce_sum(next(iterator))

    result = []
    for _ in range(10):
      # batch_size == 24 and each input is [0, 1, -2]
      result.append(
          coordinator.schedule(worker_fn, args=(per_worker_iterator,)))

    for r in result:
      returned_input = r.fetch()
      self.assertAllClose(-24, returned_input)

  class Model(module.Module):
    """Module owning a `StaticHashTable` that maps 0, 1, 2 to themselves."""

    def __init__(self, init_source, filepath):
      """Creates the initializer and the table.

      Args:
        init_source: "textfile" writes the values to `filepath` and uses a
          `TextFileInitializer`; any other value uses a
          `KeyValueTensorInitializer`.
        filepath: Path of the text file to write; only used for "textfile".
      """
      vals = [0, 1, 2]
      if init_source == "textfile":

        with open(filepath, "w") as f:
          f.write("\n".join(str(v) for v in vals) + "\n")

        self.initializer = lookup_ops.TextFileInitializer(
            filepath, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
            dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
      else:
        keys_tensor = constant_op.constant(
            list(range(len(vals))), dtype=dtypes.int64)
        vals_tensor = constant_op.constant(vals)
        self.initializer = lookup_ops.KeyValueTensorInitializer(
            keys_tensor, vals_tensor)

      self.table = lookup_ops.StaticHashTable(
          self.initializer, default_value=-2)

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.int64)])
    def use_table(self, x):
      """Looks up `x` in the table."""
      return self.table.lookup(x)

  def verifyWorkerLocalInstance(self, coordinator, model):
    """Asserts `model.use_table` captures a resource on each worker's CPU.

    Args:
      coordinator: The `ClusterCoordinator` whose workers are inspected.
      model: An object exposing a `use_table` `tf.function`.
    """
    # assert capturing a worker-local resource on each worker
    for worker in coordinator._cluster.workers:
      with coordinator_context.with_dispatch_context(worker):
        captures = model.use_table.get_concrete_function().captured_inputs
        resource_capture = [t for t in captures if t.dtype == dtypes.resource]
        self.assertNotEmpty(resource_capture)
        for capture in resource_capture:
          self.assertEqual(
              capture.device,
              device_util.canonicalize("/CPU:0", default=worker.device_name))

  @combinations.generate(source_combination)
  def testInModelAndCapture(self, source):
    """Under a strategy the table handle becomes a deferred capture."""

    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")

    model = self.Model(source, file_path)
    func_captures = model.use_table.get_concrete_function(
    ).graph.external_captures
    self.assertLen(func_captures, 2)
    self.assertTrue(
        any(model.table.resource_handle is t for t in func_captures))
    deferred_captures = model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertEmpty(deferred_captures)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)
    with strategy.scope():
      # NOTE(review): "value" (not `source`) always selects the key/value
      # initializer branch of `Model`, so the distributed half of this test
      # never exercises the text-file initializer. Confirm this is intended.
      distributed_model = self.Model("value", file_path)
    func_captures = distributed_model.use_table.get_concrete_function(
    ).graph.external_captures
    # One less external_capture, since the table handle becomes a closure in the
    # deferred_external_capture
    self.assertLen(func_captures, 1)
    self.assertFalse(
        any(model.table.resource_handle is t for t in func_captures))
    deferred_captures = distributed_model.use_table.get_concrete_function(
    ).graph.deferred_external_captures
    self.assertNotEmpty(deferred_captures)

    self.verifyWorkerLocalInstance(coordinator, distributed_model)

  @combinations.generate(source_and_load_combination)
  def testDistributeTableSaveAndServe(self, load, source):
    """A model built under a strategy can be saved and loaded without one."""
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
    with strategy.scope():
      model = self.Model(source, file_path)

    model_dir = self.get_temp_dir()
    tf_save.save(model, model_dir)

    if load == "tf_load":
      load_fn = tf_load.load
    else:
      load_fn = keras_save.load_model

    loaded_without_strategy = load_fn(model_dir)
    loaded_func_captures_without_strategy = (
        loaded_without_strategy.use_table.get_concrete_function().graph
        .external_captures)
    loaded_func_deferred_captures_without_strategy = (
        loaded_without_strategy.use_table.get_concrete_function().graph
        .deferred_external_captures)
    self.assertLen(loaded_func_captures_without_strategy, 2)
    self.assertEmpty(loaded_func_deferred_captures_without_strategy)

    self.assertAllEqual(
        loaded_without_strategy.use_table(
            constant_op.constant([0, 1, 3], dtype=dtypes.int64)), [0, 1, -2])

  @combinations.generate(source_and_load_combination)
  def testDistributeTableSaveAndLoadUnderStrategy(self, load, source):
    """Loading under a strategy scope yields a `DistributedTable`."""
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    coordinator = coordinator_lib.ClusterCoordinator(strategy)
    file_path = os.path.join(self.get_temp_dir(), "text_file_initializer")
    with strategy.scope():
      model = self.Model(source, file_path)
    model_dir = self.get_temp_dir()
    tf_save.save(model, model_dir)

    if load == "tf_load":
      load_fn = tf_load.load
    else:
      load_fn = keras_save.load_model

    with strategy.scope():
      loaded = load_fn(model_dir)

    loaded_func_captures = (
        loaded.use_table.get_concrete_function().graph.external_captures)
    loaded_func_deferred_captures = (
        loaded.use_table.get_concrete_function().graph
        .deferred_external_captures)
    # Compared with loading without strategy, there is one less
    # external_capture, since the captured table handle has been swapped to a
    # closure in the deferred_external_capture
    self.assertLen(loaded_func_captures, 1)
    self.assertNotEmpty(loaded_func_deferred_captures)

    self.assertIsInstance(loaded.table, ps_values.DistributedTable)

    self.assertLen([
        t for t in loaded.use_table.get_concrete_function().captured_inputs
        if t.dtype == dtypes.resource
    ], 1)

    self.verifyWorkerLocalInstance(coordinator, loaded)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()