xref: /aosp_15_r20/external/tensorflow/tensorflow/python/lib/core/pybind11_status.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
17 #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
18 
19 #include <Python.h>
20 
21 #include "pybind11/cast.h"
22 #include "pybind11/pybind11.h"
23 #include "tensorflow/c/tf_status_internal.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/core/platform/statusor.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28 #include "tensorflow/python/lib/core/py_exception_registry.h"
29 
30 namespace tensorflow {
31 
32 namespace internal {
33 
CodeToPyExc(const int code)34 inline PyObject* CodeToPyExc(const int code) {
35   switch (code) {
36     case error::Code::INVALID_ARGUMENT:
37       return PyExc_ValueError;
38     case error::Code::OUT_OF_RANGE:
39       return PyExc_IndexError;
40     case error::Code::UNIMPLEMENTED:
41       return PyExc_NotImplementedError;
42     default:
43       return PyExc_RuntimeError;
44   }
45 }
46 
StatusToPyExc(const Status & status)47 inline PyObject* StatusToPyExc(const Status& status) {
48   return CodeToPyExc(status.code());
49 }
50 
TFStatusToPyExc(const TF_Status * status)51 inline PyObject* TFStatusToPyExc(const TF_Status* status) {
52   return CodeToPyExc(TF_GetCode(status));
53 }
54 
StatusPayloadToDict(const Status & status)55 inline pybind11::dict StatusPayloadToDict(const Status& status) {
56   pybind11::dict dict;
57   const auto& payloads = errors::GetPayloads(status);
58   for (auto& pair : payloads) {
59     dict[PyBytes_FromString(pair.first.c_str())] =
60         PyBytes_FromString(pair.second.c_str());
61   }
62   return dict;
63 }
64 
TFStatusPayloadToDict(TF_Status * status)65 inline pybind11::dict TFStatusPayloadToDict(TF_Status* status) {
66   return StatusPayloadToDict(status->status);
67 }
68 
69 }  // namespace internal
70 
MaybeRaiseFromStatus(const Status & status)71 inline void MaybeRaiseFromStatus(const Status& status) {
72   if (!status.ok()) {
73     PyErr_SetString(internal::StatusToPyExc(status),
74                     status.error_message().c_str());
75     throw pybind11::error_already_set();
76   }
77 }
78 
SetRegisteredErrFromStatus(const tensorflow::Status & status)79 inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) {
80   PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
81                   pybind11::make_tuple(pybind11::none(), pybind11::none(),
82                                        status.error_message(),
83                                        internal::StatusPayloadToDict(status))
84                       .ptr());
85 }
86 
SetRegisteredErrFromTFStatus(TF_Status * status)87 inline void SetRegisteredErrFromTFStatus(TF_Status* status) {
88   PyErr_SetObject(PyExceptionRegistry::Lookup(TF_GetCode(status)),
89                   pybind11::make_tuple(pybind11::none(), pybind11::none(),
90                                        TF_Message(status),
91                                        internal::TFStatusPayloadToDict(status))
92                       .ptr());
93 }
94 
MaybeRaiseRegisteredFromStatus(const tensorflow::Status & status)95 inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) {
96   if (!status.ok()) {
97     SetRegisteredErrFromStatus(status);
98     throw pybind11::error_already_set();
99   }
100 }
101 
MaybeRaiseRegisteredFromStatusWithGIL(const tensorflow::Status & status)102 inline void MaybeRaiseRegisteredFromStatusWithGIL(
103     const tensorflow::Status& status) {
104   if (!status.ok()) {
105     // Acquire GIL for throwing exception.
106     pybind11::gil_scoped_acquire acquire;
107     SetRegisteredErrFromStatus(status);
108     throw pybind11::error_already_set();
109   }
110 }
111 
MaybeRaiseFromTFStatus(TF_Status * status)112 inline void MaybeRaiseFromTFStatus(TF_Status* status) {
113   TF_Code code = TF_GetCode(status);
114   if (code != TF_OK) {
115     PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status));
116     throw pybind11::error_already_set();
117   }
118 }
119 
MaybeRaiseRegisteredFromTFStatus(TF_Status * status)120 inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
121   TF_Code code = TF_GetCode(status);
122   if (code != TF_OK) {
123     SetRegisteredErrFromTFStatus(status);
124     throw pybind11::error_already_set();
125   }
126 }
127 
MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status * status)128 inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
129   TF_Code code = TF_GetCode(status);
130   if (code != TF_OK) {
131     // Acquire GIL for throwing exception.
132     pybind11::gil_scoped_acquire acquire;
133     SetRegisteredErrFromTFStatus(status);
134     throw pybind11::error_already_set();
135   }
136 }
137 
138 }  // namespace tensorflow
139 
140 namespace pybind11 {
141 namespace detail {
142 
143 // Convert tensorflow::Status
144 //
145 // Raise an exception if a given status is not OK, otherwise return None.
146 //
147 // The correspondence between status codes and exception classes is given
148 // by PyExceptionRegistry. Note that the registry should be initialized
149 // in order to be used, see PyExceptionRegistry::Init.
150 template <>
151 struct type_caster<tensorflow::Status> {
152  public:
153   PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status"));
154   static handle cast(tensorflow::Status status, return_value_policy, handle) {
155     tensorflow::MaybeRaiseFromStatus(status);
156     return none().inc_ref();
157   }
158 };
159 
160 // Convert tensorflow::StatusOr
161 //
162 // Uses the same logic as the Abseil implementation: raise an exception if the
163 // status is not OK, otherwise return its payload.
164 template <typename PayloadType>
165 struct type_caster<tensorflow::StatusOr<PayloadType>> {
166  public:
167   using PayloadCaster = make_caster<PayloadType>;
168   using StatusCaster = make_caster<tensorflow::Status>;
169   static constexpr auto name = PayloadCaster::name;
170 
171   static handle cast(const tensorflow::StatusOr<PayloadType>* src,
172                      return_value_policy policy, handle parent) {
173     if (!src) return none().release();
174     return cast_impl(*src, policy, parent);
175   }
176 
177   static handle cast(const tensorflow::StatusOr<PayloadType>& src,
178                      return_value_policy policy, handle parent) {
179     return cast_impl(src, policy, parent);
180   }
181 
182   static handle cast(tensorflow::StatusOr<PayloadType>&& src,
183                      return_value_policy policy, handle parent) {
184     return cast_impl(std::move(src), policy, parent);
185   }
186 
187  private:
188   template <typename CType>
189   static handle cast_impl(CType&& src, return_value_policy policy,
190                           handle parent) {
191     if (src.ok()) {
192       // Convert and return the payload.
193       return PayloadCaster::cast(std::forward<CType>(src).ValueOrDie(), policy,
194                                  parent);
195     } else {
196       // Convert and return the error.
197       return StatusCaster::cast(std::forward<CType>(src).status(),
198                                 return_value_policy::move, parent);
199     }
200   }
201 };
202 
203 }  // namespace detail
204 }  // namespace pybind11
205 
206 #endif  // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
207