1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
17 #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
18
19 #include <Python.h>
20
21 #include "pybind11/cast.h"
22 #include "pybind11/pybind11.h"
23 #include "tensorflow/c/tf_status_internal.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/core/platform/statusor.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28 #include "tensorflow/python/lib/core/py_exception_registry.h"
29
30 namespace tensorflow {
31
32 namespace internal {
33
CodeToPyExc(const int code)34 inline PyObject* CodeToPyExc(const int code) {
35 switch (code) {
36 case error::Code::INVALID_ARGUMENT:
37 return PyExc_ValueError;
38 case error::Code::OUT_OF_RANGE:
39 return PyExc_IndexError;
40 case error::Code::UNIMPLEMENTED:
41 return PyExc_NotImplementedError;
42 default:
43 return PyExc_RuntimeError;
44 }
45 }
46
StatusToPyExc(const Status & status)47 inline PyObject* StatusToPyExc(const Status& status) {
48 return CodeToPyExc(status.code());
49 }
50
TFStatusToPyExc(const TF_Status * status)51 inline PyObject* TFStatusToPyExc(const TF_Status* status) {
52 return CodeToPyExc(TF_GetCode(status));
53 }
54
StatusPayloadToDict(const Status & status)55 inline pybind11::dict StatusPayloadToDict(const Status& status) {
56 pybind11::dict dict;
57 const auto& payloads = errors::GetPayloads(status);
58 for (auto& pair : payloads) {
59 dict[PyBytes_FromString(pair.first.c_str())] =
60 PyBytes_FromString(pair.second.c_str());
61 }
62 return dict;
63 }
64
TFStatusPayloadToDict(TF_Status * status)65 inline pybind11::dict TFStatusPayloadToDict(TF_Status* status) {
66 return StatusPayloadToDict(status->status);
67 }
68
69 } // namespace internal
70
MaybeRaiseFromStatus(const Status & status)71 inline void MaybeRaiseFromStatus(const Status& status) {
72 if (!status.ok()) {
73 PyErr_SetString(internal::StatusToPyExc(status),
74 status.error_message().c_str());
75 throw pybind11::error_already_set();
76 }
77 }
78
SetRegisteredErrFromStatus(const tensorflow::Status & status)79 inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) {
80 PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
81 pybind11::make_tuple(pybind11::none(), pybind11::none(),
82 status.error_message(),
83 internal::StatusPayloadToDict(status))
84 .ptr());
85 }
86
SetRegisteredErrFromTFStatus(TF_Status * status)87 inline void SetRegisteredErrFromTFStatus(TF_Status* status) {
88 PyErr_SetObject(PyExceptionRegistry::Lookup(TF_GetCode(status)),
89 pybind11::make_tuple(pybind11::none(), pybind11::none(),
90 TF_Message(status),
91 internal::TFStatusPayloadToDict(status))
92 .ptr());
93 }
94
MaybeRaiseRegisteredFromStatus(const tensorflow::Status & status)95 inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) {
96 if (!status.ok()) {
97 SetRegisteredErrFromStatus(status);
98 throw pybind11::error_already_set();
99 }
100 }
101
MaybeRaiseRegisteredFromStatusWithGIL(const tensorflow::Status & status)102 inline void MaybeRaiseRegisteredFromStatusWithGIL(
103 const tensorflow::Status& status) {
104 if (!status.ok()) {
105 // Acquire GIL for throwing exception.
106 pybind11::gil_scoped_acquire acquire;
107 SetRegisteredErrFromStatus(status);
108 throw pybind11::error_already_set();
109 }
110 }
111
MaybeRaiseFromTFStatus(TF_Status * status)112 inline void MaybeRaiseFromTFStatus(TF_Status* status) {
113 TF_Code code = TF_GetCode(status);
114 if (code != TF_OK) {
115 PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status));
116 throw pybind11::error_already_set();
117 }
118 }
119
MaybeRaiseRegisteredFromTFStatus(TF_Status * status)120 inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
121 TF_Code code = TF_GetCode(status);
122 if (code != TF_OK) {
123 SetRegisteredErrFromTFStatus(status);
124 throw pybind11::error_already_set();
125 }
126 }
127
MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status * status)128 inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
129 TF_Code code = TF_GetCode(status);
130 if (code != TF_OK) {
131 // Acquire GIL for throwing exception.
132 pybind11::gil_scoped_acquire acquire;
133 SetRegisteredErrFromTFStatus(status);
134 throw pybind11::error_already_set();
135 }
136 }
137
138 } // namespace tensorflow
139
140 namespace pybind11 {
141 namespace detail {
142
143 // Convert tensorflow::Status
144 //
145 // Raise an exception if a given status is not OK, otherwise return None.
146 //
147 // The correspondence between status codes and exception classes is given
148 // by PyExceptionRegistry. Note that the registry should be initialized
149 // in order to be used, see PyExceptionRegistry::Init.
150 template <>
151 struct type_caster<tensorflow::Status> {
152 public:
153 PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status"));
154 static handle cast(tensorflow::Status status, return_value_policy, handle) {
155 tensorflow::MaybeRaiseFromStatus(status);
156 return none().inc_ref();
157 }
158 };
159
160 // Convert tensorflow::StatusOr
161 //
162 // Uses the same logic as the Abseil implementation: raise an exception if the
163 // status is not OK, otherwise return its payload.
164 template <typename PayloadType>
165 struct type_caster<tensorflow::StatusOr<PayloadType>> {
166 public:
167 using PayloadCaster = make_caster<PayloadType>;
168 using StatusCaster = make_caster<tensorflow::Status>;
169 static constexpr auto name = PayloadCaster::name;
170
171 static handle cast(const tensorflow::StatusOr<PayloadType>* src,
172 return_value_policy policy, handle parent) {
173 if (!src) return none().release();
174 return cast_impl(*src, policy, parent);
175 }
176
177 static handle cast(const tensorflow::StatusOr<PayloadType>& src,
178 return_value_policy policy, handle parent) {
179 return cast_impl(src, policy, parent);
180 }
181
182 static handle cast(tensorflow::StatusOr<PayloadType>&& src,
183 return_value_policy policy, handle parent) {
184 return cast_impl(std::move(src), policy, parent);
185 }
186
187 private:
188 template <typename CType>
189 static handle cast_impl(CType&& src, return_value_policy policy,
190 handle parent) {
191 if (src.ok()) {
192 // Convert and return the payload.
193 return PayloadCaster::cast(std::forward<CType>(src).ValueOrDie(), policy,
194 parent);
195 } else {
196 // Convert and return the error.
197 return StatusCaster::cast(std::forward<CType>(src).status(),
198 return_value_policy::move, parent);
199 }
200 }
201 };
202
203 } // namespace detail
204 } // namespace pybind11
205
206 #endif // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
207