1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_ 18 19 #include <atomic> 20 #include <functional> 21 22 #include "tensorflow/core/lib/core/notification.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/lib/gtl/flatmap.h" 25 #include "tensorflow/core/lib/hash/hash.h" 26 #include "tensorflow/core/platform/mutex.h" 27 #include "tensorflow/core/platform/status.h" 28 #include "tensorflow/core/platform/stringpiece.h" 29 #include "tensorflow/core/platform/thread_annotations.h" 30 #include "tensorflow/core/platform/types.h" 31 32 namespace tensorflow { 33 34 // A token that can be used to register and deregister a 35 // CancelCallback with a CancellationManager. 36 // 37 // CancellationToken values must be created by a call to 38 // CancellationManager::get_cancellation_token. 39 typedef int64_t CancellationToken; 40 41 // A callback that is invoked when a step is canceled. 42 // 43 // NOTE(mrry): See caveats about CancelCallback implementations in the 44 // comment for CancellationManager::RegisterCallback. 45 typedef std::function<void()> CancelCallback; 46 47 // This class should never simultaneously be used as the cancellation manager 48 // for two separate sets of executions (i.e two separate steps, or two separate 49 // function executions). 50 class CancellationManager { 51 public: 52 // A value that won't be returned by get_cancellation_token(). 53 static const CancellationToken kInvalidToken; 54 55 CancellationManager(); 56 57 // Constructs a new CancellationManager that is a "child" of `*parent`. 58 // 59 // If `*parent` is cancelled, `*this` will be cancelled. `*parent` must 60 // outlive the created CancellationManager. 61 explicit CancellationManager(CancellationManager* parent); 62 63 ~CancellationManager(); 64 65 // Run all callbacks associated with this manager. 66 void StartCancel(); 67 68 // Run all callbacks associated with this manager with a status. 69 // Currently the status is for logging purpose only. See also 70 // CancellationManager::RegisterCallbackWithErrorLogging. 71 void StartCancelWithStatus(const Status& status); 72 73 // Returns true iff StartCancel() has been called. IsCancelled()74 bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); } 75 76 // Returns a token that must be used in calls to RegisterCallback 77 // and DeregisterCallback. get_cancellation_token()78 CancellationToken get_cancellation_token() { 79 return next_cancellation_token_.fetch_add(1); 80 } 81 82 // Attempts to register the given callback to be invoked when this 83 // manager is cancelled. Returns true if the callback was 84 // registered; returns false if this manager was already cancelled, 85 // and the callback was not registered. 86 // 87 // If this method returns false, it is the caller's responsibility 88 // to perform any cancellation cleanup. 89 // 90 // This method is tricky to use correctly. The following usage pattern 91 // is recommended: 92 // 93 // class ObjectWithCancellableOperation { 94 // mutex mu_; 95 // void CancellableOperation(CancellationManager* cm, 96 // std::function<void(Status)> callback) { 97 // bool already_cancelled; 98 // CancellationToken token = cm->get_cancellation_token(); 99 // { 100 // mutex_lock(mu_); 101 // already_cancelled = !cm->RegisterCallback( 102 // [this, token]() { Cancel(token); }); 103 // if (!already_cancelled) { 104 // // Issue asynchronous operation. Associate the pending operation 105 // // with `token` in some object state, or provide another way for 106 // // the Cancel method to look up the operation for cancellation. 107 // // Ensure that `cm->DeregisterCallback(token)` is called without 108 // // holding `mu_`, before `callback` is invoked. 109 // // ... 110 // } 111 // } 112 // if (already_cancelled) { 113 // callback(errors::Cancelled("Operation was cancelled")); 114 // } 115 // } 116 // 117 // void Cancel(CancellationToken token) { 118 // mutex_lock(mu_); 119 // // Take action to cancel the operation with the given cancellation 120 // // token. 121 // } 122 // 123 // NOTE(mrry): The caller should take care that (i) the calling code 124 // is robust to `callback` being invoked asynchronously (e.g. from 125 // another thread), (ii) `callback` is deregistered by a call to 126 // this->DeregisterCallback(token) when the operation completes 127 // successfully, and (iii) `callback` does not invoke any method 128 // on this cancellation manager. Furthermore, it is important that 129 // the eventual caller of the complementary DeregisterCallback does not 130 // hold any mutexes that are required by `callback`. 131 bool RegisterCallback(CancellationToken token, CancelCallback callback); 132 133 // Similar to RegisterCallback, but if the cancellation manager starts a 134 // cancellation with an error status, it will log the error status before 135 // invoking the callback. `callback_name` is a human-readable name of the 136 // callback, which will be displayed on the log. 137 bool RegisterCallbackWithErrorLogging(CancellationToken token, 138 CancelCallback callback, 139 tensorflow::StringPiece callback_name); 140 141 // Deregister the callback that, when registered, was associated 142 // with the given cancellation token. Returns true iff the callback 143 // was deregistered and will not be invoked; otherwise returns false 144 // after the callback has been invoked, blocking if necessary. 145 // 146 // NOTE(mrry): This method may block if cancellation is in progress. 147 // The caller of this method must not hold any mutexes that are required 148 // to invoke any cancellation callback that has been registered with this 149 // cancellation manager. 150 bool DeregisterCallback(CancellationToken token); 151 152 // Deregister the callback that, when registered, was associated 153 // with the given cancellation token. Returns true iff the callback 154 // was deregistered and will not be invoked; otherwise returns false 155 // immediately, with no guarantee that the callback has completed. 156 // 157 // This method is guaranteed to return true if StartCancel has not been 158 // called. 159 bool TryDeregisterCallback(CancellationToken token); 160 161 // Returns true iff cancellation is in progress. 162 bool IsCancelling(); 163 164 private: 165 struct CallbackConfiguration { 166 CancelCallback callback; 167 std::string name; 168 bool log_error = false; 169 }; 170 171 struct State { 172 Notification cancelled_notification; 173 gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks; 174 175 // If this CancellationManager has any children, this member points to the 176 // head of a doubly-linked list of its children. 177 CancellationManager* first_child = nullptr; // Not owned. 178 }; 179 180 bool RegisterCallbackConfig(CancellationToken token, 181 CallbackConfiguration config); 182 183 bool RegisterChild(CancellationManager* child); 184 void DeregisterChild(CancellationManager* child); 185 186 bool is_cancelling_; 187 std::atomic_bool is_cancelled_; 188 std::atomic<CancellationToken> next_cancellation_token_; 189 190 CancellationManager* const parent_ = nullptr; // Not owned. 191 192 // If this CancellationManager is associated with a parent, this member will 193 // be set to `true` after this is removed from the parent's list of children. 194 bool is_removed_from_parent_ TF_GUARDED_BY(parent_->mu_) = false; 195 196 // If this CancellationManager is associated with a parent, these members form 197 // a doubly-linked list of that parent's children. 198 // 199 // These fields are valid only when `this->is_removed_from_parent_` is false. 200 CancellationManager* prev_sibling_ TF_GUARDED_BY(parent_->mu_) = 201 nullptr; // Not owned. 202 CancellationManager* next_sibling_ TF_GUARDED_BY(parent_->mu_) = 203 nullptr; // Not owned. 204 205 mutex mu_; 206 std::unique_ptr<State> state_ TF_GUARDED_BY(mu_); 207 }; 208 209 // Registers the given cancellation callback, returning a function that can be 210 // used to deregister the callback. If `cancellation_manager` is NULL, no 211 // registration occurs and `deregister_fn` will be a no-op. 212 Status RegisterCancellationCallback(CancellationManager* cancellation_manager, 213 std::function<void()> callback, 214 std::function<void()>* deregister_fn); 215 216 } // namespace tensorflow 217 218 #endif // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_ 219