1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ 17 18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h" 19 #include "tensorflow/core/framework/collective.h" 20 21 namespace tensorflow { 22 class CollectiveParamResolverDistributed; 23 class ConfigProto; 24 class DeviceMgr; 25 class DeviceResolverDistributed; 26 class WorkerCacheInterface; 27 class StepSequenceRequest; 28 class StepSequenceResponse; 29 30 // An implementation of CollectiveExecutorMgr for a distributed environment 31 // that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs. 32 // 33 // In some execution environments it may be possible to implement a 34 // higher-performance solution and use it in place of this class. 35 class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { 36 public: 37 RpcCollectiveExecutorMgr( 38 const ConfigProto& config, const DeviceMgr* dev_mgr, 39 std::unique_ptr<DeviceResolverDistributed> dev_resolver, 40 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver, 41 std::unique_ptr<NcclCommunicatorInterface> nccl_communicator, 42 WorkerCacheInterface* worker_cache, const string& task_name); 43 44 virtual ~RpcCollectiveExecutorMgr(); 45 46 // This function should only be called at the group_leader, by an RPC. 47 // Other needs for StepIds should be satisfied by NextStepId. 48 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 49 GetStepSequenceResponse* response, 50 const StatusCallback& done) override; 51 52 void RefreshStepIdSequenceAsync(int64_t graph_key, 53 const StatusCallback& done) override; 54 55 int64_t NextStepId(int64_t graph_key) override; 56 57 void RetireStepId(int64_t graph_key, int64_t step_id) override; 58 59 protected: 60 virtual CollectiveExecutor* Create(int64_t step_id) override; 61 62 WorkerCacheInterface* const worker_cache_; // Not owned. 63 const string task_name_; 64 string group_leader_; 65 friend class RpcCollectiveExecutorMgrTest; 66 67 private: 68 Status UpdateStepSequences(const GetStepSequenceResponse& resp); 69 70 // This class maintains the step_id sequencing for a single 71 // collective_graph_key. 72 struct GraphKeySequence { GraphKeySequenceGraphKeySequence73 explicit GraphKeySequence(int64_t k) 74 : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {} 75 76 const int64_t graph_key_; 77 int64_t next_step_id_; 78 }; 79 80 mutex sequence_mu_; 81 gtl::FlatMap<int64_t, GraphKeySequence*> sequence_table_ 82 TF_GUARDED_BY(sequence_mu_); 83 }; 84 85 // Creates a distributed CollectiveExecutorMgr with production implementations 86 // of each components. Cases that need to inject other implementations of these 87 // components should call CollectiveExecutorMgr constructor directly. 88 std::unique_ptr<RpcCollectiveExecutorMgr> CreateProdRpcCollectiveExecutorMgr( 89 const ConfigProto& config, const DeviceMgr* device_mgr, 90 std::unique_ptr<NcclCommunicatorInterface> nccl_communicator, 91 WorkerCacheInterface* worker_cache, const string& default_worker_name); 92 93 } // namespace tensorflow 94 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ 95