xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/cancellation.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
18 
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
31 
32 namespace tensorflow {
33 
// Identifies a single CancelCallback registration with a CancellationManager.
// The same token is later handed to DeregisterCallback() or
// TryDeregisterCallback() to remove that registration.
//
// CancellationToken values must be created by a call to
// CancellationManager::get_cancellation_token.
using CancellationToken = int64_t;
40 
// The signature of a function run when a step is canceled: it takes no
// arguments and returns nothing.
//
// NOTE(mrry): See caveats about CancelCallback implementations in the
// comment for CancellationManager::RegisterCallback.
using CancelCallback = std::function<void()>;
46 
47 // This class should never simultaneously be used as the cancellation manager
48 // for two separate sets of executions (i.e two separate steps, or two separate
49 // function executions).
50 class CancellationManager {
51  public:
52   // A value that won't be returned by get_cancellation_token().
53   static const CancellationToken kInvalidToken;
54 
55   CancellationManager();
56 
57   // Constructs a new CancellationManager that is a "child" of `*parent`.
58   //
59   // If `*parent` is cancelled, `*this` will be cancelled. `*parent` must
60   // outlive the created CancellationManager.
61   explicit CancellationManager(CancellationManager* parent);
62 
63   ~CancellationManager();
64 
65   // Run all callbacks associated with this manager.
66   void StartCancel();
67 
68   // Run all callbacks associated with this manager with a status.
69   // Currently the status is for logging purpose only. See also
70   // CancellationManager::RegisterCallbackWithErrorLogging.
71   void StartCancelWithStatus(const Status& status);
72 
73   // Returns true iff StartCancel() has been called.
IsCancelled()74   bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
75 
76   // Returns a token that must be used in calls to RegisterCallback
77   // and DeregisterCallback.
get_cancellation_token()78   CancellationToken get_cancellation_token() {
79     return next_cancellation_token_.fetch_add(1);
80   }
81 
82   // Attempts to register the given callback to be invoked when this
83   // manager is cancelled. Returns true if the callback was
84   // registered; returns false if this manager was already cancelled,
85   // and the callback was not registered.
86   //
87   // If this method returns false, it is the caller's responsibility
88   // to perform any cancellation cleanup.
89   //
90   // This method is tricky to use correctly. The following usage pattern
91   // is recommended:
92   //
93   // class ObjectWithCancellableOperation {
94   //   mutex mu_;
95   //   void CancellableOperation(CancellationManager* cm,
96   //                             std::function<void(Status)> callback) {
97   //     bool already_cancelled;
98   //     CancellationToken token = cm->get_cancellation_token();
99   //     {
100   //       mutex_lock(mu_);
101   //       already_cancelled = !cm->RegisterCallback(
102   //           [this, token]() { Cancel(token); });
103   //       if (!already_cancelled) {
104   //         // Issue asynchronous operation. Associate the pending operation
105   //         // with `token` in some object state, or provide another way for
106   //         // the Cancel method to look up the operation for cancellation.
107   //         // Ensure that `cm->DeregisterCallback(token)` is called without
108   //         // holding `mu_`, before `callback` is invoked.
109   //         // ...
110   //       }
111   //     }
112   //     if (already_cancelled) {
113   //       callback(errors::Cancelled("Operation was cancelled"));
114   //     }
115   //   }
116   //
117   //   void Cancel(CancellationToken token) {
118   //     mutex_lock(mu_);
119   //     // Take action to cancel the operation with the given cancellation
120   //     // token.
121   //   }
122   //
123   // NOTE(mrry): The caller should take care that (i) the calling code
124   // is robust to `callback` being invoked asynchronously (e.g. from
125   // another thread), (ii) `callback` is deregistered by a call to
126   // this->DeregisterCallback(token) when the operation completes
127   // successfully, and (iii) `callback` does not invoke any method
128   // on this cancellation manager. Furthermore, it is important that
129   // the eventual caller of the complementary DeregisterCallback does not
130   // hold any mutexes that are required by `callback`.
131   bool RegisterCallback(CancellationToken token, CancelCallback callback);
132 
133   // Similar to RegisterCallback, but if the cancellation manager starts a
134   // cancellation with an error status, it will log the error status before
135   // invoking the callback. `callback_name` is a human-readable name of the
136   // callback, which will be displayed on the log.
137   bool RegisterCallbackWithErrorLogging(CancellationToken token,
138                                         CancelCallback callback,
139                                         tensorflow::StringPiece callback_name);
140 
141   // Deregister the callback that, when registered, was associated
142   // with the given cancellation token. Returns true iff the callback
143   // was deregistered and will not be invoked; otherwise returns false
144   // after the callback has been invoked, blocking if necessary.
145   //
146   // NOTE(mrry): This method may block if cancellation is in progress.
147   // The caller of this method must not hold any mutexes that are required
148   // to invoke any cancellation callback that has been registered with this
149   // cancellation manager.
150   bool DeregisterCallback(CancellationToken token);
151 
152   // Deregister the callback that, when registered, was associated
153   // with the given cancellation token. Returns true iff the callback
154   // was deregistered and will not be invoked; otherwise returns false
155   // immediately, with no guarantee that the callback has completed.
156   //
157   // This method is guaranteed to return true if StartCancel has not been
158   // called.
159   bool TryDeregisterCallback(CancellationToken token);
160 
161   // Returns true iff cancellation is in progress.
162   bool IsCancelling();
163 
164  private:
165   struct CallbackConfiguration {
166     CancelCallback callback;
167     std::string name;
168     bool log_error = false;
169   };
170 
171   struct State {
172     Notification cancelled_notification;
173     gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks;
174 
175     // If this CancellationManager has any children, this member points to the
176     // head of a doubly-linked list of its children.
177     CancellationManager* first_child = nullptr;  // Not owned.
178   };
179 
180   bool RegisterCallbackConfig(CancellationToken token,
181                               CallbackConfiguration config);
182 
183   bool RegisterChild(CancellationManager* child);
184   void DeregisterChild(CancellationManager* child);
185 
186   bool is_cancelling_;
187   std::atomic_bool is_cancelled_;
188   std::atomic<CancellationToken> next_cancellation_token_;
189 
190   CancellationManager* const parent_ = nullptr;  // Not owned.
191 
192   // If this CancellationManager is associated with a parent, this member will
193   // be set to `true` after this is removed from the parent's list of children.
194   bool is_removed_from_parent_ TF_GUARDED_BY(parent_->mu_) = false;
195 
196   // If this CancellationManager is associated with a parent, these members form
197   // a doubly-linked list of that parent's children.
198   //
199   // These fields are valid only when `this->is_removed_from_parent_` is false.
200   CancellationManager* prev_sibling_ TF_GUARDED_BY(parent_->mu_) =
201       nullptr;  // Not owned.
202   CancellationManager* next_sibling_ TF_GUARDED_BY(parent_->mu_) =
203       nullptr;  // Not owned.
204 
205   mutex mu_;
206   std::unique_ptr<State> state_ TF_GUARDED_BY(mu_);
207 };
208 
209 // Registers the given cancellation callback, returning a function that can be
210 // used to deregister the callback. If `cancellation_manager` is NULL, no
211 // registration occurs and `deregister_fn` will be a no-op.
212 Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
213                                     std::function<void()> callback,
214                                     std::function<void()>* deregister_fn);
215 
216 }  // namespace tensorflow
217 
218 #endif  // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
219