/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/client/session_ref.h"

#include <stdlib.h>
#include <memory>
#include <utility>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/replay_log.pb.h"

namespace tensorflow {

namespace {

// Scope helper to track active calls and manage session lifetime.
// SessionRef blocks closing until all active calls complete or are cancelled.
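// A RunCounter is constructed on the stack for each forwarded call (see
// LOG_AND_RUN_OPERATION below); its destructor wakes any thread blocked in
// SessionRef::Close() once the last in-flight call finishes.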
struct RunCounter {
  std::shared_ptr<Session> session;
  uint64* value;
  mutex* m;
  condition_variable* cv;

  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
                      condition_variable* cv)
      : session(std::move(s)), value(v), m(m), cv(cv) {
    mutex_lock l(*m);
    ++*value;
  }

  ~RunCounter() {
    mutex_lock l(*m);
    if (--*value == 0) {
      cv->notify_all();
    }
  }
};

std::string SessionToHandle(Session* session) {
  return strings::Printf("%llu", static_cast<unsigned long long>(
                                     reinterpret_cast<uintptr_t>(session)));
}

// The Session interface has many methods of the form:
//
// X(a, b);
// X(RunOptions, a, b);
//
// Not all sessions support the second case (with an empty RunOptions()).
// We use this variable as a sentinel to dispatch to the correct call.
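// For example, RecordRun(session, inputs, ...) below forwards to its
// RunOptions overload with *kEmptyRunOptions(); that overload then checks
// `&run_options == kEmptyRunOptions()` to pick the matching Session::Run
// signature.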
RunOptions* kEmptyRunOptions() {
  static RunOptions* options = new RunOptions();
  return options;
}

}  // namespace

// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request.  This should be run _after_ all request information has been
// added to the current op.
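// Note: each expansion declares a local `Status status` and expects `op` (a
// ReplayOp), `session`, and Flush() to be in scope, so the macro can appear
// at most once per scope in the recording methods below.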
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }

// Records requests (and optionally responses) performed against a session.
// The resulting replay log can be used with the `tf_replay` tool to replicate
// the operations against a simulated environment, without requiring the
// original code or cluster setup.
//
// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
// variable.
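// Example (illustrative): launching a program with
//   TF_REPLAY_LOG_FILE=/tmp/tf_replay.log python train.py
// writes one serialized ReplayOp record to that file per session call.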
class SessionLogger {
 public:
  SessionLogger() {
    const char* log_file_env = getenv("TF_REPLAY_LOG_FILE");
    std::string log_name = log_file_env ? std::string(log_file_env) : ".";
    LOG(INFO) << "Constructing new session logger for " << log_name;
    TF_CHECK_OK(
        Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
    Env::Default()->DeleteFile(log_name).IgnoreError();

    TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
    log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
  }

  ~SessionLogger() {
    log_writer_->Close().IgnoreError();
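    // Note: release() leaks the RecordWriter rather than destroying it; the
    // underlying file is closed separately below.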
    log_writer_.release();
    log_file_->Close().IgnoreError();
  }

  Status RecordNewSession(Session* session) {
    ReplayOp op;
    NewReplaySession* req = op.mutable_new_replay_session();
    req->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordRun(Session* session,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs) {
    return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
                     target_node_names, outputs, nullptr);
  }

  Status RecordRun(Session* session, const RunOptions& run_options,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();

    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = run_options;

    for (const auto& it : inputs) {
      NamedTensorProto* feed = req->add_feed();
      feed->set_name(it.first);
      it.second.AsProtoField(feed->mutable_tensor());
    }

    // Build an index from fetch tensor name to first index in
    // output_tensor_names.
    std::unordered_map<string, int> output_name_to_offset;
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
        req->add_fetch(name);
      }
    }
    for (const string& target : target_node_names) {
      req->add_target(target);
    }

    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
                         outputs);
    } else {
      RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
    }

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_tensor_names[i]);
    }

    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }

    return Flush(op);
  }

  Status RecordCreate(Session* session, const GraphDef& graph) {
    return RecordCreate(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in CreateRequest)
  Status RecordCreate(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    CreateSessionRequest* req = op.mutable_create_session();
    *req->mutable_graph_def() = graph;

    CreateSessionResponse* resp = op.mutable_create_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Create, graph);
    } else {
      RUN_WITH_TIMESTAMP(Create, run_options, graph);
    }
    resp->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordExtend(Session* session, const GraphDef& graph) {
    return RecordExtend(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
  Status RecordExtend(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    ExtendSessionRequest* req = op.mutable_extend_session();
    op.mutable_extend_session_response();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_graph_def() = graph;
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Extend, graph);
    } else {
      RUN_WITH_TIMESTAMP(Extend, run_options, graph);
    }

    return Flush(op);
  }

  Status RecordClose(Session* session) {
    return RecordClose(session, *kEmptyRunOptions());
  }

  // N.B. RunOptions is not stored (it has no entry in CloseRequest)
  Status RecordClose(Session* session, const RunOptions& run_options) {
    ReplayOp op;
    CloseSessionRequest* req = op.mutable_close_session();
    req->set_session_handle(SessionToHandle(session));
    op.mutable_close_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Close);
    } else {
      RUN_WITH_TIMESTAMP(Close, run_options);
    }
    return Flush(op);
  }

  Status RecordListDevices(Session* session,
                           std::vector<DeviceAttributes>* response) {
    ReplayOp op;
    ListDevicesRequest* req = op.mutable_list_devices();
    ListDevicesResponse* resp = op.mutable_list_devices_response();
    req->set_session_handle(SessionToHandle(session));
    RUN_WITH_TIMESTAMP(ListDevices, response);

    // TODO(power) -- local vs remote device distinction is lost here!
    *resp->mutable_local_device() = {response->begin(), response->end()};
    return Flush(op);
  }

  Status RecordPRunSetup(Session* session,
                         const std::vector<string>& input_names,
                         const std::vector<string>& output_names,
                         const std::vector<string>& target_nodes,
                         string* handle) {
    ReplayOp op;
    PartialRunSetupRequest* req = op.mutable_partial_run_setup();
    req->set_session_handle(SessionToHandle(session));
    for (auto& input : input_names) {
      req->add_feed(input);
    }
    for (auto& output : output_names) {
      req->add_fetch(output);
    }
    for (auto& target : target_nodes) {
      req->add_target(target);
    }
    RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
                       handle);
    op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
    return Flush(op);
  }

  Status RecordPRun(Session* session, const string& handle,
                    const std::vector<std::pair<string, Tensor> >& inputs,
                    const std::vector<string>& output_names,
                    std::vector<Tensor>* outputs) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();
    req->set_session_handle(SessionToHandle(session));

    // Mark this step as a partial run for replay.
    req->set_partial_run_handle(handle);
    for (auto& input : inputs) {
      auto* feed = req->add_feed();
      feed->set_name(input.first);
      input.second.AsProtoField(feed->mutable_tensor());
    }

    for (auto& output : output_names) {
      req->add_fetch(output);
    }

    RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_names[i]);
    }

    return Flush(op);
  }

  Status RecordMakeCallable(Session* session,
                            const CallableOptions& callable_options,
                            Session::CallableHandle* handle) {
    ReplayOp op;
    MakeCallableRequest* req = op.mutable_make_callable();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = callable_options;

    RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);

    MakeCallableResponse* resp = op.mutable_make_callable_response();
    resp->set_handle(*handle);

    return Flush(op);
  }

  Status RecordRunCallable(Session* session, Session::CallableHandle handle,
                           const std::vector<Tensor>& feed_tensors,
                           std::vector<Tensor>* fetch_tensors,
                           RunMetadata* run_metadata) {
    ReplayOp op;
    RunCallableRequest* req = op.mutable_run_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    for (auto& tensor : feed_tensors) {
      tensor.AsProtoField(req->add_feed());
    }
    RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
                       run_metadata);

    RunCallableResponse* resp = op.mutable_run_callable_response();
    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }
    for (const Tensor& tensor : *fetch_tensors) {
      tensor.AsProtoTensorContent(resp->add_fetch());
    }
    return Flush(op);
  }

  Status RecordReleaseCallable(Session* session,
                               Session::CallableHandle handle) {
    ReplayOp op;
    ReleaseCallableRequest* req = op.mutable_release_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
    return Flush(op);
  }

 private:
  Status Flush(const ReplayOp& op) {
    mutex_lock l(log_mutex_);

    string buf;
    op.SerializeToString(&buf);
    TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));

    // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
    return log_file_->Sync();
  }

  std::unique_ptr<WritableFile> log_file_;
  std::unique_ptr<io::RecordWriter> log_writer_;
  mutex log_mutex_;
};

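// Returns the process-wide SessionLogger, created lazily on first use and
// never destroyed.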
static SessionLogger* global_session_logger() {
  static SessionLogger* logger = new SessionLogger();
  return logger;
}

SessionRef::SessionRef(Session* session) : session_(session) {
  if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
    logger_ = global_session_logger();
    logger_->RecordNewSession(this->session_.get()).IgnoreError();
  } else {
    logger_ = nullptr;
  }
}

SessionRef::~SessionRef() = default;

Status SessionRef::CheckNotClosed() {
  mutex_lock l(run_lock_);
  if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
  return OkStatus();
}

// If logging is active, log the start and end time of the operation along with
// the request and response.
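// Each call first checks that the session is still open, then pins the
// session with a stack-allocated RunCounter so that Close() blocks until the
// call completes, and finally dispatches either directly to the underlying
// Session or through the SessionLogger.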
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);

Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}

Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}

Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}

Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}

Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}

Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}

Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}

Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}

Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}

Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}

Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}

Status SessionRef::ReleaseCallable(CallableHandle handle) {
  {
    mutex_lock l(run_lock_);
    if (session_ == nullptr) {
      // Session already closed. Do nothing.
      return OkStatus();
    }
  }
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}

Status SessionRef::Close(const RunOptions& run_options) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get(), run_options);
  } else {
    status = session_->Close(run_options);
  }
  session_.reset();
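  // Wait for any in-flight calls that still hold a RunCounter to finish
  // before returning from Close().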
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

Status SessionRef::Close() {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get());
  } else {
    status = session_->Close();
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

}  // namespace tensorflow