xref: /aosp_15_r20/external/pytorch/c10/util/complex_math.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
2*da0073e9SAndroid Build Coastguard Worker #error \
3*da0073e9SAndroid Build Coastguard Worker     "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead."
4*da0073e9SAndroid Build Coastguard Worker #endif
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10_complex_math {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker // Exponential functions
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker template <typename T>
exp(const c10::complex<T> & x)11*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) {
12*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
13*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
14*da0073e9SAndroid Build Coastguard Worker       thrust::exp(static_cast<thrust::complex<T>>(x)));
15*da0073e9SAndroid Build Coastguard Worker #else
16*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
17*da0073e9SAndroid Build Coastguard Worker       std::exp(static_cast<std::complex<T>>(x)));
18*da0073e9SAndroid Build Coastguard Worker #endif
19*da0073e9SAndroid Build Coastguard Worker }
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker template <typename T>
log(const c10::complex<T> & x)22*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) {
23*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
24*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
25*da0073e9SAndroid Build Coastguard Worker       thrust::log(static_cast<thrust::complex<T>>(x)));
26*da0073e9SAndroid Build Coastguard Worker #else
27*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
28*da0073e9SAndroid Build Coastguard Worker       std::log(static_cast<std::complex<T>>(x)));
29*da0073e9SAndroid Build Coastguard Worker #endif
30*da0073e9SAndroid Build Coastguard Worker }
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker template <typename T>
log10(const c10::complex<T> & x)33*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) {
34*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
35*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
36*da0073e9SAndroid Build Coastguard Worker       thrust::log10(static_cast<thrust::complex<T>>(x)));
37*da0073e9SAndroid Build Coastguard Worker #else
38*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
39*da0073e9SAndroid Build Coastguard Worker       std::log10(static_cast<std::complex<T>>(x)));
40*da0073e9SAndroid Build Coastguard Worker #endif
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker 
43*da0073e9SAndroid Build Coastguard Worker template <typename T>
log2(const c10::complex<T> & x)44*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) {
45*da0073e9SAndroid Build Coastguard Worker   const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0);
46*da0073e9SAndroid Build Coastguard Worker   return c10_complex_math::log(x) / log2;
47*da0073e9SAndroid Build Coastguard Worker }
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker // Power functions
50*da0073e9SAndroid Build Coastguard Worker //
// Out-of-line sqrt/acos helpers. They are only declared when building
// against libc++, or against libstdc++ without _GLIBCXX11_USE_C99_COMPLEX;
// the definitions live outside this header.
#if defined(_LIBCPP_VERSION) || \
    (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
namespace _detail {
// Square root, single and double precision.
C10_API c10::complex<float> sqrt(const c10::complex<float>& in);
C10_API c10::complex<double> sqrt(const c10::complex<double>& in);

// Inverse cosine, single and double precision.
C10_API c10::complex<float> acos(const c10::complex<float>& in);
C10_API c10::complex<double> acos(const c10::complex<double>& in);
} // namespace _detail
#endif
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker template <typename T>
sqrt(const c10::complex<T> & x)62*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) {
63*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
64*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
65*da0073e9SAndroid Build Coastguard Worker       thrust::sqrt(static_cast<thrust::complex<T>>(x)));
66*da0073e9SAndroid Build Coastguard Worker #elif !(                        \
67*da0073e9SAndroid Build Coastguard Worker     defined(_LIBCPP_VERSION) || \
68*da0073e9SAndroid Build Coastguard Worker     (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)))
69*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
70*da0073e9SAndroid Build Coastguard Worker       std::sqrt(static_cast<std::complex<T>>(x)));
71*da0073e9SAndroid Build Coastguard Worker #else
72*da0073e9SAndroid Build Coastguard Worker   return _detail::sqrt(x);
73*da0073e9SAndroid Build Coastguard Worker #endif
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker 
76*da0073e9SAndroid Build Coastguard Worker template <typename T>
pow(const c10::complex<T> & x,const c10::complex<T> & y)77*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> pow(
78*da0073e9SAndroid Build Coastguard Worker     const c10::complex<T>& x,
79*da0073e9SAndroid Build Coastguard Worker     const c10::complex<T>& y) {
80*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
81*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(thrust::pow(
82*da0073e9SAndroid Build Coastguard Worker       static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
83*da0073e9SAndroid Build Coastguard Worker #else
84*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(std::pow(
85*da0073e9SAndroid Build Coastguard Worker       static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
86*da0073e9SAndroid Build Coastguard Worker #endif
87*da0073e9SAndroid Build Coastguard Worker }
88*da0073e9SAndroid Build Coastguard Worker 
89*da0073e9SAndroid Build Coastguard Worker template <typename T>
pow(const c10::complex<T> & x,const T & y)90*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> pow(
91*da0073e9SAndroid Build Coastguard Worker     const c10::complex<T>& x,
92*da0073e9SAndroid Build Coastguard Worker     const T& y) {
93*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
94*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
95*da0073e9SAndroid Build Coastguard Worker       thrust::pow(static_cast<thrust::complex<T>>(x), y));
96*da0073e9SAndroid Build Coastguard Worker #else
97*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
98*da0073e9SAndroid Build Coastguard Worker       std::pow(static_cast<std::complex<T>>(x), y));
99*da0073e9SAndroid Build Coastguard Worker #endif
100*da0073e9SAndroid Build Coastguard Worker }
101*da0073e9SAndroid Build Coastguard Worker 
102*da0073e9SAndroid Build Coastguard Worker template <typename T>
pow(const T & x,const c10::complex<T> & y)103*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> pow(
104*da0073e9SAndroid Build Coastguard Worker     const T& x,
105*da0073e9SAndroid Build Coastguard Worker     const c10::complex<T>& y) {
106*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
107*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
108*da0073e9SAndroid Build Coastguard Worker       thrust::pow(x, static_cast<thrust::complex<T>>(y)));
109*da0073e9SAndroid Build Coastguard Worker #else
110*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
111*da0073e9SAndroid Build Coastguard Worker       std::pow(x, static_cast<std::complex<T>>(y)));
112*da0073e9SAndroid Build Coastguard Worker #endif
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker 
115*da0073e9SAndroid Build Coastguard Worker template <typename T, typename U>
pow(const c10::complex<T> & x,const c10::complex<U> & y)116*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
117*da0073e9SAndroid Build Coastguard Worker     const c10::complex<T>& x,
118*da0073e9SAndroid Build Coastguard Worker     const c10::complex<U>& y) {
119*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
120*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(thrust::pow(
121*da0073e9SAndroid Build Coastguard Worker       static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
122*da0073e9SAndroid Build Coastguard Worker #else
123*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(std::pow(
124*da0073e9SAndroid Build Coastguard Worker       static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
125*da0073e9SAndroid Build Coastguard Worker #endif
126*da0073e9SAndroid Build Coastguard Worker }
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker template <typename T, typename U>
pow(const c10::complex<T> & x,const U & y)129*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
130*da0073e9SAndroid Build Coastguard Worker     const c10::complex<T>& x,
131*da0073e9SAndroid Build Coastguard Worker     const U& y) {
132*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
133*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
134*da0073e9SAndroid Build Coastguard Worker       thrust::pow(static_cast<thrust::complex<T>>(x), y));
135*da0073e9SAndroid Build Coastguard Worker #else
136*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
137*da0073e9SAndroid Build Coastguard Worker       std::pow(static_cast<std::complex<T>>(x), y));
138*da0073e9SAndroid Build Coastguard Worker #endif
139*da0073e9SAndroid Build Coastguard Worker }
140*da0073e9SAndroid Build Coastguard Worker 
141*da0073e9SAndroid Build Coastguard Worker template <typename T, typename U>
pow(const T & x,const c10::complex<U> & y)142*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
143*da0073e9SAndroid Build Coastguard Worker     const T& x,
144*da0073e9SAndroid Build Coastguard Worker     const c10::complex<U>& y) {
145*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
146*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
147*da0073e9SAndroid Build Coastguard Worker       thrust::pow(x, static_cast<thrust::complex<T>>(y)));
148*da0073e9SAndroid Build Coastguard Worker #else
149*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
150*da0073e9SAndroid Build Coastguard Worker       std::pow(x, static_cast<std::complex<T>>(y)));
151*da0073e9SAndroid Build Coastguard Worker #endif
152*da0073e9SAndroid Build Coastguard Worker }
153*da0073e9SAndroid Build Coastguard Worker 
154*da0073e9SAndroid Build Coastguard Worker // Trigonometric functions
155*da0073e9SAndroid Build Coastguard Worker 
156*da0073e9SAndroid Build Coastguard Worker template <typename T>
sin(const c10::complex<T> & x)157*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) {
158*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
159*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
160*da0073e9SAndroid Build Coastguard Worker       thrust::sin(static_cast<thrust::complex<T>>(x)));
161*da0073e9SAndroid Build Coastguard Worker #else
162*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
163*da0073e9SAndroid Build Coastguard Worker       std::sin(static_cast<std::complex<T>>(x)));
164*da0073e9SAndroid Build Coastguard Worker #endif
165*da0073e9SAndroid Build Coastguard Worker }
166*da0073e9SAndroid Build Coastguard Worker 
167*da0073e9SAndroid Build Coastguard Worker template <typename T>
cos(const c10::complex<T> & x)168*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) {
169*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
170*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
171*da0073e9SAndroid Build Coastguard Worker       thrust::cos(static_cast<thrust::complex<T>>(x)));
172*da0073e9SAndroid Build Coastguard Worker #else
173*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
174*da0073e9SAndroid Build Coastguard Worker       std::cos(static_cast<std::complex<T>>(x)));
175*da0073e9SAndroid Build Coastguard Worker #endif
176*da0073e9SAndroid Build Coastguard Worker }
177*da0073e9SAndroid Build Coastguard Worker 
178*da0073e9SAndroid Build Coastguard Worker template <typename T>
tan(const c10::complex<T> & x)179*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) {
180*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
181*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
182*da0073e9SAndroid Build Coastguard Worker       thrust::tan(static_cast<thrust::complex<T>>(x)));
183*da0073e9SAndroid Build Coastguard Worker #else
184*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
185*da0073e9SAndroid Build Coastguard Worker       std::tan(static_cast<std::complex<T>>(x)));
186*da0073e9SAndroid Build Coastguard Worker #endif
187*da0073e9SAndroid Build Coastguard Worker }
188*da0073e9SAndroid Build Coastguard Worker 
189*da0073e9SAndroid Build Coastguard Worker template <typename T>
asin(const c10::complex<T> & x)190*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) {
191*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
192*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
193*da0073e9SAndroid Build Coastguard Worker       thrust::asin(static_cast<thrust::complex<T>>(x)));
194*da0073e9SAndroid Build Coastguard Worker #else
195*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
196*da0073e9SAndroid Build Coastguard Worker       std::asin(static_cast<std::complex<T>>(x)));
197*da0073e9SAndroid Build Coastguard Worker #endif
198*da0073e9SAndroid Build Coastguard Worker }
199*da0073e9SAndroid Build Coastguard Worker 
200*da0073e9SAndroid Build Coastguard Worker template <typename T>
acos(const c10::complex<T> & x)201*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) {
202*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
203*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
204*da0073e9SAndroid Build Coastguard Worker       thrust::acos(static_cast<thrust::complex<T>>(x)));
205*da0073e9SAndroid Build Coastguard Worker #elif !defined(_LIBCPP_VERSION)
206*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
207*da0073e9SAndroid Build Coastguard Worker       std::acos(static_cast<std::complex<T>>(x)));
208*da0073e9SAndroid Build Coastguard Worker #else
209*da0073e9SAndroid Build Coastguard Worker   return _detail::acos(x);
210*da0073e9SAndroid Build Coastguard Worker #endif
211*da0073e9SAndroid Build Coastguard Worker }
212*da0073e9SAndroid Build Coastguard Worker 
213*da0073e9SAndroid Build Coastguard Worker template <typename T>
atan(const c10::complex<T> & x)214*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) {
215*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
216*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
217*da0073e9SAndroid Build Coastguard Worker       thrust::atan(static_cast<thrust::complex<T>>(x)));
218*da0073e9SAndroid Build Coastguard Worker #else
219*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
220*da0073e9SAndroid Build Coastguard Worker       std::atan(static_cast<std::complex<T>>(x)));
221*da0073e9SAndroid Build Coastguard Worker #endif
222*da0073e9SAndroid Build Coastguard Worker }
223*da0073e9SAndroid Build Coastguard Worker 
224*da0073e9SAndroid Build Coastguard Worker // Hyperbolic functions
225*da0073e9SAndroid Build Coastguard Worker 
226*da0073e9SAndroid Build Coastguard Worker template <typename T>
sinh(const c10::complex<T> & x)227*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) {
228*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
229*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
230*da0073e9SAndroid Build Coastguard Worker       thrust::sinh(static_cast<thrust::complex<T>>(x)));
231*da0073e9SAndroid Build Coastguard Worker #else
232*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
233*da0073e9SAndroid Build Coastguard Worker       std::sinh(static_cast<std::complex<T>>(x)));
234*da0073e9SAndroid Build Coastguard Worker #endif
235*da0073e9SAndroid Build Coastguard Worker }
236*da0073e9SAndroid Build Coastguard Worker 
237*da0073e9SAndroid Build Coastguard Worker template <typename T>
cosh(const c10::complex<T> & x)238*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) {
239*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
240*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
241*da0073e9SAndroid Build Coastguard Worker       thrust::cosh(static_cast<thrust::complex<T>>(x)));
242*da0073e9SAndroid Build Coastguard Worker #else
243*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
244*da0073e9SAndroid Build Coastguard Worker       std::cosh(static_cast<std::complex<T>>(x)));
245*da0073e9SAndroid Build Coastguard Worker #endif
246*da0073e9SAndroid Build Coastguard Worker }
247*da0073e9SAndroid Build Coastguard Worker 
248*da0073e9SAndroid Build Coastguard Worker template <typename T>
tanh(const c10::complex<T> & x)249*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) {
250*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
251*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
252*da0073e9SAndroid Build Coastguard Worker       thrust::tanh(static_cast<thrust::complex<T>>(x)));
253*da0073e9SAndroid Build Coastguard Worker #else
254*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
255*da0073e9SAndroid Build Coastguard Worker       std::tanh(static_cast<std::complex<T>>(x)));
256*da0073e9SAndroid Build Coastguard Worker #endif
257*da0073e9SAndroid Build Coastguard Worker }
258*da0073e9SAndroid Build Coastguard Worker 
259*da0073e9SAndroid Build Coastguard Worker template <typename T>
asinh(const c10::complex<T> & x)260*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) {
261*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
262*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
263*da0073e9SAndroid Build Coastguard Worker       thrust::asinh(static_cast<thrust::complex<T>>(x)));
264*da0073e9SAndroid Build Coastguard Worker #else
265*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
266*da0073e9SAndroid Build Coastguard Worker       std::asinh(static_cast<std::complex<T>>(x)));
267*da0073e9SAndroid Build Coastguard Worker #endif
268*da0073e9SAndroid Build Coastguard Worker }
269*da0073e9SAndroid Build Coastguard Worker 
270*da0073e9SAndroid Build Coastguard Worker template <typename T>
acosh(const c10::complex<T> & x)271*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) {
272*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
273*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
274*da0073e9SAndroid Build Coastguard Worker       thrust::acosh(static_cast<thrust::complex<T>>(x)));
275*da0073e9SAndroid Build Coastguard Worker #else
276*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
277*da0073e9SAndroid Build Coastguard Worker       std::acosh(static_cast<std::complex<T>>(x)));
278*da0073e9SAndroid Build Coastguard Worker #endif
279*da0073e9SAndroid Build Coastguard Worker }
280*da0073e9SAndroid Build Coastguard Worker 
281*da0073e9SAndroid Build Coastguard Worker template <typename T>
atanh(const c10::complex<T> & x)282*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
283*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
284*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
285*da0073e9SAndroid Build Coastguard Worker       thrust::atanh(static_cast<thrust::complex<T>>(x)));
286*da0073e9SAndroid Build Coastguard Worker #else
287*da0073e9SAndroid Build Coastguard Worker   return static_cast<c10::complex<T>>(
288*da0073e9SAndroid Build Coastguard Worker       std::atanh(static_cast<std::complex<T>>(x)));
289*da0073e9SAndroid Build Coastguard Worker #endif
290*da0073e9SAndroid Build Coastguard Worker }
291*da0073e9SAndroid Build Coastguard Worker 
292*da0073e9SAndroid Build Coastguard Worker template <typename T>
log1p(const c10::complex<T> & z)293*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
294*da0073e9SAndroid Build Coastguard Worker #if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \
295*da0073e9SAndroid Build Coastguard Worker     defined(__HIPCC__)
296*da0073e9SAndroid Build Coastguard Worker   // For Mac, the new implementation yielded a high relative error. Falling back
297*da0073e9SAndroid Build Coastguard Worker   // to the old version for now.
298*da0073e9SAndroid Build Coastguard Worker   // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
299*da0073e9SAndroid Build Coastguard Worker   // For CUDA we also use this one, as thrust::log(thrust::complex) takes
300*da0073e9SAndroid Build Coastguard Worker   // *forever* to compile
301*da0073e9SAndroid Build Coastguard Worker 
302*da0073e9SAndroid Build Coastguard Worker   // log1p(z) = log(1 + z)
303*da0073e9SAndroid Build Coastguard Worker   // Let's define 1 + z = r * e ^ (i * a), then we have
304*da0073e9SAndroid Build Coastguard Worker   // log(r * e ^ (i * a)) = log(r) + i * a
305*da0073e9SAndroid Build Coastguard Worker   // With z = x + iy, the term r can be written as
306*da0073e9SAndroid Build Coastguard Worker   // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
307*da0073e9SAndroid Build Coastguard Worker   //   = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
308*da0073e9SAndroid Build Coastguard Worker   // So, log(r) is
309*da0073e9SAndroid Build Coastguard Worker   // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
310*da0073e9SAndroid Build Coastguard Worker   //        = 0.5 * log1p(x * (x + 2) + y ^ 2)
311*da0073e9SAndroid Build Coastguard Worker   // we need to use the expression only on certain condition to avoid overflow
312*da0073e9SAndroid Build Coastguard Worker   // and underflow from `(x * (x + 2) + y ^ 2)`
313*da0073e9SAndroid Build Coastguard Worker   T x = z.real();
314*da0073e9SAndroid Build Coastguard Worker   T y = z.imag();
315*da0073e9SAndroid Build Coastguard Worker   T zabs = std::abs(z);
316*da0073e9SAndroid Build Coastguard Worker   T theta = std::atan2(y, x + T(1));
317*da0073e9SAndroid Build Coastguard Worker   if (zabs < 0.5) {
318*da0073e9SAndroid Build Coastguard Worker     T r = x * (T(2) + x) + y * y;
319*da0073e9SAndroid Build Coastguard Worker     if (r == 0) { // handle underflow
320*da0073e9SAndroid Build Coastguard Worker       return {x, theta};
321*da0073e9SAndroid Build Coastguard Worker     }
322*da0073e9SAndroid Build Coastguard Worker     return {T(0.5) * std::log1p(r), theta};
323*da0073e9SAndroid Build Coastguard Worker   } else {
324*da0073e9SAndroid Build Coastguard Worker     T z0 = std::hypot(x + 1, y);
325*da0073e9SAndroid Build Coastguard Worker     return {std::log(z0), theta};
326*da0073e9SAndroid Build Coastguard Worker   }
327*da0073e9SAndroid Build Coastguard Worker #else
328*da0073e9SAndroid Build Coastguard Worker   // CPU path
329*da0073e9SAndroid Build Coastguard Worker   // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
330*da0073e9SAndroid Build Coastguard Worker   c10::complex<T> u = z + T(1);
331*da0073e9SAndroid Build Coastguard Worker   if (u == T(1)) {
332*da0073e9SAndroid Build Coastguard Worker     return z;
333*da0073e9SAndroid Build Coastguard Worker   } else {
334*da0073e9SAndroid Build Coastguard Worker     auto log_u = log(u);
335*da0073e9SAndroid Build Coastguard Worker     if (u - T(1) == z) {
336*da0073e9SAndroid Build Coastguard Worker       return log_u;
337*da0073e9SAndroid Build Coastguard Worker     }
338*da0073e9SAndroid Build Coastguard Worker     return log_u * (z / (u - T(1)));
339*da0073e9SAndroid Build Coastguard Worker   }
340*da0073e9SAndroid Build Coastguard Worker #endif
341*da0073e9SAndroid Build Coastguard Worker }
342*da0073e9SAndroid Build Coastguard Worker 
343*da0073e9SAndroid Build Coastguard Worker template <typename T>
expm1(const c10::complex<T> & z)344*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
345*da0073e9SAndroid Build Coastguard Worker   // expm1(z) = exp(z) - 1
346*da0073e9SAndroid Build Coastguard Worker   // Define z = x + i * y
347*da0073e9SAndroid Build Coastguard Worker   // f = e ^ (x + i * y) - 1
348*da0073e9SAndroid Build Coastguard Worker   //   = e ^ x * e ^ (i * y) - 1
349*da0073e9SAndroid Build Coastguard Worker   //   = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
350*da0073e9SAndroid Build Coastguard Worker   //   = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
351*da0073e9SAndroid Build Coastguard Worker   //   = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
352*da0073e9SAndroid Build Coastguard Worker   T x = z.real();
353*da0073e9SAndroid Build Coastguard Worker   T y = z.imag();
354*da0073e9SAndroid Build Coastguard Worker   T a = std::sin(y / 2);
355*da0073e9SAndroid Build Coastguard Worker   T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
356*da0073e9SAndroid Build Coastguard Worker   T ei = std::exp(x) * std::sin(y);
357*da0073e9SAndroid Build Coastguard Worker   return {er, ei};
358*da0073e9SAndroid Build Coastguard Worker }
359*da0073e9SAndroid Build Coastguard Worker 
360*da0073e9SAndroid Build Coastguard Worker } // namespace c10_complex_math
361*da0073e9SAndroid Build Coastguard Worker 
362*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::acos;
363*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::acosh;
364*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::asin;
365*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::asinh;
366*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::atan;
367*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::atanh;
368*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::cos;
369*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::cosh;
370*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::exp;
371*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::expm1;
372*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log;
373*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log10;
374*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log1p;
375*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log2;
376*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::pow;
377*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::sin;
378*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::sinh;
379*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::sqrt;
380*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::tan;
381*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::tanh;
382*da0073e9SAndroid Build Coastguard Worker 
383*da0073e9SAndroid Build Coastguard Worker namespace std {
384*da0073e9SAndroid Build Coastguard Worker 
385*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::acos;
386*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::acosh;
387*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::asin;
388*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::asinh;
389*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::atan;
390*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::atanh;
391*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::cos;
392*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::cosh;
393*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::exp;
394*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::expm1;
395*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log;
396*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log10;
397*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log1p;
398*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::log2;
399*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::pow;
400*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::sin;
401*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::sinh;
402*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::sqrt;
403*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::tan;
404*da0073e9SAndroid Build Coastguard Worker using c10_complex_math::tanh;
405*da0073e9SAndroid Build Coastguard Worker 
406*da0073e9SAndroid Build Coastguard Worker } // namespace std
407